1. 均匀分布

torch.nn.init.uniform_(tensor, a=0, b=1)

从均匀分布U(a, b)中采样,初始化张量。

参数:

  • tensor - 需要填充的张量
  • a - 均匀分布的下界
  • b - 均匀分布的上界

代码示例:

>>> w = torch.Tensor(3, 5)
>>> torch.nn.init.uniform_(w)
tensor([[0.1755, 0.4399, 0.8769, 0.8465, 0.2909],
[0.9962, 0.6918, 0.1033, 0.7028, 0.5835],
[0.1754, 0.8796, 0.1900, 0.9504, 0.4433]])

均匀分布详解:

若 x 服从均匀分布,即 x~U(a,b),其概率密度函数(表征随机变量每个取值有多大的可能性)为,

\[f(x)=\begin{cases}\dfrac{1}{b-a},\quad a\lt x\lt b\\
0,\quad 其他
\end{cases}\\
\]

则有期望和方差,

\[E(x) = \int_{-\infty}^{\infty} xf(x)dx=\frac{1}{2}(a+b)\\
D(x) = E(x^2)-[E(x)]^2 = \frac{(b-a)^2}{12}
\]

2. 正态(高斯)分布

torch.nn.init.normal_(tensor, mean=0.0, std=1.0)

从给定的均值和标准差的正态分布\(N(mean, std^2)\)中生成值,初始化张量。

参数:

  • tensor - 需要填充的张量
  • mean - 正态分布的均值
  • std - 正态分布的标准偏差

代码示例:

>>> w = torch.Tensor(3, 5)
>>> torch.nn.init.normal_(w, mean=0, std=1)
tensor([[-0.9444, -0.3295, 0.1785, 0.4165, 0.3658],
[ 0.5130, -1.1455, -0.1335, -1.6953, 0.2862],
[-2.3368, 0.2380, -2.2001, -0.5455, 0.8179]])

正态分布详解:

若随机变量x服从正态分布,即 \(x\sim N(\mu, \sigma^2)\), 其概率密度函数为,

\[f(x) = \frac{1}{\sigma\sqrt{2\pi}}\exp\left(-\frac{(x-\mu^2)}{2\sigma^2}\right)
\]

正态分布概率密度函数中一些特殊的概率值:

  • 68.268949%的面积在平均值左右的一个标准差\(\sigma\)范围内(\(\mu\pm\sigma\))
  • 95.449974%的面积在平均值左右两个标准差\(2\sigma\)的范围内(\(\mu\pm2\sigma\))
  • 99.730020%的面积在平均值左右三个标准差\(3\sigma\)的范围内(\(\mu\pm3\sigma\))
  • 99.993666%的面积在平均值左右四个标准差\(4\sigma\)的范围内(\(\mu\pm4\sigma\))

\(\mu=0, \sigma=1\)时的正态分布是标准正态分布

3. Xavier初始化

3.1 Xavier均匀分布初始化

torch.nn.init.xavier_normal_(tensor, gain=1.0)

又称Glorot初始化,按照Glorot, X. & Bengio, Y.(2010)在论文Understanding the difficulty of training deep feedforward neural networks中描述的方法,从均匀分布U(−a, a)中采样,初始化输入张量tensor,其中a值由下式确定:

\[\rm a=gain\times\sqrt{\dfrac{6}{fan\_in+fan\_out}}
\]

参数:

  • tensor - 需要初始化的张量
  • gain - 可选的放缩因子

代码示例:

>>> w = torch.Tensor(3, 5)
>>> torch.nn.init.xavier_uniform_(w, gain=torch.nn.init.calculate_gain('relu'))
tensor([[ 0.2481, -0.8435, 0.0588, 0.1573, 0.2759],
[ 0.2016, -0.5504, -0.5280, -0.3070, 0.0889],
[-0.9897, -0.9890, -0.8091, 0.8624, -0.5661]])

3.2 Xavier正态分布初始化

torch.nn.init.xavier_normal_(tensor, gain=1.0)

又称Glorot初始化,按照Glorot, X. & Bengio, Y.(2010)在论文Understanding the difficulty of training deep feedforward neural networks中描述的方法,从正态分布\(N(0, std^2)\)中采样,初始化输入张量tensor,其中std值由下式确定:

\[\rm std=gain\times\sqrt{\dfrac{2}{fan\_in+fan\_out}}
\]

参数:

  • tensor - 需要初始化的张量
  • gain - 可选的放缩因子

