containers

graph LR
A["Containers"] --> B["nn.Sequetial"]
B["nn.Sequetial"] --> C["wrap multiple network layers in sequence"]
A["Containers"] --> D["nn.ModuleList"]
D["nn.ModuleList"] --> E["wrap multiple network layers in the way like the list of python"]
A["Containers"] --> F["nn.ModuleDict"]
F["nn.ModuleDict"] --> G["wrap multiple network layers in the way like the dict of python"]

the sequential of the container

The nn.Sequential is the container of the nn.module,wrapping the network layers in sequence.

在传统的机器学习中,有一步叫做特征工程。需要人为的设计特征,将特征送去分类器中分类。在深度学习的阶段,就已经弱化了特征工程的概念,尤其是卷积神经网络,对图像的特征我们完全不需要去设计,都可以交给卷积神经网络自动去学习特征,最后会加上几个全连接层,用于输出分类结果。而在早期的神经网络中,用于分类的分类器正是全连接构成的,所以在神经网络阶段,也有习惯以全连接层为界限,将网络模型分为特征提取模块和分类模块,对一个大的模型进行划分,对模型进行管理。

e.g

LeNet:

Conv1 --> pool1 --> Conv2 --> pool2 --> fc1 --> fc2 --> fc3

from the [Conv1 to pool2] is features(特征提取器)

from the [fc1 to fc3] is classifier(分类器)

**nn.Sequentail is the container of the nn.module,wrapping the network layers in sequence.

  • sequential:the construction of the layers is strictly in order
  • 自带forward():自带forward,通过for循环一次执行前向传播运算。**
    def __init__(self, classes):
super(LeNetSequential, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(3, 6, 5),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(6, 16, 5),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),) self.classifier = nn.Sequential(
nn.Linear(16*5*5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, classes),)

以上就完成了模型构建的构建子模块,我们可以看到子模块就两个概念,feature和classifier。紧接着定义它的前向传播forward。

def forward(self, x):
x = self.features(x)
x = x.view(x.size()[0], -1)
x = self.classifier(x)
return x

只需要把输入x输入到features,features输出再经过一个形状的变换,再输入到classifier中得到一个分类的结果。

在代码中的

net = LeNetSequential(classes=2)

中设置断点进行debug,进行观察如何构建一个nn.sequential

nn.sequential给它输入一系列的网络层就能构成一个容器。我们在它的结构体内进行debug,在最后一个maxpool2d层进行step into。此时进入到container.py的sequential类,class Sequential(Module)表明了sequential也是继承的module类。既然是继承module类的,那么也会有8个参数来管理属性。

if len(args) == 1 and isinstance(args[0], OrderedDict):
for key, module in args[0].items():
self.add_module(key, module)
else:
for idx, module in enumerate(args):
self.add_module(str(idx), module)

先判断传进来的参数判断,数据类型是不是一个有序质点。在这里应该不是一个有序质点,因此进入else,对循环传入进来参数进行取出每一个网络层,然后采用Module类方法,将网络层添加到sequential中。不断的添加进来。完成了一个sequential构成。然后回到self.features = nn.Sequential。发现此时的modules还是空的,是因为我们没有进行赋值“=”的运算,只是进行了等号右边的sequential的构建。再进行。

目前进入到含有sequential的modules不为空,由于本身sequential也是一个modules,因此自身也可以进行探查。在sequential的Module里就有网络层了,是按顺序构成的。

接下来我们进入

output = net(fake_img)

进行step into,观察前向传播是如何实现的。进入内部后,我们通过_call_impl的函数体内部进行观察,能够看出在

result = self.forward(*input, **kwargs)

设置step into。此时来到了原始的代码中class LeNetSequential(nn.Module):的def forward(self, x)。我们发现,只需要把x赋值给self.features(x),那么我们就能够自动执行这六层的网络传播。比之前的方法简单很多,我们进入self.features看看为什么那么简洁,我们进入feature后,因为self.feature是一个sequential,sequential又继承与module,因此我们又会进入module.py的call函数中。我们直接进入result = self.forward(*input,**kwargs),这下我们来到了container.py中的class Sequential(Module):函数中的def foward,代码如下

