[PyTorch]PyTorch中模型的参数初始化的几种方法(转)
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
本文目录
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
转载请注明出处:
http://www.cnblogs.com/darkknightzh/p/8297793.html
参考网址:
http://pytorch.org/docs/0.3.0/nn.html?highlight=kaiming#torch.nn.init.kaiming_normal
https://github.com/prlz77/ResNeXt.pytorch/blob/master/models/model.py
https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua
https://github.com/bamos/densenet.pytorch/blob/master/densenet.py
https://github.com/szagoruyko/wide-residual-networks/blob/master/models/utils.lua
说明:暂时就这么多吧,错误之处请见谅。前两个初始化的方法见pytorch官方文档
1. xavier初始化
torch.nn.init.xavier_uniform(tensor, gain=1)
对于输入的tensor或者变量,通过论文Understanding the difficulty of training deep feedforward neural networks” - Glorot, X. & Bengio, Y. (2010)的方法初始化数据。初始化服从均匀分布U(−a,a)" role="presentation" style="position: relative;">U(−a,a)U(−a,a),其中a=gain×2/(fan_in+fan_out)×3" role="presentation" style="position: relative;">a=gain×2/(fan_in+fan_out)−−−−−−−−−−−−−−−−−−√×3–√a=gain×2/(fan_in+fan_out)×3,该初始化方法也称Glorot initialisation。
参数:
tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据
a:可选择的缩放参数
例如:
w = torch.Tensor(3, 5)
nn.init.xavier_uniform(w, gain=nn.init.calculate_gain('relu'))
torch.nn.init.xavier_normal(tensor, gain=1)
对于输入的tensor或者变量,通过论文Understanding the difficulty of training deep feedforward neural networks” - Glorot, X. & Bengio, Y. (2010)的方法初始化数据。初始化服从高斯分布N(0,std)" role="presentation" style="position: relative;">N(0,std)N(0,std),其中std=gain×2/(fan_in+fan_out)" role="presentation" style="position: relative;">std=gain×2/(fan_in+fan_out)−−−−−−−−−−−−−−−−−−√std=gain×2/(fan_in+fan_out),该初始化方法也称Glorot initialisation。
参数:
tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据
a:可选择的缩放参数
例如:
w = torch.Tensor(3, 5)
nn.init.xavier_normal(w)
2. kaiming初始化
torch.nn.init.kaiming_uniform(tensor, a=0, mode='fan_in')
对于输入的tensor或者变量,通过论文“Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification” - He, K. et al. (2015)的方法初始化数据。初始化服从均匀分布U(−bound,bound)" role="presentation" style="position: relative;">U(−bound,bound)U(−bound,bound),其中bound=2/((1+a2)×fan_in)×3" role="presentation" style="position: relative;">bound=2/((1+a2)×fan_in)−−−−−−−−−−−−−−−−−−√×3–√bound=2/((1+a2)×fan_in)×3,该初始化方法也称He initialisation。
参数:
tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据
a:该层后面一层的激活函数中负的斜率(默认为ReLU,此时a=0)
mode:‘fan_in’ (default) 或者 ‘fan_out’. 使用fan_in保持weights的方差在前向传播中不变;使用fan_out保持weights的方差在反向传播中不变。
例如:
w = torch.Tensor(3, 5)
nn.init.kaiming_uniform(w, mode='fan_in')
torch.nn.init.kaiming_normal(tensor, a=0, mode='fan_in')
对于输入的tensor或者变量,通过论文“Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification” - He, K. et al. (2015)的方法初始化数据。初始化服从高斯分布N(0,std)" role="presentation" style="position: relative;">N(0,std)N(0,std),其中std=2/((1+a2)×fan_in)" role="presentation" style="position: relative;">std=2/((1+a2)×fan_in)−−−−−−−−−−−−−−−−−−√std=2/((1+a2)×fan_in),该初始化方法也称He initialisation。
参数:
tensor:n维的 torch.Tensor 或者 autograd.Variable类型的数据
a:该层后面一层的激活函数中负的斜率(默认为ReLU,此时a=0)
mode:‘fan_in’ (default) 或者 ‘fan_out’. 使用fan_in保持weights的方差在前向传播中不变;使用fan_out保持weights的方差在反向传播中不变。
例如:
w = torch.Tensor(3, 5)
nn.init.kaiming_normal(w, mode='fan_out')
使用的例子(具体参见原始网址):
https://github.com/prlz77/ResNeXt.pytorch/blob/master/models/model.py

from torch.nn import init
self.classifier = nn.Linear(self.stages[3], nlabels)
init.kaiming_normal(self.classifier.weight)
for key in self.state_dict():
if key.split('.')[-1] == 'weight':
if 'conv' in key:
init.kaiming_normal(self.state_dict()[key], mode='fan_out')
if 'bn' in key:
self.state_dict()[key][...] = 1
elif key.split('.')[-1] == 'bias':
self.state_dict()[key][...] = 0

