本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson3/module_containers.py

这篇文章来看下 PyTorch 中网络模型的创建步骤。网络模型的内容如下,包括模型创建和权值初始化,这些内容都在nn.Module中有实现。

网络模型的创建步骤

创建模型有 2 个要素:构建子模块拼接子模块。如 LeNet 里包含很多卷积层、池化层、全连接层,当我们构建好所有的子模块之后,按照一定的顺序拼接起来。

这里以上一篇文章中 `lenet.py`的 LeNet 为例,继承`nn.Module`,必须实现`__init__()` 方法和`forward()`方法。其中`__init__()` 方法里创建子模块,在`forward()`方法里拼接子模块。

class LeNet(nn.Module):
# 子模块创建
def __init__(self, classes):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, classes)
# 子模块拼接
def forward(self, x):
out = F.relu(self.conv1(x))
out = F.max_pool2d(out, 2)
out = F.relu(self.conv2(out))
out = F.max_pool2d(out, 2)
out = out.view(out.size(0), -1)
out = F.relu(self.fc1(out))
out = F.relu(self.fc2(out))
out = self.fc3(out)
return out

当我们调用net = LeNet(classes=2)创建模型时,会调用__init__()方法创建模型的子模块。

当我们在训练时调用outputs = net(inputs)时,会进入module.pycall()函数中:

    def __call__(self, *input, **kwargs):
for hook in self._forward_pre_hooks.values():
result = hook(self, input)
if result is not None:
if not isinstance(result, tuple):
result = (result,)
input = result
if torch._C._get_tracing_state():
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs)
...
...
...

最终会调用result = self.forward(*input, **kwargs)函数,该函数会进入模型的forward()函数中,进行前向传播。

torch.nn中包含 4 个模块,如下图所示。

其中所有网络模型都是继承于`nn.Module`的,下面重点分析`nn.Module`模块。

nn.Module

nn.Module 有 8 个属性,都是OrderDict(有序字典)。在 LeNet 的__init__()方法中会调用父类nn.Module__init__()方法,创建这 8 个属性。

    def __init__(self):
"""
Initializes internal Module state, shared by both nn.Module and ScriptModule.
"""
torch._C._log_api_usage_once("python.nn_module") self.training = True
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()
  • _parameters 属性:存储管理 nn.Parameter 类型的参数
  • _modules 属性:存储管理 nn.Module 类型的参数
  • _buffers 属性:存储管理缓冲属性,如 BN 层中的 running_mean
  • 5 个 ***_hooks 属性:存储管理钩子函数

其中比较重要的是parametersmodules属性。

在 LeNet 的__init__()中创建了 5 个子模块,nn.Conv2d()nn.Linear()都是 继承于nn.module,也就是说一个 module 都是包含多个子 module 的。

class LeNet(nn.Module):
# 子模块创建
def __init__(self, classes):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, classes)
...
...
...

当调用net = LeNet(classes=2)创建模型后,net对象的 modules 属性就包含了这 5 个子网络模块。

下面看下每个子模块是如何添加到 LeNet 的`_modules` 属性中的。以`self.conv1 = nn.Conv2d(3, 6, 5)`为例,当我们运行到这一行时,首先 Step Into 进入 `Conv2d`的构造,然后 Step Out。右键`Evaluate Expression`查看`nn.Conv2d(3, 6, 5)`的属性。

上面说了`Conv2d`也是一个 module,里面的`_modules`属性为空,`_parameters`属性里包含了该卷积层的可学习参数,这些参数的类型是 Parameter,继承自 Tensor。

此时只是完成了`nn.Conv2d(3, 6, 5)` module 的创建。还没有赋值给`self.conv1 `。在`nn.Module`里有一个机制,会拦截所有的类属性赋值操作(`self.conv1`是类属性),进入到`__setattr__()`函数中。我们再次 Step Into 就可以进入`__setattr__()`。

    def __setattr__(self, name, value):
