官方github上已经有了pytorch基础模型的实现,链接

但是其中一些模型,尤其是resnet,都是用函数生成的各个层,自己看起来是真的难受!

所以自己按照caffe的样子,写一个pytorch的resnet18模型,当然和1000分类模型不同,模型做了一些修改,输入48*48的3通道图片,输出7类。

import torch.nn as nn
import torch.nn.functional as F class ResNet18Model(nn.Module):
def __init__(self):
super().__init__() self.bn64_0 = nn.BatchNorm2d(64)
self.bn64_1 = nn.BatchNorm2d(64)
self.bn64_2 = nn.BatchNorm2d(64)
self.bn64_3 = nn.BatchNorm2d(64)
self.bn64_4 = nn.BatchNorm2d(64) self.bn128_0 = nn.BatchNorm2d(128)
self.bn128_1 = nn.BatchNorm2d(128)
self.bn128_2 = nn.BatchNorm2d(128)
self.bn128_3 = nn.BatchNorm2d(128) self.bn256_0 = nn.BatchNorm2d(256)
self.bn256_1 = nn.BatchNorm2d(256)
self.bn256_2 = nn.BatchNorm2d(256)
self.bn256_3 = nn.BatchNorm2d(256) self.bn512_0 = nn.BatchNorm2d(512)
self.bn512_1 = nn.BatchNorm2d(512)
self.bn512_2 = nn.BatchNorm2d(512)
self.bn512_3 = nn.BatchNorm2d(512) self.shortcut_straight_0 = nn.Sequential()
self.shortcut_straight_1 = nn.Sequential()
self.shortcut_straight_2 = nn.Sequential()
self.shortcut_straight_3 = nn.Sequential()
self.shortcut_straight_4 = nn.Sequential() self.shortcut_conv_bn_64_128_0 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),nn.BatchNorm2d(128)) self.shortcut_conv_bn_128_256_0 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=1, stride=2, bias=False),nn.BatchNorm2d(256)) self.shortcut_conv_bn_256_512_0 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=1, stride=2, bias=False),nn.BatchNorm2d(512)) self.conv_w3_h3_in3_out64_s1_p1_0 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) self.conv_w3_h3_in64_out64_s1_p1_0 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in64_out64_s1_p1_1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in64_out64_s1_p1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in64_out64_s1_p1_3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False) self.conv_w3_h3_in64_out128_s2_p1_0 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False) self.conv_w3_h3_in128_out128_s1_p1_0 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in128_out128_s1_p1_1 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in128_out128_s1_p1_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False) self.conv_w3_h3_in128_out256_s2_p1_0 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False) self.conv_w3_h3_in256_out256_s1_p1_0 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in256_out256_s1_p1_1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in256_out256_s1_p1_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False) self.conv_w3_h3_in256_out512_s2_p1_0 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False) self.conv_w3_h3_in512_out512_s1_p1_0 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in512_out512_s1_p1_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in512_out512_s1_p1_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False) self.avg_pool_0 = nn.AdaptiveAvgPool2d((1, 1))
self.fc_512_7_0 = nn.Linear(512, 7)
self.dropout_0 = nn.Dropout(p=0.5) def forward(self, x): # 48*48*3
t = self.conv_w3_h3_in3_out64_s1_p1_0(x) #48*48*64
t = self.bn64_0(t)
y1 = F.relu(t) t = self.conv_w3_h3_in64_out64_s1_p1_0(y1) #48*48*64
t = self.bn64_1(t)
y2 = F.relu(t) t = self.conv_w3_h3_in64_out64_s1_p1_1(y2) #48*48*64
t = self.bn64_2(t)
t += self.shortcut_straight_0(y1)
y3 = F.relu(t) t = self.conv_w3_h3_in64_out64_s1_p1_2(y3) #48*48*64
t = self.bn64_3(t)
y4 = F.relu(t) t = self.conv_w3_h3_in64_out64_s1_p1_3(y4) #48*48*64
t = self.bn64_4(t)
t += self.shortcut_straight_1(y3)
y5 = F.relu(t) t = self.conv_w3_h3_in64_out128_s2_p1_0(y5) #24*24*128
t = self.bn128_0(t)
y6 = F.relu(t) t = self.conv_w3_h3_in128_out128_s1_p1_0(y6) #24*24*128
t = self.bn128_1(t)
t += self.shortcut_conv_bn_64_128_0(y5)
y7 = F.relu(t) t = self.conv_w3_h3_in128_out128_s1_p1_1(y7) #24*24*128
t = self.bn128_2(t)
y8 = F.relu(t) t = self.conv_w3_h3_in128_out128_s1_p1_2(y8) #24*24*128
t = self.bn128_3(t)
t += self.shortcut_straight_2(y7)
y9 = F.relu(t) t = self.conv_w3_h3_in128_out256_s2_p1_0(y9) #12*12*256
t = self.bn256_0(t)
y10 = F.relu(t) t = self.conv_w3_h3_in256_out256_s1_p1_0(y10) #12*12*256
t = self.bn256_1(t)
t += self.shortcut_conv_bn_128_256_0(y9)
y11 = F.relu(t) t = self.conv_w3_h3_in256_out256_s1_p1_1(y11) #12*12*256
t = self.bn256_2(t)
y12 = F.relu(t) t = self.conv_w3_h3_in256_out256_s1_p1_2(y12) #12*12*256
t = self.bn256_3(t)
t += self.shortcut_straight_3(y11)
y13 = F.relu(t) t = self.conv_w3_h3_in256_out512_s2_p1_0(y13) #6*6*512
t = self.bn512_0(t)
y14 = F.relu(t) t = self.conv_w3_h3_in512_out512_s1_p1_0(y14) #6*6*512
t = self.bn512_1(t)
t += self.shortcut_conv_bn_256_512_0(y13)
y15 = F.relu(t) t = self.conv_w3_h3_in512_out512_s1_p1_1(y15) #6*6*512
t = self.bn512_2(t)
y16 = F.relu(t) t = self.conv_w3_h3_in512_out512_s1_p1_2(y16) #6*6*512
t = self.bn512_3(t)
t += self.shortcut_straight_4(y15)
y17 = F.relu(t) out = self.avg_pool_0(y17) #1*1*512
out = out.view(out.size(0), -1)
out = self.dropout_0(out)
out = self.fc_512_7_0(out) return out if __name__ == '__main__':
net = ResNet18Model()
# print(net) import torch
net_in = torch.rand(1, 3, 48, 48)
net_out = net(net_in)
print(net_out)
print(net_out.size())

  

