博主根据自身多年的深度学习算法研发经验,整理分享以下十条必知。

含参考资料链接,部分附上相关代码实现。

独乐乐不如众乐乐,希望对各位看客有所帮助。

待回头有时间再展开细节说一说深度学习里的那些道道。

有什么技术需求需要有偿解决的也可以邮件或者QQ联系博主。

邮箱QQ同ID:gaozhihan@vip.qq.com

当然除了这十条,肯定还有其他“必知”,

欢迎评论分享更多,这里只是暂时拟定的十条,别较真哈。

主要学习其中的思路,切记,以下思路在个别场景并不适用 。

1.数据回流

[1907.05550] Faster Neural Network Training with Data Echoing

def data_echoing(factor):
return lambda image, label: tf.data.Dataset.from_tensors((image, label)).repeat(factor)

作用:

数据集加载后,在数据增广前后重复当前批次进模型的次数,减少数据的加载耗时。

等价于让模型看n次当前的数据,或者看n个增广后的数据样本。

2.AMP 自动精度混合

在bert4keras中使用混合精度和XLA加速训练 - 科学空间|Scientific Spaces

    tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

作用:

降低显存占用,加速训练,将部分网络计算转为等价的低精度计算,以此降低计算量。

3.优化器节省显存

3.1  [1804.04235]Adafactor: Adaptive Learning Rates with Sublinear Memory Cost

mesh/optimize.py at master · tensorflow/mesh · GitHub

3.2 [1901.11150] Memory-Efficient Adaptive Optimization

google-research/sm3 at master · google-research/google-research (github.com)

作用:

节省显存,加速训练,

主要是对二阶动量进行特例化解构,减少显存存储。

4.权重标准化(归一化)

[2102.06171] High-Performance Large-Scale Image Recognition Without Normalization

deepmind-research/nfnets at master · deepmind/deepmind-research · GitHub

class WSConv2D(tf.keras.layers.Conv2D):
def __init__(self, *args, **kwargs):
super(WSConv2D, self).__init__(
kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=1.0, mode='fan_in', distribution='untruncated_normal',
),
use_bias=False,
kernel_regularizer=tf.keras.regularizers.l2(1e-4), *args, **kwargs
)
self.gain = self.add_weight(
name='gain',
shape=(self.filters,),
initializer="ones",
trainable=True,
dtype=self.dtype
) def standardize_weight(self, eps):
mean, var = tf.nn.moments(self.kernel, axes=[0, 1, 2], keepdims=True)
fan_in = np.prod(self.kernel.shape[:-1])
# Manually fused normalization, eq. to (w - mean) * gain / sqrt(N * var)
scale = tf.math.rsqrt(
tf.math.maximum(
var * fan_in,
tf.convert_to_tensor(eps, dtype=self.dtype)
)
) * self.gain
shift = mean * scale
return self.kernel * scale - shift def call(self, inputs):
eps = 1e-4
weight = self.standardize_weight(eps)
return tf.nn.conv2d(
inputs, weight, strides=self.strides,
padding=self.padding.upper(), dilations=self.dilation_rate
) if self.bias is None else tf.nn.bias_add(
tf.nn.conv2d(
inputs, weight, strides=self.strides,
padding=self.padding.upper(), dilations=self.dilation_rate
), self.bias)

作用:

通过对kernel进行标准化或归一化,相当于对kernel做一个先验约束,以此加速模型训练收敛。

5.自适应梯度裁剪

deepmind-research/agc_optax.py at master · deepmind/deepmind-research · GitHub

def unitwise_norm(x):
if len(tf.squeeze(x).shape) <= 1: # Scalars and vectors
axis = None
keepdims = False
elif len(x.shape) in [2, 3]: # Linear layers of shape IO
axis = 0
keepdims = True
elif len(x.shape) == 4: # Conv kernels of shape HWIO
axis = [0, 1, 2, ]
keepdims = True
else:
raise ValueError(f'Got a parameter with shape not in [1, 2, 3, 4]! {x}')
square_sum = tf.reduce_sum(tf.square(x), axis, keepdims=keepdims)
return tf.sqrt(square_sum) def gradient_clipping(grad, var):
clipping = 0.01
max_norm = tf.maximum(unitwise_norm(var), 1e-3) * clipping
grad_norm = unitwise_norm(grad)
trigger = (grad_norm > max_norm)
clipped_grad = (max_norm / tf.maximum(grad_norm, 1e-6))
return grad * tf.where(trigger, clipped_grad, tf.ones_like(clipped_grad))

作用:

防止梯度爆炸,稳定训练。通过梯度和参数的关系,对梯度进行裁剪,约束学习率。

6.recompute_grad

