Pytorch系列:(七)模型初始化
为什么要进行初始化
首先假设有一个两层全连接网络,第一层的第一个节点值为 \(H_{11}= \sum_{i=0}^n X_i*W_{1i}\),
这个时候,方差为 \(D(H_{11}) = \sum_{i=0}^n D(X_i) * D(W_{1i})\), 这个时候,输入\(X_i\)一般会做归一化,那么其方差为1,而权重W如果不进行归一化的话,H的方差就会变得很大,然后多层累计,下一次的输入会越来越大,使得网络不好收敛,如果权重W进行了初始化,使得其方差保持在1/n附近,那么方差H则会收敛在1附近,从而使得网络变得更好优化。 很多初始化都是使用的这个原理,控制每一层的输出,使其保持在一定的范围内。
一些常见初始化方法
Xavier
Xavier初始化也是类似的原理, 假设输入X 以及做了归一化,其方差为1 ,那么Xavier所希望的就是上述公式D(H) 保持在1左右,那么就可以得到公式
\]
其中n1 和 n2 为网络层的输入输出节点数量,一般情况下,输入输出是不一样的,为了均衡考虑,可以做一个平均操作,于是变得到 \(D(W) = \frac{2}{n_1+n_2}\)
这个时候,我们假设 W服从均匀分布 \(U[-a, a]\), 那么在这个条件下,
\]
推出\(a = \frac{\sqrt{6}}{\sqrt{n_1+n_2+1}}\),从而得到:
\]
这样就可以得到Xavier初始化,在pytorch中使用Xavier初始化方式如下,值得注意的是,Xavier对于sigmoid和tanh比较好,对于其他的可能效果就不是那么好了
nn.init.xavier_uniform_(m.weight.data)
Kaiming
Kaiming 初始化比较适合ReLU激活函数,其原理也跟上述差不多,也是希望将权重的方差保持在一定的范围内,使得正反向传播的值得到有效的控制,在kaiming初始化中,主要将权重的方差设置为 \(D(w) = \frac{2}{ni}\),由于考虑到ReLU激活函数,将方差调整为\(D(w)= \frac{2}{(1+a^2)*n_i}\), 这里的a是ReLU的斜率。
在pytorch中使用Kaiming初始化
nn.init.kaiming_normal_(m.weight.data)


