pytorch在torch.nn.init中提供了常用的初始化方法函数,这里简单介绍,方便查询使用。

介绍分两部分:

1. Xavier,kaiming系列;

2. 其他方法分布

Xavier初始化方法,论文在《Understanding the difficulty of training deep feedforward neural networks》

公式推导是从“方差一致性”出发,初始化的分布有均匀分布和正态分布两种。

1. Xavier均匀分布

torch.nn.init.xavier_uniform_(tensor, gain=1)

xavier初始化方法中服从均匀分布U(−a,a) ,分布的参数a = gain * sqrt(6/fan_in+fan_out),

这里有一个gain,增益的大小是依据激活函数类型来设定

eg:nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))

PS:上述初始化方法,也称为Glorot initialization

2. Xavier正态分布

torch.nn.init.xavier_normal_(tensorgain=1)

xavier初始化方法中服从正态分布,

mean=0,std = gain * sqrt(2/fan_in + fan_out)

kaiming初始化方法,论文在《 Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification》,公式推导同样从“方差一致性”出法,kaiming是针对xavier初始化方法在relu这一类激活函数表现不佳而提出的改进,详细可以参看论文。

3. kaiming均匀分布

torch.nn.init.kaiming_uniform_(tensora=0mode='fan_in'nonlinearity='leaky_relu')

此为均匀分布,U~(-bound, bound), bound = sqrt(6/(1+a^2)*fan_in)

其中,a为激活函数的负半轴的斜率,relu是0

mode- 可选为fan_in 或 fan_out, fan_in使正向传播时,方差一致; fan_out使反向传播时,方差一致

nonlinearity- 可选 relu 和 leaky_relu ,默认值为 。 leaky_relu

nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')

4. kaiming正态分布

torch.nn.init.kaiming_normal_(tensora=0mode='fan_in'nonlinearity='leaky_relu')

此为0均值的正态分布,N~ (0,std),其中std = sqrt(2/(1+a^2)*fan_in)

其中,a为激活函数的负半轴的斜率,relu是0

mode- 可选为fan_in 或 fan_out, fan_in使正向传播时,方差一致;fan_out使反向传播时,方差一致

nonlinearity- 可选 relu 和 leaky_relu ,默认值为 。 leaky_relu

nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')

2.其他

5. 均匀分布初始化

torch.nn.init.uniform_(tensora=0b=1)

使值服从均匀分布U(a,b)

6. 正态分布初始化

torch.nn.init.normal_(tensormean=0std=1)

使值服从正态分布N(mean, std),默认值为0,1

7. 常数初始化

torch.nn.init.constant_(tensorval)

使值为常数val nn.init.constant_(w, 0.3)

8. 单位矩阵初始化

torch.nn.init.eye_(tensor)

将二维tensor初始化为单位矩阵(the identity matrix)

9. 正交初始化

torch.nn.init.orthogonal_(tensorgain=1)

使得tensor是正交的,论文:Exact solutions to the nonlinear dynamics of learning in deep linear neural networks” - Saxe, A. et al. (2013)

10. 稀疏初始化

torch.nn.init.sparse_(tensorsparsitystd=0.01)

从正态分布N~(0. std)中进行稀疏化,使每一个column有一部分为0

sparsity- 每一个column稀疏的比例,即为0的比例

nn.init.sparse_(w, sparsity=0.1)

11. 计算增益

torch.nn.init.calculate_gain(nonlinearityparam=None)

