Tensorflow2 深度学习十必知
博主根据自身多年的深度学习算法研发经验,整理分享以下十条必知。
含参考资料链接,部分附上相关代码实现。
独乐乐不如众乐乐,希望对各位看客有所帮助。
待回头有时间再展开细节说一说深度学习里的那些道道。
有什么技术需求需要有偿解决的也可以邮件或者QQ联系博主。
邮箱QQ同ID:gaozhihan@vip.qq.com
当然除了这十条,肯定还有其他“必知”,
欢迎评论分享更多,这里只是暂时拟定的十条,别较真哈。
主要学习其中的思路,切记,以下思路在个别场景并不适用 。
1.数据回流
[1907.05550] Faster Neural Network Training with Data Echoing
def data_echoing(factor):
return lambda image, label: tf.data.Dataset.from_tensors((image, label)).repeat(factor)
作用:
数据集加载后,在数据增广前后重复当前批次进模型的次数,减少数据的加载耗时。
等价于让模型看n次当前的数据,或者看n个增广后的数据样本。
2.AMP 自动精度混合
在bert4keras中使用混合精度和XLA加速训练 - 科学空间|Scientific Spaces
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
作用:
降低显存占用,加速训练,将部分网络计算转为等价的低精度计算,以此降低计算量。
3.优化器节省显存
3.1 [1804.04235]Adafactor: Adaptive Learning Rates with Sublinear Memory Cost
mesh/optimize.py at master · tensorflow/mesh · GitHub
3.2 [1901.11150] Memory-Efficient Adaptive Optimization
google-research/sm3 at master · google-research/google-research (github.com)
作用:
节省显存,加速训练,
主要是对二阶动量进行特例化解构,减少显存存储。
4.权重标准化(归一化)
[2102.06171] High-Performance Large-Scale Image Recognition Without Normalization
deepmind-research/nfnets at master · deepmind/deepmind-research · GitHub
class WSConv2D(tf.keras.layers.Conv2D):
def __init__(self, *args, **kwargs):
super(WSConv2D, self).__init__(
kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=1.0, mode='fan_in', distribution='untruncated_normal',
),
use_bias=False,
kernel_regularizer=tf.keras.regularizers.l2(1e-4), *args, **kwargs
)
self.gain = self.add_weight(
name='gain',
shape=(self.filters,),
initializer="ones",
trainable=True,
dtype=self.dtype
) def standardize_weight(self, eps):
mean, var = tf.nn.moments(self.kernel, axes=[0, 1, 2], keepdims=True)
fan_in = np.prod(self.kernel.shape[:-1])
# Manually fused normalization, eq. to (w - mean) * gain / sqrt(N * var)
scale = tf.math.rsqrt(
tf.math.maximum(
var * fan_in,
tf.convert_to_tensor(eps, dtype=self.dtype)
)
) * self.gain
shift = mean * scale
return self.kernel * scale - shift def call(self, inputs):
eps = 1e-4
weight = self.standardize_weight(eps)
return tf.nn.conv2d(
inputs, weight, strides=self.strides,
padding=self.padding.upper(), dilations=self.dilation_rate
) if self.bias is None else tf.nn.bias_add(
tf.nn.conv2d(
inputs, weight, strides=self.strides,
padding=self.padding.upper(), dilations=self.dilation_rate
), self.bias)
作用:
通过对kernel进行标准化或归一化,相当于对kernel做一个先验约束,以此加速模型训练收敛。
5.自适应梯度裁剪
deepmind-research/agc_optax.py at master · deepmind/deepmind-research · GitHub
def unitwise_norm(x):
if len(tf.squeeze(x).shape) <= 1: # Scalars and vectors
axis = None
keepdims = False
elif len(x.shape) in [2, 3]: # Linear layers of shape IO
axis = 0
keepdims = True
elif len(x.shape) == 4: # Conv kernels of shape HWIO
axis = [0, 1, 2, ]
keepdims = True
else:
raise ValueError(f'Got a parameter with shape not in [1, 2, 3, 4]! {x}')
square_sum = tf.reduce_sum(tf.square(x), axis, keepdims=keepdims)
return tf.sqrt(square_sum) def gradient_clipping(grad, var):
clipping = 0.01
max_norm = tf.maximum(unitwise_norm(var), 1e-3) * clipping
grad_norm = unitwise_norm(grad)
trigger = (grad_norm > max_norm)
clipped_grad = (max_norm / tf.maximum(grad_norm, 1e-6))
return grad * tf.where(trigger, clipped_grad, tf.ones_like(clipped_grad))
作用:
防止梯度爆炸,稳定训练。通过梯度和参数的关系,对梯度进行裁剪,约束学习率。
6.recompute_grad
[1604.06174] Training Deep Nets with Sublinear Memory Cost
google-research/recompute_grad.py at master · google-research/google-research (github.com)
bojone/keras_recompute: saving memory by recomputing for keras (github.com)
作用:
通过梯度重计算,节省显存。
7.归一化
[2003.05569] Extended Batch Normalization (arxiv.org)
from keras.layers.normalization.batch_normalization import BatchNormalizationBase class ExtendedBatchNormalization(BatchNormalizationBase):
def __init__(self,
axis=-1,
momentum=0.99,
epsilon=1e-3,
center=True,
scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
moving_mean_initializer='zeros',
moving_variance_initializer='ones',
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
renorm=False,
renorm_clipping=None,
renorm_momentum=0.99,
trainable=True,
name=None,
**kwargs):
# Currently we only support aggregating over the global batch size.
super(ExtendedBatchNormalization, self).__init__(
axis=axis,
momentum=momentum,
epsilon=epsilon,
center=center,
scale=scale,
beta_initializer=beta_initializer,
gamma_initializer=gamma_initializer,
moving_mean_initializer=moving_mean_initializer,
moving_variance_initializer=moving_variance_initializer,
beta_regularizer=beta_regularizer,
gamma_regularizer=gamma_regularizer,
beta_constraint=beta_constraint,
gamma_constraint=gamma_constraint,
renorm=renorm,
renorm_clipping=renorm_clipping,
renorm_momentum=renorm_momentum,
fused=False,
trainable=trainable,
virtual_batch_size=None,
name=name,
**kwargs) def _calculate_mean_and_var(self, x, axes, keep_dims):
with tf.keras.backend.name_scope('moments'):
y = tf.cast(x, tf.float32) if x.dtype == tf.float16 else x
replica_ctx = tf.distribute.get_replica_context()
if replica_ctx:
local_sum = tf.math.reduce_sum(y, axis=axes, keepdims=True)
local_squared_sum = tf.math.reduce_sum(tf.math.square(y), axis=axes,
keepdims=True)
batch_size = tf.cast(tf.shape(y)[0], tf.float32)
y_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_sum)
y_squared_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM,
local_squared_sum)
global_batch_size = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM,
batch_size)
axes_vals = [(tf.shape(y))[i] for i in range(1, len(axes))]
multiplier = tf.cast(tf.reduce_prod(axes_vals), tf.float32)
multiplier = multiplier * global_batch_size
mean = y_sum / multiplier
y_squared_mean = y_squared_sum / multiplier
# var = E(x^2) - E(x)^2
variance = y_squared_mean - tf.math.square(mean)
else:
# Compute true mean while keeping the dims for proper broadcasting.
mean = tf.math.reduce_mean(y, axes, keepdims=True, name='mean')
variance = tf.math.reduce_mean(
tf.math.squared_difference(y, tf.stop_gradient(mean)),
axes,
keepdims=True,
name='variance')
if not keep_dims:
mean = tf.squeeze(mean, axes)
variance = tf.squeeze(variance, axes)
variance = tf.math.reduce_mean(variance)
if x.dtype == tf.float16:
return (tf.cast(mean, tf.float16),
tf.cast(variance, tf.float16))
else:
return mean, variance
作用:
一个简易改进版的Batch Normalization,思路简单有效。
8.学习率策略
[1506.01186] Cyclical Learning Rates for Training Neural Networks (arxiv.org)
作用:
一个推荐的学习率策略方案,特定情况下可以取得更好的泛化。
9.重参数化
https://zhuanlan.zhihu.com/p/361090497
作用:
通过同时训练多份参数,合并权重的思路来提升模型泛化性。
10.长尾学习
[2110.04596] Deep Long-Tailed Learning: A Survey (arxiv.org)
Jorwnpay/A-Long-Tailed-Survey: 本项目是 Deep Long-Tailed Learning: A Survey 文章的中译版 (github.com)
作用:
解决长尾问题,可以加速收敛,提升模型泛化,稳定训练。
Tensorflow2 深度学习十必知的更多相关文章
- 对比深度学习十大框架:TensorFlow 并非最好?
http://www.oschina.net/news/80593/deep-learning-frameworks-a-review-before-finishing-2016 TensorFlow ...
- 推荐系统遇上深度学习(十)--GBDT+LR融合方案实战
推荐系统遇上深度学习(十)--GBDT+LR融合方案实战 0.8012018.05.19 16:17:18字数 2068阅读 22568 推荐系统遇上深度学习系列:推荐系统遇上深度学习(一)--FM模 ...
- 《TensorFlow2深度学习》学习笔记(一)Tensorflow基础
本系列笔记记录了学习TensorFlow2的过程,主要依据 https://github.com/dragen1860/Deep-Learning-with-TensorFlow-book 进行学习 ...
- mysql学习--mysql必知必会1
例如以下为mysql必知必会第九章開始: 正則表達式用于匹配特殊的字符集合.mysql通过where子句对正則表達式提供初步的支持. keywordregexp用来表示后面跟的东西作为正則表達式 ...
- mysql学习--mysql必知必会
上图为数据库操作分类: 下面的操作參考(mysql必知必会) 创建数据库 运行脚本建表: mysql> create database mytest; Query OK, 1 row ...
- 学习axios必知必会(2)~axios基本使用、使用axios前必知细节、axios和实例对象区别、拦截器、取消请求
一.axios的基本使用: ✿ 使用axios前必知细节: 1.axios 函数对象(可以作为axios(config)函数使用去发送请求,也可以作为对象调用方法axios.request(confi ...
- 【Android Api 翻译4】android api 完整翻译之Contacts Provider (学习安卓必知的api,中英文对照)
Contacts Provider 电话簿(注:联系人,联络人.通信录)提供者 ------------------------------- QUICKVIEW 快速概览 * Android's r ...
- 《TensorFlow2深度学习》学习笔记(四)对笔记二中的模型增加正确率展示
全部代码如下:(红色部分为与笔记二不同之处) #1.Import the neccessary libraries needed import numpy as np import tensorflo ...
- 学习MyBatis必知必会(2)~MyBatis基本介绍和MyBatis基本使用
一.MyBatis框架基本介绍: 1.认识 MyBatis: MyBatis 是支持普通 SQL 查询,存储过程和高级映射的持久层框架,严格上说应该是一个 SQL 映射框架. 其前身是 iBatis, ...
随机推荐
- 【大话云原生】煮饺子与docker、kubernetes之间的关系
云原生的概念最近非常火爆,企业落地云原生的愿望也越发强烈.看过很多关于云原生的文章,要么云山雾罩,要么曲高和寡. 所以笔者就有了写<大话云原生>系列文章的想法,期望用最通俗.简单的语言说明 ...
- 运行npm install命令的时候会发生什么?
摘要:我们日常在下载第三方依赖的时候,都会用到一个命令npm install,那么你知道,在运行这个命令的时候都会发生什么吗? 本文分享自华为云社区<运行npm install命令的时候会发生什 ...
- Nacos在企业生产中如何使用集群环境?
点赞再看,养成习惯,微信搜索[牧小农]关注我获取更多资讯,风里雨里,小农等你,很高兴能够成为你的朋友. 项目源码地址:公众号回复 nacos,即可免费获取源码 前言 由于在公司,注册中心和配置中心都是 ...
- Git (常用命令)
某程序猿退休后决定练习书法,于是花重金买下文房四宝.某日,饭后突生雅兴,一番磨墨拟纸 并点上上好檀香.定神片刻,泼墨挥毫,郑重地写下一行:Hello World 斯~ 有被冷到吗哈哈哈 Git常用命令 ...
- 10个 Linux 命令,让你的操作更有效率
点击上方"开源Linux",选择"设为星标" 回复"学习"获取独家整理的学习资料! 根据老九大师兄口头阐述,Linux是最适合开发的操作系统 ...
- ReentrantLock可重入、可打断、Condition原理剖析
本文紧接上文的AQS源码,如果对于ReentrantLock没有基础可以先阅读我的上一篇文章学习ReentrantLock的源码 ReentrantLock锁重入原理 重入加锁其实就是将AQS的sta ...
- 对象、Map、Set、WeakMap、WeakSet
对象.Map.Set.WeakMap.WeakSet 本文写于 2020 年 11 月 24 日 总的来说,Set 和 Map 主要的应用场景分别在于数据重组和数据储存.Set 是一种叫做「集合」的数 ...
- git clone指定分支
技术背景 Git是代码版本最常用的管理工具,此前也写过一篇介绍Git的基本使用的博客,而本文介绍一个可能在特定场景下能够用到的功能--直接拉取指定分支的内容. Git Clone 首先看一下如果我们按 ...
- spring boot 统一接口异常返回值
创建业务 Exception 一般在实际项目中,推荐创建自己的 Exception 类型,这样在后期会更容易处理,也比较方便统一,否则,可能每个人都抛出自己喜欢的异常类型,而造成代码混乱 Servic ...
- yolov2学习笔记
Yolov2学习笔记 yolov2在yolov1的基础上进行一系列改进: 1.比如Batch Normalization,High Resolution Classifier,使用Anchor Box ...