VGG 主要有两种结构,分别是 VGG16 和 VGG19,两者并没有本质上的区别,只是网络深度不一样。

对于给定的感受野,采用堆积的小卷积核是优于采用大的卷积核的,因为多层非线性层可以增加网络深度来保证学习更复杂的模式,而且代价还比较小(参数更少)。

比如,三个步长为 $1$ 的 $3 \times 3$ 卷积核的叠加,即对应 $7 \times 7$ 的感受野(即三个 $3 \times 3$ 连续卷积相当于一个 $7 \times 7$ 卷积),如果我们假设卷积输入输出的 channel 数均为 $C$,那么三个 $3 \times 3$ 连续卷积的参数总量为 $3 \times (9 C^2)$(一个卷积核的大小 $3 \times 3 \times C$,一层有 $C$ 个卷积核,总共有三层),而如果直接使用 $7 \times 7$ 卷积核,其参数总量为 $49 C^2$。很明显 $27C^2 < 49C^2$,即减少了参数。

VGG的网络结构非常简单,全部都是 $(3,3)$ 的卷积核,步长为 $1$,四周补 $1$ 圈 $0$(使得输入输出的 $(H,W)$ 不变):

我们可以参考 torchvision.models 中的源码:

class VGG(nn.Module):

    def __init__(self, features, num_classes=1000, init_weights=True):
super(VGG, self).__init__()
self.features = features
self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
self.classifier = nn.Sequential(
nn.Linear(512 * 7 * 7, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(True),
nn.Dropout(),
nn.Linear(4096, num_classes),
)
if init_weights:
self._initialize_weights() def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0) def make_layers(cfg, batch_norm=False):
layers = []
in_channels = 3
for v in cfg:
if v == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = v
return nn.Sequential(*layers) cfgs = {
'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
} def _vgg(arch, cfg, batch_norm, pretrained, progress, **kwargs):
if pretrained:
kwargs['init_weights'] = False
model = VGG(make_layers(cfgs[cfg], batch_norm=batch_norm), **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict)
return model

cfgs 中的 "A,B,D,E" 就分别对应上表中的 A,B,D,E 四种网络结构。

'M' 代表最大池化层, nn.MaxPool2d(kernel_size=2, stride=2)

从代码中可以看出,经过 self.features 的一大堆卷积层之后,先是经过一个 nn.AdaptiveAvgPool2d((7, 7)) 的自适应平均池化层(自动选择参数使得输出的 tensor 满足 $(H = 7,W = 7)$),再通过一个 torch.flatten(x, 1) 推平成一个 $(batch\_size, n)$ 的 tensor,再连接后面的全连接层。

torchvision.models.vgg19_bn(pretrained=False, progress=True, **kwargs)

pretrained: 如果设置为 True,则返回在 ImageNet 预训练过的模型。

关于torchvision.models中VGG的笔记的更多相关文章

  1. PyTorch源码解读之torchvision.models(转)

    原文地址:https://blog.csdn.net/u014380165/article/details/79119664 PyTorch框架中有一个非常重要且好用的包:torchvision,该包 ...

  2. 转:openwrt中luci学习笔记

    原文地址:openwrt中luci学习笔记 最近在学习OpenWrt,需要在OpenWrt的WEB界面增加内容,本文将讲述修改OpenWrt的过程和其中遇到的问题. 一.WEB界面开发         ...

  3. VS2013中Python学习笔记[Django Web的第一个网页]

    前言 前面我简单介绍了Python的Hello World.看到有人问我搞搞Python的Web,一时兴起,就来试试看. 第一篇 VS2013中Python学习笔记[环境搭建] 简单介绍Python环 ...

  4. Django之admin中管理models中的表格

    Django之admin中管理models中的表格 django中使用admin管理models中的表格时,如何将表格注册到admin中呢? 具体操作就是在项目文件夹中的app文件夹中的admin中注 ...

  5. Django models中关于blank与null的补充说明

    Django models中关于blank与null的补充说明 建立一个简易Model class Person(models.Model): GENDER_CHOICES=( (1,'Male'), ...

  6. J粒子发现40周年-丁肇中中科院讲座笔记

    J粒子发现40周年-丁肇中中科院讲座笔记 华清远见2014-10-18   北京海淀区  张俊浩 watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQveXVuZm ...

  7. models中,字段参数limit_choices_to的用法

    这里,在使用 ModelForm 渲染前端页面的前提下,对于 models 中的 ManyToManyField 类型字段会在 ModelForm 中被转化为 ModelMultipleChoiceF ...

  8. Django——models中导入数据重复的解决办法

    如果你导入数据过多,导入时出错了,或者你手动停止了,导入了一部分,还有一部分没有导入.或者你再次运行上面的命令,你会发现数据重复了,怎么办呢? django.db.models 中还有一个函数叫 ge ...

  9. Ant Design框架中不同的组件访问不同的models中的数据

    Ant Design框架中不同的组件访问不同的models中的数据 本文记录了我在使用该框架的时候踩过的坑,方便以后查阅. 一.models绑定 在某个组件(控件或是页面),要想从某个models中获 ...

随机推荐

  1. Python爬取51job实例

    用Python爬取51job里面python相关职业.工作地址和薪资. 51job上的信息 程序代码 from bs4 import BeautifulSoup from urllib.request ...

  2. c3p0 获取数据源

    getDataSourcec3p0Resource private static void f3Resource() throws Exception { Connection conn = getD ...

  3. 【动画演示】:JS 作用域链不在话下

    作者:Lydia Hallie译者:前端小智来源:dev 点赞再看,养成习惯 本文 GitHub https://github.com/qq44924588... 上已经收录,更多往期高赞文章的分类, ...

  4. js通过cookie对两个没有关系的jsp页面进行传值

    //Cookie取值 function readCookie (name) { var cookieValue = ""; var search = name + "=& ...

  5. 学习:Android框架

      引言 通过前面两篇: Android 开发之旅:环境搭建及HelloWorld Android 开发之旅:HelloWorld项目的目录结构 我 们对android有了个大致的了解,知道如何搭建a ...

  6. smoj2806建筑物

    题面 有R红色立方体,G绿色立方体和B蓝色立方体.每个立方体的边长是1.现在有一个N × N的木板,该板被划分成1×1个单元.现在要把所有的R+G+B个立方体都放在木板上.立方体必须放置在单元格内,单 ...

  7. 关于2017届学长制作分享软件share(失物招领)的使用体验和需改进的内容

    使用体验 1.注册界面 注册界面提示明显,提示用户输入什么类型的密码,而且输入什么样的用户名不限,注册界面色调比较单一,注册内容比较少,而且比较简单,体验感比较好,但注册界面色调和设计全无,使用感一般 ...

  8. 谈一下你对uWSGI和 nginx的理解(原理)

    要注意 WSGI / uwsgi / uWSGI 这三个概念的区分. WSGI是一种通信协议. uwsgi是一种线路协议而不是通信协议,在此常用于在uWSGI服务器与其他网络服务器的数据通信. uWS ...

  9. VScode禁用alt+key触发菜单栏快捷键

    因为用惯了Mac,突然改回Windows,但是已经习惯了按Command键.所以在Windows下把vscode的快捷键全改成alt+key了.但是Windows的alt+key快捷键就比较烦人了.所 ...

  10. Jackson自定义反序列化

    // 设置jackson时间反系列化格式 SimpleModule module = new SimpleModule(); module.addDeserializer(Date.class, ne ...