一、使用Numpy初始化:【直接对Tensor操作】

  • 对Sequential模型的参数进行修改:

 import numpy as np
import torch
from torch import nn # 定义一个 Sequential 模型
net1 = nn.Sequential(
nn.Linear(30, 40),
nn.ReLU(),
nn.Linear(40, 50),
nn.ReLU(),
nn.Linear(50, 10)
) # 访问第一层的参数
w1 = net1[0].weight
b1 = net1[0].bias
print(w1) #对第一层Linear的参数进行修改:
# 定义第一层的参数 Tensor 直接对其进行替换
net1[0].weight.data = torch.from_numpy(np.random.uniform(3, 5, size=(40, 30)))
print(net1[0].weight)
23
24 #若模型中相同类型的层都需要初始化成相同的方式,一种更高效的方式:使用循环去访问:
25 for layer in net1:
26     if isinstance(layer, nn.Linear): # 判断是否是线性层
27         param_shape = layer.weight.shape
28         layer.weight.data = torch.from_numpy(np.random.normal(0, 0.5, size=param_shape))
29         # 定义为均值为 0,方差为 0.5 的正态分布
  • 对Module模型 的参数初始化:

对于 Module 的参数初始化,其实也非常简单,如果想对其中的某层进行初始化,可以直接像 Sequential 一样对其 Tensor 进行重新定义,其唯一不同的地方在于,如果要用循环的方式访问,需要介绍两个属性,children 和 modules,下面我们举例来说明:

1、创建Module模型类:

 class sim_net(nn.Module):
def __init__(self):
super(sim_net, self).__init__()
self.l1 = nn.Sequential(
nn.Linear(30, 40),
nn.ReLU()
) self.l1[0].weight.data = torch.randn(40, 30) # 直接对某一层初始化 self.l2 = nn.Sequential(
nn.Linear(40, 50),
nn.ReLU()
) self.l3 = nn.Sequential(
nn.Linear(50, 10),
nn.ReLU()
) def forward(self, x):
x = self.l1(x)
x =self.l2(x)
x = self.l3(x)
return x

2、创建模型对象:

net2 = sim_net()

3、访问children:

# 访问 children
for i in net2.children():
print(i)
     #打印的结果:
Sequential(
(0): Linear(in_features=30, out_features=40)
(1): ReLU()
)
Sequential(
(0): Linear(in_features=40, out_features=50)
(1): ReLU()
)
Sequential(
(0): Linear(in_features=50, out_features=10)
(1): ReLU()
)

4、访问modules:

# 访问 modules
for i in net2.modules():
print(i) #打印的结果
sim_net(
(l1): Sequential(
(0): Linear(in_features=30, out_features=40)
(1): ReLU()
)
(l2): Sequential(
(0): Linear(in_features=40, out_features=50)
(1): ReLU()
)
(l3): Sequential(
(0): Linear(in_features=50, out_features=10)
(1): ReLU()
)
)
Sequential(
(0): Linear(in_features=30, out_features=40)
(1): ReLU()
)
Linear(in_features=30, out_features=40)
ReLU()
Sequential(
(0): Linear(in_features=40, out_features=50)
(1): ReLU()
)
Linear(in_features=40, out_features=50)
ReLU()
Sequential(
(0): Linear(in_features=50, out_features=10)
(1): ReLU()
)
Linear(in_features=50, out_features=10)
ReLU()

通过上面的例子,可以看到:

children 只会访问到模型定义中的第一层,因为上面的模型中定义了三个 Sequential,所以只会访问到三个 Sequential,而 modules 会访问到最后的结构,比如上面的例子,modules 不仅访问到了 Sequential,也访问到了 Sequential 里面,这就对我们做初始化非常方便。

5、采用循环初始化:

for layer in net2.modules():
if isinstance(layer, nn.Linear):
param_shape = layer.weight.shape
layer.weight.data = torch.from_numpy(np.random.normal(0, 0.5, size=param_shape))

二、torch.nn.init初始化

PyTorch 还提供了初始化的函数帮助我们快速初始化,就是 torch.nn.init,其操作层面仍然在 Tensor 上。先介绍一种初始化方法:

Xavier 初始化方法:

其中 表示该层的输入和输出数目。

这种非常流行的初始化方式叫 Xavier,方法来源于 2010 年的一篇论文 Understanding the difficulty of training deep feedforward neural networks,其通过数学的推到,证明了这种初始化方式可以使得每一层的输出方差是尽可能相等的。

torch.nn.init:

from torch.nn import init

init.xavier_uniform(net1[0].weight) # 这就是上面我们讲过的 Xavier 初始化方法,PyTorch 直接内置了其实现

#这就直接修改了net1[0].weight的值