PyTorch 学习笔记(四):权值初始化的十种方法的更多相关文章

  1. pytorch(14)权值初始化

    权值的方差过大导致梯度爆炸的原因 方差一致性原则分析Xavier方法与Kaiming初始化方法 饱和激活函数tanh,非饱和激活函数relu pytorch提供的十种初始化方法 梯度消失与爆炸 \[H ...

  2. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

  3. [PyTorch 学习笔记] 4.1 权值初始化

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/grad_vanish_explod.py 在搭建好网络 ...

  4. ensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试

    http://www.cnblogs.com/denny402/p/5852983.html ensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试   刚开始学习tf时,我们从 ...

  5. 机器学习实战(Machine Learning in Action)学习笔记————06.k-均值聚类算法(kMeans)学习笔记

    机器学习实战(Machine Learning in Action)学习笔记————06.k-均值聚类算法(kMeans)学习笔记 关键字:k-均值.kMeans.聚类.非监督学习作者:米仓山下时间: ...

  6. PyTorch学习系列(九)——参数_初始化

    from:http://blog.csdn.net/VictoriaW/article/details/72872036 之前我学习了神经网络中权值初始化的方法 那么如何在pytorch里实现呢. P ...

  7. python3.4学习笔记(四) 3.x和2.x的区别,持续更新

    python3.4学习笔记(四) 3.x和2.x的区别 在2.x中:print html,3.x中必须改成:print(html) import urllib2ImportError: No modu ...

  8. [PyTorch 学习笔记] 3.1 模型创建步骤与 nn.Module

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson3/module_containers.py 这篇文章来看下 ...

  9. [PyTorch 学习笔记] 6.2 Normalization

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson6/bn_and_initialize.py https: ...

随机推荐

  1. Redis → Windows下搭建redis集群

    一,redis集群介绍 Redis cluster(redis集群)是在版本3.0后才支持的架构,和其他集群一样,都是为了解决单台服务器不够用的情况,也防止了主服务器宕机无备用服务器,多个节点网络互联 ...

  2. 猜数字游戏_Python

    预先设置数字变量 age_of_test = 25 #这里设置为25,也可随意 guess_age = int (input("guess age:")) if guess_age ...

  3. hive语句on和where一点小问题

    hive join 后面必须=(0.13版本后支持,不支持like,<>),on后面如需加条件语句必须放到where中不然会产生错误结果 (可以一对多,一对一,不可以多对多‘会出现数据翻倍 ...

  4. node中__dirname、__filename表示的路径

    __dirname 表示当前文件所在的目录的绝对路径__filename 表示当前文件的绝对路径module.filename ==== __filename 等价process.cwd() 返回运行 ...

  5. 洛谷P1681 最大正方形II

    P1681 最大正方形II 题目背景 忙完了学校的事,v神终于可以做他的“正事”:陪女朋友散步.一天,他和女朋友走着走着,不知不觉就来到 了一个千里无烟的地方.v神正要往回走,如发现了一块牌子,牌子上 ...

  6. php rmdir使用递归函数删除非空目录的方法

    php rmdir()函数 rmdir ― 删除空目录 语法: bool rmdir ( string $dirname [, resource $context ] )尝试删除 dirname 所指 ...

  7. 【机器学习PAI实战】—— 玩转人工智能之利用GAN自动生成二次元头像

    前言 深度学习作为人工智能的重要手段,迎来了爆发,在NLP.CV.物联网.无人机等多个领域都发挥了非常重要的作用.最近几年,各种深度学习算法层出不穷, Generative Adverarial Ne ...

  8. Ajax--Ajax基于原生javascript:创建Ajax对象、链接服务器、发送请求、接受响应结果

    Ajax概述 异步:指某段程序执行时不会阻塞其它程序执行,其表现形式为程序的执行顺序不依赖程序本身的书写顺序,相反则为同步. 同步请求: 请求是由浏览器发送 页面会刷新 异步请求: 请求是由浏览器的一 ...

  9. PLUTO平台是由美林数据技术股份有限公司下属西安交大美林数据挖掘研究中心自主研发的一款基于云计算技术架构的数据挖掘产品,产品设计严格遵循国际数据挖掘标准CRISP-DM(跨行业数据挖掘过程标准),具备完备的数据准备、模型构建、模型评估、模型管理、海量数据处理和高纬数据可视化分析能力。

    http://www.meritdata.com.cn/article/90 PLUTO平台是由美林数据技术股份有限公司下属西安交大美林数据挖掘研究中心自主研发的一款基于云计算技术架构的数据挖掘产品, ...

  10. Mathcad 是一种工程计算软件,主要运算功能:代数运算、线性代数、微积分、符号计算、2D和3D图表、动画、函数、程序编写、逻辑运算、变量与单位的定义和计算等。

    Mathcad软件包Mathcad是由MathSoft公司(2006 年4 月被美国PTC收购)推出的一种交互式数值计算系统. Mathcad 是一种工程计算软件,作为工程计算的全球标准,与专有的计算 ...