博主根据自身多年的深度学习算法研发经验,整理分享以下十条必知。

含参考资料链接,部分附上相关代码实现。

独乐乐不如众乐乐,希望对各位看客有所帮助。

待回头有时间再展开细节说一说深度学习里的那些道道。

有什么技术需求需要有偿解决的也可以邮件或者QQ联系博主。

邮箱QQ同ID:gaozhihan@vip.qq.com

当然除了这十条,肯定还有其他“必知”,

欢迎评论分享更多,这里只是暂时拟定的十条,别较真哈。

主要学习其中的思路,切记,以下思路在个别场景并不适用 。

1.数据回流

[1907.05550] Faster Neural Network Training with Data Echoing

def data_echoing(factor):
return lambda image, label: tf.data.Dataset.from_tensors((image, label)).repeat(factor)

作用:

数据集加载后,在数据增广前后重复当前批次进模型的次数,减少数据的加载耗时。

等价于让模型看n次当前的数据,或者看n个增广后的数据样本。

2.AMP 自动精度混合

在bert4keras中使用混合精度和XLA加速训练 - 科学空间|Scientific Spaces

    tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})

作用:

降低显存占用,加速训练,将部分网络计算转为等价的低精度计算,以此降低计算量。

3.优化器节省显存

3.1  [1804.04235]Adafactor: Adaptive Learning Rates with Sublinear Memory Cost

mesh/optimize.py at master · tensorflow/mesh · GitHub

3.2 [1901.11150] Memory-Efficient Adaptive Optimization

google-research/sm3 at master · google-research/google-research (github.com)

作用:

节省显存,加速训练,

主要是对二阶动量进行特例化解构,减少显存存储。

4.权重标准化(归一化)

[2102.06171] High-Performance Large-Scale Image Recognition Without Normalization

deepmind-research/nfnets at master · deepmind/deepmind-research · GitHub

class WSConv2D(tf.keras.layers.Conv2D):
def __init__(self, *args, **kwargs):
super(WSConv2D, self).__init__(
kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=1.0, mode='fan_in', distribution='untruncated_normal',
),
use_bias=False,
kernel_regularizer=tf.keras.regularizers.l2(1e-4), *args, **kwargs
)
self.gain = self.add_weight(
name='gain',
shape=(self.filters,),
initializer="ones",
trainable=True,
dtype=self.dtype
) def standardize_weight(self, eps):
mean, var = tf.nn.moments(self.kernel, axes=[0, 1, 2], keepdims=True)
fan_in = np.prod(self.kernel.shape[:-1])
# Manually fused normalization, eq. to (w - mean) * gain / sqrt(N * var)
scale = tf.math.rsqrt(
tf.math.maximum(
var * fan_in,
tf.convert_to_tensor(eps, dtype=self.dtype)
)
) * self.gain
shift = mean * scale
return self.kernel * scale - shift def call(self, inputs):
eps = 1e-4
weight = self.standardize_weight(eps)
return tf.nn.conv2d(
inputs, weight, strides=self.strides,
padding=self.padding.upper(), dilations=self.dilation_rate
) if self.bias is None else tf.nn.bias_add(
tf.nn.conv2d(
inputs, weight, strides=self.strides,
padding=self.padding.upper(), dilations=self.dilation_rate
), self.bias)

作用:

通过对kernel进行标准化或归一化,相当于对kernel做一个先验约束,以此加速模型训练收敛。

5.自适应梯度裁剪

deepmind-research/agc_optax.py at master · deepmind/deepmind-research · GitHub

def unitwise_norm(x):
if len(tf.squeeze(x).shape) <= 1: # Scalars and vectors
axis = None
keepdims = False
elif len(x.shape) in [2, 3]: # Linear layers of shape IO
axis = 0
keepdims = True
elif len(x.shape) == 4: # Conv kernels of shape HWIO
axis = [0, 1, 2, ]
keepdims = True
else:
raise ValueError(f'Got a parameter with shape not in [1, 2, 3, 4]! {x}')
square_sum = tf.reduce_sum(tf.square(x), axis, keepdims=keepdims)
return tf.sqrt(square_sum) def gradient_clipping(grad, var):
clipping = 0.01
max_norm = tf.maximum(unitwise_norm(var), 1e-3) * clipping
grad_norm = unitwise_norm(grad)
trigger = (grad_norm > max_norm)
clipped_grad = (max_norm / tf.maximum(grad_norm, 1e-6))
return grad * tf.where(trigger, clipped_grad, tf.ones_like(clipped_grad))

作用:

