今天在写一个分类网络时,要使用nn.Sequential中的一个模块,因为nn.Sequential中模块都没有名字,我一时竟无从下笔。于是决定写这篇博客梳理pytorch的nn.Module类,看完这篇博客,你大概率可以学会:

  • 提取nn.Sequential中任意一个模块
  • 能初始化一个网络的所有权重,不管是随机初始化还是使用权重文件
  • 对nn.Module类有个总体把握

1 __init__方法

  我们先不看代码,自己小脑袋里想一想这个类应该有什么东西。既然这个类是和各种layer相关,里面一定存着各种layer,如卷积或者relu,肯定还会有layer对应的权重。没错,这些东西类里都有:


from ..backends.thnn import backend as thnn_backend
def __init__(self):
        self._backend = thnn_backend
        self._parameters = OrderedDict()
        self._buffers = OrderedDict()
        self._backward_hooks = OrderedDict()
        self._forward_hooks = OrderedDict()
        self._forward_pre_hooks = OrderedDict()
        self._modules = OrderedDict()
        self.training = True

  其中_modules和_parameters就是存储这些的,那么剩下这些奇奇怪怪的属性是什么呢?第一个奇怪的东西就是这个_backend属性,他的值是thnn_backend,好像和torch的底层代码有关,有点麻烦啊!不用怕,这个属性我们只要大概了解它的意思就可以,对之后这个模块的使用没有任何影响。举个例子来帮助我们理解,想象说我们有maxpooling层,pytorch底层会用cudnn实现,也会用cunn实现,pytorch在前向传播时会自动选择一个速度最快的实现,这里thnn_backend就是指定我们前向传播时用thnn这种实现方式。更多过于backend的直觉理解可以参考这个thread

  第二个奇怪的东西就是_buffer属性,如果我说这个buffer存储的也是网络参数,你会感觉到更加迷惑吗?一个提示:batch norm的参数。(花一分钟想一想。)batch norm中除了有参数$ alpha $和$ beta $外,还有running_mean和running_var,这些在整个学习过程中也需要存储起来但是不需要学习的参数我们就把它们存储到_buffer中。

  第三个奇怪的东西就是3个钩子,这三个钩子具体的实现我可能很后面才会讲,不过为了一些好奇宝宝,如果我说forward_hook的功能是在module完成前向传播时做一些事,你能推断出其他两个钩子的功能吗?(答案

2 children方法和modules方法

  children方法和modules方法的作用是很类似的,我们先看一下children方法的代码。

def named_children(self):
    memo = set()
    for name, module in self._modules.items():
        if module is not None and module not in memo:
            memo.add(module)
            yield name, module
def children(self):
    for name, module in self.named_children():
        yield module

  我们看完代码发现children方法的作用就是把_modules遍历一遍,我们来看一下具体例子(你也可以自己在命令行中把这两个命令输入进去,尽量不要复制粘贴):

>>>model=nn.Sequential(nn.Linear(3,1), \
                       nn.Sequential(nn.BatchNorm2d(1), \
                       nn.Linear(1,3)))
>>> for m in model.children():
...     print(m)
Linear(in_features=3, out_features=1, bias=True)
Sequential(
  (0): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True)
  (1): Linear(in_features=1, out_features=3, bias=True)
)    

  如果我们还要把更里层的modules提取出来,我们就需要用到modules方法:

def named_modules(self, memo=None, prefix=''):
    if memo is None:
        memo = set()
    if self not in memo:
       memo.add(self)     yield prefix, self    for name, module in self._modules.items():       if module is None:             continue       submodule_prefix = prefix + ('.' if prefix else '') + name       for m in module.named_modules(memo, submodule_prefix):         yield mdef modules(self):
    for name, module in self.named_modules():
        yeild module
>>> for m in model.modules():
...     print(m)
Sequential(
  (0): Linear(in_features=3, out_features=1, bias=True)
  (1): Sequential(
    (0): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True)
    (1): Linear(in_features=1, out_features=3, bias=True)
  )
)
Linear(in_features=3, out_features=1, bias=True)
Sequential(
  (0): BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True)
  (1): Linear(in_features=1, out_features=3, bias=True)
)
BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True)
Linear(in_features=1, out_features=3, bias=True)  

了解了这两个方法你应该可以完成我们的第一个目标:提取nn.Sequential的任一个模块。在下一节中我们会完成我们的第二个目标,初始化权重。在结束之前给大家个小问题:有时候你只需要底层的module,而不需要module的子类,如nn.Sequential,那么怎么去除呢?

机器学习小贴士:支持向量机的意思就是我们最后选择的模型只与支持向量有关。

                                                  最后编辑于2018-10-1019:39:53 有什么错误请不吝赐教

