网络中的网络NIN

之前介绍的LeNet,AlexNet,VGG设计思路上的共同之处,是加宽(增加卷积层的输出的channel数量)和加深(增加卷积层的数量),再接全连接层做分类.  

NIN提出了一个不同的思路,串联多个由卷积层和'全连接层'(1x1卷积)构成的小网络来构建一个深层网络.

论文地址:https://arxiv.org/pdf/1312.4400.pdf

nin的重点我总结主要就2点:

  • mlpconv的提出(我们用1x1卷积实现),整合多个feature map上的特征.进一步增强非线性.
  • 全局平均池化替代全连接层

推荐一篇我觉得不错的解读博客:https://blog.csdn.net/hjimce/article/details/50458190

1x1卷积



1x1卷积对channel维度上的元素做乘加操作.

如上图所示,由于1x1卷积对空间维度上的元素并没有做关联,所以空间维度(h,w)上的信息得以传递到后面的层中.

举个例子,以[h,w,c]这种顺序为例,1x1卷积只会将[0,0,0],[0,0,1],[0,0,2]做乘加操作.

[0,0,x]的元素和[0,1,x]的元素是不会发生关系的.

NIN结构

NIN Net是在AlexNet的基础上提出的他们的结构分别如下所示:

AlexNet结构如下:



注意,这个图里的maxpool是在第一二五个卷积层以后.这个图稍微有点误导.即11x11的卷积核后做maxpool,再做卷积.而不是卷积-卷积-池化.

NIN结构如下:



这是网上找的一个示意图,nin的论文里并没有完整的结构图.

这个图有一点不对,最后一个卷积那里应该用的卷积核的shape应该是3x3x384.共1000个,下图红圈处应该是3x3x384x1000,1000,1000.对应到我们的实现,应该是3x3x384x10,10,10.因为我们的数据集只有10个类别.

下面我们先来实现卷积部分:

首先我们定义nin的'小网络'模块.即'常规卷积-1x1卷积-1x1卷积'这一部分.

def make_layers(in_channels,out_channels,kernel_size, stride, padding):
conv = nn.Sequential(
nn.Conv2d(in_channels,out_channels,kernel_size, stride, padding),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels,out_channels,kernel_size=1, stride=1, padding=0),#1x1卷积,整合多个feature map的特征
nn.ReLU(inplace=True),
nn.Conv2d(out_channels,out_channels,kernel_size=1, stride=1, padding=0),#1x1卷积,整合多个feature map的特征
nn.ReLU(inplace=True)
) return conv

然后对于网络的卷积部分,我们就可以写出如下代码

conv1 = make_layers(1,96,11,4,2)
pool1 = nn.MaxPool2d(kernel_size=3,stride=2)
conv2 = make_layers(96,256,kernel_size=5,stride=1,padding=2)
pool2 = nn.MaxPool2d(kernel_size=3,stride=2)
conv3 = make_layers(256,384,kernel_size=3,stride=1,padding=1)
pool3 = nn.MaxPool2d(kernel_size=3,stride=2)
conv4 = make_layers(384,10,kernel_size=3,stride=1,padding=1)

我们来验证一下模型

X = torch.rand(1, 1, 224, 224)
o1 = conv1(X)
print(o1.shape) #[1,96,55,55]
o1_1 = pool1(o1)
print(o1_1.shape) #[1,96,27,27] o2 = conv2(o1_1)
print(o2.shape) #[1,256,27,27]
o2_1 = pool2(o2)
print(o2_1.shape) #[1,256,13,13] o3 = conv3(o2_1)
print(o3.shape) #[1,384,13,13]
o3_1 = pool3(o3)
print(o3_1.shape) #[1,384,6,6] o4 = conv4(o3_1)
print(o4.shape) #[1,10,6,6]

每一层的输出shape都是对的,说明我们模型写对了.如果不对,我们就去调整make_layers()的参数,主要是padding.

卷积部分得到[1,10,6,6]的输出以后,我们要做一个全局平均池化,全局平均池化什么意思呢?

我们先看普通池化,比方说一个10x10的输入,用2x2的窗口去做池化,然后这个窗口不断地滑动,从而对不同的2x2区域可以做求平均(平均池化),取最大值(最大值池化)等.这个就可以理解为'局部'的池化,2x2是10x10的一部分嘛.

相应地,所谓全局池化,自然就是用一个和输入大小一样的窗口做池化,即对全部的输入做池化操作.

所以我们可以实现出全局平均池化部分:

ap = nn.AvgPool2d(kernel_size=6,stride=1)
o5 = ap(o4)
print(o5.shape) #[1,10,1,1]

torch中的nn模块已经提供了平均池化操作函数,我们要做的就是把kernel_size赋值成和输入的feature map的size一样大小就好了,这样就实现了全局平均池化.

全局平均池化的重要意义