代码示例:

>>> w = torch.Tensor(3, 5)
>>> torch.nn.init.xavier_normal_(w)
tensor([[ 0.6707, -0.3928, -0.0894, 0.4096, 0.4631],
[ 0.1267, -0.5806, 0.3407, -0.1110, -0.2400],
[-0.4212, 0.2857, -0.1210, -0.2891, 0.7141]])

3.3 Xavier初始化方法的由来:

神经网络的参数可以初始化全为0吗?当然不能。如果初始化为全0,输入经过每个神经元后的输出都是一样的,且后向传播时梯度根本无法向后传播(因为求梯度公式里有个乘积因子是参数w,全0的参数使得梯度全没了,感兴趣的可以去查看神经网络的BP推导过程),这样的模型训练一万年也没有意义。

如果不能初始化为全0,那么我们应该如何初始化呢?

在深层神经网络中,每一层的输出都是下一层的输入,为了使网络中的信息更好的流动,应保证每层方差尽可能相等(可以把前向传播过程看作是输入和一系列参数的连乘,若数值过大容易进入饱和区,反向传播时数值过大可能造成梯度爆炸,反之可能梯度消失)。因此,参数初始化就可以看作是从某个概率分布的区间中进行采样的过程,则初始化问题转化为求解特定概率分布的参数问题。

那么又如何保证每一层的方差尽可能相等呢?即\(Var(z^{l-1})=Var(z^l)\)

先考虑单层网络,n为神经元的数量,输出z的表达式为,

\[z=\sum\limits_{i=1}^n w_i x_i
\]

根据概率统计中的方差公式,有,

\[Var(w_i x_i)=Var(w_i)Var(x_i)+E[w_i]^2Var(x_i)+E[x_i]^2Var(w_i)
\]

当输入\(x_i\)和权重\(w_i\)的均值都是0时,即\(E[x_i]=E[w_i]=0\)(可使用BatchNormalization将输入的均值置0),上式简化为,

\[Var(w_i x_i)=Var(w_i)Var(x_i) \\
Var(z) = \sum\limits_{i=1}^nVar(w_ix_i) = \sum\limits_{i=1}^nVar(w_i)Var(x_i)
\]

进一步,假设随机变量\(w_i\)和\(x_i\)为独立同分布,则

\[Var(z)=nVar(w)Var(x)
\]

若是输入输出的方差一致,即\(Var(z)=Var(x)\),应有,

\[Var(w)=\dfrac{1}{n}
\]

其中n为输入层的神经元数量,即论文中的fan_in,而输出层的神经元数量fan_out往往和fan_in不相等,考虑到反向传播时是从后往前计算,所以论文中取了二者的均值,即令

\[Var(w)=\dfrac{2}{\rm fan\_in+fan\_out}
\]

由概率论基础知识知,若随机变量x服从区间[a,b]上的均匀分布,则x的方差为,

\[Var(x)=\dfrac{(b-a)^2}{12}
\]

代入上边\(Var(w)\)的方差公式,可以解得\(b-a=\dfrac{2\sqrt6}{\sqrt{\rm fan\_in+fan\_out}}\),即得参数w均匀采样情况下的采样区间,

\[w\sim U(-\dfrac{\sqrt6}{\sqrt{\rm fan\_in+fan\_out}},\dfrac{\sqrt6}{\sqrt{\rm fan\_in+fan\_out}})
\]

以上就是采用Xavier初始化方法,对均匀分布的参数进行求解的过程。同理,可以推出正态分布采样下的参数分布满足,

\[w\sim N(0,\dfrac{2}{\rm fan\_in+fan\_out})
\]

由于Xavier初始化方法是基于“均值为0”这个假设推导出的,对于ReLU等激活函数,其输出均大于等于0,\(E(x_i)=0\)的假设不再成立,所以Xavier初始化方法对ReLU通常效果不好。

4. kaiming初始化

4.1 kaiming均匀分布初始化

torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

又称He初始化,按照He, K. et al. (2015)在论文Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification中描述的方法,从均匀分布U(−bound, bound)中采样,初始化输入张量tensor,其中bound值由下式确定:

\[\rm bound=gain\times\sqrt{\dfrac{3}{fan\_mode}}
\]