pytroch nn.Module源码解析(1)的更多相关文章

  1. Cognitive Graph for Multi-Hop Reading Comprehension at Scale(ACL2019) 阅读笔记与源码解析

    论文地址为:Cognitive Graph for Multi-Hop Reading Comprehension at Scale github地址:CogQA 背景 假设你手边有一个维基百科的搜索 ...

  2. [源码解析] 深度学习流水线并行Gpipe(1)---流水线基本实现

    [源码解析] 深度学习流水线并行Gpipe(1)---流水线基本实现 目录 [源码解析] 深度学习流水线并行Gpipe(1)---流水线基本实现 0x00 摘要 0x01 概述 1.1 什么是GPip ...

  3. [源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积

    [源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积 目录 [源码解析] 深度学习流水线并行GPipe (2) ----- 梯度累积 0x00 摘要 0x01 概述 1.1 前文回 ...

  4. [源码解析] 深度学习流水线并行之PopeDream(1)--- Profile阶段

    [源码解析] 深度学习流水线并行之PopeDream(1)--- Profile阶段 目录 [源码解析] 深度学习流水线并行之PopeDream(1)--- Profile阶段 0x00 摘要 0x0 ...

  5. [源码解析] 深度学习流水线并行 PipeDream(3)--- 转换模型

    [源码解析] 深度学习流水线并行 PipeDream(3)--- 转换模型 目录 [源码解析] 深度学习流水线并行 PipeDream(3)--- 转换模型 0x00 摘要 0x01 前言 1.1 改 ...

  6. [源码解析] 深度学习流水线并行 PipeDream(4)--- 运行时引擎

    [源码解析] 深度学习流水线并行 PipeDream(4)--- 运行时引擎 目录 [源码解析] 深度学习流水线并行 PipeDream(4)--- 运行时引擎 0x00 摘要 0x01 前言 1.1 ...

  7. [源码解析] 深度学习分布式训练框架 horovod (21) --- 之如何恢复训练

    [源码解析] 深度学习分布式训练框架 horovod (21) --- 之如何恢复训练 目录 [源码解析] 深度学习分布式训练框架 horovod (21) --- 之如何恢复训练 0x00 摘要 0 ...

  8. [源码解析] PyTorch 流水线并行实现 (1)--基础知识

    [源码解析] PyTorch 流水线并行实现 (1)--基础知识 目录 [源码解析] PyTorch 流水线并行实现 (1)--基础知识 0x00 摘要 0x01 历史 1.1 GPipe 1.2 t ...

  9. [源码解析] PyTorch 流水线并行实现 (2)--如何划分模型

    [源码解析] PyTorch 流水线并行实现 (2)--如何划分模型 目录 [源码解析] PyTorch 流水线并行实现 (2)--如何划分模型 0x00 摘要 0x01 问题 0x01 自动平衡 1 ...

随机推荐

  1. 动态解析xml,并生成excel,然后发邮件。

    直接贴代码了! DECLARE @CurrentServer NVARCHAR(100)DECLARE @CurrentDatabase NVARCHAR(100)DECLARE @CurrentLo ...

  2. Studio 5000编程:如何判断AB PLC系统中的硬件设备是否在正常工作

    前言:PLC控制系统,主要由CPU.本机架I/O模块,分布式I/O模块,通信模块,或其他设备(如:伺服驱动器.交换机.第三方设备)等组成,如何判断这些设备是否工作正常?或是一旦出现故障,能在第一时间判 ...

  3. redis集群配置与管理

    Redis在3.0版本以后开始支持集群,经过中间几个版本的不断更新优化,最新的版本集群功能已经非常完善.本文简单介绍一下Redis集群搭建的过程和配置方法,redis版本是5.0.4,操作系统是中标麒 ...

  4. 标准库类型string

    定义和初始化string对象 初始化string对象方式: string s1;//默认初始化,s1是一个字符串 string s2(s1);//s2是s1的副本 string s2 = s1;//等 ...

  5. npx 是什么?

    参考链接:https://www.jianshu.com/p/cee806439865

  6. android shape 圆圈 圆环 圆角

    定义圆圈:比如角标: xml布局文件 <TextView android:id="@+id/item_order_pay_count" android:layout_widt ...

  7. 51nod 2513

    写代码的时候抄错变量,晕! 另外有个while循环条件错的,因为两个指针必须都要有终止条件 代码: #include<iostream> #include<cstdio> #i ...

  8. CNN的反向传播

    在一般的全联接神经网络中,我们通过反向传播算法计算参数的导数.BP 算法本质上可以认为是链式法则在矩阵求导上的运用.但 CNN 中的卷积操作则不再是全联接的形式,因此 CNN 的 BP 算法需要在原始 ...

  9. ABP core学习之二 IIS部署.NET CORE

    本文是关于IIS部署.NET CORE的总结,以后有碰到问题将陆续添加 IIS部署.NET CORE总结 一.服务器环境 首先确定自己项目的core版本,然后下载对应的包在服务器上安装 下载地址: h ...

  10. centos命令安装

    1.解决ifconfig命令失效:需要安装net-tools工具 yum install net-tools 2.免密码登录 (1)通过命令,产生公钥信息 ssh-keygen -t rsa 如果提示 ...