官方github上已经有了pytorch基础模型的实现,链接

但是其中一些模型,尤其是resnet,都是用函数生成的各个层,自己看起来是真的难受!

所以自己按照caffe的样子,写一个pytorch的resnet18模型,当然和1000分类模型不同,模型做了一些修改,输入48*48的3通道图片,输出7类。

import torch.nn as nn
import torch.nn.functional as F class ResNet18Model(nn.Module):
def __init__(self):
super().__init__() self.bn64_0 = nn.BatchNorm2d(64)
self.bn64_1 = nn.BatchNorm2d(64)
self.bn64_2 = nn.BatchNorm2d(64)
self.bn64_3 = nn.BatchNorm2d(64)
self.bn64_4 = nn.BatchNorm2d(64) self.bn128_0 = nn.BatchNorm2d(128)
self.bn128_1 = nn.BatchNorm2d(128)
self.bn128_2 = nn.BatchNorm2d(128)
self.bn128_3 = nn.BatchNorm2d(128) self.bn256_0 = nn.BatchNorm2d(256)
self.bn256_1 = nn.BatchNorm2d(256)
self.bn256_2 = nn.BatchNorm2d(256)
self.bn256_3 = nn.BatchNorm2d(256) self.bn512_0 = nn.BatchNorm2d(512)
self.bn512_1 = nn.BatchNorm2d(512)
self.bn512_2 = nn.BatchNorm2d(512)
self.bn512_3 = nn.BatchNorm2d(512) self.shortcut_straight_0 = nn.Sequential()
self.shortcut_straight_1 = nn.Sequential()
self.shortcut_straight_2 = nn.Sequential()
self.shortcut_straight_3 = nn.Sequential()
self.shortcut_straight_4 = nn.Sequential() self.shortcut_conv_bn_64_128_0 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),nn.BatchNorm2d(128)) self.shortcut_conv_bn_128_256_0 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=1, stride=2, bias=False),nn.BatchNorm2d(256)) self.shortcut_conv_bn_256_512_0 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=1, stride=2, bias=False),nn.BatchNorm2d(512)) self.conv_w3_h3_in3_out64_s1_p1_0 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) self.conv_w3_h3_in64_out64_s1_p1_0 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in64_out64_s1_p1_1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in64_out64_s1_p1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in64_out64_s1_p1_3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False) self.conv_w3_h3_in64_out128_s2_p1_0 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False) self.conv_w3_h3_in128_out128_s1_p1_0 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in128_out128_s1_p1_1 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in128_out128_s1_p1_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False) self.conv_w3_h3_in128_out256_s2_p1_0 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False) self.conv_w3_h3_in256_out256_s1_p1_0 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in256_out256_s1_p1_1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in256_out256_s1_p1_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False) self.conv_w3_h3_in256_out512_s2_p1_0 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False) self.conv_w3_h3_in512_out512_s1_p1_0 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in512_out512_s1_p1_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in512_out512_s1_p1_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False) self.avg_pool_0 = nn.AdaptiveAvgPool2d((1, 1))
self.fc_512_7_0 = nn.Linear(512, 7)
self.dropout_0 = nn.Dropout(p=0.5) def forward(self, x): # 48*48*3
t = self.conv_w3_h3_in3_out64_s1_p1_0(x) #48*48*64
t = self.bn64_0(t)
y1 = F.relu(t) t = self.conv_w3_h3_in64_out64_s1_p1_0(y1) #48*48*64
t = self.bn64_1(t)
y2 = F.relu(t) t = self.conv_w3_h3_in64_out64_s1_p1_1(y2) #48*48*64
t = self.bn64_2(t)
t += self.shortcut_straight_0(y1)
y3 = F.relu(t) t = self.conv_w3_h3_in64_out64_s1_p1_2(y3) #48*48*64
t = self.bn64_3(t)
y4 = F.relu(t) t = self.conv_w3_h3_in64_out64_s1_p1_3(y4) #48*48*64
t = self.bn64_4(t)
t += self.shortcut_straight_1(y3)
y5 = F.relu(t) t = self.conv_w3_h3_in64_out128_s2_p1_0(y5) #24*24*128
t = self.bn128_0(t)
y6 = F.relu(t) t = self.conv_w3_h3_in128_out128_s1_p1_0(y6) #24*24*128
t = self.bn128_1(t)
t += self.shortcut_conv_bn_64_128_0(y5)
y7 = F.relu(t) t = self.conv_w3_h3_in128_out128_s1_p1_1(y7) #24*24*128
t = self.bn128_2(t)
y8 = F.relu(t) t = self.conv_w3_h3_in128_out128_s1_p1_2(y8) #24*24*128
t = self.bn128_3(t)
t += self.shortcut_straight_2(y7)
y9 = F.relu(t) t = self.conv_w3_h3_in128_out256_s2_p1_0(y9) #12*12*256
t = self.bn256_0(t)
y10 = F.relu(t) t = self.conv_w3_h3_in256_out256_s1_p1_0(y10) #12*12*256
t = self.bn256_1(t)
t += self.shortcut_conv_bn_128_256_0(y9)
y11 = F.relu(t) t = self.conv_w3_h3_in256_out256_s1_p1_1(y11) #12*12*256
t = self.bn256_2(t)
y12 = F.relu(t) t = self.conv_w3_h3_in256_out256_s1_p1_2(y12) #12*12*256
t = self.bn256_3(t)
t += self.shortcut_straight_3(y11)
y13 = F.relu(t) t = self.conv_w3_h3_in256_out512_s2_p1_0(y13) #6*6*512
t = self.bn512_0(t)
y14 = F.relu(t) t = self.conv_w3_h3_in512_out512_s1_p1_0(y14) #6*6*512
t = self.bn512_1(t)
t += self.shortcut_conv_bn_256_512_0(y13)
y15 = F.relu(t) t = self.conv_w3_h3_in512_out512_s1_p1_1(y15) #6*6*512
t = self.bn512_2(t)
y16 = F.relu(t) t = self.conv_w3_h3_in512_out512_s1_p1_2(y16) #6*6*512
t = self.bn512_3(t)
t += self.shortcut_straight_4(y15)
y17 = F.relu(t) out = self.avg_pool_0(y17) #1*1*512
out = out.view(out.size(0), -1)
out = self.dropout_0(out)
out = self.fc_512_7_0(out) return out if __name__ == '__main__':
net = ResNet18Model()
# print(net) import torch
net_in = torch.rand(1, 3, 48, 48)
net_out = net(net_in)
print(net_out)
print(net_out.size())

  