用全局平均池化替代全连接层,一个显而易见的好处就是,参数量极大地减少了,从而也防止了过拟合.

另一个角度看,是从网络结构上做正则化防止过拟合.比方说[1,10,6,6]的输入,即10个6x6的feature map,我们做全局平均池化后得到[1,10,1,1]的输出,展平后即10x1的输出,这10个标量,我们认为代表十个类别.训练的过程就是使这十个标量不断逼近代表真实类别的标量的过程.这使得模型的可解释性更好了.

参考:https://zhuanlan.zhihu.com/p/46235425

基于以上讨论,我们可以给出NinNet定义如下:

class NinNet(nn.Module):
def __init__(self):
super(NinNet, self).__init__()
self.conv = nn.Sequential(
make_layers(1,96,11,4,2),
nn.MaxPool2d(kernel_size=3,stride=2),
make_layers(96,256,kernel_size=5,stride=1,padding=2),
nn.MaxPool2d(kernel_size=3,stride=2),
make_layers(256,384,kernel_size=3,stride=1,padding=1),
nn.MaxPool2d(kernel_size=3,stride=2),
make_layers(384,10,kernel_size=3,stride=1,padding=1)
) self.gap = nn.Sequential(
nn.AvgPool2d(kernel_size=6,stride=1)
) def forward(self, img):
feature = self.conv(img)
output = self.gap(feature)
output = output.view(img.shape[0],-1)#[batch,10,1,1]-->[batch,10] return output

我们可以简单测试一下:

X = torch.rand(1, 1, 224, 224)
net = NinNet()
for name,module in net.named_children():
X = module(X)
print(name,X.shape)

输出

conv torch.Size([1, 10, 6, 6])
gap torch.Size([1, 10, 1, 1])

接下来就是熟悉的套路:

加载数据

batch_size,num_workers=16,4
train_iter,test_iter = learntorch_utils.load_data(batch_size,num_workers,resize=224)

定义模型

net = NinNet().cuda()
print(net)

定义损失函数

loss = nn.CrossEntropyLoss()

定义优化器

opt = torch.optim.Adam(net.parameters(),lr=0.001)

定义评估函数

def test():
start = time.time()
acc_sum = 0
batch = 0
for X,y in test_iter:
X,y = X.cuda(),y.cuda()
y_hat = net(X)
acc_sum += (y_hat.argmax(dim=1) == y).float().sum().item()
batch += 1
#print('acc_sum %d,batch %d' % (acc_sum,batch)) acc = 1.0*acc_sum/(batch*batch_size)
end = time.time()
print('acc %3f,test for test dataset:time %d' % (acc,end - start)) return acc

训练

num_epochs = 3
def train():
for epoch in range(num_epochs):
train_l_sum,batch,acc_sum = 0,0,0
start = time.time()
for X,y in train_iter:
# start_batch_begin = time.time()
X,y = X.cuda(),y.cuda()
y_hat = net(X)
acc_sum += (y_hat.argmax(dim=1) == y).float().sum().item() l = loss(y_hat,y)
opt.zero_grad()
l.backward() opt.step()
train_l_sum += l.item() batch += 1 mean_loss = train_l_sum/(batch*batch_size) #计算平均到每张图片的loss
start_batch_end = time.time()
time_batch = start_batch_end - start train_acc = acc_sum/(batch*batch_size)
if batch % 100 == 0:
print('epoch %d,batch %d,train_loss %.3f,train_acc:%.3f,time %.3f' %
(epoch,batch,mean_loss,train_acc,time_batch)) if batch % 1000 == 0:
model_state = net.state_dict()
model_name = 'nin_epoch_%d_batch_%d_acc_%.2f.pt' % (epoch,batch,train_acc)
torch.save(model_state,model_name) print('***************************************')
mean_loss = train_l_sum/(batch*batch_size) #计算平均到每张图片的loss
train_acc = acc_sum/(batch*batch_size) #计算训练准确率
test_acc = test() #计算测试准确率
end = time.time()
time_per_epoch = end - start
print('epoch %d,train_loss %f,train_acc %f,test_acc %f,time %f' %
(epoch + 1,mean_loss,train_acc,test_acc,time_per_epoch)) train()

部分输出如下

epoch 0,batch 3600,train_loss 0.070,train_acc:0.603,time 176.200
epoch 0,batch 3700,train_loss 0.069,train_acc:0.606,time 181.160
***************************************
acc 0.701800,test for test dataset:time 11
epoch 1,train_loss 0.069109,train_acc 0.607550,test_acc 0.701800,time 195.619591
epoch 1,batch 100,train_loss 0.044,train_acc:0.736,time 5.053
epoch 1,batch 200,train_loss 0.047,train_acc:0.727,time 10.011
epoch 1,batch 300,train_loss 0.048,train_acc:0.735,time 15.210

可以看到由于没有了全连接层,训练时间明显缩短.

完整代码戳我