def forward(self, input):
for module in self:
input = module(input)
return input

此时我们能看出这段代码非常简洁,其主要思想就是利用for循环对网络进行循环的forward。我们进行调试,从modules中取出module进行forwad,将我们的input迭代输入到卷积、relu、maxpool等

这里发现网络层在上一节中是有name的,比如conv1,conv2,但是在这里是没有的,通过序号来索引的。因此在复杂的网络中,很难通过序号去索引,因此采用dict去命名索引。

LeNetSequentialOrderDict

核心代码如下:


class LeNetSequentialOrderDict(nn.Module):
def __init__(self, classes):
super(LeNetSequentialOrderDict, self).__init__() self.features = nn.Sequential(OrderedDict({
'conv1': nn.Conv2d(3, 6, 5),
'relu1': nn.ReLU(inplace=True),
'pool1': nn.MaxPool2d(kernel_size=2, stride=2), 'conv2': nn.Conv2d(6, 16, 5),
'relu2': nn.ReLU(inplace=True),
'pool2': nn.MaxPool2d(kernel_size=2, stride=2),
})) self.classifier = nn.Sequential(OrderedDict({
'fc1': nn.Linear(16*5*5, 120),
'relu3': nn.ReLU(), 'fc2': nn.Linear(120, 84),
'relu4': nn.ReLU(inplace=True), 'fc3': nn.Linear(84, classes),
})) def forward(self, x):
x = self.features(x)
x = x.view(x.size()[0], -1)
x = self.classifier(x)
return x

通过单步调试去理解sequential是如何通过dict给网络命名的。

容器之ModuleList

nn.ModuleList是 nn.module的容器,用于包装一组网络层,以迭代方式调用网络层

主要方法:

  • append():在ModuleList后面添加网络层
  • extend():拼接两个ModuleList
  • insert():指定在ModuleList中位置插入网络层

    modulelist就可以采用列表生成式,通过for循环去生成网络层。例子生成一个20个全连接层,每个全连接层有10个神经元的网络

    其主要的核心算法分别为:
class ModuleList(nn.Module):
def __init__(self):
super(ModuleList, self).__init__()
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(20)]) def forward(self, x):
for i, linear in enumerate(self.linears):
x = linear(x)
return x

以及核心的modulelist代码:

def __init__(self, modules: Optional[Iterable[Module]] = None) -> None:
super(ModuleList, self).__init__()
if modules is not None:
self += modules

此时modules是一个List,因此通过不断的+的拼接合成一个modulelist.

容器之ModuleLDict

nn.ModuleDict是 nn.module的容器,用于包装一组网络层,以索引方式调用网络层。字典一样包装,给每一个网络层加上一个名称,可以通过名称去索引网络层。

主要方法:

  • clear():清空ModuleDict
  • items():返回可迭代的键值对(key-value pairs)
  • keys():返回字典的键(key)
  • values():返回字典的值(value)
  • pop():返回一对键值,并从字典中删除

    核心代码

class ModuleDict(nn.Module):
def __init__(self):
super(ModuleDict, self).__init__()
self.choices = nn.ModuleDict({
'conv': nn.Conv2d(10, 10, 3),
'pool': nn.MaxPool2d(3)
}) self.activations = nn.ModuleDict({
'relu': nn.ReLU(),
'prelu': nn.PReLU()
}) def forward(self, x, choice, act):
x = self.choices[choice](x)
x = self.activations[act](x)
return x

和之前的sequentaildict的不同,sequentaildict是有序的,moduledict是无序的,能够通过dict对Key来指定网络层。这里的forward会增加两个参数,choice和act来选择我们的网络层。因此在后面的output就需要进行修改,代码如下:

output = net(fake_img,'conv','prelu')

修改后的Output用来指定对应的网络和激活函数

之前的output无需修改,为

output = net(fake_img)

