初始化参数的方法

nn.Module模块对于参数进行了内置的较为合理的初始化方式,当我们使用nn.Parameter时,初始化就很重要,而且我们也可以指定代替内置初始化的方式对nn.Module模块进行补充。

除了之前的.data进行赋值,或者.data.初始化方式外,我们可以使用torch.nn.init进行初始化参数。

from torch.nn import init

linear = nn.Linear(3, 4)

t.manual_seed(1)

init.xavier_normal(linear.weight)
print(linear.weight.data) import math std = math.sqrt(2)/math.sqrt(7.)
linear.weight.data.normal_(0, std)

不同层类型定制化初始化

除此之外,我们可以使用如下的方式对不同的类型的层(卷积层、全连接层……)进行不同的赋值方式,

for name, params in net.named_parameters():
if name.find('linear') != -1:
params[0] # weights
params[1] # bias
elif name.find('conv') != -1:
pass
elif name.find('norm') != -1:
pass

这里使用了str.find()方法,如下:

'asda'.find('a')
Out[3]:
0

即返回第一个find参数在原str中的位置索引。

『PyTorch』第十三弹_torch.nn.init参数初始化的更多相关文章

  1. 『PyTorch』第十一弹_torch.optim优化器

    一.简化前馈网络LeNet import torch as t class LeNet(t.nn.Module): def __init__(self): super(LeNet, self).__i ...

  2. 『PyTorch』第十一弹_torch.optim优化器 每层定制参数

    一.简化前馈网络LeNet 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 im ...

  3. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下

    『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...

  4. 『PyTorch』第三弹重置_Variable对象

    『PyTorch』第三弹_自动求导 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Varibale包含三个属性: data ...

  5. 『PyTorch』第十弹_循环神经网络

    RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...

  6. 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上

    总结一下相关概念: torch.Tensor - 一个近似多维数组的数据结构 autograd.Variable - 改变Tensor并且记录下来操作的历史记录.和Tensor拥有相同的API,以及b ...

  7. 『PyTorch』第五弹_深入理解autograd_上:Variable属性方法

    在PyTorch中计算图的特点可总结如下: autograd根据用户对variable的操作构建其计算图.对变量的操作抽象为Function. 对于那些不是任何函数(Function)的输出,由用户创 ...

  8. 『PyTorch』第七弹_nn.Module扩展层

    有下面代码可以看出torch层函数(nn.Module)用法,使用超参数实例化层函数类(常位于网络class的__init__中),而网络class实际上就是一个高级的递归的nn.Module的cla ...

  9. 『PyTorch』第五弹_深入理解autograd_下:函数扩展&高阶导数

    一.封装新的PyTorch函数 继承Function类 forward:输入Variable->中间计算Tensor->输出Variable backward:均使用Variable 线性 ...

随机推荐

  1. Struts2 Spring Hibernate 框架整合 Annotation MavenProject

    项目结构目录 pom.xml       添加和管理jar包 <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns ...

  2. 深入浅出JavaScript运行机制

    一.引子 本文介绍JavaScript运行机制,这一部分比较抽象,我们先从一道面试题入手: console.log(1); setTimeout(function(){ console.log(3); ...

  3. mysql 触发器 trigger用法 two (稍微复杂的)

    触发器(trigger):监视某种情况,并触发某种操作. 触发器创建语法四要素:1.监视地点(table) 2.监视事件(insert/update/delete) 3.触发时间(after/befo ...

  4. PHP中private和public还有protected的区别

    原文链接:http://www.thinkphp.cn/code/1898.html <? //父类 class father{ public function a(){ echo " ...

  5. 20145302张薇 《网络对抗技术》 web安全基础实践

    20145302张薇 <网络对抗技术> web安全基础实践 实验问题回答 1.SQL注入攻击原理,如何防御 原理:攻击者把SQL命令插入到网页的各种查询字符串处,达到欺骗服务器执行恶意的S ...

  6. Cocos 开发笔记

    经发现: cocos creator 提供的hello world 模版中.只有HelloWorkd.js中 properties 属性 text的值不是'hello world!' Label 组件 ...

  7. HDU1251 统计难题 (字典树模板)题解

    思路:模板题,贴个模板 代码: #include<cstdio> #include<cstring> #include<cstdlib> #include<q ...

  8. 【Git安装】centos安装git

    1 yum install git 安装后的默认存放地点/usr/bin/git

  9. c#之有参和无参构造函数,扩展方法

    例如在程序中创建 Parent类和Test类,在Test有三个构造函数,parent类继承Test类,那么我们可以在Test类自身中添加 扩展 方法吗? 答案:是不可以的.因为扩展方法必须是静态的,且 ...

  10. [AtCoder ARC101D/ABC107D] Median of Medians

    题目链接 题意:给n个数,求出所有子区间的中位数,组成另外一个序列,求出它的中位数 这里的中位数的定义是:将当前区间排序后,设区间长度为m,则中位数为第m/2+1个数 做法:二分+前缀和+树状数组维护 ...