防止梯度爆炸,稳定训练。通过梯度和参数的关系,对梯度进行裁剪,约束学习率。

6.recompute_grad

[1604.06174] Training Deep Nets with Sublinear Memory Cost

google-research/recompute_grad.py at master · google-research/google-research (github.com)

bojone/keras_recompute: saving memory by recomputing for keras (github.com)

作用:

通过梯度重计算,节省显存。

7.归一化

[2003.05569] Extended Batch Normalization (arxiv.org)

from keras.layers.normalization.batch_normalization import BatchNormalizationBase

class ExtendedBatchNormalization(BatchNormalizationBase):
def __init__(self,
axis=-1,
momentum=0.99,
epsilon=1e-3,
center=True,
scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
moving_mean_initializer='zeros',
moving_variance_initializer='ones',
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
renorm=False,
renorm_clipping=None,
renorm_momentum=0.99,
trainable=True,
name=None,
**kwargs):
# Currently we only support aggregating over the global batch size.
super(ExtendedBatchNormalization, self).__init__(
axis=axis,
momentum=momentum,
epsilon=epsilon,
center=center,
scale=scale,
beta_initializer=beta_initializer,
gamma_initializer=gamma_initializer,
moving_mean_initializer=moving_mean_initializer,
moving_variance_initializer=moving_variance_initializer,
beta_regularizer=beta_regularizer,
gamma_regularizer=gamma_regularizer,
beta_constraint=beta_constraint,
gamma_constraint=gamma_constraint,
renorm=renorm,
renorm_clipping=renorm_clipping,
renorm_momentum=renorm_momentum,
fused=False,
trainable=trainable,
virtual_batch_size=None,
name=name,
**kwargs) def _calculate_mean_and_var(self, x, axes, keep_dims):
with tf.keras.backend.name_scope('moments'):
y = tf.cast(x, tf.float32) if x.dtype == tf.float16 else x
replica_ctx = tf.distribute.get_replica_context()
if replica_ctx:
local_sum = tf.math.reduce_sum(y, axis=axes, keepdims=True)
local_squared_sum = tf.math.reduce_sum(tf.math.square(y), axis=axes,
keepdims=True)
batch_size = tf.cast(tf.shape(y)[0], tf.float32)
y_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_sum)
y_squared_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM,
local_squared_sum)
global_batch_size = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM,
batch_size)
axes_vals = [(tf.shape(y))[i] for i in range(1, len(axes))]
multiplier = tf.cast(tf.reduce_prod(axes_vals), tf.float32)
multiplier = multiplier * global_batch_size
mean = y_sum / multiplier
y_squared_mean = y_squared_sum / multiplier
# var = E(x^2) - E(x)^2
variance = y_squared_mean - tf.math.square(mean)
else:
# Compute true mean while keeping the dims for proper broadcasting.
mean = tf.math.reduce_mean(y, axes, keepdims=True, name='mean')
variance = tf.math.reduce_mean(
tf.math.squared_difference(y, tf.stop_gradient(mean)),
axes,
keepdims=True,
name='variance')
if not keep_dims:
mean = tf.squeeze(mean, axes)
variance = tf.squeeze(variance, axes)
variance = tf.math.reduce_mean(variance)
if x.dtype == tf.float16:
return (tf.cast(mean, tf.float16),
tf.cast(variance, tf.float16))
else:
return mean, variance

  

作用:

一个简易改进版的Batch Normalization,思路简单有效。

8.学习率策略

[1506.01186] Cyclical Learning Rates for Training Neural Networks (arxiv.org)

作用:

一个推荐的学习率策略方案,特定情况下可以取得更好的泛化。

9.重参数化

[1908.03930] ACNet: Strengthening the Kernel Skeletons for Powerful CNN via Asymmetric Convolution Blocks

https://zhuanlan.zhihu.com/p/361090497

作用:

通过同时训练多份参数,合并权重的思路来提升模型泛化性。

10.长尾学习

[2110.04596] Deep Long-Tailed Learning: A Survey (arxiv.org)

Jorwnpay/A-Long-Tailed-Survey: 本项目是 Deep Long-Tailed Learning: A Survey 文章的中译版 (github.com)

作用:

解决长尾问题,可以加速收敛,提升模型泛化,稳定训练。

