1. 关于parameters()方法

Pytorch中继承了torch.nn.Module的模型类具有named_parameters()/parameters()方法,这两个方法都会返回一个用于迭代模型参数的迭代器(named_parameters还包括参数名字):

import torch

net = torch.nn.LSTM(input_size=512, hidden_size=64)
print(net.parameters())
print(net.named_parameters())
# <generator object Module.parameters at 0x12a4e9890>
# <generator object Module.named_parameters at 0x12a4e9890>

我们可以将net.parameters()迭代器和将net.named_parameters()转化为列表类型,前者列表元素是模型参数,后者是包含参数名和模型参数的元组。

当然,我们更多的是对迭代器直接进行迭代:

for param in net.parameters():
print(param.shape)
# torch.Size([256, 512])
# torch.Size([256, 64])
# torch.Size([256])
# torch.Size([256])
for name, param in net.named_parameters():
print(name, param.shape)
# weight_ih_l0 torch.Size([256, 512])
# weight_hh_l0 torch.Size([256, 64])
# bias_ih_l0 torch.Size([256])
# bias_hh_l0 torch.Size([256])

我们知道,Pytorch在进行优化时需要给优化器传入这个参数迭代器,如:

from torch.optim import RMSprop
optimizer = RMSprop(net.parameters(), lr=0.01)

2. 关于参数修改

那么底层具体是怎么对参数进行修改的呢?

我们在博客《Python对象模型与序列迭代陷阱》中介绍过,Python序列中本身存放的就是对象的引用,而迭代器返回的是序列中的对象的二次引用,如果序列的引用指向基础数据类型,则是不可以通过遍历序列进行修改的,如:

my_list = [1, 2, 3, 4]
for x in my_list:
x += 1
print(my_list) #[1, 2, 3, 4]

而序列中的引用指向复合数据类型,则可以通过遍历序列来完成修改操作,如:

my_list = [[1, 2],[3, 4]]
for sub_list in my_list:
sub_list[0] += 1
print(my_list)
# [1, 2, 3, 4]
# [[2, 2], [4, 4]]

具体原理可参照该篇博客,此处我就不在赘述。这里想提到的是,用net.parameters()/net.named_parameters()来迭代并修改参数,本质上就是上述第二种对复合数据类型序列的修改。我们可以如下写:

for param in net.parameters():
with torch.no_grad():
param += 1

with torch.no_grad():表示将将所要修改的张量关闭梯度计算。所增加的1会广播到param张量的中的每一个元素上。上述操作本质上为:

for param in net.parameters():
with torch.no_grad():
param += torch.ones(param.shape)

但是需要注意,如果我们想让参数全部置为0,切不可像下列这样写:

for param in net.parameters():
with torch.no_grad():
param = torch.zeros(param.shape)

param是二次引用,param=0操作再语义上会被解释为让param这个二次引用去指向新的全0张量对象,但是对参数张量本身并不会产生任何变动。该操作实际上类似下列这种操作:

list_1 = [1, 2]
list_2 = list_1
list_2 = [0, 0]
print(list_1) # [1, 2]

修改二次引用list_2自然不会影响到list_1引用的对象。

下面让我们纠正这种错误,采用下列方法直接来将参数张量中的所有数值置0:

for param in net.parameters():
with torch.no_grad():
param[:] = 0 #张量类型自带广播操作,等效于param[:] = torch.zeros(param.shape)

这时语义上就类似

list_1 = [1, 2]
list_2 = list_1
list_2[:] = [0, 0]
print(list_1) # [0, 0]

自然就能完成修改的操作了。

参考