pytorch resnet实现的更多相关文章

  1. PyTorch ResNet 使用与源码解析

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson8/resnet_inference.py 这篇文章首先会简 ...

  2. [源码解读] ResNet源码解读(pytorch)

    自己看读完pytorch封装的源码后,自己又重新写了一边(模仿其书写格式), 一些问题在代码中说明. import torch import torchvision import argparse i ...

  3. 解读 pytorch对resnet的官方实现

    地址:https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 贴代码 import torch.nn as ...

  4. 【深度学习】基于Pytorch的ResNet实现

    目录 1. ResNet理论 2. pytorch实现 2.1 基础卷积 2.2 模块 2.3 使用ResNet模块进行迁移学习 1. ResNet理论 论文:https://arxiv.org/pd ...

  5. ResNet网络的Pytorch实现

    1.文章原文地址 Deep Residual Learning for  Image Recognition 2.文章摘要 神经网络的层次越深越难训练.我们提出了一个残差学习框架来简化网络的训练,这些 ...

  6. Pytorch构建ResNet

    学了几天Pytorch,大致明白代码在干什么了,贴一下.. import torch from torch.utils.data import DataLoader from torchvision ...

  7. 陈云pytorch学习笔记_用50行代码搭建ResNet

    import torch as t import torch.nn as nn import torch.nn.functional as F from torchvision import mode ...

  8. PyTorch对ResNet网络的实现解析

    PyTorch对ResNet网络的实现解析 1.首先导入需要使用的包 import torch.nn as nn import torch.utils.model_zoo as model_zoo # ...

  9. 【pytorch】改造resnet为全卷积神经网络以适应不同大小的输入

    为什么resnet的输入是一定的? 因为resnet最后有一个全连接层.正是因为这个全连接层导致了输入的图像的大小必须是固定的. 输入为固定的大小有什么局限性? 原始的resnet在imagenet数 ...

随机推荐

  1. easy-ui的datagrid

    <div id="magazineGrid"></div> <script> $('#magazineGrid').datagrid({ hei ...

  2. 八:SpringBoot-集成JPA持久层框架,简化数据库操作

    SpringBoot-集成JPA持久层框架,简化数据库操作 1.JPA框架简介 1.1 JPA与Hibernate的关系: 2.SpringBoot整合JPA Spring Data JPA概述: S ...

  3. SparkMLlib—协同过滤推荐算法,电影推荐系统,物品喜好推荐

    SparkMLlib-协同过滤推荐算法,电影推荐系统,物品喜好推荐 一.协同过滤 1.1 显示vs隐式反馈 1.2 实例介绍 1.2.1 数据说明 评分数据说明(ratings.data) 用户信息( ...

  4. centos安装、升级新火狐最新版 31

    1.登录火狐主页 下载最新版本firefox-31.0.tar.bz2 解压: tar -jxvf firefox-31.0.tar.bz2 2.然后把旧版本的firefox卸掉 # yum eras ...

  5. php文件下载的实现(header)

    php文件下载的实现(header) $file_xls=$path;    //   文件的保存路径 $example_name=basename($file_xls);  //获取文件名   he ...

  6. 基于efcore的分表组件开源

    ShardingCore ShardingCore 是一个支持efcore 2.x 3.x 5.x的一个对于数据库分表的一个简易扩展, 目前该库暂未支持分库(未来会支持),仅支持分表,该项目的理念是让 ...

  7. 牛客网暑期ACM多校训练营(第二场)carpet

    传送门:carpet 题意 有一个n*m的地毯,aij表示地毯每格的元素,bij表示地毯每格的价格,要求选取一块价格最大值最小的地毯,并且这块地毯无限铺开之后,原地毯是其子矩阵. 题解 先找到这个矩阵 ...

  8. Codeforces Round #608 (Div. 2) E. Common Number (二分,构造)

    题意:对于一个数\(x\),有函数\(f(x)\),如果它是偶数,则\(x/=2\),否则\(x-=1\),不断重复这个过程,直到\(x-1\),我们记\(x\)到\(1\)的这个过程为\(path( ...

  9. hdu 1045 Fire Net 二分图匹配 && HDU-1281-棋盘游戏

    题意:任意两个个'车'不能出现在同一行或同一列,当然如果他们中间有墙的话那就没有什么事,问最多能放多少个'车' 代码+注释: 1 //二分图最大匹配问题 2 //难点在建图方面,如果这个图里面一道墙也 ...

  10. 洛谷-P1434 [SHOI2002]滑雪 (记忆化搜索)

    题意:有一个\(R*C\)的矩阵,可以从矩阵中的任意一个数开始,每次都可以向上下左右选一个比当前位置小的数走,求走到\(1\)的最长路径长度. 题解:这题很明显看到就知道是dfs,但是直接爆搜会TLE ...