Tensorflow2 深度学习十必知
博主根据自身多年的深度学习算法研发经验,整理分享以下十条必知。
含参考资料链接,部分附上相关代码实现。
独乐乐不如众乐乐,希望对各位看客有所帮助。
待回头有时间再展开细节说一说深度学习里的那些道道。
有什么技术需求需要有偿解决的也可以邮件或者QQ联系博主。
邮箱QQ同ID:gaozhihan@vip.qq.com
当然除了这十条,肯定还有其他“必知”,
欢迎评论分享更多,这里只是暂时拟定的十条,别较真哈。
主要学习其中的思路,切记,以下思路在个别场景并不适用 。
1.数据回流
[1907.05550] Faster Neural Network Training with Data Echoing
def data_echoing(factor):
return lambda image, label: tf.data.Dataset.from_tensors((image, label)).repeat(factor)
作用:
数据集加载后,在数据增广前后重复当前批次进模型的次数,减少数据的加载耗时。
等价于让模型看n次当前的数据,或者看n个增广后的数据样本。
2.AMP 自动精度混合
在bert4keras中使用混合精度和XLA加速训练 - 科学空间|Scientific Spaces
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": True})
作用:
降低显存占用,加速训练,将部分网络计算转为等价的低精度计算,以此降低计算量。
3.优化器节省显存
3.1 [1804.04235]Adafactor: Adaptive Learning Rates with Sublinear Memory Cost
mesh/optimize.py at master · tensorflow/mesh · GitHub
3.2 [1901.11150] Memory-Efficient Adaptive Optimization
google-research/sm3 at master · google-research/google-research (github.com)
作用:
节省显存,加速训练,
主要是对二阶动量进行特例化解构,减少显存存储。
4.权重标准化(归一化)
[2102.06171] High-Performance Large-Scale Image Recognition Without Normalization
deepmind-research/nfnets at master · deepmind/deepmind-research · GitHub
class WSConv2D(tf.keras.layers.Conv2D):
def __init__(self, *args, **kwargs):
super(WSConv2D, self).__init__(
kernel_initializer=tf.keras.initializers.VarianceScaling(
scale=1.0, mode='fan_in', distribution='untruncated_normal',
),
use_bias=False,
kernel_regularizer=tf.keras.regularizers.l2(1e-4), *args, **kwargs
)
self.gain = self.add_weight(
name='gain',
shape=(self.filters,),
initializer="ones",
trainable=True,
dtype=self.dtype
) def standardize_weight(self, eps):
mean, var = tf.nn.moments(self.kernel, axes=[0, 1, 2], keepdims=True)
fan_in = np.prod(self.kernel.shape[:-1])
# Manually fused normalization, eq. to (w - mean) * gain / sqrt(N * var)
scale = tf.math.rsqrt(
tf.math.maximum(
var * fan_in,
tf.convert_to_tensor(eps, dtype=self.dtype)
)
) * self.gain
shift = mean * scale
return self.kernel * scale - shift def call(self, inputs):
eps = 1e-4
weight = self.standardize_weight(eps)
return tf.nn.conv2d(
inputs, weight, strides=self.strides,
padding=self.padding.upper(), dilations=self.dilation_rate
) if self.bias is None else tf.nn.bias_add(
tf.nn.conv2d(
inputs, weight, strides=self.strides,
padding=self.padding.upper(), dilations=self.dilation_rate
), self.bias)
作用:
通过对kernel进行标准化或归一化,相当于对kernel做一个先验约束,以此加速模型训练收敛。
5.自适应梯度裁剪
deepmind-research/agc_optax.py at master · deepmind/deepmind-research · GitHub
def unitwise_norm(x):
if len(tf.squeeze(x).shape) <= 1: # Scalars and vectors
axis = None
keepdims = False
elif len(x.shape) in [2, 3]: # Linear layers of shape IO
axis = 0
keepdims = True
elif len(x.shape) == 4: # Conv kernels of shape HWIO
axis = [0, 1, 2, ]
keepdims = True
else:
raise ValueError(f'Got a parameter with shape not in [1, 2, 3, 4]! {x}')
square_sum = tf.reduce_sum(tf.square(x), axis, keepdims=keepdims)
return tf.sqrt(square_sum) def gradient_clipping(grad, var):
clipping = 0.01
max_norm = tf.maximum(unitwise_norm(var), 1e-3) * clipping
grad_norm = unitwise_norm(grad)
trigger = (grad_norm > max_norm)
clipped_grad = (max_norm / tf.maximum(grad_norm, 1e-6))
return grad * tf.where(trigger, clipped_grad, tf.ones_like(clipped_grad))
作用:
防止梯度爆炸,稳定训练。通过梯度和参数的关系,对梯度进行裁剪,约束学习率。
6.recompute_grad
[1604.06174] Training Deep Nets with Sublinear Memory Cost
google-research/recompute_grad.py at master · google-research/google-research (github.com)
bojone/keras_recompute: saving memory by recomputing for keras (github.com)
作用:
通过梯度重计算,节省显存。
7.归一化
[2003.05569] Extended Batch Normalization (arxiv.org)
from keras.layers.normalization.batch_normalization import BatchNormalizationBase class ExtendedBatchNormalization(BatchNormalizationBase):
def __init__(self,
axis=-1,
momentum=0.99,
epsilon=1e-3,
center=True,
scale=True,
beta_initializer='zeros',
gamma_initializer='ones',
moving_mean_initializer='zeros',
moving_variance_initializer='ones',
beta_regularizer=None,
gamma_regularizer=None,
beta_constraint=None,
gamma_constraint=None,
renorm=False,
renorm_clipping=None,
renorm_momentum=0.99,
trainable=True,
name=None,
**kwargs):
# Currently we only support aggregating over the global batch size.
super(ExtendedBatchNormalization, self).__init__(
axis=axis,
momentum=momentum,
epsilon=epsilon,
center=center,
scale=scale,
beta_initializer=beta_initializer,
gamma_initializer=gamma_initializer,
moving_mean_initializer=moving_mean_initializer,
moving_variance_initializer=moving_variance_initializer,
beta_regularizer=beta_regularizer,
gamma_regularizer=gamma_regularizer,
beta_constraint=beta_constraint,
gamma_constraint=gamma_constraint,
renorm=renorm,
renorm_clipping=renorm_clipping,
renorm_momentum=renorm_momentum,
fused=False,
trainable=trainable,
virtual_batch_size=None,
name=name,
**kwargs) def _calculate_mean_and_var(self, x, axes, keep_dims):
with tf.keras.backend.name_scope('moments'):
y = tf.cast(x, tf.float32) if x.dtype == tf.float16 else x
replica_ctx = tf.distribute.get_replica_context()
if replica_ctx:
local_sum = tf.math.reduce_sum(y, axis=axes, keepdims=True)
local_squared_sum = tf.math.reduce_sum(tf.math.square(y), axis=axes,
keepdims=True)
batch_size = tf.cast(tf.shape(y)[0], tf.float32)
y_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM, local_sum)
y_squared_sum = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM,
local_squared_sum)
global_batch_size = replica_ctx.all_reduce(tf.distribute.ReduceOp.SUM,
batch_size)
axes_vals = [(tf.shape(y))[i] for i in range(1, len(axes))]
multiplier = tf.cast(tf.reduce_prod(axes_vals), tf.float32)
multiplier = multiplier * global_batch_size
mean = y_sum / multiplier
y_squared_mean = y_squared_sum / multiplier
# var = E(x^2) - E(x)^2
variance = y_squared_mean - tf.math.square(mean)
else:
# Compute true mean while keeping the dims for proper broadcasting.
mean = tf.math.reduce_mean(y, axes, keepdims=True, name='mean')
variance = tf.math.reduce_mean(
tf.math.squared_difference(y, tf.stop_gradient(mean)),
axes,
keepdims=True,
name='variance')
if not keep_dims:
mean = tf.squeeze(mean, axes)
variance = tf.squeeze(variance, axes)
variance = tf.math.reduce_mean(variance)
if x.dtype == tf.float16:
return (tf.cast(mean, tf.float16),
tf.cast(variance, tf.float16))
else:
return mean, variance
作用:
一个简易改进版的Batch Normalization,思路简单有效。
8.学习率策略
[1506.01186] Cyclical Learning Rates for Training Neural Networks (arxiv.org)
作用:
一个推荐的学习率策略方案,特定情况下可以取得更好的泛化。
9.重参数化
https://zhuanlan.zhihu.com/p/361090497
作用:
通过同时训练多份参数,合并权重的思路来提升模型泛化性。
10.长尾学习
[2110.04596] Deep Long-Tailed Learning: A Survey (arxiv.org)
Jorwnpay/A-Long-Tailed-Survey: 本项目是 Deep Long-Tailed Learning: A Survey 文章的中译版 (github.com)
作用:
解决长尾问题,可以加速收敛,提升模型泛化,稳定训练。
Tensorflow2 深度学习十必知的更多相关文章
- 对比深度学习十大框架:TensorFlow 并非最好?
http://www.oschina.net/news/80593/deep-learning-frameworks-a-review-before-finishing-2016 TensorFlow ...
- 推荐系统遇上深度学习(十)--GBDT+LR融合方案实战
推荐系统遇上深度学习(十)--GBDT+LR融合方案实战 0.8012018.05.19 16:17:18字数 2068阅读 22568 推荐系统遇上深度学习系列:推荐系统遇上深度学习(一)--FM模 ...
- 《TensorFlow2深度学习》学习笔记(一)Tensorflow基础
本系列笔记记录了学习TensorFlow2的过程,主要依据 https://github.com/dragen1860/Deep-Learning-with-TensorFlow-book 进行学习 ...
- mysql学习--mysql必知必会1
例如以下为mysql必知必会第九章開始: 正則表達式用于匹配特殊的字符集合.mysql通过where子句对正則表達式提供初步的支持. keywordregexp用来表示后面跟的东西作为正則表達式 ...
- mysql学习--mysql必知必会
上图为数据库操作分类: 下面的操作參考(mysql必知必会) 创建数据库 运行脚本建表: mysql> create database mytest; Query OK, 1 row ...
- 学习axios必知必会(2)~axios基本使用、使用axios前必知细节、axios和实例对象区别、拦截器、取消请求
一.axios的基本使用: ✿ 使用axios前必知细节: 1.axios 函数对象(可以作为axios(config)函数使用去发送请求,也可以作为对象调用方法axios.request(confi ...
- 【Android Api 翻译4】android api 完整翻译之Contacts Provider (学习安卓必知的api,中英文对照)
Contacts Provider 电话簿(注:联系人,联络人.通信录)提供者 ------------------------------- QUICKVIEW 快速概览 * Android's r ...
- 《TensorFlow2深度学习》学习笔记(四)对笔记二中的模型增加正确率展示
全部代码如下:(红色部分为与笔记二不同之处) #1.Import the neccessary libraries needed import numpy as np import tensorflo ...
- 学习MyBatis必知必会(2)~MyBatis基本介绍和MyBatis基本使用
一.MyBatis框架基本介绍: 1.认识 MyBatis: MyBatis 是支持普通 SQL 查询,存储过程和高级映射的持久层框架,严格上说应该是一个 SQL 映射框架. 其前身是 iBatis, ...
随机推荐
- formData一般用法,移动端,pc端都可以用,pc有兼容性问题
其实FormData是一个 对象他是一个比较新的东东(其实我也不知道改叫什么好) 利用FormData对象,你可以使用一系列的键值对来模拟一个完整的表单,然后使用XMLHttpRequest发送这个& ...
- javaWeb代码整理02-jdbcTemplete数据库连接工具
jar包: maven坐标: /**属于spring框架的包*/<dependency> <groupId>org.springframework</groupId> ...
- Java语言学习day30--8月5日
###10String类的其他方法 * A:String类的其他方法 * a: 方法介绍 * int length(): 返回字符串的长度 * String substring(int beginIn ...
- 2021.08.05 P1738 洛谷的文件夹(树形结构)
2021.08.05 P1738 洛谷的文件夹(树形结构) P1738 洛谷的文件夹 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn) 重点: 1.树!! 题意: 给出n个网页路径,求 ...
- 甲骨文严查Java授权,换openJDK要避坑
背景 外媒The Register报道,甲骨文稽查企业用户,近期开始将把过去看管较松散的Java授权加入. 甲骨文针对标准版Java(Java SE)有2种商业授权.2019年4月甲骨文宣布Java ...
- Rancher部署PostgreSQL容器
1.打开工作负载,选择部署服务 2.选择合适的PostgreSQL镜像 镜像地址https://registry.hub.docker.com/_/postgres,也可使用公司内部镜像库 网络模式选 ...
- Gson解析:java.lang.IllegalArgumentException: declares multiple JSON fields named status 问题的解决
在一次写定义系统统一返回值的情况下,碰到了java.lang.IllegalArgumentException: declares multiple JSON fields named status这 ...
- BootstrapBlazor实战 Menu 导航菜单使用(1)
实战BootstrapBlazorMenu 导航菜单的使用, 以及整合Freesql orm快速制作菜单项数据库后台维护页面 demo演示的是Sqlite驱动,FreeSql支持多种数据库,MySql ...
- C/C++游戏项目:中国程序员一定要会的中国象棋教程
中国象棋是中国一种流传十分广泛的游戏. 下棋双方根据自己对棋局形式的理解和对棋艺规律的掌握,调动车马,组织兵力,协调作战在棋盘这块特定的战场上进行着象征性的军事战斗. 象棋,亦作"象碁&qu ...
- MySQL常用数据类型及细节
目录 1 整数类型 1.1 可选属性 1.1.1 M 1.1.2 UNSIGNED 1.1.3 ZEROFILL 2 浮点类型 2.1 精度误差 3 定点数类型 3.1 数据精度说明 3.2 类型介绍 ...