[源码解读] ResNet源码解读(pytorch)
自己读完pytorch封装的源码后,又重新写了一遍(模仿其书写格式),一些问题在代码中说明。
import torch
import torchvision
import argparse
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms, models
import torch.utils.model_zoo as model_zoo
import math
# Names exported by a star-import of this module.
__all__ = [
    'ResNet',
    'resnet18',
    'resnet34',
    'resnet50',
    'resnet101',
    'resnet152',
]
# Checkpoint download location for each architecture name; the factory
# functions hand these URLs to ``model_zoo.load_url`` when ``pretrained`` is
# requested.
model_urls = {
    'resnet18':
        'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34':
        'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50':
        'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101':
        'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152':
        'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}
def conv3x3(in_planes, out_planes, stride=1):
    """Return a bias-free 3x3 convolution with padding 1.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the convolution (default 1).

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


# Residual block for the shallower variants (ResNet-18 and ResNet-34).
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 convolutions plus a shortcut.

    Args:
        in_planes: Number of channels of the block input.
        planes: Number of channels produced by both convolutions; the block
            output has ``planes * expansion`` channels.
        stride: Stride of the first convolution. Only the first convolution
            is strided, so the block reduces the spatial size exactly once.
        downsample: Optional module applied to the input before it is added
            back, used when the input and output shapes differ. ``None``
            means the input is added unchanged.
    """

    expansion = 1  # ratio of block output channels to ``planes``

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_planes, planes, stride)
        # Each convolution gets its own BatchNorm: sharing one layer would
        # tie the affine parameters and mix the running statistics of two
        # different activations.
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # Stride 1 here; the spatial reduction already happened in conv1, so
        # the main path matches a shortcut that is strided once.
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Apply conv-BN-ReLU-conv-BN, add the shortcut, then ReLU."""
        residual = x  # block input, kept for the shortcut branch

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)  # normalise before the addition and final ReLU

        if self.downsample is not None:
            # Bring the shortcut to the shape of the main path.
            residual = self.downsample(x)

        out += residual
        return self.relu(out)
# Residual block for the deeper variants (ResNet-50 and above).
class Bottleneck(nn.Module):
    """Bottleneck residual block: 1x1 reduce, 3x3, 1x1 expand, plus shortcut.

    Args:
        in_planes: Number of channels of the block input.
        planes: Number of channels inside the bottleneck; the block output
            has ``planes * expansion`` channels.
        stride: Stride of the 3x3 convolution, i.e. the spatial reduction
            performed by the block.
        downsample: Optional module applied to the input before it is added
            back, used when the input and output shapes differ. ``None``
            means the input is added unchanged.
    """

    expansion = 4  # ratio of block output channels to ``planes``

    def __init__(self, in_planes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        out_planes = planes * self.expansion

        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        # The 3x3 convolution carries the stride so the main path shrinks
        # together with a strided shortcut.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        """Run the three conv-BN stages, add the shortcut, then ReLU."""
        residual = x  # block input, kept for the shortcut branch

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)  # no ReLU until after the addition

        if self.downsample is not None:
            # Bring the shortcut to the shape of the main path.
            residual = self.downsample(x)

        out += residual
        return self.relu(out)
class ResNet(nn.Module):
    """Residual network: a convolutional stem, four stages and a classifier.

    Args:
        block: Residual block class. It must expose an ``expansion`` class
            attribute and be constructible as
            ``block(in_planes, planes, stride=..., downsample=...)``.
        layers: Sequence of four integers giving the number of blocks in
            each of the four stages.
        num_classes: Size of the final fully connected layer (default 100).
    """

    def __init__(self, block, layers, num_classes=100):
        super(ResNet, self).__init__()
        self.inplanes = 64  # channel count fed to the next stage
        self.num_classes = num_classes

        # Stem: 7x7 stride-2 convolution followed by a 3x3 stride-2 max-pool.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages; widths and strides follow the ResNet
        # architecture table.
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        # Global average pooling. A fixed 7x7 window would only collapse a
        # 7x7 feature map; the adaptive form gives the same result there and
        # also accepts other input resolutions.
        self.average_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        # Initialise convolution and BatchNorm layers: convolution weights
        # are drawn from N(0, sqrt(2 / n)) with n = kernel area * output
        # channels; BatchNorm starts as the identity (weight 1, bias 0).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, num_blocks, stride=1):
        """Build one stage consisting of ``num_blocks`` residual blocks.

        Only the first block of a stage can change the tensor shape, so it
        alone receives ``stride`` and, when needed, a projection shortcut.
        The remaining blocks map ``planes * block.expansion`` channels to
        the same number of channels at stride 1.

        Args:
            block: Residual block class.
            planes: Base channel width of the stage.
            num_blocks: Number of blocks in the stage.
            stride: Stride applied by the first block of the stage.

        Returns:
            An ``nn.Sequential`` holding the blocks of the stage.
        """
        out_planes = planes * block.expansion

        downsample = None
        # A projection shortcut is required whenever the first block changes
        # the spatial size or the channel count.
        if stride != 1 or self.inplanes != out_planes:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, out_planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_planes),
            )

        # The first block gets the same stride as its shortcut; otherwise
        # the two branches end up with different spatial sizes and cannot
        # be added.
        layers = [block(self.inplanes, planes, stride=stride, downsample=downsample)]
        self.inplanes = out_planes
        for _ in range(1, num_blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the ``(batch, num_classes)`` output of the final linear layer."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.max_pool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.average_pool(x)
        x = x.view(x.size(0), -1)  # flatten to (batch, features)
        return self.fc(x)
def resnet18(pretrained=False, **kwargs):
    """Build a ResNet-18: ``BasicBlock`` with stage depths 2, 2, 2, 2.

    Args:
        pretrained (bool): If True, load the ImageNet pre-trained weights
            downloaded from ``model_urls['resnet18']``.
        **kwargs: Forwarded unchanged to the ``ResNet`` constructor.

    Returns:
        The constructed ``ResNet`` instance.
    """
    net = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if not pretrained:
        return net
    # NOTE(review): the classifier size stored in the checkpoint is not
    # visible here; if it differs from the model's ``num_classes`` the load
    # below will presumably be rejected -- confirm and pass ``num_classes``.
    checkpoint = model_zoo.load_url(model_urls['resnet18'])
    net.load_state_dict(checkpoint)
    return net
def resnet34(pretrained=False, **kwargs):
    """Build a ResNet-34: ``BasicBlock`` with stage depths 3, 4, 6, 3.

    Args:
        pretrained (bool): If True, load the ImageNet pre-trained weights
            downloaded from ``model_urls['resnet34']``.
        **kwargs: Forwarded unchanged to the ``ResNet`` constructor.

    Returns:
        The constructed ``ResNet`` instance.
    """
    net = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return net
    # NOTE(review): the classifier size stored in the checkpoint is not
    # visible here; if it differs from the model's ``num_classes`` the load
    # below will presumably be rejected -- confirm and pass ``num_classes``.
    checkpoint = model_zoo.load_url(model_urls['resnet34'])
    net.load_state_dict(checkpoint)
    return net
def resnet50(pretrained=False, **kwargs):
    """Build a ResNet-50: ``Bottleneck`` with stage depths 3, 4, 6, 3.

    Args:
        pretrained (bool): If True, load the ImageNet pre-trained weights
            downloaded from ``model_urls['resnet50']``.
        **kwargs: Forwarded unchanged to the ``ResNet`` constructor.

    Returns:
        The constructed ``ResNet`` instance.
    """
    net = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if not pretrained:
        return net
    # NOTE(review): the classifier size stored in the checkpoint is not
    # visible here; if it differs from the model's ``num_classes`` the load
    # below will presumably be rejected -- confirm and pass ``num_classes``.
    checkpoint = model_zoo.load_url(model_urls['resnet50'])
    net.load_state_dict(checkpoint)
    return net
def resnet101(pretrained=False, **kwargs):
    """Build a ResNet-101: ``Bottleneck`` with stage depths 3, 4, 23, 3.

    Args:
        pretrained (bool): If True, load the ImageNet pre-trained weights
            downloaded from ``model_urls['resnet101']``.
        **kwargs: Forwarded unchanged to the ``ResNet`` constructor.

    Returns:
        The constructed ``ResNet`` instance.
    """
    net = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
    if not pretrained:
        return net
    # NOTE(review): the classifier size stored in the checkpoint is not
    # visible here; if it differs from the model's ``num_classes`` the load
    # below will presumably be rejected -- confirm and pass ``num_classes``.
    checkpoint = model_zoo.load_url(model_urls['resnet101'])
    net.load_state_dict(checkpoint)
    return net
def resnet152(pretrained=False, **kwargs):
    """Build a ResNet-152: ``Bottleneck`` with stage depths 3, 8, 36, 3.

    Args:
        pretrained (bool): If True, load the ImageNet pre-trained weights
            downloaded from ``model_urls['resnet152']``.
        **kwargs: Forwarded unchanged to the ``ResNet`` constructor.

    Returns:
        The constructed ``ResNet`` instance.
    """
    net = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
    if not pretrained:
        return net
    # NOTE(review): the classifier size stored in the checkpoint is not
    # visible here; if it differs from the model's ``num_classes`` the load
    # below will presumably be rejected -- confirm and pass ``num_classes``.
    checkpoint = model_zoo.load_url(model_urls['resnet152'])
    net.load_state_dict(checkpoint)
    return net
[源码解读] ResNet源码解读(pytorch)的更多相关文章
- RxJava系列6(从微观角度解读RxJava源码)
RxJava系列1(简介) RxJava系列2(基本概念及使用介绍) RxJava系列3(转换操作符) RxJava系列4(过滤操作符) RxJava系列5(组合操作符) RxJava系列6(从微观角 ...
- 入口开始,解读Vue源码(一)-- 造物创世
Why? 网上现有的Vue源码解析文章一搜一大批,但是为什么我还要去做这样的事情呢?因为觉得纸上得来终觉浅,绝知此事要躬行. 然后平时的项目也主要是Vue,在使用Vue的过程中,也对其一些约定产生了一 ...
- JVM源码分析之SystemGC完全解读
JVM源码分析之SystemGC完全解读 概述 JVM的GC一般情况下是JVM本身根据一定的条件触发的,不过我们还是可以做一些人为的触发,比如通过jvmti做强制GC,通过System.gc触发,还可 ...
- Spring源码-循环依赖源码解读
Spring源码-循环依赖源码解读 笔者最近无论是看书还是从网上找资料,都没发现对Spring源码是怎么解决循环依赖这一问题的详解,大家都是解释了Spring解决循环依赖的想法(有的解释也不准确,在& ...
- Derek解读Bytom源码-持久化存储LevelDB
作者:Derek 简介 Github地址:https://github.com/Bytom/bytom Gitee地址:https://gitee.com/BytomBlockchain/bytom ...
- Derek解读Bytom源码-创世区块
作者:Derek 简介 Github地址:https://github.com/Bytom/bytom Gitee地址:https://gitee.com/BytomBlockchain/bytom ...
- Redux学习之解读applyMiddleware源码深入middleware工作机制
随笔前言 在上一周的学习中,我们熟悉了如何通过redux去管理数据,而在这一节中,我们将一起深入到redux的知识中学习. 首先谈一谈为什么要用到middleware 我们知道在一个简单的数据流场景中 ...
- SpringMVC源码解读 - RequestMapping注解实现解读 - RequestMappingInfo
使用@RequestMapping注解时,配置的信息最后都设置到了RequestMappingInfo中. RequestMappingInfo封装了PatternsRequestCondition, ...
- SpringMVC源码解读 - RequestMapping注解实现解读 - RequestCondition体系
一般我们开发时,使用最多的还是@RequestMapping注解方式. @RequestMapping(value = "/", param = "role=guest& ...
随机推荐
- Python--格式化输出%s和%d
https://www.cnblogs.com/claidx/p/7253288.html python print格式化输出. %r 用来做 debug 比较好,因为它会显示变量的原始数据(raw d ...
- 【题解】P5151 HKE与他的小朋友
[题解]P5151 HKE与他的小朋友 实际上,位置的关系可以看做一组递推式,\(f(a_i)=f(a_j),f(a_j)=f(a_t),etc...\)那么我们可以压进一个矩阵里面. 考虑到这个矩阵 ...
- 【opencv入门篇】 10个程序快速上手opencv【上】
导言:本系列博客目的在于能够在vs快速上手opencv,理论知识涉及较少,大家有兴趣可以查阅其他博客深入了解相关的理论知识,本博客后续也会对图像方向的理论进一步分析,敬请期待:) PS:官方文档永远是 ...
- K-medodis聚类算法MATLAB
国内博客,上介绍实现的K-medodis方法为: 与K-means算法类似.只是距离选择与聚类中心选择不同. 距离为曼哈顿距离 聚类中心选择为:依次把一个聚类中的每一个点当作当前类的聚类中心,求出代价 ...
- 发现一个小技巧:火狐浏览器对phpmyadmin支持更友好
这段时间ytkah正在迁移服务器(A→B),为了方便起见,直接用phpmyadmin导入数据库.一般我们是用navicat来操作数据库的,但是服务器A设置了权限,无法用navicat连接,只好在浏览器 ...
- C的指针疑惑:C和指针13(高级指针话题)上
int *f(); f为一个函数,返回值类型是一个指向整形的指针. int (*f)(); 两对括号,第二对括号是函数调用操作符,但第一对括号只起到聚组的作用. f为一个函数指针,它所指向的函数返回一 ...
- UVA10700:Camel trading(栈和队列)
题目链接:http://acm.hust.edu.cn/vjudge/contest/view.action?cid=68990#problem/J 题目大意: 给一个没有加上括号的表达式且只有+ , ...
- Linux系统——Raid磁盘阵列
Raid磁盘阵列 作用:解决磁盘速度.安全问题 Raid原理 Raid0 写入速度极快,有几块硬盘,写入速度就近似几倍,但是安全性极差,只要一块盘坏了,所有盘的数据全部坏掉,最少两块硬盘组合 性价比最 ...
- Django:学习笔记(8)——视图
Django:学习笔记(8)——视图
- ZOJ - 3229 Shoot the Bullet (有源汇点上下界最大流)
题意:要在n天里给m个女生拍照,每个女生有拍照数量的下限Gi,每天有拍照数量的上限Di,每天当中每个人有拍照的上限Lij和Rij.求在满足限制的基础上,所有人最大能拍多少张照片. 分析:抛开限制,显然 ...