Pytorch基础(6)----参数初始化的更多相关文章

  1. pytorch对模型参数初始化

    1.使用apply() 举例说明: Encoder :设计的编码其模型 weights_init(): 用来初始化模型 model.apply():实现初始化 # coding:utf- from t ...

  2. PyTorch常用参数初始化方法详解

    1. 均匀分布 torch.nn.init.uniform_(tensor, a=0, b=1) 从均匀分布U(a, b)中采样,初始化张量. 参数: tensor - 需要填充的张量 a - 均匀分 ...

  3. PyTorch模型读写、参数初始化、Finetune

    使用了一段时间PyTorch,感觉爱不释手(0-0),听说现在已经有C++接口.在应用过程中不可避免需要使用Finetune/参数初始化/模型加载等. 模型保存/加载 1.所有模型参数 训练过程中,有 ...

  4. 【转载】 pytorch自定义网络结构不进行参数初始化会怎样?

    原文地址: https://blog.csdn.net/u011668104/article/details/81670544 ------------------------------------ ...

  5. pytorch和tensorflow的爱恨情仇之参数初始化

    pytorch和tensorflow的爱恨情仇之基本数据类型 pytorch和tensorflow的爱恨情仇之张量 pytorch和tensorflow的爱恨情仇之定义可训练的参数 pytorch版本 ...

  6. DL基础补全计划(五)---数值稳定性及参数初始化(梯度消失、梯度爆炸)

    PS:要转载请注明出处,本人版权所有. PS: 这个只是基于<我自己>的理解, 如果和你的原则及想法相冲突,请谅解,勿喷. 前置说明   本文作为本人csdn blog的主站的备份.(Bl ...

  7. [源码解析] PyTorch分布式(6) -------- DistributedDataParallel -- 初始化&store

    [源码解析] PyTorch分布式(6) ---DistributedDataParallel -- 初始化&store 目录 [源码解析] PyTorch分布式(6) ---Distribu ...

  8. pytorch基础教程1

    0.迅速入门:根据上一个博客先安装好,然后终端python进入,import torch ******************************************************* ...

  9. [人工智能]Pytorch基础

    PyTorch基础 摘抄自<深度学习之Pytorch>. Tensor(张量) PyTorch里面处理的最基本的操作对象就是Tensor,表示的是一个多维矩阵,比如零维矩阵就是一个点,一维 ...

随机推荐

  1. HTML5:控件自动获得焦点

    在HTML5中,页面打开后,需要指定的控件自动获得焦点很简单,只需要一个属性就可以实现 -  autofocus. 示例如下: <input type="text" auto ...

  2. AngularJS:实现页面滚动到底自动加载数据的功能

    要实现这个功能,可以通过https://github.com/sroze/ngInfiniteScroll这个第三方控件来实现.步骤如下: 1. 下载ng-infinite-scroll.js程序ht ...

  3. LINQ体验(2)——C# 3.0新语言特性和改进(上篇)

    整体来说.Visual Studio 2008和.NET 3.5是建立在.NET2.0核心的基础之上,.NET2.0核心本身将不再变化(假设不了解.NET2.0的朋友,请參看MSDN或者一些经典的书籍 ...

  4. 【独立开发人员er Cocos2d-x实战 011】Cocos2dx 3.x命令行生成APK具体解释

    Cocos2d-x 3.6项目打包生成apk安卓应用文件,搭建安卓环境的步骤有点繁琐.但搭建一次之后,以后就会很快捷! 过程例如以下: 一.下载安卓环境:搭建Android环境须要用到Android ...

  5. codeforces 437D The Child and Zoo

    time limit per test 2 seconds memory limit per test 256 megabytes input standard input output standa ...

  6. linux更改gitlab存储位置

    更改仓库存储位置默认时GitLab的仓库存储位置在“/var/opt/gitlab/git-data/repositories”,在实际生产环境中显然我们不会存储在这个位置,一般都会划分一个独立的分区 ...

  7. UVa 263 - Number Chains

    题目:给你一个数字n0.将它的每一个位的数字按递增排序生成数a,按递减排序生成数b, 新的数字为n1 = a-b,下次依照相同方法计算n1,知道出现循环,问计算了多少次. 分析:数论.模拟.直接模拟计 ...

  8. 【NOIP2018】为什么这么无力啊

    菜鸡又要爆零了 辛辛苦苦背板子结果考时候脑子一片空白 第一题线段树调了半小时 看完三道题两道写暴搜一道写暴力(说是暴搜,觉得更像写了个背包) 别提暴搜还忘记剪枝. . . . . . 我觉得考场上最菜 ...

  9. Linux:命令gedit

    首先,gedit是一个GNOME桌面环境下兼容UTF-8的文本编辑器.它使用GTK+编写而成,因此它十分的简单易用,有良好的语法高亮,对中文支持很好,支持包括gb2312.gbk在内的多种字符编码. ...

  10. 讲一讲WiFi快连、SmartConfig、SmartConnect

    最近要给公司同事们培训WiFi快连技术,整理了相关资料,也分享在博客这,献给有缘人. 前言 现在的智能硬件产品,以WiFi品类居多,这些WiFi硬件没有人机交互界面,但设备要上网肯定要配置SSID等相 ...