[1604.06174] Training Deep Nets with Sublinear Memory Cost

google-research/recompute_grad.py at master · google-research/google-research (github.com)

bojone/keras_recompute: saving memory by recomputing for keras (github.com)

作用:

通过梯度重计算,节省显存。

7.归一化

[2003.05569] Extended Batch Normalization (arxiv.org)

from keras.layers.normalization.batch_normalization import BatchNormalizationBase

class ExtendedBatchNormalization(BatchNormalizationBase):
def __init__(self,
axis=-1,
momentum=0.99,
epsilon=1e-3,
center=True,
scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
moving_mean_initializer='zeros',
moving_variance_initializer='ones',
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
renorm=False,
renorm_clipping=None,
renorm_momentum=0.99,
trainable=True,
name=None,
**kwargs):
# Currently we only support aggregating over the global batch size.
super(ExtendedBatchNormalization, self).__init__(
axis=axis,
momentum=momentum,
epsilon=epsilon,
center=center,
scale=scale,
beta_initializer=beta_initializer,
gamma_initializer=gamma_initializer,
moving_mean_initializer=moving_mean_initializer,
moving_variance_initializer=moving_variance_initializer,
beta_regularizer=beta_regularizer,
gamma_regularizer=gamma_regularizer,
beta_constraint=beta_constraint,
gamma_constraint=gamma_constraint,
renorm=renorm,
renorm_clipping=renorm_clipping,
renorm_momentum=renorm_momentum,
fused=False,
trainable=trainable,
virtual_batch_size=None,
name=name,
**kwargs) def _calculate_mean_and_var(self, x, axes, keep_dims):
with tf.keras.backend.name_scope('moments'):
y = tf.cast(x, tf.float32) if x.dtype == tf.float16 else x
replica_ctx = tf.distribute.get_replica_context()
if replica_ctx:
local_sum = tf.math.reduce_sum(y, axis=axes, keepdims=True)
local_squared_sum = tf.math.reduce_sum(tf.math.square(y), axis=axes,
keepdims=True)
batch_size = tf.cast(tf.shape(y)[0], tf.float32)
y_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_sum)
y_squared_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM,
local_squared_sum)
global_batch_size = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM,
batch_size)
axes_vals = [(tf.shape(y))[i] for i in range(1, len(axes))]
multiplier = tf.cast(tf.reduce_prod(axes_vals), tf.float32)
multiplier = multiplier * global_batch_size
mean = y_sum / multiplier
y_squared_mean = y_squared_sum / multiplier
# var = E(x^2) - E(x)^2
variance = y_squared_mean - tf.math.square(mean)
else:
# Compute true mean while keeping the dims for proper broadcasting.
mean = tf.math.reduce_mean(y, axes, keepdims=True, name='mean')
variance = tf.math.reduce_mean(
tf.math.squared_difference(y, tf.stop_gradient(mean)),
axes,
keepdims=True,
name='variance')
if not keep_dims:
mean = tf.squeeze(mean, axes)
variance = tf.squeeze(variance, axes)
variance = tf.math.reduce_mean(variance)
if x.dtype == tf.float16:
return (tf.cast(mean, tf.float16),
tf.cast(variance, tf.float16))
else:
return mean, variance

  

作用:

一个简易改进版的Batch Normalization,思路简单有效。

8.学习率策略

[1506.01186] Cyclical Learning Rates for Training Neural Networks (arxiv.org)

作用:

一个推荐的学习率策略方案,特定情况下可以取得更好的泛化。

9.重参数化

[1908.03930] ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks

https://zhuanlan.zhihu.com/p/361090497

作用:

通过同时训练多份参数,合并权重的思路来提升模型泛化性。

10.长尾学习

[2110.04596] Deep Long-Tailed Learning: A Survey (arxiv.org)

Jorwnpay/A-Long-Tailed-Survey: 本项目是 Deep Long-Tailed Learning: A Survey 文章的中译版 (github.com)

作用:

解决长尾问题,可以加速收敛,提升模型泛化,稳定训练。