def remove_from(*dicts):
for d in dicts:
if name in d:
del d[name] params = self.__dict__.get('_parameters')
if isinstance(value, Parameter):
if params is None:
raise AttributeError(
"cannot assign parameters before Module.__init__() call")
remove_from(self.__dict__, self._buffers, self._modules)
self.register_parameter(name, value)
elif params is not None and name in params:
if value is not None:
raise TypeError("cannot assign '{}' as parameter '{}' "
"(torch.nn.Parameter or None expected)"
.format(torch.typename(value), name))
self.register_parameter(name, value)
else:
modules = self.__dict__.get('_modules')
if isinstance(value, Module):
if modules is None:
raise AttributeError(
"cannot assign module before Module.__init__() call")
remove_from(self.__dict__, self._parameters, self._buffers)
modules[name] = value
elif modules is not None and name in modules:
if value is not None:
raise TypeError("cannot assign '{}' as child module '{}' "
"(torch.nn.Module or None expected)"
.format(torch.typename(value), name))
modules[name] = value
...
...
...

在这里判断 value 的类型是Parameter还是Module,存储到对应的有序字典中。

这里nn.Conv2d(3, 6, 5)的类型是Module,因此会执行modules[name] = value,key 是类属性的名字conv1,value 就是nn.Conv2d(3, 6, 5)

总结

  • 一个 module 里可包含多个子 module。比如 LeNet 是一个 Module,里面包括多个卷积层、池化层、全连接层等子 module
  • 一个 module 相当于一个运算,必须实现 forward() 函数
  • 每个 module 都有 8 个字典管理自己的属性

模型容器

除了上述的模块之外,还有一个重要的概念是模型容器 (Containers),常用的容器有 3 个,这些容器都是继承自nn.Module

  • nn.Sequetial:按照顺序包装多个网络层
  • nn.ModuleList:像 python 的 list 一样包装多个网络层,可以迭代
  • nn.ModuleDict:像 python 的 dict 一样包装多个网络层,通过 (key, value) 的方式为每个网络层指定名称。

nn.Sequetial

在传统的机器学习中,有一个步骤是特征工程,我们需要从数据中认为地提取特征,然后把特征输入到分类器中预测。在深度学习的时代,特征工程的概念被弱化了,特征提取和分类器这两步被融合到了一个神经网络中。在卷积神经网络中,前面的卷积层以及池化层可以认为是特征提取部分,而后面的全连接层可以认为是分类器部分。比如 LeNet 就可以分为特征提取分类器两部分,这 2 部分都可以分别使用 nn.Seuqtial 来包装。

代码如下:

class LeNetSequetial(nn.Module):
def __init__(self, classes):
super(LeNet2, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 6, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5),
nn.ReLU(),
nn.MaxPool2d(2, 2)
)
self.classifier = nn.Sequential(
nn.Linear(16*5*5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, classes)
) def forward(self, x):
x = self.features(x)
x = x.view(x.size()[0], -1)
x = self.classifier(x)
return x

在初始化时,nn.Sequetial会调用__init__()方法,将每一个子 module 添加到 自身的_modules属性中。这里可以看到,我们传入的参数可以是一个 list,或者一个 OrderDict。如果是一个 OrderDict,那么则使用 OrderDict 里的 key,否则使用数字作为 key (OrderDict 的情况会在下面提及)。

    def __init__(self, *args):
super(Sequential, self).__init__()
if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
self.add_module(key, module)
else:
for idx, module in enumerate(args):
self.add_module(str(idx), module)

网络初始化完成后有两个子 modulefeaturesclassifier

而`features`中的子 module 如下,每个网络层以序号作为 key:

在进行前向传播时,会进入 LeNet 的`forward()`函数,首先调用第一个`Sequetial`容器:`self.features`,由于`self.features`也是一个 module,因此会调用`__call__()`函数,里面调用

result = self.forward(*input, **kwargs),进入nn.Seuqetialforward()函数,在这里依次调用所有的 module。

    def forward(self, input):
for module in self:
input = module(input)
return input

在上面可以看到在nn.Sequetial中,里面的每个子网络层 module 是使用序号来索引的,即使用数字来作为 key。一旦网络层增多,难以查找特定的网络层,这种情况可以使用 OrderDict (有序字典)。代码中使用

class LeNetSequentialOrderDict(nn.Module):
def __init__(self, classes):
super(LeNetSequentialOrderDict, self).__init__() self.features = nn.Sequential(OrderedDict({
'conv1': nn.Conv2d(3, 6, 5),
'relu1': nn.ReLU(inplace=True),
'pool1': nn.MaxPool2d(kernel_size=2, stride=2), 'conv2': nn.Conv2d(6, 16, 5),
'relu2': nn.ReLU(inplace=True),
'pool2': nn.MaxPool2d(kernel_size=2, stride=2),
})) self.classifier = nn.Sequential(OrderedDict({
'fc1': nn.Linear(16*5*5, 120),
'relu3': nn.ReLU(), 'fc2': nn.Linear(120, 84),
'relu4': nn.ReLU(inplace=True), 'fc3': nn.Linear(84, classes),
}))
...
...
...

