[源码解读] ResNet源码解读(pytorch)
自己看读完pytorch封装的源码后,自己又重新写了一边(模仿其书写格式), 一些问题在代码中说明。
import torch
import torchvision
import argparse
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
import torch.utils.model_zoo as model_zoo
import math
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152']
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
def conv3x3(in_planes, out_planes, stride=1):
# 3x3 kernel
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
# get BasicBlock which layers < 50(18, 34)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = conv3x3(in_planes, planes, stride)
self.BN = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes, stride) # outplane is not in_planes*self.expansion, is planes
self.stride = stride
self.downsample = downsample
def forward(self, x):
residual = x # mark the data before BasicBlock
x = self.conv1(x)
x = self.BN(x)
x = self.relu(x)
x = self.conv2(x)
x = self.BN(x) # BN operation is before relu operation
if self.downsample is not None: # is not None
residual = self.downsample(residual) # resize the channel
x += residual
x = self.relu(x)
return x
# get BottleBlock which layers >= 50
class Bottleneck(nn.Module):
expansion = 4 # the factor of the last layer of BottleBlock and the first layer of it
def __init__(self, in_planes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.con2 = conv3x3(planes, planes)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes*4, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes*4)
self.downsample = downsample
self.stride = stride
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
residual = x
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.con2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.conv3(x)
x = self.bn3(x)
if self.downsample is not None:
residual = self.downsample(residual)
x += residual
x = self.relu(x)
return x
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=100):
self.inplanes = 64 # the original channel
super(ResNet, self).__init__()
self.num_classes = num_classes
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.relu = nn.ReLU(inplace=True)
self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
# 以下构建残差块, 具体参数可以查看resnet参数表
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
self.average_pool = nn.AvgPool2d(7, stride=1)
self.fc = nn.Linear(512*block.expansion, num_classes)
# 对卷积和与BN层初始化,论文中也提到过
for m in self.modules():
if isinstance(m, nn.Conv2d):
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
m.weight.data.normal_(0, math.sqrt(2. / n))
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
# 这里是为了结局两个残差块之间可能维度不匹配无法直接相加的问题,相同类型的残差块只需要改变第一个输入的维数就好,后面的输入维数都等于输出维数
def _make_layer(self, block, planes, num_blocks, stride=1):
downsample = None
# 扩维
if stride != 1 or self.inplanes != block.expansion * planes:
downsample = nn.Sequential(
nn.Conv2d(self.inplanes, block.expansion*planes,kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(block.expansion*planes)
)
layers = []
# 特判第一残差块
layers.append(block(self.inplanes, planes, downsample=downsample)) # outplane is planes not planes*block.expansion
self.inplanes = planes * block.expansion
for i in range(1, num_blocks):
layers.append(block(self.inplanes, planes))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.max_pool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.average_pool(x)
x = x.view(x.size(0), -1) # resize batch-size x H
x = self.fc(x)
return x
def resnet18(pretrained=False, **kwargs):
"""Constructs a ResNet-18 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
return model
def resnet34(pretrained=False, **kwargs):
"""Constructs a ResNet-34 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
return model
def resnet50(pretrained=False, **kwargs):
"""Constructs a ResNet-50 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
return model
def resnet101(pretrained=False, **kwargs):
"""Constructs a ResNet-101 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
return model
def resnet152(pretrained=False, **kwargs):
"""Constructs a ResNet-152 model.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
return model
[源码解读] ResNet源码解读(pytorch)的更多相关文章
- RxJava系列6(从微观角度解读RxJava源码)
RxJava系列1(简介) RxJava系列2(基本概念及使用介绍) RxJava系列3(转换操作符) RxJava系列4(过滤操作符) RxJava系列5(组合操作符) RxJava系列6(从微观角 ...
- 入口开始,解读Vue源码(一)-- 造物创世
Why? 网上现有的Vue源码解析文章一搜一大批,但是为什么我还要去做这样的事情呢?因为觉得纸上得来终觉浅,绝知此事要躬行. 然后平时的项目也主要是Vue,在使用Vue的过程中,也对其一些约定产生了一 ...
- JVM源码分析之SystemGC完全解读
JVM源码分析之SystemGC完全解读 概述 JVM的GC一般情况下是JVM本身根据一定的条件触发的,不过我们还是可以做一些人为的触发,比如通过jvmti做强制GC,通过System.gc触发,还可 ...
- Spring源码-循环依赖源码解读
Spring源码-循环依赖源码解读 笔者最近无论是看书还是从网上找资料,都没发现对Spring源码是怎么解决循环依赖这一问题的详解,大家都是解释了Spring解决循环依赖的想法(有的解释也不准确,在& ...
- Derek解读Bytom源码-持久化存储LevelDB
作者:Derek 简介 Github地址:https://github.com/Bytom/bytom Gitee地址:https://gitee.com/BytomBlockchain/bytom ...
- Derek解读Bytom源码-创世区块
作者:Derek 简介 Github地址:https://github.com/Bytom/bytom Gitee地址:https://gitee.com/BytomBlockchain/bytom ...
- Redux学习之解读applyMiddleware源码深入middleware工作机制
随笔前言 在上一周的学习中,我们熟悉了如何通过redux去管理数据,而在这一节中,我们将一起深入到redux的知识中学习. 首先谈一谈为什么要用到middleware 我们知道在一个简单的数据流场景中 ...
- SpringMVC源码解读 - RequestMapping注解实现解读 - RequestMappingInfo
使用@RequestMapping注解时,配置的信息最后都设置到了RequestMappingInfo中. RequestMappingInfo封装了PatternsRequestCondition, ...
- SpringMVC源码解读 - RequestMapping注解实现解读 - RequestCondition体系
一般我们开发时,使用最多的还是@RequestMapping注解方式. @RequestMapping(value = "/", param = "role=guest& ...
随机推荐
- jquery的强大选择器
$("#myELement") 选择id值等于myElement的元素,id值不能重复在文档中只能有一个id值是myElement所以得到的是唯一的元素 $("di ...
- js 获取Array数组 最大值 最小值
https://stackoverflow.com/questions/1669190/find-the-min-max-element-of-an-array-in-javascript // 错误 ...
- <Android 开源库> GreenDAO 使用方法具体解释<译文>
简单介绍 greenDAO是一个开源的Android ORM,使SQLite数据库的开发再次变得有趣. 它减轻了开发者处理底层的数据库需求,同一时候节省开发时间. SQLite是一个非常不错的关系型数 ...
- (0.2.5)Mysql安装——RPM方式安装
rpm安装mysql 卸载与安装服务端 一.安装服务端与客户端 #查看RPM包中所有的文件shell> rpm -qpl mysql-community-server-version-dis ...
- 前端 javascript 写代码方式
javascript 和python一样可以用终端写代码 写Js代码: - html文件中编写 - 临时,浏览器的终端 console
- 注册表REG文件编写大全
Windows 中的注册表文件( system.dat 和 user.dat )是 Windows 的核心数据库,因此,对 Windows 来说是非常重要的. 通过修改注册表文件中的数据,可以达到优化 ...
- 008-Hadoop Hive sql语法详解3-DML 操作:元数据存储
一.概述 hive不支持用insert语句一条一条的进行插入操作,也不支持update操作.数据是以load的方式加载到建立好的表中.数据一旦导入就不可以修改. DML包括:INSERT插入.UPDA ...
- JQuery变量数字相加的研究
在 jquery中,一个变量和一个数字相加,期望得到1+2=3但是: eg: <input type="text" class="input_Num" i ...
- 47. Permutations II (全排列有重复的元素)
Given a collection of numbers that might contain duplicates, return all possible unique permutations ...
- Linux下修改时间
修改linux的时间可以使用date指令 date命令的功能是显示和设置系统日期和时间. 输入date 查看目前系统时间. 修改时间需要 date -功能字符 修改内容 命令中各选项的含义分别为: - ...