Tensorflow2 深度学习十必知的更多相关文章

  1. 对比深度学习十大框架:TensorFlow 并非最好?

    http://www.oschina.net/news/80593/deep-learning-frameworks-a-review-before-finishing-2016 TensorFlow ...

  2. 推荐系统遇上深度学习(十)--GBDT+LR融合方案实战

    推荐系统遇上深度学习(十)--GBDT+LR融合方案实战 0.8012018.05.19 16:17:18字数 2068阅读 22568 推荐系统遇上深度学习系列:推荐系统遇上深度学习(一)--FM模 ...

  3. 《TensorFlow2深度学习》学习笔记(一)Tensorflow基础

    本系列笔记记录了学习TensorFlow2的过程,主要依据 https://github.com/dragen1860/Deep-Learning-with-TensorFlow-book 进行学习 ...

  4. mysql学习--mysql必知必会1

     例如以下为mysql必知必会第九章開始: 正則表達式用于匹配特殊的字符集合.mysql通过where子句对正則表達式提供初步的支持. keywordregexp用来表示后面跟的东西作为正則表達式 ...

  5. mysql学习--mysql必知必会

      上图为数据库操作分类:     下面的操作參考(mysql必知必会) 创建数据库 运行脚本建表: mysql> create database mytest; Query OK, 1 row ...

  6. 学习axios必知必会(2)~axios基本使用、使用axios前必知细节、axios和实例对象区别、拦截器、取消请求

    一.axios的基本使用: ✿ 使用axios前必知细节: 1.axios 函数对象(可以作为axios(config)函数使用去发送请求,也可以作为对象调用方法axios.request(confi ...

  7. 【Android Api 翻译4】android api 完整翻译之Contacts Provider (学习安卓必知的api,中英文对照)

    Contacts Provider 电话簿(注:联系人,联络人.通信录)提供者 ------------------------------- QUICKVIEW 快速概览 * Android's r ...

  8. 《TensorFlow2深度学习》学习笔记(四)对笔记二中的模型增加正确率展示

    全部代码如下:(红色部分为与笔记二不同之处) #1.Import the neccessary libraries needed import numpy as np import tensorflo ...

  9. 学习MyBatis必知必会(2)~MyBatis基本介绍和MyBatis基本使用

    一.MyBatis框架基本介绍: 1.认识 MyBatis: MyBatis 是支持普通 SQL 查询,存储过程和高级映射的持久层框架,严格上说应该是一个 SQL 映射框架. 其前身是 iBatis, ...

随机推荐

  1. 【大话云原生】煮饺子与docker、kubernetes之间的关系

    云原生的概念最近非常火爆,企业落地云原生的愿望也越发强烈.看过很多关于云原生的文章,要么云山雾罩,要么曲高和寡. 所以笔者就有了写<大话云原生>系列文章的想法,期望用最通俗.简单的语言说明 ...

  2. 运行npm install命令的时候会发生什么?

    摘要:我们日常在下载第三方依赖的时候,都会用到一个命令npm install,那么你知道,在运行这个命令的时候都会发生什么吗? 本文分享自华为云社区<运行npm install命令的时候会发生什 ...

  3. Nacos在企业生产中如何使用集群环境?

    点赞再看,养成习惯,微信搜索[牧小农]关注我获取更多资讯,风里雨里,小农等你,很高兴能够成为你的朋友. 项目源码地址:公众号回复 nacos,即可免费获取源码 前言 由于在公司,注册中心和配置中心都是 ...

  4. Git (常用命令)

    某程序猿退休后决定练习书法,于是花重金买下文房四宝.某日,饭后突生雅兴,一番磨墨拟纸 并点上上好檀香.定神片刻,泼墨挥毫,郑重地写下一行:Hello World 斯~ 有被冷到吗哈哈哈 Git常用命令 ...

  5. 10个 Linux 命令,让你的操作更有效率

    点击上方"开源Linux",选择"设为星标" 回复"学习"获取独家整理的学习资料! 根据老九大师兄口头阐述,Linux是最适合开发的操作系统 ...

  6. ReentrantLock可重入、可打断、Condition原理剖析

    本文紧接上文的AQS源码,如果对于ReentrantLock没有基础可以先阅读我的上一篇文章学习ReentrantLock的源码 ReentrantLock锁重入原理 重入加锁其实就是将AQS的sta ...

  7. 对象、Map、Set、WeakMap、WeakSet

    对象.Map.Set.WeakMap.WeakSet 本文写于 2020 年 11 月 24 日 总的来说,Set 和 Map 主要的应用场景分别在于数据重组和数据储存.Set 是一种叫做「集合」的数 ...

  8. git clone指定分支

    技术背景 Git是代码版本最常用的管理工具,此前也写过一篇介绍Git的基本使用的博客,而本文介绍一个可能在特定场景下能够用到的功能--直接拉取指定分支的内容. Git Clone 首先看一下如果我们按 ...

  9. spring boot 统一接口异常返回值

    创建业务 Exception 一般在实际项目中,推荐创建自己的 Exception 类型,这样在后期会更容易处理,也比较方便统一,否则,可能每个人都抛出自己喜欢的异常类型,而造成代码混乱 Servic ...

  10. yolov2学习笔记

    Yolov2学习笔记 yolov2在yolov1的基础上进行一系列改进: 1.比如Batch Normalization,High Resolution Classifier,使用Anchor Box ...