总结

nn.Sequetialnn.Module的容器,用于按顺序包装一组网络层,有以下两个特性。

  • 顺序性:各网络层之间严格按照顺序构建,我们在构建网络时,一定要注意前后网络层之间输入和输出数据之间的形状是否匹配
  • 自带forward()函数:在nn.Sequetialforward()函数里通过 for 循环依次读取每个网络层,执行前向传播运算。这使得我们我们构建的模型更加简洁

nn.ModuleList

nn.ModuleListnn.Module的容器,用于包装一组网络层,以迭代的方式调用网络层,主要有以下 3 个方法:

  • append():在 ModuleList 后面添加网络层
  • extend():拼接两个 ModuleList
  • insert():在 ModuleList 的指定位置中插入网络层

下面的代码通过列表生成式来循环迭代创建 20 个全连接层,非常方便,只是在 forward()函数中需要手动调用每个网络层。

class ModuleList(nn.Module):
def __init__(self):
super(ModuleList, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)]) def forward(self, x):
for i, linear in enumerate(self.linears):
x = linear(x)
return x net = ModuleList() print(net) fake_data = torch.ones((10, 10)) output = net(fake_data) print(output)

nn.ModuleDict

nn.ModuleDictnn.Module的容器,用于包装一组网络层,以索引的方式调用网络层,主要有以下 5 个方法:

  • clear():清空 ModuleDict
  • items():返回可迭代的键值对 (key, value)
  • keys():返回字典的所有 key
  • values():返回字典的所有 value
  • pop():返回一对键值,并从字典中删除

下面的模型创建了两个ModuleDictself.choicesself.activations,在前向传播时通过传入对应的 key 来执行对应的网络层。

class ModuleDict(nn.Module):
def __init__(self):
super(ModuleDict, self).__init__()
self.choices = nn.ModuleDict({
'conv': nn.Conv2d(10, 10, 3),
'pool': nn.MaxPool2d(3)
}) self.activations = nn.ModuleDict({
'relu': nn.ReLU(),
'prelu': nn.PReLU()
}) def forward(self, x, choice, act):
x = self.choices[choice](x)
x = self.activations[act](x)
return x net = ModuleDict() fake_img = torch.randn((4, 10, 32, 32)) output = net(fake_img, 'conv', 'relu')
# output = net(fake_img, 'conv', 'prelu')
print(output)

容器总结

  • nn.Sequetial:顺序性,各网络层之间严格按照顺序执行,常用于 block 构建,在前向传播时的代码调用变得简洁
  • nn.ModuleList:迭代行,常用于大量重复网络构建,通过 for 循环实现重复构建
  • nn.ModuleDict:索引性,常用于可选择的网络层

PyTorch 中的 AlexNet

AlexNet 是 Hinton 和他的学生等人在 2012 年提出的卷积神经网络,以高出第二名 10 多个百分点的准确率获得 ImageNet 分类任务冠军,从此卷积神经网络开始在世界上流行,是划时代的贡献。

AlexNet 特点如下:

  • 采用 ReLU 替换饱和激活 函数,减轻梯度消失
  • 采用 LRN (Local Response Normalization) 对数据进行局部归一化,减轻梯度消失
  • 采用 Dropout 提高网络的鲁棒性,增加泛化能力
  • 使用 Data Augmentation,包括 TenCrop 和一些色彩修改

AlexNet 的网络结构可以分为两部分:features 和 classifier。

在`PyTorch`的计算机视觉库`torchvision.models`中的 AlexNet 的代码中,使用了`nn.Sequential`来封装网络层。

class AlexNet(nn.Module):

    def __init__(self, num_classes=1000):
super(AlexNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(64, 192, kernel_size=5, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(192, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
)
self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(256 * 6 * 6, 4096),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, num_classes),
) def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x

参考资料

如果你觉得这篇文章对你有帮助,不妨点个赞,让我有更多动力写出好文章。

