model.apply(fn)或net.apply(fn)
详情可参考:https://pytorch.org/docs/1.11/generated/torch.nn.Module.html?highlight=torch%20nn%20module%20apply#torch.nn.Module.apply
首先,我们知道pytorch的任何网络net,都是torch.nn.Module的子类,都算是module,也就是模块。
pytorch中的model.apply(fn)会递归地将函数fn应用到父模块的每个子模块submodule,也包括model这个父模块自身。
比如下面的网络例子中。net这个模块有两个子模块,分别为Linear(2,4)和Linear(4,8)。函数首先对Linear(2,4)和Linear(4,8)两个子模块调用init_weights函数,即print(m)打印Linear(2,4)和Linear(4,8)两个子模块。然后再对net模块进行同样的操作。如此完成递归地调用。从而完成model.apply(fn)或者net.apply(fn)。
个人水平有限,不足处望指正。
参考链接:https://blog.csdn.net/qq_37025073/article/details/106739513
@torch.no_grad()
def init_weights(m):
print(m)
if type(m) == nn.Linear:
m.weight.fill_(1.0)
print(m.weight)
net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))
net.apply(init_weights)
#输出:
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[ 1., 1.],
[ 1., 1.]])
Linear(in_features=2, out_features=2, bias=True)
Parameter containing:
tensor([[ 1., 1.],
[ 1., 1.]])
Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
)
Sequential(
(0): Linear(in_features=2, out_features=2, bias=True)
(1): Linear(in_features=2, out_features=2, bias=True)
)
https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=nn%20module%20apply#torch.nn.Module.apply
如果我们想对某些特定的子模块submodule做一些针对性的处理,该怎么做呢。我们可以加入type(m) == nn.Linear:这类判断语句,从而对子模块m进行处理。如下,读者可以细细体会一下。
import torch.nn as nn
@torch.no_grad()
def init_weights(m):
print(m)
if type(m) == nn.Linear:
m.weight.fill_(1.0)
print(m.weight)
net = nn.Sequential(nn.Linear(2,4), nn.Linear(4, 8))
print(net)
print('isinstance torch.nn.Module',isinstance(net,torch.nn.Module))
print(' ')
net.apply(init_weights)
可以先打印网络整体看看。调用apply函数后,先逐一打印子模块m,然后对子模块进行判断,打印Linear这类子模块m的权重。
#输出:
Sequential(
(0): Linear(in_features=2, out_features=4, bias=True)
(1): Linear(in_features=4, out_features=8, bias=True)
)
isinstance torch.nn.Module True Linear(in_features=2, out_features=4, bias=True)
Parameter containing:
tensor([[1., 1.],
[1., 1.],
[1., 1.],
[1., 1.]], requires_grad=True)
Linear(in_features=4, out_features=8, bias=True)
Parameter containing:
tensor([[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.],
[1., 1., 1., 1.]], requires_grad=True)
Sequential(
(0): Linear(in_features=2, out_features=4, bias=True)
(1): Linear(in_features=4, out_features=8, bias=True)
)
网友:所以说apply函数是有顺序的,先在子模块上操作,最后在父模块上操作。
model.apply(fn)或net.apply(fn)的更多相关文章
- Function.apply.bind()与Function.apply.bind()
1.Function.apply.bind(…) 我在学习promise部分的时候遇到了这样的代码: Promise.resolve([10,20]).then(Function.apply.bind ...
- jQuery源码-jQuery.fn.attr与jQuery.fn.prop
jQuery.fn.attr.jQuery.fn.prop的区别 假设页面有下面这么个标签,$('#ddd').attr('nick').$('#ddd').prop('nick')分别会取得什么值? ...
- jQuery.fn.attr与jQuery.fn.prop
jQuery.fn.attr与jQuery.fn.prop jQuery.fn.attr.jQuery.fn.prop的区别 假设页面有下面这么个标签,$('#ddd').attr('nick').$ ...
- 探索 Reflect.apply 与 Function.prototype.apply 的区别
探索 Reflect.apply 与 Function.prototype.apply 的区别 众所周知, ES6 新增了一个全局.内建.不可构造的 Reflect 对象,并提供了其下一系列可被拦截的 ...
- django(新增model)No migrations to apply.
django 1.8版本,在models下新建一个class,无法在数据库创建新表的问题: - models.py class HostPwd(models.Model): hostname = mo ...
- $.extend() 或 jQuery.extend() 与 $.fn.Xxx 或 jQuery.fn.extend(object) 之jQuery插件开发
jQuery为开发插件提拱了两个方法 语法现象1:$.extend() 或 jQuery.extend() 或 jQuery.extend(object)//可以理解为为jQuery类添加类方法或静态 ...
- 【JavaScript】JQuery中$.fn、$.extend、$.fn.extend
Web开发肯定要使用第三方插件,对于一个炫丽的效果都忍不住想看看对方是如何实现的,刚下载了一个仿京东商品鼠标经过时局部放大的插件.看了两眼JQuery源码,看看就感觉一头雾水.JQuery本来自己学的 ...
- jQuery属性--html([val|fn])、text([val|fn])和val([val|fn|arr])
html([val|fn]) 概述 取得第一个匹配元素的html内容,这个函数不能用于XML文档.但可以用于XHTML文档. 在一个 HTML 文档中, 我们可以使用 .html() 方法来获取任意一 ...
- angularJS $watch $apply $digest
看O'Reilly的书看到$watch这部分,不过没看懂,网上很多资料也含糊不清,不过还是找到了几个好的,简单记录一下. 一句话说明,$watch是用来监视变量的,好了直接上代码 <html&g ...
随机推荐
- Ubuntu安装docker(摘自官网,自用)
在 Ubuntu 上安装 Docker 引擎(按照标红顺序执行命令) 预计阅读时间:11分钟 适用于 Linux 的 Docker 桌面 Docker Desktop 可帮助您在 Mac 和 Wind ...
- Struts2-使用forEach标签+el标签获取值栈数据
import cn.web.body.User; import com.opensymphony.xwork2.ActionSupport; import java.util.ArrayList; i ...
- drf的JWT认证
JWT认证(5星) token发展史 在用户注册或登录后,我们想记录用户的登录状态,或者为用户创建身份认证的凭证.我们不再使用Session认证机制,而使用Json Web Token(本质就是tok ...
- mysql查询 if判断、case语句的使用等
一. 查询的数字转换为中文返回前端 1. 如果是0或1表状态等,可用: IF(字段 = 0, '否', '是') AS xxx 2. 如果是多个值,比如1,2,3可用: ELT(字段, '计划治理', ...
- Python 国家地震台网中心地震数据集完整分析、pyecharts、plotly,分析强震次数、震级分布、震级震源关系、发生位置、发生时段、最大震级、平均震级
注意,本篇内容根据我老师布置的数据分析作业展开.请勿抄袭,后果自负! 前情提要 编写这篇文章是为了记录自己是如何分析地震数据集,使用模块,克服一系列 \(bug\) 的过程.如果你是 \(python ...
- 前端vue之属性指令、style和class、条件渲染、列表渲染、事件处理、数据双向绑定、表单控制、v-model进阶
今日内容概要 属性指令 style和class 条件渲染 列表渲染 事件处理 数据的双向绑定 v-model进阶 购物车案例 内容详细 1.属性指令 <!DOCTYPE html> < ...
- 『现学现忘』Git基础 — 5、Git的协作模式
目录 1.分布式工作流程 2.集中式工作流 3.分支工作流 4.GitFlow 工作流(最流行) 5.Forking 工作流(偶尔使用) 6.总结 1.分布式工作流程 与传统的集中式版本控制系统(CV ...
- Spring的3级缓存和循环引用的理解
此处是我自己的一个理解,防止以后忘记,如若那个地方理解不对,欢迎指出. 一.背景 在我们写代码的过程中一般会使用 @Autowired 来注入另外的一个对象,但有些时候发生了 循环依赖,但是我们的代码 ...
- 攻防世界-MISC:Aesop_secret
这是攻防世界高手进阶区的的第九题,题目如下: 点击下载附件一,得到一个压缩包,解压后得到一张GIF动图,找个网站给他分解一下,得到如下图片 不知道是什么意思,所以就跑去看WP了,用010editor打 ...
- c# 一些警告的处理方法
在使用.Net 6开发程序时,发现多了很多新的警告类型.这里总结一下处理方法. CS8618 在退出构造函数时,不可为 null 的 属性"Name"必须包含非 null 值 经常 ...