LSTM初始化
LSTM中,公式和参数值的设定如下所示
在LSTM中,由于很多门控的权重尺寸是一样的,所以可以使用如下方法进行初始化
def _init_lstm(self, weight):
for w in weight.chunk(4, 0):
init.xavier_uniform(w)
self._init_lstm(self.lstm.weight_ih_l0)
self._init_lstm(self.lstm.weight_hh_l0)
self.lstm.bias_ih_l0.data.zero_()
self.lstm.bias_hh_l0.data.zero_()
Embedding进行初始化
self.embedding = nn.Embedding(embedding_tokens, embedding_features, padding_idx=0)
init.xavier_uniform(self.embedding.weight)
其他通用初始化方法
遍历初始化
for name, param in net.named_parameters():
if 'weight' in name:
init.normal_(param, mean=0, std=0.01)
print(name, param.data)
for name, param in net.named_parameters():
if 'bias' in name:
init.constant_(param, val=0)
print(name, param.data)
## 通过instance 初始化
for m in self.children():
if isinstance(m, nn.Linear):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, -100)
# 也可以判断是否为conv2d,使用相应的初始化方式
elif isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight.item(), 1)
nn.init.constant_(m.bias.item(), 0)
直接使用pytorch内置初始化
from torch.nn import init
init.normal_(net[0].weight, mean=0, std=0.01)
init.constant_(net[0].bias, val=0)
自带初始化方法中,会自动消除梯度反向传播,但是手动情况下必须自己设定
def no_grad_uniform(tensor, a, b):
with torch.no_grad():
return tensor.uniform_(a, b)
使用apply进行初始化
批量初始化方法,注意net里面的apply函数,可以作用网络的所有module
def weights_init(m): # 1
classname = m.__class__.__name__ # 2
if classname.find('Conv') != -1: # 3
nn.init.kaiming_normal_(m.weight.data) # 4
elif classname.find('BatchNorm') != -1: # 5
nn.init.normal_(m.weight.data, 1.0, 0.02) # 6
nn.init.constant_(m.bias.data, 0) # 7
net.apply(weights_init)
Pytorch系列:(七)模型初始化的更多相关文章
- 计算广告CTR预估系列(七)--Facebook经典模型LR+GBDT理论与实践
计算广告CTR预估系列(七)--Facebook经典模型LR+GBDT理论与实践 2018年06月13日 16:38:11 轻春 阅读数 6004更多 分类专栏: 机器学习 机器学习荐货情报局 版 ...
- Alamofire源码解读系列(七)之网络监控(NetworkReachabilityManager)
Alamofire源码解读系列(七)之网络监控(NetworkReachabilityManager) 本篇主要讲解iOS开发中的网络监控 前言 在开发中,有时候我们需要获取这些信息: 手机是否联网 ...
- [Asp.net MVC]Asp.net MVC5系列——在模型中添加验证规则
目录 概述 在模型中添加验证规则 自定义验证规则 伙伴类的使用 总结 系列文章 [Asp.net MVC]Asp.net MVC5系列——第一个项目 [Asp.net MVC]Asp.net MVC5 ...
- WCF编程系列(七)信道及信道工厂
WCF编程系列(七)信道及信道工厂 信道及信道栈 前面已经提及过,WCF中客户端与服务端的交互都是通过消息来进行的.消息从客户端传送到服务端会经过多个处理动作,在WCF编程模型中,这些动作是按层 ...
- Asp.net MVC]Asp.net MVC5系列——在模型中添加
目录 概述 在模型中添加验证规则 自定义验证规则 伙伴类的使用 总结 系列文章 [Asp.net MVC]Asp.net MVC5系列——第一个项目 [Asp.net MVC]Asp.net MVC5 ...
- iOS流布局UICollectionView系列七——三维中的球型布局
摘要: 类似标签云的球状布局,也类似与魔方的3D布局 iOS流布局UICollectionView系列七——三维中的球型布局 一.引言 通过6篇的博客,从平面上最简单的规则摆放的布局,到不规则的瀑 ...
- [源码解析] PyTorch分布式(6) -------- DistributedDataParallel -- 初始化&store
[源码解析] PyTorch分布式(6) ---DistributedDataParallel -- 初始化&store 目录 [源码解析] PyTorch分布式(6) ---Distribu ...
- Keil MDK STM32系列(七) STM32F4基于HAL的PWM和定时器
Keil MDK STM32系列 Keil MDK STM32系列(一) 基于标准外设库SPL的STM32F103开发 Keil MDK STM32系列(二) 基于标准外设库SPL的STM32F401 ...
- SQL Server 2008空间数据应用系列七:基于Bing Maps(Silverlight) 的空间数据展现
原文:SQL Server 2008空间数据应用系列七:基于Bing Maps(Silverlight) 的空间数据展现 友情提示,您阅读本篇博文的先决条件如下: 1.本文示例基于Microsoft ...
随机推荐
- 基于ARM Cortex-M的SoC存储体系结构和实战
基于ARM Cortex-M的SoC存储体系结构和实战 System on Chip Architecture Tutorial Memory Architecture for ARM Cortex- ...
- ContOS8 使用yum安装MariaDB
首先全部删除MySQL/MariaDB(若是首次安装可根据需要跳过此步) 若不清楚MySQL和MariaDB的关系请移步至 Mariadb百科 1.查看系统版本(以下任一命令即可). # cat /p ...
- .h5图像文件(数据集)的读取并存储 工具贴(二)
概述 H5文件是层次数据格式第5代的版本(Hierarchical Data Format,HDF5),它是用于存储科学数据的一种文件格式和库文件.由美国超级计算中心与应用中心研发的文件格式,用以存储 ...
- 详解 DNS 解析
背景 前面讲了域名.IP,那么还缺少一个主角,就是 DNS 这些都是网络中最最最基础的,也是最最最重要的概念,很有必要深入学习下 所有素材均来自:https://www.bilibili.com/vi ...
- Redmine部署
Redmine部署文章: 第一篇:Redmine部署 第二篇:Redmine部署中遇到的问题 部门内部需要项目开发维护的网站,这种网站有付费的,也有开源项目.这类项目管理与协作的工具主要的MS Sha ...
- 什么是DDoS黑洞路由?
1. 什么是DDoS黑洞路由? DDoS黑洞路由/过滤(有时称为黑孔)是缓解DDoS攻击的一种对策,网络流量将被路由到"黑洞"中并且丢失.如果在没有特定限制条件下实施黑洞过滤,合法 ...
- HTTP 5XX代码理解
500: 内部服务器错误,服务程序出现错误 501: 请求为完成,服务器不支持所请求功能,很少遇到 502: Bad Gateway,服务器从上游服务器收到一个无效响应, 一般就是nginx后端服 ...
- 07 修改JumpServer网页信息
1.7.修改JumpServer网页信息 注意:在修改相关配置文件之前要先进行备份,防止文件修改错误无法恢复. 1.Luna图标: /opt/luna/static/imgs/logo.png 2.j ...
- JVM到底是什么呢
在我们运行和调试Java程序的时候,经常会提到一个JVM的概念.那JVM到底是什么呢? JVM是Java程序的运行环境,它同时也是一个操作系统的一个应用程序.一个进程,因此他也有他自己的运行生命周期, ...
- CentOS-Docker安装RabbitMQ(单点)
这里注意获取镜像的时候要获取management版本的,不要获取last版本的,management版本的才带有管理界面. 获取镜像 $ docker pull rabbitmq:management ...