3. 实际使用中看到的初始化
3.1 ResNeXt,densenet中初始化
https://github.com/facebookresearch/ResNeXt/blob/master/models/resnext.lua
https://github.com/bamos/densenet.pytorch/blob/master/densenet.py
conv
n = kW* kH*nOutputPlane
weight:normal(,math.sqrt(/n))
bias:zero()
batchnorm
weight:fill()
bias:zero()
linear
bias:zero()
3.2 wide-residual-networks中初始化(MSRinit)
https://github.com/szagoruyko/wide-residual-networks/blob/master/models/utils.lua
conv
n = kW* kH*nInputPlane
weight:normal(,math.sqrt(/n))
bias:zero()
linear
bias:zero()
[PyTorch]PyTorch中模型的参数初始化的几种方法(转)的更多相关文章
- java中Map和List初始化的两种方法
第一种方法(常用方法): //初始化List List<string> list = new ArrayList</string><string>(); list. ...
- Pytorch基础(6)----参数初始化
一.使用Numpy初始化:[直接对Tensor操作] 对Sequential模型的参数进行修改: import numpy as np import torch from torch import n ...
- 服务器文档下载zip格式 SQL Server SQL分页查询 C#过滤html标签 EF 延时加载与死锁 在JS方法中返回多个值的三种方法(转载) IEnumerable,ICollection,IList接口问题 不吹不擂,你想要的Python面试都在这里了【315+道题】 基于mvc三层架构和ajax技术实现最简单的文件上传 事件管理
服务器文档下载zip格式 刚好这次项目中遇到了这个东西,就来弄一下,挺简单的,但是前台调用的时候弄错了,浪费了大半天的时间,本人也是菜鸟一枚.开始吧.(MVC的) @using Rattan.Co ...
- Spring3 MVC请求参数获取的几种方法
Spring3 MVC请求参数获取的几种方法 一. 通过@PathVariabl获取路径中的参数 @RequestMapping(value="user/{id}/{name}&q ...
- 获取网页URL地址及参数等的两种方法(js和C#)
转:获取网页URL地址及参数等的两种方法(js和C#) 一 js 先看一个示例 用javascript获取url网址信息 <script type="text/javascript&q ...
- 在Java Web程序中使用监听器可以通过以下两种方法
之前学习了很多涉及servlet的内容,本小结我们说一下监听器,说起监听器,编过桌面程序和手机App的都不陌生,常见的套路都是拖一个控件,然后给它绑定一个监听器,即可以对该对象的事件进行监听以便发生响 ...
- Spring3 MVC请求参数获取的几种方法[转]
Spring3 MVC请求参数获取的几种方法 Spring3 MVC请求参数获取的几种方法 一. 通过@PathVariabl获取路径中的参数 @RequestMapping(value=& ...
- PHP中获取文件扩展名的N种方法
PHP中获取文件扩展名的N种方法 从网上收罗的,基本上就以下这几种方式: 第1种方法:function get_extension($file){substr(strrchr($file, '.'), ...
- 在MySQL中设置事务隔离级别有2种方法:
在MySQL中设置事务隔离级别有2种方法: 1 在my.cnf中设置,在mysqld选项中如下设置 [mysqld] transaction-isolation = READ-COMMITTED 2 ...
随机推荐
- 设计模式之——Template模板模式
Template模式又叫模板模式,是在父类中定义处理流程的框架,在子类中实现具体处理逻辑的模式.当父类的模板方法被调用时程序行为也会不同,但是,不论子类的具体实现如何,处理的流程都会按照父类中所定义的 ...
- Python 之RabbitMQ使用
1. IO 多路复用 # select 模拟socket server # server 端 import select import socket import sys import queue s ...
- python widows安裝scipy
https://blog.csdn.net/github_39611196/article/details/76718707 Python3.x直接运行pip install scipy即可.Pyth ...
- 完全用nosql轻松打造千万级数据量的微博系统
其实微博是一个结构相对简单,但数据量却是很庞大的一种产品.标题所说的是千万级数据量也并不是一千万条微博信息而已,而是千万级订阅关系之间发布.在看 我这篇文章之前,大多数人都看过sina的杨卫华大牛的微 ...
- Linux(5)- MariaDB、mysql主从复制、初识redis
一.MYSQL(mariadb) MariaDB数据库管理系统是MySQL的一个分支,主要由开源社区在维护,采用GPL授权许可. 开发这个分支的原因之一是:甲骨文公司收购了MySQL后,有将MySQL ...
- 【Servlet】把文件写到Respond输出流里面供用户下载
本文区分于<[Jsp]把Java写到Respond输出流里面供用户下载>(点击打开链接)把原本该打印到控制台的内容,直接打印到一个文本文件txt中给用户下载. 实际上是<[Strut ...
- sql server动态行列转换
原文链接:https://www.cnblogs.com/gaizai/p/3753296.html sql server动态行列转换 一.本文所涉及的内容(Contents) 本文所涉及的内容(Co ...
- JavaScript Big-Int
这个库是为JavaScript中的大整数操作,如加,减,乘,除,mod,比较等. 这个库的原理是模拟笔和纸的操作,你可以操作整数,大到你的RAM允许. 例 var bigInt = require(' ...
- 在github上新建一个仓库并上传本地工程
扫盲:在github上新建一个仓库并上传本地工程 http://1ke.co/course/194 我自己新建了个项目,一步一步流程如下. zhoudd@desay:~/桌面/mini_embed_d ...
- 玩转DOM遍历——用NodeIterator实现getElementById,getElementsByTagName方法
先声明一下DOM2中NodeIterator和TreeWalker这两类型真的只是用来玩玩的,因为性能不行遍历起来超级慢,在JS中基本用不到它们,除了<高程>上有两三页对它的讲解外,谷歌的 ...