『PyTorch』第十三弹_torch.nn.init参数初始化
初始化参数的方法
nn.Module模块对于参数进行了内置的较为合理的初始化方式,当我们使用nn.Parameter时,初始化就很重要,而且我们也可以指定代替内置初始化的方式对nn.Module模块进行补充。
除了之前的.data进行赋值,或者.data.初始化方式外,我们可以使用torch.nn.init进行初始化参数。
from torch.nn import init linear = nn.Linear(3, 4) t.manual_seed(1) init.xavier_normal(linear.weight)
print(linear.weight.data) import math std = math.sqrt(2)/math.sqrt(7.)
linear.weight.data.normal_(0, std)
不同层类型定制化初始化
除此之外,我们可以使用如下的方式对不同的类型的层(卷积层、全连接层……)进行不同的赋值方式,
for name, params in net.named_parameters():
if name.find('linear') != -1:
params[0] # weights
params[1] # bias
elif name.find('conv') != -1:
pass
elif name.find('norm') != -1:
pass
这里使用了str.find()方法,如下:
'asda'.find('a')
Out[3]:
0
即返回第一个find参数在原str中的位置索引。
『PyTorch』第十三弹_torch.nn.init参数初始化的更多相关文章
- 『PyTorch』第十一弹_torch.optim优化器
一.简化前馈网络LeNet import torch as t class LeNet(t.nn.Module): def __init__(self): super(LeNet, self).__i ...
- 『PyTorch』第十一弹_torch.optim优化器 每层定制参数
一.简化前馈网络LeNet 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 im ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...
- 『PyTorch』第三弹重置_Variable对象
『PyTorch』第三弹_自动求导 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Varibale包含三个属性: data ...
- 『PyTorch』第十弹_循环神经网络
RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上
总结一下相关概念: torch.Tensor - 一个近似多维数组的数据结构 autograd.Variable - 改变Tensor并且记录下来操作的历史记录.和Tensor拥有相同的API,以及b ...
- 『PyTorch』第五弹_深入理解autograd_上:Variable属性方法
在PyTorch中计算图的特点可总结如下: autograd根据用户对variable的操作构建其计算图.对变量的操作抽象为Function. 对于那些不是任何函数(Function)的输出,由用户创 ...
- 『PyTorch』第七弹_nn.Module扩展层
有下面代码可以看出torch层函数(nn.Module)用法,使用超参数实例化层函数类(常位于网络class的__init__中),而网络class实际上就是一个高级的递归的nn.Module的cla ...
- 『PyTorch』第五弹_深入理解autograd_下:函数扩展&高阶导数
一.封装新的PyTorch函数 继承Function类 forward:输入Variable->中间计算Tensor->输出Variable backward:均使用Variable 线性 ...
随机推荐
- 检测u盘是否挂载上方法
打开内核log:echo "8" > /proc/sys/kernel/printk 关闭内核log:echo "1" > /proc/sys/ke ...
- CPU负载过高异常排查实践与总结
昨天下午突然收到运维邮件报警,显示数据平台服务器cpu利用率达到了98.94%,而且最近一段时间一直持续在70%以上,看起来像是硬件资源到瓶颈需要扩容了,但仔细思考就会发现咱们的业务系统并不是一个高并 ...
- Python学习笔记之在Python中实现单例模式
有些时候你的项目中难免需要一些全局唯一的对象,这些对象大多是一些工具性的东西,在Python中实现单例模式并不是什么难事.以下总结几种方法: 使用类装饰器 使用装饰器实现单例类的时候,类本身并不知道自 ...
- Android实践项目汇报(三)
Google天气客户端 本周学习计划 调试代码使之成功运行并实现天气预报功能. 实际完成情况 由于google取消api接口服务,天气源的传输.所以我换了一个使用 haoserver API接口的程序 ...
- 不到 200 行代码,教你如何用 Keras 搭建生成对抗网络(GAN)【转】
本文转载自:https://www.leiphone.com/news/201703/Y5vnDSV9uIJIQzQm.html 生成对抗网络(Generative Adversarial Netwo ...
- rabbitmq direct、fanout、topic 三种Exchange java 代码比较
Producer端 1.channel的创建 无论是才用什么样的Exchange,创建channel代码都是相同的,如下 ConnectionFactory factory = new Connect ...
- 【传输对象】kafka传递实体类消息
工具类 负责对象字节数组的相互转换,传输数据用 package com.yq.utils; import java.io.ByteArrayInputStream; import java.io.By ...
- C# 查出数据表DataTable 清除一列中的重复项保留其他项
http://bbs.csdn.net/topics/391085792 DataTable 老表= 新表.AsEnumerable().GroupBy(p => p["姓名& ...
- asp.net <asp:Repeater>下的radio的单选使用
aspx页面 <asp:Repeater ID="rptData" runat="server"> <ItemTemplate> < ...
- Js页面自动跳转
//声明 t = 1 var t = 10; function openwin() { t -= 1; if(t==0){ location.href='index2.html'; } setTime ...