pytorch resnet实现的更多相关文章

  1. PyTorch ResNet 使用与源码解析

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson8/resnet_inference.py 这篇文章首先会简 ...

  2. [源码解读] ResNet源码解读(pytorch)

    自己看读完pytorch封装的源码后,自己又重新写了一边(模仿其书写格式), 一些问题在代码中说明. import torch import torchvision import argparse i ...

  3. 解读 pytorch对resnet的官方实现

    地址:https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 贴代码 import torch.nn as ...

  4. 【深度学习】基于Pytorch的ResNet实现

    目录 1. ResNet理论 2. pytorch实现 2.1 基础卷积 2.2 模块 2.3 使用ResNet模块进行迁移学习 1. ResNet理论 论文:https://arxiv.org/pd ...

  5. ResNet网络的Pytorch实现

    1.文章原文地址 Deep Residual Learning for  Image Recognition 2.文章摘要 神经网络的层次越深越难训练.我们提出了一个残差学习框架来简化网络的训练,这些 ...

  6. Pytorch构建ResNet

    学了几天Pytorch,大致明白代码在干什么了,贴一下.. import torch from torch.utils.data import DataLoader from torchvision ...

  7. 陈云pytorch学习笔记_用50行代码搭建ResNet

    import torch as t import torch.nn as nn import torch.nn.functional as F from torchvision import mode ...

  8. PyTorch对ResNet网络的实现解析

    PyTorch对ResNet网络的实现解析 1.首先导入需要使用的包 import torch.nn as nn import torch.utils.model_zoo as model_zoo # ...

  9. 【pytorch】改造resnet为全卷积神经网络以适应不同大小的输入

    为什么resnet的输入是一定的? 因为resnet最后有一个全连接层.正是因为这个全连接层导致了输入的图像的大小必须是固定的. 输入为固定的大小有什么局限性? 原始的resnet在imagenet数 ...

随机推荐

  1. python 拼接字

    在编译脚本的时候,由于脚本的框架是统一写好的,于是乎用上了拼接字的功能, 本脚本实现的是波特率设置的自动化,利用的是正则表达式,TASK函数是统一写好的,此处只做调用 from Args import ...

  2. HDU4358 Boring counting【dsu on tree】

    Boring counting Problem Description In this problem we consider a rooted tree with N vertices. The v ...

  3. 【洛谷 p3371】模板-单源最短路径(图论)

    题目:给出一个有向图,请输出从某一点出发到所有点的最短路径长度. 解法:spfa算法. 1 #include<cstdio> 2 #include<cstdlib> 3 #in ...

  4. Broken robot CodeForces - 24D (三对角矩阵简化高斯消元+概率dp)

    题意: 有一个N行M列的矩阵,机器人最初位于第i行和第j列.然后,机器人可以在每一步都转到另一个单元.目的是转到最底部(第N个)行.机器人可以停留在当前单元格处,向左移动,向右移动或移动到当前位置下方 ...

  5. AtCoder Beginner Contest 184 F - Programming Contest (双向搜索)

    题意:有一个长度为\(n\)的数组,你可以从中选一些数出来使得它们的和不大于\(t\),问能选出来的最大的和是多少. 题解:\(n\)的数据范围是\(40\),直接二进制枚举贴TLE,之前写过这样的一 ...

  6. L2-013 红色警报 (25分) 并查集复杂度

    代码: 1 /* 2 这道题也是简单并查集,并查集复杂度: 3 空间复杂度为O(N),建立一个集合的时间复杂度为O(1),N次合并M查找的时间复杂度为O(M Alpha(N)), 4 这里Alpha是 ...

  7. CF1462-E1. Close Tuples (easy version)

    题意: 给出一个由n个数字组成的数组,先让你找出符合下列条件的子集的数量: 每个子集包含的数字个数为m = 3 这三个数字中的最大值减去最小值不超过k = 2 思路: 首先对给出的数组进行排序,现在假 ...

  8. Gitlab日常维护(三)之Gitlab的备份、迁移、升级

    一.Gitlab的备份 使用Gitlab一键安装包安装Gitlab非常简单, 同样的备份恢复与迁移也非常简单. 使用一条命令即可创建完整的Gitlab备份 [root@gitlab ~]# gitla ...

  9. js create Array ways All In One

    js create Array ways All In One ES6 const arr = [...document.querySelectorAll(`[data-dom="^div& ...

  10. uni-app 实战-打包 &#128230; App

    uni-app 实战-打包 App Android & iOS 证书 广告 refs xgqfrms 2012-2020 www.cnblogs.com 发布文章使用:只允许注册用户才可以访问 ...