参数:

  • tensor - 需要初始化的张量
  • a - 这层之后使用的rectifier的斜率系数,用来计算\(\rm gain=\sqrt{\dfrac{2}{1+a^2}}\) (此参数仅在参数nonlinearity为'leaky_relu'时生效)
  • mode - 可以为“fan_in”(默认)或“fan_out”。“fan_in”维持前向传播时权值方差,“fan_out”维持反向传播时的方差
  • nonlinearity - 非线性函数(nn.functional中的函数名),pytorch建议仅与“relu”或“leaky_relu”(默认)一起使用。

代码示例:

>>> w = torch.Tensor(3, 5)
>>> torch.nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
tensor([[-0.3387, 0.8507, 0.5339, -0.2552, 0.4829],
[ 0.6565, -0.7444, -0.2138, -0.9352, -0.1449],
[-0.7871, 0.4095, 0.3562, -0.2796, -0.8638]])

4.1 kaiming正态分布初始化

torch.nn.init.kaiming_normal_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

又称He初始化,按照He, K. et al. (2015)在论文Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification中描述的方法,从正态分布\(N(0, std^2)\)中采样,初始化输入张量tensor,其中std值由下式确定:

\[\rm std=\dfrac{gain}{\sqrt{fan\_mode}}
\]

参数:

  • tensor - 需要初始化的张量
  • a - 这层之后使用的rectifier的斜率系数,用来计算\(\rm gain=\sqrt{\dfrac{2}{1+a^2}}\) (此参数仅在参数nonlinearity为'leaky_relu'时生效)
  • mode - 可以为“fan_in”(默认)或“fan_out”。“fan_in”维持前向传播时权值方差,“fan_out”维持反向传播时的方差
  • nonlinearity - 非线性函数(nn.functional中的函数名),pytorch建议仅与“relu”或“leaky_relu”(默认)一起使用。

代码示例:

>>> w = torch.Tensor(3, 5)
>>> torch.nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
tensor([[ 0.0251, 0.5042, 1.7288, 0.8096, -0.2114],
[ 0.0527, 0.2605, 0.8833, 0.4466, 1.8076],
[-1.1390, -0.8388, -1.0632, 0.0480, -0.2835]])

6. 正交矩阵初始化

torch.nn.init.orthogonal_(tensor, gain=1)

用一个(半)正交矩阵初始化输入张量,参考Saxe, A. et al. (2013) - Exact solutions to the nonlinear dynamics of learning in deep linear neural networks。输入张量必须至少有2维,对于大于2维的张量,超出的维度将被flatten化。

正交初始化可以使得卷积核更加紧凑,可以去除相关性,使模型更容易学到有效的参数。

参数:

  • tensor - 需要初始化的张量
  • gain - 可选的放缩因子

代码示例:

>>> w = torch.Tensor(3, 5)
>>> torch.nn.init.orthogonal_(w)
tensor([[ 0.1725, 0.7215, -0.3494, -0.3499, 0.4530],
[ 0.7070, 0.0384, -0.0893, -0.3016, -0.6322],
[-0.0815, 0.6231, 0.7038, 0.2127, -0.2542]])

7. 稀疏矩阵初始化

torch.nn.init.sparse_(tensor, sparsity, std=0.01)

将2维的输入张量作为稀疏矩阵填充,其中非零元素由正态分布\(N(0,0.01^2)\)生成。 参考Martens, J.(2010)的 Deep learning via Hessian-free optimization

参数:

  • tensor - 需要填充的张量
  • sparsity - 每列中需要被设置成零的元素比例
  • std - 用于生成非零元素的正态分布的标准偏差

代码示例:

>>> w = torch.Tensor(3, 5)
>>> torch.nn.init.sparse_(w, sparsity=0.1)
tensor([[ 0.0030, 0.0000, 0.0049, -0.0161, 0.0000],
[-0.0081, -0.0022, 0.0000, 0.0112, 0.0060],
[ 0.0000, -0.0211, 0.0161, 0.0000, 0.0147]])