Pytorch:利用torch.nn.Modules.parameters修改模型参数的更多相关文章

  1. pytorch中torch.nn构建神经网络的不同层的含义

    主要是参考这里,写的很好PyTorch 入门实战(四)--利用Torch.nn构建卷积神经网络 卷积层nn.Con2d() 常用参数 in_channels:输入通道数 out_channels:输出 ...

  2. [pytorch笔记] torch.nn vs torch.nn.functional; model.eval() vs torch.no_grad(); nn.Sequential() vs nn.moduleList

    1. torch.nn与torch.nn.functional之间的区别和联系 https://blog.csdn.net/GZHermit/article/details/78730856 nn和n ...

  3. PyTorch官方中文文档:torch.nn

    torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...

  4. 到底什么是TORCH.NN?

    该教程是在notebook上运行的,而不是脚本,下载notebook文件. PyTorch提供了设计优雅的模块和类:torch.nn, torch.optim, Dataset, DataLoader ...

  5. 从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系

    从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系 relu多种实现之间的关系 relu 函数在 pytorch 中总共有 3 次出现: torc ...

  6. 小白学习之pytorch框架(3)-模型训练三要素+torch.nn.Linear()

    模型训练的三要素:数据处理.损失函数.优化算法    数据处理(模块torch.utils.data) 从线性回归的的简洁实现-初始化模型参数(模块torch.nn.init)开始 from torc ...

  7. pytorch保存模型等相关参数,利用torch.save(),以及读取保存之后的文件

    本文分为两部分,第一部分讲如何保存模型参数,优化器参数等等,第二部分则讲如何读取. 假设网络为model = Net(), optimizer = optim.Adam(model.parameter ...

  8. 小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())

    学习pytorch路程之动手学深度学习-3.4-3.7 置信度.置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的 交 ...

  9. 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)

    我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...

  10. [深度学习] Pytorch学习(二)—— torch.nn 实践:训练分类器(含多GPU训练CPU加载预测的使用方法)

    Learn From: Pytroch 官方Tutorials Pytorch 官方文档 环境:python3.6 CUDA10 pytorch1.3 vscode+jupyter扩展 #%% #%% ...

随机推荐

  1. openGauss升级脚本撰写

    概述 重要提示: 升级过程通过执行升级 sql 脚本实现系统表变更,这些脚本必须由开发人员在修改系统表的同时一并提供升级 sql 脚本,请将这些脚本代码提交至 openGauss-server/src ...

  2. 本周三晚19:00 Hello HarmonyOS进阶课程第6课—短视频应用开发

    短视频应用软件的开发一直保持着快速发展,在用户流量增长和规模扩大的同时,短视频行业的受欢迎程度也在持续上升.在生活节奏不断加快的今天,人们过着越来越充实的生活,碎片化已经渐渐成为人们习以为常的节奏,比 ...

  3. Maven 必备技能:MAC 系统下 JDK和Maven 安装及环境变量配置详细讲解

    开发中难免因系统问题或者版本变更反复折腾JDK和Maven环境变量,干脆写个笔记备忘个,也方便小伙伴们节省时间. =================JDK安装与环境变量配置====== 1.官网下载j ...

  4. Spring Cloud 核心组件之Spring Cloud Hystrix:服务容错保护

    Spring Cloud Hystrix:服务容错保护 SpringCloud学习教程 SpringCloud Spring Cloud Hystrix 是Spring Cloud Netflix 子 ...

  5. 【6】Spring JavaConfig和常见Annotation

    Java 5 的推出,加上当年基于纯 Java Annotation 的依赖注入框架 Guice 的出现,使得 Spring 框架及其社区也"顺应民意",推出并持续完善了基于 Ja ...

  6. mysql 重新整理——索引简介[七]

    前言 百度百科索引: 在关系数据库中,索引是一种单独的.物理的对数据库表中一列或多列的值进行排序的一种存储结构,它是某个表中一列或若干列值的集合和相应的指向表中物理标识这些值的数据页的逻辑指针清单. ...

  7. ERP财务管理有哪些功能?如何选择合适的ERP软件开发商定制开发适合自己的ERP财务管理?

    企业日常运营中,分工明确.结构清晰的财务管理非常重要,因此在完整的ERP解决方案中,财务管理是不可或缺的部分,甚至财务管理是整个ERP解决方案的核心,其它功能模块都围绕着财务管理构建价值链创造流程,最 ...

  8. JavaScript中数值小知识

    1. 数值10.0 这种类似的会被去掉数值后的0 之所以这样是因为,整数的存储空间占用比浮点数小,当一个数值不是真浮点数(即10.0这种格式),会被转换为整数10,如果业务中有一些需求需要进行数值位数 ...

  9. Résumé Review 二分方法题解

    一道非常好的数学题,不愧是CF的题,跟某些网站上的水题.恶心题没法比~ 题意 这里就要夸一下某谷了,翻译的很好,不像我,在CF上用deepl翻译,不够清晰(←全是废话) 分析 先不考虑 bi ,考虑转 ...

  10. 力扣121(java&python)-买卖股票的最佳时机(简单)

    题目: 给定一个数组 prices ,它的第 i 个元素 prices[i] 表示一支给定股票第 i 天的价格. 你只能选择 某一天 买入这只股票,并选择在 未来的某一个不同的日子 卖出该股票.设计一 ...