[PyTorch 学习笔记] 3.1 模型创建步骤与 nn.Module的更多相关文章

  1. pytorch(11)模型创建步骤与nn.Module

    模型创建与nn.Module 网络模型创建步骤 nn.Module graph LR 模型 --> 模型创建 模型创建 --> 构建网络层 构建网络层 --> id[卷积层,池化层, ...

  2. V-rep学习笔记:机器人模型创建3—搭建动力学模型

    接着之前写的V-rep学习笔记:机器人模型创建2—添加关节继续机器人创建流程.如果已经添加好关节,那么就可以进入流程的最后一步:搭建层次结构模型和模型定义(build the model hierar ...

  3. V-rep学习笔记:机器人模型创建2—添加关节

    下面接着之前经过简化并调整好视觉效果的模型继续工作流,为了使模型能受控制运动起来必须在合适的位置上添加相应的运动副/关节.一般情况下我们可以查阅手册或根据设计图纸获得这些关节的准确位置和姿态,知道这些 ...

  4. [PyTorch 学习笔记] 7.1 模型保存与加载

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py https://githu ...

  5. pytorch学习笔记(十二):详解 Module 类

    Module 是 pytorch 提供的一个基类,每次我们要 搭建 自己的神经网络的时候都要继承这个类,继承这个类会使得我们 搭建网络的过程变得异常简单. 本文主要关注 Module 类的内部是怎么样 ...

  6. V-rep学习笔记:机器人模型创建4—定义模型

    完成之前的操作后终于来到最后一步——定义模型,即将之前创建的几何体.关节等元素按层级关系组织成为一个整体. 将最后一个连杆robot_link_dyn6拖放到相应的关节(robot_joint6)下, ...

  7. V-rep学习笔记:机器人模型创建1—模型简化

    要进行机器人仿真首先需要得到机器人的几何模型.我们可以直接通过VREP中提供的基本几何体来搭建一个简易的机器人[Menu bar --> Add --> Primitive shape - ...

  8. PyTorch学习笔记之CBOW模型实践

    import torch from torch import nn, optim from torch.autograd import Variable import torch.nn.functio ...

  9. PyTorch学习笔记之n-gram模型实现

    import torch import torch.nn as nn from torch.autograd import Variable import torch.nn.functional as ...

随机推荐

  1. The JOIN operation -- SQLZOO

    The JOIN operation 注意:where语句中对表示条件的需要用单引号, 下面的译文使用的是有道翻译如有不正确,请直接投诉有道 01.Modify it to show the matc ...

  2. samba服务及配置

    samba 目录 samba 1. samba简介 2. samba访问 配置示例 3.搭建用户认证共享服务器 4.搭建匿名用户共享服务器 1. samba简介 Samba是在Linux和UNIX系统 ...

  3. day3. 六大标准数据类型的类型转换

    一.强制类型转换Number 1.int  强制转换成整型 var1 = 13 var2 = 13.789 var3 = True var4 = 5-7j var5 = "" va ...

  4. 认识SpringData JPA

    简介 JPA全称Java Persistence API,中文名是Java持久层API.用来描述对象-关系表的映射关系,并将运行期的实体对象持久化到数据库中. 名词解释 RDS:关系型数据库服务 Re ...

  5. "点"醒自己

    回顾: 过去的经历 从18年开始在某机构进行了测试培训,9月正式加入测试小白大军,中间也经历了半年的空窗期,入职2个月应能力不够被公司辞退,后通过培训班的朋友内推到一家软件外包公司,工作到现在. 思考 ...

  6. Ajax中关于xmlhttp.readyState的值及解释

    xmlhttp.readyState的值及解释:0:请求未初始化(还没有调用 open()).1:请求已经建立,但是还没有发送(还没有调用 send()).2:请求已发送,正在处理中(通常现在可以从响 ...

  7. CI4框架应用六 - 控制器应用

    这节我们来分析一下控制器的应用,我们看到系统提供的控制器都是继承自一个BaseController,我们来分析一下这个BaseController的作用 use CodeIgniter\Control ...

  8. 再也不怕别人动电脑了!用Python实时监控

    作者:美图博客 https://www.meitubk.com/zatan/386.html 前言 最近突然有个奇妙的想法,就是当我对着电脑屏幕的时候,电脑会先识别屏幕上的人脸是否是本人,如果识别是本 ...

  9. Python最好IDE:Pycharm使用小技巧总结,让你写代码更为舒适

  10. DotNet Core

    安装 dotnet add package Pomelo.EntityFrameworkCore.MySql 使用 MySQL 作为后端     在继承 DbContext 类中重写 OnConfig ...