1. 权值的方差过大导致梯度爆炸的原因
  2. 方差一致性原则分析Xavier方法与Kaiming初始化方法

    饱和激活函数tanh,非饱和激活函数relu
  3. pytorch提供的十种初始化方法

梯度消失与爆炸

\[H_2 = H_1 * W_2\\
\Delta W_2 = \frac{\partial Loss}{\partial W_2}
=\frac{\partial Loss}{\partial out}
*\frac{\partial out}{\partial H_2}
*\frac{\partial H_2}{\partial W_2}
=\frac{\partial Loss}{\partial out}
*\frac{\partial out}{\partial H_2}*H_1
\]
\[{梯度消失:}H_1 \rightarrow 0 \Rightarrow \Delta W_2 \rightarrow 0\\
{梯度爆炸:}H_1 \rightarrow \infty \Rightarrow \Delta W_2 \rightarrow \infty
\]
\[1. E(X*Y)=E(X)*E(Y)\\
2. D(X)=E(X^2)-[E(X)]^2\\
3. D(X+Y)=D(X)+D(Y)\\
1.2.3. \Rightarrow D(X*Y)=D(X)D(Y)+D(X)*[E(Y)]^2+D(Y)*[E(X)]^2\\
若E(X)=0,E(Y)=0 \Rightarrow D(X*Y)=D(X)*D(Y)
\]
\[H_{11} = \sum ^{n}_{i=0} X_i * W_{1i}\\
D(X*Y) = D(X)*D(Y)\\
D(H_{11})=\sum ^{n}_{i=0} D(X_i)*D(W_1i)=n*(1*1)=n\\
std(H_{11})=\sqrt D(H_11) = \sqrt n\\
D(H_1) = n*D(X)*D(W)=1\\
D(W)=\frac{1}{n}\Rightarrow std(W)=\sqrt \frac {1}{n}
\]

Xavier方法与Kaiming方法

Xavier初始化

方差一致性,保持数据尺度维持在恰当范围,通常方差为1

激活函数:饱和函数,如Sigmoid,Tanh

\[n_i * D(W)=1\\
n_{i+1} *D(W)=1\\
\Rightarrow D(W)=\frac{2}{n_i+n_i+1}
\]
\[W \sim U[-a,a]\\
D(W) = \frac {(-a-a)^2}{12} = \frac {(2a)^2}{12}=\frac {a^2}{3}\\
\frac{2}{n_i+n_{i+1}}=\frac{a^2}{3}\Rightarrow a = \frac{\sqrt 6}{\sqrt {n_i+n_{i+1}}}\\
\Rightarrow W \sim U[-\frac{\sqrt 6}{\sqrt {n_i+n_{i+1}}},\frac{\sqrt 6}{\sqrt {n_i+n_{i+1}}}]
\]

Kaiming初始化

方差一致性:保持数据尺度维持在恰当范围,通常方差为1

激活函数:ReLU及其变种

\[D(W) = \frac{2}{n_i}\\
D(W) = \frac{2}{(1+a^2)*n_i}\\
std(W) = \sqrt{\frac{2}{(1+a^2)*n_i}}
\]

参考文献:

《Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification》

常用初始化方法

  1. Xavier均匀分布
  2. Xavier正态分布
  3. Kaiming均匀分布
  4. Kaiming正态分布
  5. 均匀分布
  6. 正态分布
  7. 常数分布
  8. 正交矩阵初始化
  9. 单位矩阵初始化
  10. 稀疏矩阵初始化
nn.init.calculate_gain(nonlinearity, param=None)

功能:计算激活函数的方差变化尺度

输入数据的方差和输出数据方差的比例。

参数:

  • nonlinearity:激活函数名称
  • param:激活函数参数,Leaky ReLU的negative_slop
# -*- coding: utf-8 -*-
"""
# @file name : grad_vanish_explod.py
# @author : TingsongYu https://github.com/TingsongYu
# @date : 2019-09-30 10:08:00
# @brief : 梯度消失与爆炸实验
"""
import os
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
import torch
import random
import numpy as np
import torch.nn as nn
path_tools = os.path.abspath(os.path.join(BASE_DIR, "..", "..", "tools", "common_tools.py"))
assert os.path.exists(path_tools), "{}不存在,请将common_tools.py文件放到 {}".format(path_tools, os.path.dirname(path_tools)) import sys
hello_pytorch_DIR = os.path.abspath(os.path.dirname(__file__)+os.path.sep+".."+os.path.sep+"..")
sys.path.append(hello_pytorch_DIR) from tools.common_tools import set_seed set_seed(1) # 设置随机种子 class MLP(nn.Module):
def __init__(self, neural_num, layers):
super(MLP, self).__init__()
self.linears = nn.ModuleList([nn.Linear(neural_num, neural_num, bias=False) for i in range(layers)])
self.neural_num = neural_num def forward(self, x):
for (i, linear) in enumerate(self.linears):
x = linear(x)
x = torch.relu(x) print("layer:{}, std:{}".format(i, x.std()))
if torch.isnan(x.std()):
print("output is nan in {} layers".format(i))
break return x def initialize(self):
for m in self.modules():
if isinstance(m, nn.Linear):
# nn.init.normal_(m.weight.data, std=np.sqrt(1/self.neural_num)) # normal: mean=0, std=1 # a = np.sqrt(6 / (self.neural_num + self.neural_num))
#
# tanh_gain = nn.init.calculate_gain('tanh')
# a *= tanh_gain
#
# nn.init.uniform_(m.weight.data, -a, a) # nn.init.xavier_uniform_(m.weight.data, gain=tanh_gain) # nn.init.normal_(m.weight.data, std=np.sqrt(2 / self.neural_num))
nn.init.kaiming_normal_(m.weight.data) flag = 0
# flag = 1 if flag:
layer_nums = 100
neural_nums = 256
batch_size = 16 net = MLP(neural_nums, layer_nums)
net.initialize() inputs = torch.randn((batch_size, neural_nums)) # normal: mean=0, std=1 output = net(inputs)
print(output) # ======================================= calculate gain ======================================= # flag = 0
flag = 1 if flag: x = torch.randn(10000)
out = torch.tanh(x) gain = x.std() / out.std()
print('gain:{}'.format(gain)) tanh_gain = nn.init.calculate_gain('tanh')
print('tanh_gain in PyTorch:', tanh_gain)