从头学pytorch(十七):网络中的网络NIN的更多相关文章

  1. 从头学pytorch(一):数据操作

    跟着Dive-into-DL-PyTorch.pdf从头开始学pytorch,夯实基础. Tensor创建 创建未初始化的tensor import torch x = torch.empty(5,3 ...

  2. 从头学pytorch(九):模型构造

    模型构造 nn.Module nn.Module是pytorch中提供的一个类,是所有神经网络模块的基类.我们自定义的模块要继承这个基类. import torch from torch import ...

  3. Linux从头学02:x86中内存【段寻址】方式的来龙去脉

    作 者:道哥,10+年的嵌入式开发老兵. 公众号:[IOT物联网小镇],专注于:C/C++.Linux操作系统.应用程序设计.物联网.单片机和嵌入式开发等领域. 公众号回复[书籍],获取 Linux. ...

  4. 从头学pytorch(二十):残差网络resnet

    残差网络ResNet resnet是何凯明大神在2015年提出的.并且获得了当年的ImageNet比赛的冠军. 残差网络具有里程碑的意义,为以后的网络设计提出了一个新的思路. googlenet的思路 ...

  5. 从头学pytorch(三) 线性回归

    关于什么是线性回归,不多做介绍了.可以参考我以前的博客https://www.cnblogs.com/sdu20112013/p/10186516.html 实现线性回归 分为以下几个部分: 生成数据 ...

  6. 从头学pytorch(六):权重衰减

    深度学习中常常会存在过拟合现象,比如当训练数据过少时,训练得到的模型很可能在训练集上表现非常好,但是在测试集上表现不好. 应对过拟合,可以通过数据增强,增大训练集数量.我们这里先不介绍数据增强,先从模 ...

  7. 从头学pytorch(十八):GoogLeNet

    GoogLeNet GoogLeNet和vgg分别是2014的ImageNet挑战赛的冠亚军.GoogLeNet则做了更加大胆的网络结构尝试,虽然深度只有22层,但大小却比AlexNet和VGG小很多 ...

  8. 从头学pytorch(二) 自动求梯度

    PyTorch提供的autograd包能够根据输⼊和前向传播过程⾃动构建计算图,并执⾏反向传播. Tensor Tensor的几个重要属性或方法 .requires_grad 设为true的话,ten ...

  9. 从头学pytorch(十二):模型保存和加载

    模型读取和存储 总结下来,就是几个函数 torch.load()/torch.save() 通过python的pickle完成序列化与反序列化.完成内存<-->磁盘转换. Module.s ...

随机推荐

  1. H3C 帧中继协议栈

  2. axis2 wsdl2java工具

    wsdl2java工具使用方法描述: C:\Users\Administrator>wsdl2java -h Using AXIS2_HOME: E:\Apache_Projects\axis2 ...

  3. 符合阿里巴巴代码规范的checkstyle检测文件

    一.安装与简介 eclipse和idea都有对应的插件,找到插件安装界面.搜索checkstyle,点击安装后,重启IDE即可.(网上有很多安装教程,就不重复制造轮子了) 二.导入配置文件 在chec ...

  4. LR性能测试自动化集成JENKINS

    LR11不支持JENKINS集成,解决方案可以使用BAT代替执行,JENKINS定时调用BAT执行性能测试用例.   1. 先随便录制l一个LR脚本,保存为 D:\TEST\test01 2. 打开 ...

  5. 高并发WEB服务的演变

    一.越来越多的并发连接数 现在的Web系统面对的并发连接数在近几年呈现指数增长,高并发成为了一种常态,给Web系统带来不小的挑战.以最简单粗暴的方式解决,就是增加 Web系统的机器和升级硬件配置.虽然 ...

  6. vue中处理时间格式化的问题

    vue main.js中修改Date原型链,插入(百度) Date.prototype.format = function(fmt) { var o = { "M+" : this ...

  7. D Thanking-Bear magic

    题目描述 In order to become a magical girl, Thinking-Bear are learning magic circle. He first drew a reg ...

  8. JMETER+JENKINS接口测试持续集成

    FIDDER+ANT+JENKINS+JMETER+SVN+tomcat接口测试集成 操作流程: 1.测试人员通过FIDDER过滤抓取接口调用信息,导出成jmx文件.(jmeter支持命令行方式调用j ...

  9. Vue中通过属性绑定为元素绑定style行内样式

    1.直接在元素上通过:style的形式,书写样式对象 2.将样式对象定义在data中,并直接引用到:style中 3.在:style中通过数组,引用多个data上的样式对象

  10. Linux 内核列举设备和驱动

    如果你在编写总线级别的代码, 你可能不得不对所有已经注册到你的总线的设备或驱动进 行一些操作. 它可能会诱惑人直接进入 bus_type 结构中的各种结构, 但是最好使用已经 提供的帮助函数. 为操作 ...