DL基础补全计划(五)---数值稳定性及参数初始化(梯度消失、梯度爆炸)
PS:要转载请注明出处,本人版权所有。
PS: 这个只是基于《我自己》的理解,
如果和你的原则及想法相冲突,请谅解,勿喷。
前置说明
本文作为本人csdn blog的主站的备份。(BlogID=109)
环境说明
- Windows 10
- VSCode
- Python 3.8.10
- Pytorch 1.8.1
- Cuda 10.2
前言
如果有计算机背景的相关童鞋,都应该知道数值计算中的上溢和下溢的问题。关于计算机中的数值表示,在我的《数与计算机 (编码、原码、反码、补码、移码、IEEE 754、定点数、浮点数)》 (https://blog.csdn.net/u011728480/article/details/100277582) 一文中有比较好的介绍。计算机中的数值表示,相对于实数数轴来说是离散且有限的,意思就是计算机中的能表示的数有最大值和最小值以及最小单位,特别是浮点数表示,有兴趣的可以看看上文。
其实很好理解,深度学习里面具有大量的乘法加法,一不小心你就会遇见上溢和下溢的问题,因此我们一不小心就会遇见NAN和INF的问题(NAN和INF详见上文提到的文章)。此外,由于一些特殊的情况,可能会导致我们的参数的偏导数接近于0,让我们的模型收敛的非常的慢。因此我们可能需要从模型的初始化以及相关的模型构造方面来好好的讨论一下我们在训练过程中可能出现的问题。
一般来说,我们训练的时候都非常的关注我们的损失函数,如果损失函数值异常,会导致相关的偏导数出现接近于0或者接近于无限大,那么就会直接导致模型训练及其困难。此外,我们的权重参数也会参与网络计算,按照上述的描述,权重参数的初始值也可能导致损失函数的值异常。因此大佬们也引入了另外一种常见的初始化方式Xavier,比较具有普适性。下面我们简单的验证一下我们训练过程中出现梯度接近于0和接近于无限大的情况,这里也就是说的梯度消失和梯度爆炸问题。同时也简单说明参数初始化相关的问题。
梯度消失(gradient vanishing)
在深度学习中有一个激活层叫做Sigmoid层,其定义如下是:\(Sigmoid(x)=1/(1+\exp(-x))\),如果我们的模型里面接入了这种激活函数,很容易构造出梯度消失的情况,下面我们看一下其导数和函数值相对于X的相关关系。
代码如下:
import torch
import numpy as np
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
xdata, ydata = [[], []], [[], []]
line0, = ax.plot([], [], 'r-', label='sigmoid')
line1, = ax.plot([], [], 'b-', label='gradient-sigmoid')
def init_and_show(xlim_min, xlim_max, ylim_min, ylim_max):
ax.set_xlabel('x')
ax.set_ylabel('sigmoid(x)')
ax.set_title('sigmoid/gradient-sigmoid')
ax.set_xlim(xlim_min, xlim_max)
ax.set_ylim(ylim_min, ylim_max)
ax.legend([line0, line1], ('sigmoid', 'gradient-sigmoid'))
line0.set_data(xdata[0], ydata[0])
line1.set_data(xdata[1], ydata[1])
plt.show()
def sigmoid_test():
x = np.arange(-10.0, 10.0, 0.1)
x = torch.tensor(x, dtype=torch.float, requires_grad=True)
sig_fun = torch.nn.Sigmoid()
y = sig_fun(x)
y.backward(torch.ones_like(y))
xdata[0] = x.detach().numpy()
xdata[1] = x.detach().numpy()
ydata[0] = y.detach().numpy()
ydata[1] = x.grad.detach().numpy()
init_and_show(-10.0, 10.0, 0, 1)
def multi_mat_dot():
M = np.random.normal(size=(4, 4))
print('⼀个矩阵\n', M)
for i in range(10000):
M = np.dot(M, np.random.normal(size=(4, 4)))
print('乘以100个矩阵后\n', M)
if __name__ == '__main__':
sigmoid_test()
结果图如下
我们可以从图中看到,当x小于-5和大于+5的时候,其导数的值接近于0,导致bp的时候,参数更新小,模型收敛的特别的慢。
梯度爆炸(gradient exploding)
现在我们假设我们有一个模型,其有N个线性层构成,定义输入为X,标签为Y,模型为 \(M(X) = X*W_1 .... W_{n-2}*W_{n-1}*W_n\),损失函数为\(L(X) = M(X) - Y = X*W_1 .... W_{n-2}*W_{n-1}*W_n - Y\),求W1关于损失函数的偏导数\(\frac{dL(X)}{dW_1} = X*W_2 .... W_{n-2}*W_{n-1}*W_n\)。从这里我们可以看到W2到Wn与输入的X的乘积构成了W1的偏导数。
下面我们简单的构造一个矩阵,然后让他计算100次乘法。代码如下:
import torch
import numpy as np
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
xdata, ydata = [[], []], [[], []]
line0, = ax.plot([], [], 'r-', label='sigmoid')
line1, = ax.plot([], [], 'b-', label='gradient-sigmoid')
def init_and_show(xlim_min, xlim_max, ylim_min, ylim_max):
ax.set_xlabel('x')
ax.set_ylabel('sigmoid(x)')
ax.set_title('sigmoid/gradient-sigmoid')
ax.set_xlim(xlim_min, xlim_max)
ax.set_ylim(ylim_min, ylim_max)
ax.legend([line0, line1], ('sigmoid', 'gradient-sigmoid'))
line0.set_data(xdata[0], ydata[0])
line1.set_data(xdata[1], ydata[1])
plt.show()
def sigmoid_test():
x = np.arange(-10.0, 10.0, 0.1)
x = torch.tensor(x, dtype=torch.float, requires_grad=True)
sig_fun = torch.nn.Sigmoid()
y = sig_fun(x)
y.backward(torch.ones_like(y))
xdata[0] = x.detach().numpy()
xdata[1] = x.detach().numpy()
ydata[0] = y.detach().numpy()
ydata[1] = x.grad.detach().numpy()
init_and_show(-10.0, 10.0, 0, 1)
def multi_mat_dot():
M = np.random.normal(size=(4, 4))
print('⼀个矩阵\n', M)
for i in range(100):
M = np.dot(M, np.random.normal(size=(4, 4)))
print('乘以100个矩阵后\n', M)
if __name__ == '__main__':
multi_mat_dot()
他计算100次乘法后结果如下:
我们可以看到,经过100次乘法后,其值已经非常大(小)了指数都是到了25了。这个时候算出来的损失非常大的,这个时候梯度也非常大,很容易导致训练异常。
参数初始化之Xavier
文首我们提到,我们之前的参数初始化都是基于期望为0,方差为一个指定值初始化的,这里面的指定值是随个人定义的,这个可能会给我们的训练过程带来困扰。
但是我们可以从以下的角度来看待这个事情,我们的权重参数W是一个期望为0,方差为\(\delta^2\)的特定分布。我们的输入特征X是一个期望为0,方差为\(\lambda^2\)的特定分布(注意这里不仅仅是正态分布)。我们假设我们的模型是线性模型,那么其输出为:\(O_i = \sum\limits_{j=1}^{n}W_{ij}X_{j}\),\(O_i\)是代表第i层的输出。这个时候,我们求出\(O_i\)的期望是:\(E(O_i) = \sum\limits_{j=1}^{n}E(W_{ij}X_{j}) = \sum\limits_{j=1}^{n}E(W_{ij})E(X_{j}) = 0\),其方差为:\(Variance(O_i) = E(O_i^2) - (E(O_i))^2 = \sum\limits_{j=1}^{n}E(W_{ij}^2X_{j}^2) - 0 = \sum\limits_{j=1}^{n}E(W_{ij}^2)E(X_{j}^2) = n*\delta^2*\lambda^2\)。我们现在假设如果要\(O_i\)的方差等于X的方差,那么\(n*\delta^2 = 1\)才能够满足要求。现在我们考虑BP的时候,也需要\(n_{out}*\delta^2 = 1\)才能够保证方差不会变,至少从数值稳定性来说,我们应该保证方差尽量稳定,不应该放大。我们同时考虑n和\(n_{out}\),那么我们可以认为当\(1/2*(n+n_{out})*\delta^2 = 1\)时,我们保证了输出O的方差在约定范围内,尽量保证了其数值的稳定性,这就是Xavier方法的核心内容。
初始化方法有很多,但是Xavier方法有较大的普适性。对于某些模型,特定的初始化方法有奇效。
后记
到本文结束,其实我们可以训练一些简单的模型了,但是本文所介绍的3个概念会一直伴随着我们以后的学习过程,如果训练出现了INF,NAN这些特殊的值,基本我们就需要往这方面去想和解决问题。
参考文献
- https://github.com/d2l-ai/d2l-zh/releases (V1.0.0)
- https://github.com/d2l-ai/d2l-zh/releases (V2.0.0 alpha1)
- https://blog.csdn.net/u011728480/article/details/100277582 《数与计算机 (编码、原码、反码、补码、移码、IEEE 754、定点数、浮点数)》
打赏、订阅、收藏、丢香蕉、硬币,请关注公众号(攻城狮的搬砖之路)
PS: 请尊重原创,不喜勿喷。
PS: 要转载请注明出处,本人版权所有。
PS: 有问题请留言,看到后我会第一时间回复。
DL基础补全计划(五)---数值稳定性及参数初始化(梯度消失、梯度爆炸)的更多相关文章
- DL基础补全计划(二)---Softmax回归及示例(Pytorch,交叉熵损失)
PS:要转载请注明出处,本人版权所有. PS: 这个只是基于<我自己>的理解, 如果和你的原则及想法相冲突,请谅解,勿喷. 前置说明 本文作为本人csdn blog的主站的备份.(Bl ...
- DL基础补全计划(三)---模型选择、欠拟合、过拟合
PS:要转载请注明出处,本人版权所有. PS: 这个只是基于<我自己>的理解, 如果和你的原则及想法相冲突,请谅解,勿喷. 前置说明 本文作为本人csdn blog的主站的备份.(Bl ...
- DL基础补全计划(六)---卷积和池化
PS:要转载请注明出处,本人版权所有. PS: 这个只是基于<我自己>的理解, 如果和你的原则及想法相冲突,请谅解,勿喷. 前置说明 本文作为本人csdn blog的主站的备份.(Bl ...
- DL基础补全计划(一)---线性回归及示例(Pytorch,平方损失)
PS:要转载请注明出处,本人版权所有. PS: 这个只是基于<我自己>的理解, 如果和你的原则及想法相冲突,请谅解,勿喷. 前置说明 本文作为本人csdn blog的主站的备份.(Bl ...
- OSPF补全计划-0 preface
哇靠,一看日历吓了我一跳,我这一个月都没写任何东西,好吧,事情的确多了点儿,同事离职,我需要处理很多untechnical的东西,弄得我很烦,中间学的一点小东西(关于Linux的)也没往这里记,但是我 ...
- 【hjmmm网络流24题补全计划】
本文食用方式 按ABC--分层叙述思路 可以看完一步有思路后自行思考 飞行员配对问题 题目链接 这可能是24题里最水的一道吧... 很显然分成两个集合 左外籍飞行员 右皇家飞行员 跑二分图最大匹配 输 ...
- 2018.我的NOIP补全计划
code: efzoi.tk @ shleodai noip2011 D1 选择客栈 这道题是一道大水题,冷静分析一会就会发现我们需要维护最后一个不合法点和前缀和. 维护最后一个不合法点只要边扫描边维 ...
- OSPF补全计划-2
想起来几个面试题: 1. OSPF在什么情况下会stuck in Exstart /Exchange状态? 我知道的一个答案是两个端口的mtu不一致.当然整个也不是绝对,因为可以用ip ospf mt ...
- OSPF补全计划-1
OSPF全称是啥我就不絮叨了,什么迪杰斯特拉,什么开放最短路径优先算法都是人尽皆知的事儿,尤其是一提算法还会被学数据结构的童鞋鄙视,干脆就不提了,直接开整怎么用吧.(不过好像真有人不知道OSPF里的F ...
随机推荐
- Kubernetes ConfigMap详解,多种方式创建、多种方式使用
我最新最全的文章都在南瓜慢说 www.pkslow.com,欢迎大家来喝茶! 1 简介 配置是程序绕不开的话题,在Kubernetes中使用ConfigMap来配置,它本质其实就是键值对.本文讲解如何 ...
- 微信sdk上传图片大小1k,损坏的问题以及微信上传图片需要的配置
微信公众号的appid和appsecret有问题,会导致上传图片大小为1k这个问题 微信上传图片需要设置公众号的'JS接口安全域名'
- QTreeView 使用 QStandardItemModel
QTreeView 使用 QStandardItemModel @ 目录 QTreeView 使用 QStandardItemModel 前言 一.直接上图 二.添加同级结点项 1.思路 2.实现 二 ...
- js jq计算器
<html><head><meta http-equiv="Content-type" content="text/html; charse ...
- 全新安装Windows版 Atlassian Confluence 7.3.1 + MySQL 8.0,迁移数据,并设置服务自启
Confluence是一个专业的企业知识管理与协同软件,也可以用于构建企业wiki.使用简单,但它强大的编辑和站点管理特征能够帮助团队成员之间共享信息.文档协作.集体讨论,信息推送. 安装Conflu ...
- Qt之先用了再说系列-信号与槽
QT之信号与槽 简介:信号与槽可是Qt最大成功点,也是整个Qt基本核心机制,如果不会信号与槽,将无法领略Qt之美: 1.信号与槽函数原型: QObject::connect(const QObject ...
- 10、修改windows编码集
10.1.查看Windows的字符集编码: 1.方法一: (1) 同时按住"windows"徽标键和"r"键,在弹出的"运行"框中输入&qu ...
- scRNAseq benchmark 学习笔记
背景 把早年没填完的坑(单细胞测序的细胞类型鉴别)给重新拾起来 其Github描述的基本情况: 作者并不对单个分类器进行说明,统一包装在benchmark工程里,还建立了docker容器 但说明了在s ...
- promise的基本使用
// 什么情况下适用promise? // 一般情况下是有异步请求操作时,使用promise对这个异步操作进行封装 // new ->构造函数(1.保存了一些状态信息 2.执行传入的函数) // ...
- AOP面向切面的实现
AOP(Aspect Orient Programming),我们一般称为面向方面(切面)编程,作为面向对象的一种补充,用于处理系统中分布于各个模块的横切关注点,比如事务管理.日志.缓存等等. AOP ...