Tensorflow2 深度学习十必知的更多相关文章

  1. 对比深度学习十大框架:TensorFlow 并非最好?

    http://www.oschina.net/news/80593/deep-learning-frameworks-a-review-before-finishing-2016 TensorFlow ...

  2. 推荐系统遇上深度学习(十)--GBDT+LR融合方案实战

    推荐系统遇上深度学习(十)--GBDT+LR融合方案实战 0.8012018.05.19 16:17:18字数 2068阅读 22568 推荐系统遇上深度学习系列:推荐系统遇上深度学习(一)--FM模 ...

  3. 《TensorFlow2深度学习》学习笔记(一)Tensorflow基础

    本系列笔记记录了学习TensorFlow2的过程,主要依据 https://github.com/dragen1860/Deep-Learning-with-TensorFlow-book 进行学习 ...

  4. mysql学习--mysql必知必会1

     例如以下为mysql必知必会第九章開始: 正則表達式用于匹配特殊的字符集合.mysql通过where子句对正則表達式提供初步的支持. keywordregexp用来表示后面跟的东西作为正則表達式 ...

  5. mysql学习--mysql必知必会

      上图为数据库操作分类:     下面的操作參考(mysql必知必会) 创建数据库 运行脚本建表: mysql> create database mytest; Query OK, 1 row ...

  6. 学习axios必知必会(2)~axios基本使用、使用axios前必知细节、axios和实例对象区别、拦截器、取消请求

    一.axios的基本使用: ✿ 使用axios前必知细节: 1.axios 函数对象(可以作为axios(config)函数使用去发送请求,也可以作为对象调用方法axios.request(confi ...

  7. 【Android Api 翻译4】android api 完整翻译之Contacts Provider (学习安卓必知的api,中英文对照)

    Contacts Provider 电话簿(注:联系人,联络人.通信录)提供者 ------------------------------- QUICKVIEW 快速概览 * Android's r ...

  8. 《TensorFlow2深度学习》学习笔记(四)对笔记二中的模型增加正确率展示

    全部代码如下:(红色部分为与笔记二不同之处) #1.Import the neccessary libraries needed import numpy as np import tensorflo ...

  9. 学习MyBatis必知必会(2)~MyBatis基本介绍和MyBatis基本使用

    一.MyBatis框架基本介绍: 1.认识 MyBatis: MyBatis 是支持普通 SQL 查询,存储过程和高级映射的持久层框架,严格上说应该是一个 SQL 映射框架. 其前身是 iBatis, ...

随机推荐

  1. partOne测试收获总结

    测试收获总结   执行类中构造多个方法,将各个功能分解出来,将大的,复杂的问题转化成小的,简单的问题,来进行处理,正所谓复杂问题简单化,简单问题流程化.大道至简编程精益.现总结编程中的一些问题,①在J ...

  2. python基础练习题(九九乘法表)

    又把python捡起来了,动手能力偏弱,决定每日一练,把基础打好! ------------------------------------------------------------------ ...

  3. linux下redis开机自启动

    将/usr/local/app/redis-4.0.8/redis.conf文件中daemonize no改为daemonize yes 在/etc目录下新建redis目录:mkdir /etc/re ...

  4. HttpServletResponse & HttpServletRequest

    web服务器接收到客户端的http请求,针对这个请求,分别创建一个代表请求的HttpServletRequest对象,代表响应的一个HttpServletResponse: 如果要获取客户端请求过来的 ...

  5. 【导包】使用Sklearn构建Logistic回归分类器

    官方英文文档地址:http://scikit-learn.org/dev/modules/generated/sklearn.linear_model.LogisticRegression.html# ...

  6. javaScript中Math内置对象基本方法入门

    概念 Math 是javaScript的内置对象,包含了部分数学常数属性和数学函数方法. Math 不是一个函数对象,用户Number类型进行使用,不支持BigInt. Math 的所有属性与方法都是 ...

  7. 【已解决】vscode窗口控制台闪现(不用更改原代码)

    打开launch.json 将"type": "cppdbg"改为"type": "cppvsdbg" 会出现密钥ext ...

  8. [没接触过kubevirt?]15分钟快速入门kubevirt

    @ 目录 本文介绍 前言 环境准备 详细版 搭建步骤 安装KubeVirt 安装virtctl客户端工具 创建VirtualMachine 启动VirtualMachineInstance 启动和停止 ...

  9. .NET混合开发解决方案6 检测是否已安装合适的WebView2运行时

    系列目录     [已更新最新开发文章,点击查看详细] 长青版WebView2运行时将作为Windows 11操作系统的一部分包含在内.但是在Windows 11之前(Win10.Win8.1.Win ...

  10. IIS方式部署项目发布上线

    VS2019如何把项目部署和发布 这里演示:通过IIS文件publish的方式部署到Windows本地服务器上 第一步(安装IIS) 1.在自己电脑上搜索Windows功能里的[启用或关闭Window ...