PyTorch常用参数初始化方法详解的更多相关文章

  1. Linux线程体传递参数的方法详解

    传递参数的两种方法 线程函数只有一个参数的情况:直接定义一个变量通过应用传给线程函数. 例子 #include #include using namespace std; pthread_t thre ...

  2. python内置常用内置方法详解

    # print(locals()) # print(globals()) def func(): x = 1 y = 1 print(locals()) # 函数内部的变量 print(globals ...

  3. $.ajax()方法详解 ajax之async属性 【原创】详细案例解剖——浅谈Redis缓存的常用5种方式(String,Hash,List,set,SetSorted )

    $.ajax()方法详解   jquery中的ajax方法参数总是记不住,这里记录一下. 1.url: 要求为String类型的参数,(默认为当前页地址)发送请求的地址. 2.type: 要求为Str ...

  4. Delphi中TStringList类常用属性方法详解

    TStrings是一个抽象类,在实际开发中,是除了基本类型外,应用得最多的. 常规的用法大家都知道,现在来讨论它的一些高级的用法. 先把要讨论的几个属性列出来: 1.CommaText 2.Delim ...

  5. HTTP请求方法详解

    HTTP请求方法详解 请求方法:指定了客户端想对指定的资源/服务器作何种操作 下面我们介绍HTTP/1.1中可用的请求方法: [GET:获取资源]     GET方法用来请求已被URI识别的资源.指定 ...

  6. Clone使用方法详解【转载】

    博客引用地址:Clone使用方法详解 Clone使用方法详解   java“指针”       Java语言的一个优点就是取消了指针的概念,但也导致了许多程序员在编程中常常忽略了对象与引用的区别,本文 ...

  7. VC++常用数据类型及其操作详解

    原文地址:http://blog.csdn.net/ithomer/article/details/5019367 VC++常用数据类型及其操作详解 一.VC常用数据类型列表 二.常用数据类型转化 2 ...

  8. PHP 中 16 个魔术方法详解

    PHP 中 16 个魔术方法详解   前言 PHP中把以两个下划线__开头的方法称为魔术方法(Magic methods),这些方法在PHP中充当了举足轻重的作用. 魔术方法包括: __constru ...

  9. C++调用JAVA方法详解

    C++调用JAVA方法详解          博客分类: 本文主要参考http://tech.ccidnet.com/art/1081/20050413/237901_1.html 上的文章. C++ ...

随机推荐

  1. iscroll5 滚动条根据内容高度自动显示隐藏及强制横屏时方向错位

    横竖屏方向错位: move: function (e) { if ( !this.enabled || utils.eventType[e.type] !== this.initiated ) { r ...

  2. 被产品经理怼了,线上出Bug为啥你不知道

    前言 前几天跟读者聊天,他说被产品经理给怼了.原因是线上出 Bug 了,最后是客户反馈才知道的. 我就问他:你们是不是没做监控? 读者:我们是刚成立的创业团队,目前最重要的就是堆功能,很多基础设施都没 ...

  3. dubbo学习(九)dubbo监控中心

    安装与配置 下载地址:https://github.com/apache/dubbo-admin/tree/master(包含管理控制台和监控中心) PS:  下载前要选择master分支以后再进行下 ...

  4. 空间向量变换,以及OpenGL的glm库简单应用

    测试项目请查看GitHub库 GLBIproject2/GLBIProject2_2

  5. WAF的那些事

    介绍WAF 本节主要介绍WAF (Web Application Firewall, Web应用防火墙)及与其相关的知识,这里利用国际上公认的一种说法: Web应用防火墙是通过执行系列针对HTTP/H ...

  6. JDK动态代理详解

    JDK动态代理是代理模式的一种,且只能代理接口.spring也有动态代理,称为CGLib,现在主要来看一下JDK动态代理是如何实现的? 一.介绍 JDK动态代理是有JDK提供的工具类Proxy实现的, ...

  7. Maven依赖管理之BOM

    目录 什么是BOM 一个BOM的格式 怎么使用BOM 通过parent引用 通过dependencyManagement引用 怎么查看依赖的某个BOM的具体清单 版本冲突时的一些规则 何为依赖调节 参 ...

  8. 智能卡加密芯片SMEC90ST

    深圳市中巨伟业信息科技有限公司 最新推出一款单价低,安全性高的智能卡安全芯片,产品型号为:SMEC90ST,采用32-bit ARM SC100 SecureCore Processor 安全内核处理 ...

  9. 003 01 Android 零基础入门 01 Java基础语法 01 Java初识 03 Java程序的执行流程

    003 01 Android 零基础入门 01 Java基础语法 01 Java初识 03 Java程序的执行流程 Java程序长啥样? 首先编写一个Java程序 记事本编写程序 打开记事本 1.wi ...

  10. uni-app引入iconfont字体图标

    1 首先进入你的iconfont项目 很好, 看见圈圈的吗 , 我说蓝色的,记住了,选到这个 ,然后点击下载本地项目, 解压完就是这个了 ,然后把 圈起来的放到你的项目文件里面 ,记得引入的时候路径别 ...