容器总结

  • nn.Sequential:顺序性,各网络层之间严格按顺序执行,常用于block构建
  • nn.ModuleList:迭代性,常用于大量重复网构建,通过for循环实现重复构建
  • nn.ModuleDict:索引性,常用于可选择的网络层

代码

采用sequential容器,改写Alexnet,给features中每一个网络层增加名字,并通过下面这行代码打印出来

print(alexnet._modules['features']._modules.keys())

改写的Alexnet如下:

import torch
import torch.nn as nn
from .utils import load_state_dict_from_url
from collections import OrderedDict __all__ = ['AlexNet', 'alexnet'] model_urls = {
'alexnet': 'https://download.pytorch.org/models/alexnet-owt-4df8aa71.pth',
} class AlexNet(nn.Module): def __init__(self, num_classes=1000):
super(AlexNet, self).__init__()
self.features = nn.Sequential(OrderedDict({
'conv1' : nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
'relu1': nn.ReLU(inplace=True),
'maxpool2d1': nn.MaxPool2d(kernel_size=3, stride=2),
'conv2':nn.Conv2d(64, 192, kernel_size=5, padding=2),
'relu2':nn.ReLU(inplace=True),
'MaxPool2d2': nn.MaxPool2d(kernel_size=3, stride=2),
'Conv2d3': nn.Conv2d(192, 384, kernel_size=3, padding=1),
'ReLU3': nn.ReLU(inplace=True),
'Conv2d4':nn.Conv2d(384, 256, kernel_size=3, padding=1),
'ReLU4': nn.ReLU(inplace=True),
'Conv2d5': nn.Conv2d(256, 256, kernel_size=3, padding=1),
'ReLU5': nn.ReLU(inplace=True),
'MaxPool2d3':nn.MaxPool2d(kernel_size=3, stride=2),
}))
self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
self.classifier = nn.Sequential(OrderedDict({
'dropout1': nn.Dropout(),
'linear1': nn.Linear(256 * 6 * 6, 4096),
'relu1': nn.ReLU(inplace=True),
'dropout2': nn.Dropout(),
'linear2': nn.Linear(4096, 4096),
'relu2': nn.ReLU(inplace=True),
'linear3': nn.Linear(4096, num_classes),
})) def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x def alexnet(pretrained=False, progress=True, **kwargs):
r"""AlexNet model architecture from the
`"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper. Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = AlexNet(**kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['alexnet'],
progress=progress)
model.load_state_dict(state_dict)
return model

输出结果为

odict_keys(['conv1', 'relu1', 'maxpool2d1', 'conv2', 'relu2', 'MaxPool2d2', 'Conv2d3', 'ReLU3', 'Conv2d4', 'ReLU4', 'Conv2d5', 'ReLU5', 'MaxPool2d3'])

如果使用如下代码:

alexnet = torchvision.models.AlexNet()
print(alexnet._modules['classifier']._modules.keys())

得到的输出结果为:

odict_keys(['dropout1', 'linear1', 'relu1', 'dropout2', 'linear2', 'relu2', 'linear3'])

pytorch(12)ContainersAndAlexNet的更多相关文章

  1. Netruon 理解(12):使用 Linux bridge 将 Linux network namespace 连接外网

    学习 Neutron 系列文章: (1)Neutron 所实现的虚拟化网络 (2)Neutron OpenvSwitch + VLAN 虚拟网络 (3)Neutron OpenvSwitch + GR ...

  2. 基于MVC4+EasyUI的Web开发框架经验总结(12)--利用Jquery处理数据交互的几种方式

    在基于MVC4+EasyUI的Web开发框架里面,大量采用了Jquery的方法,对数据进行请求或者提交,方便页面和服务器后端进行数据的交互处理.本文主要介绍利用Jquery处理数据交互的几种方式,包括 ...

  3. Web 在线文件管理器学习笔记与总结(11)获取文件夹信息 (12)返回上一级操作

    (11)获取文件夹信息 文件夹没有修改操作. index.php: <?php require 'dir.func.php'; require 'file.func.php'; require ...

  4. Android菜鸟的成长笔记(12)——Handler、Loop、MessageQueue

    原文:[置顶] Android菜鸟的成长笔记(12)——Handler.Loop.MessageQueue 当一个程序第一次启动时,Android会启动一条主线程(Main Thread),主线程主要 ...

  5. Windows Phone开发(12):认识一下独具个性的磁贴

    原文:Windows Phone开发(12):认识一下独具个性的磁贴 对"磁贴"的理解是一点也不抽象的,为什么呢?只要你愿意启动WP系统,无论你是在模拟器中还是在真机中,是的,桌面 ...

  6. 网站静态化处理—web前端优化—中(12)

    网站静态化处理—web前端优化—中(12) Web前端很多优化原则都是从如何提升网络通讯效率的角度提出的,但是这些原则使用的时候还是有很多陷阱在里面,如果我们不能深入理解这些优化原则背后所隐藏的技术原 ...

  7. python入门(12)dict

    python入门(12)dict Python内置了字典:dict的支持,dict全称dictionary,在其他语言中也称为map,使用键-值(key-value)存储,具有极快的查找速度. 举个例 ...

  8. (1-2)line-height的各类属性值

    (1-2)line-height的各类属性值 首先来个疑问!没有问题印象不深嘛 一.line-height支持哪些属性值呢? 五只手指头就能数过来了咯. 比如normal, <number> ...

  9. Java设计模式(12)迭代模式(Iterator模式)

    上了这么多年学,我发现一个问题,好象老师都很喜欢点名,甚至点名都成了某些老师的嗜好,一日不点名,就饭吃不香,觉睡不好似的,我就觉得很奇怪,你的课要是讲的好,同学又怎么会不来听课呢,殊不知:“误人子弟, ...

随机推荐

  1. 解决M1 MacBook无法使用pip安装Numpy

    问题描述 Python官方已发布支持M1 Apple Silicon的版本,但是在使用pip包管理工具安装一些依赖时发生了错误,这里面就包括在科学计算领域常用的numpy.pandas等.目前可以通过 ...

  2. MySQL 企业案例:误删核心业务表

    问题描述: 1.正在运行的网站系统,MySQL 数据库,数据量 25G,日业务增量 10 - 15M 2.备份策略:每天 23:00,计划任务调用 mysqldump 执行全备脚本 3.故障时间点:上 ...

  3. CentOS7系统时间和硬件时间不同步问题

    CentOS7系统中有两个时间:系统时间 和 硬件时间 我们常用命令 date 会输出系统时间,用 date 命令修改的也是系统时间 硬件时间是写入到 BIOS 中的时间,用 hwclock -r 命 ...

  4. oslab oranges 一个操作系统的实现 实验四 认识保护模式(三):中断异常

    实验目的: 理解中断与异常机制的实现机理 对应章节:第三章3.4节,3.5节 实验内容: 1. 理解中断与异常的机制 2. 调试8259A的编程基本例程 3. 调试时钟中断例程 4. 建立IDT,实现 ...

  5. 局部变量 static new 结构体指针

    struct ListNode { int val; ListNode* next; ListNode(int x) : val(x), next(NULL) {} }; 有一个函数利用LisNode ...

  6. IPC$入侵

    一 唠叨一下: 网上关于ipc$入侵的文章可谓多如牛毛,而且也不乏优秀之作,攻击步骤甚至可以说已经成为经典的模式,因此也没人愿意再把这已经成为定式的东西拿出来摆弄. 二 什么是ipc$ IPC$(In ...

  7. when I was installing github for windows ,some errors occurred !

    1: 2: 3: 4: install.log error messages:

  8. 慕课网站 & MOOC website

    慕课网站 & MOOC website MOOC, massive open online course Mooc for everyone ! 国家精品课程 & 在线学习平台 慕课平 ...

  9. Immutable.js 实现原理

    Immutable.js 实现原理 Immutable collections for JavaScript v4.0.0-rc.12 released on Oct 31, 2018 https:/ ...

  10. Fetch & POST

    Fetch & POST fetch( `http://10.1.5.202/deploy/http/send/viewtree`, { method: "POST", m ...