pytorch(14)权值初始化的更多相关文章

  1. [PyTorch 学习笔记] 4.1 权值初始化

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/grad_vanish_explod.py 在搭建好网络 ...

  2. caffe中权值初始化方法

    首先说明:在caffe/include/caffe中的 filer.hpp文件中有它的源文件,如果想看,可以看看哦,反正我是不想看,代码细节吧,现在不想知道太多,有个宏观的idea就可以啦,如果想看代 ...

  3. 神经网络权值初始化方法-Xavier

    https://blog.csdn.net/u011534057/article/details/51673458 https://blog.csdn.net/qq_34784753/article/ ...

  4. 权值初始化 - Xavier和MSRA方法

    设计好神经网络结构以及loss function 后,训练神经网络的步骤如下: 初始化权值参数 选择一个合适的梯度下降算法(例如:Adam,RMSprop等) 重复下面的迭代过程: 输入的正向传播 计 ...

  5. PyTorch 学习笔记(四):权值初始化的十种方法

    pytorch在torch.nn.init中提供了常用的初始化方法函数,这里简单介绍,方便查询使用. 介绍分两部分: 1. Xavier,kaiming系列: 2. 其他方法分布 Xavier初始化方 ...

  6. 【5】激活函数的选择与权值w的初始化

    激活函数的选择: 西格玛只在二元分类的输出层还可以用,但在二元分类中,其效果不如tanh,效果不好的原因是当Z大时,斜率变化很小,会导致学习效率很差,从而很影响运算的速度.绝大多数情况下用的激活函数是 ...

  7. 2019.01.14 bzoj5343: [Ctsc2018]混合果汁(整体二分+权值线段树)

    传送门 整体二分好题. 题意简述:nnn种果汁,每种有三个属性:美味度,单位体积价格,购买体积上限. 现在有mmm个询问,每次问能否混合出总体积大于某个值,总价格小于某个值的果汁,如果能,求所有方案中 ...

  8. 【机器学习的Tricks】随机权值平均优化器swa与pseudo-label伪标签

    文章来自公众号[机器学习炼丹术] 1 stochastic weight averaging(swa) 随机权值平均 这是一种全新的优化器,目前常见的有SGB,ADAM, [概述]:这是一种通过梯度下 ...

  9. 51nod1459(带权值的dijkstra)

    题目链接:https://www.51nod.com/onlineJudge/questionCode.html#!problemId=1459 题意:中文题诶- 思路:带权值的最短路,这道题数据也没 ...

随机推荐

  1. python-零基础入门-自学笔记

    目录 第一章:计算机基础 1.1 硬件组成 1.2 操作系统分类 1.3 解释型和编译型介绍 第二章:Python入门 2.1 介绍 2.2 python涉及领域 2.2.1 哪些公司有使用Pytho ...

  2. Kibana 地标图可视化

    ElasticSearch 可以使用 ingest-geoip 插件可以在 Kibana 上对 IP 进行地理位置分析, 这个插件需要 Maxmind 的 GeoLite2 City,GeoLite2 ...

  3. TCP协议与UDP协议的区别以及与TCP/IP协议的联系

    先介绍下什么是TCP,什么是UDP. 1. 什么是TCP? TCP(Transmission Control Protocol,传输控制协议)是面向连接的.可靠的字节流服务,也就是说,在收发数据前,必 ...

  4. C# 数据类型(3)

    动态类型 dynamic types 动态类型是后来引进的,他其实是一个static type,但是不像其他的静态类型,编译器不会检查你到底是啥类型(也不会检查你能不能去call某个'method') ...

  5. MS16-032 windows本地提权

    试用系统:Tested on x32 Win7, x64 Win8, x64 2k12R2 提权powershell脚本: https://github.com/FuzzySecurity/Power ...

  6. Python对excel的基本操作

    Python对excel的基本操作 目录 1. 前言 2. 实验环境 3. 基本操作 3.1 安装openpyxl第三方库 3.2 新建工作簿 3.2.1 新创建工作簿 3.2.2 缺省工作表 3.2 ...

  7. JVM系列之一 JVM的基础概念与内存区域

    前言 作为一名 Java 语言的使用者,学习 JVM 有助于解决程序运行过程中出现的问题.写出性能更高的代码. 可以说:学好 JVM 是成为中高级 Java 工程师的必经之路. 有感于从未整理归纳 J ...

  8. TypeScript 1.7 & TypeScript 1.8

    TypeScript 1.7 & TypeScript 1.8 1 1 https://zh.wikipedia.org/wiki/TypeScript TypeScript是一种由微软开发的 ...

  9. Array.fill & array padding

    Array.fill & array padding arr.fill(value[, start[, end]]) https://developer.mozilla.org/en-US/d ...

  10. BPMN 2.0

    BPMN 2.0 Business Process Model and Notation 业务流程模型和符号 https://www.omg.org/spec/BPMN/2.0.2/ bpmn-js ...