官方github上已经有了pytorch基础模型的实现,链接

但是其中一些模型,尤其是resnet,都是用函数生成的各个层,自己看起来是真的难受!

所以自己按照caffe的样子,写一个pytorch的resnet18模型,当然和1000分类模型不同,模型做了一些修改,输入48*48的3通道图片,输出7类。

import torch.nn as nn
import torch.nn.functional as F class ResNet18Model(nn.Module):
def __init__(self):
super().__init__() self.bn64_0 = nn.BatchNorm2d(64)
self.bn64_1 = nn.BatchNorm2d(64)
self.bn64_2 = nn.BatchNorm2d(64)
self.bn64_3 = nn.BatchNorm2d(64)
self.bn64_4 = nn.BatchNorm2d(64) self.bn128_0 = nn.BatchNorm2d(128)
self.bn128_1 = nn.BatchNorm2d(128)
self.bn128_2 = nn.BatchNorm2d(128)
self.bn128_3 = nn.BatchNorm2d(128) self.bn256_0 = nn.BatchNorm2d(256)
self.bn256_1 = nn.BatchNorm2d(256)
self.bn256_2 = nn.BatchNorm2d(256)
self.bn256_3 = nn.BatchNorm2d(256) self.bn512_0 = nn.BatchNorm2d(512)
self.bn512_1 = nn.BatchNorm2d(512)
self.bn512_2 = nn.BatchNorm2d(512)
self.bn512_3 = nn.BatchNorm2d(512) self.shortcut_straight_0 = nn.Sequential()
self.shortcut_straight_1 = nn.Sequential()
self.shortcut_straight_2 = nn.Sequential()
self.shortcut_straight_3 = nn.Sequential()
self.shortcut_straight_4 = nn.Sequential() self.shortcut_conv_bn_64_128_0 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),nn.BatchNorm2d(128)) self.shortcut_conv_bn_128_256_0 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=1, stride=2, bias=False),nn.BatchNorm2d(256)) self.shortcut_conv_bn_256_512_0 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=1, stride=2, bias=False),nn.BatchNorm2d(512)) self.conv_w3_h3_in3_out64_s1_p1_0 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) self.conv_w3_h3_in64_out64_s1_p1_0 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in64_out64_s1_p1_1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in64_out64_s1_p1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in64_out64_s1_p1_3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False) self.conv_w3_h3_in64_out128_s2_p1_0 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False) self.conv_w3_h3_in128_out128_s1_p1_0 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in128_out128_s1_p1_1 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in128_out128_s1_p1_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False) self.conv_w3_h3_in128_out256_s2_p1_0 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False) self.conv_w3_h3_in256_out256_s1_p1_0 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in256_out256_s1_p1_1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in256_out256_s1_p1_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False) self.conv_w3_h3_in256_out512_s2_p1_0 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False) self.conv_w3_h3_in512_out512_s1_p1_0 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in512_out512_s1_p1_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in512_out512_s1_p1_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False) self.avg_pool_0 = nn.AdaptiveAvgPool2d((1, 1))
self.fc_512_7_0 = nn.Linear(512, 7)
self.dropout_0 = nn.Dropout(p=0.5) def forward(self, x): # 48*48*3
t = self.conv_w3_h3_in3_out64_s1_p1_0(x) #48*48*64
t = self.bn64_0(t)
y1 = F.relu(t) t = self.conv_w3_h3_in64_out64_s1_p1_0(y1) #48*48*64
t = self.bn64_1(t)
y2 = F.relu(t) t = self.conv_w3_h3_in64_out64_s1_p1_1(y2) #48*48*64
t = self.bn64_2(t)
t += self.shortcut_straight_0(y1)
y3 = F.relu(t) t = self.conv_w3_h3_in64_out64_s1_p1_2(y3) #48*48*64
t = self.bn64_3(t)
y4 = F.relu(t) t = self.conv_w3_h3_in64_out64_s1_p1_3(y4) #48*48*64
t = self.bn64_4(t)
t += self.shortcut_straight_1(y3)
y5 = F.relu(t) t = self.conv_w3_h3_in64_out128_s2_p1_0(y5) #24*24*128
t = self.bn128_0(t)
y6 = F.relu(t) t = self.conv_w3_h3_in128_out128_s1_p1_0(y6) #24*24*128
t = self.bn128_1(t)
t += self.shortcut_conv_bn_64_128_0(y5)
y7 = F.relu(t) t = self.conv_w3_h3_in128_out128_s1_p1_1(y7) #24*24*128
t = self.bn128_2(t)
y8 = F.relu(t) t = self.conv_w3_h3_in128_out128_s1_p1_2(y8) #24*24*128
t = self.bn128_3(t)
t += self.shortcut_straight_2(y7)
y9 = F.relu(t) t = self.conv_w3_h3_in128_out256_s2_p1_0(y9) #12*12*256
t = self.bn256_0(t)
y10 = F.relu(t) t = self.conv_w3_h3_in256_out256_s1_p1_0(y10) #12*12*256
t = self.bn256_1(t)
t += self.shortcut_conv_bn_128_256_0(y9)
y11 = F.relu(t) t = self.conv_w3_h3_in256_out256_s1_p1_1(y11) #12*12*256
t = self.bn256_2(t)
y12 = F.relu(t) t = self.conv_w3_h3_in256_out256_s1_p1_2(y12) #12*12*256
t = self.bn256_3(t)
t += self.shortcut_straight_3(y11)
y13 = F.relu(t) t = self.conv_w3_h3_in256_out512_s2_p1_0(y13) #6*6*512
t = self.bn512_0(t)
y14 = F.relu(t) t = self.conv_w3_h3_in512_out512_s1_p1_0(y14) #6*6*512
t = self.bn512_1(t)
t += self.shortcut_conv_bn_256_512_0(y13)
y15 = F.relu(t) t = self.conv_w3_h3_in512_out512_s1_p1_1(y15) #6*6*512
t = self.bn512_2(t)
y16 = F.relu(t) t = self.conv_w3_h3_in512_out512_s1_p1_2(y16) #6*6*512
t = self.bn512_3(t)
t += self.shortcut_straight_4(y15)
y17 = F.relu(t) out = self.avg_pool_0(y17) #1*1*512
out = out.view(out.size(0), -1)
out = self.dropout_0(out)
out = self.fc_512_7_0(out) return out if __name__ == '__main__':
net = ResNet18Model()
# print(net) import torch
net_in = torch.rand(1, 3, 48, 48)
net_out = net(net_in)
print(net_out)
print(net_out.size())

  

pytorch resnet实现的更多相关文章

  1. PyTorch ResNet 使用与源码解析

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson8/resnet_inference.py 这篇文章首先会简 ...

  2. [源码解读] ResNet源码解读(pytorch)

    自己看读完pytorch封装的源码后,自己又重新写了一边(模仿其书写格式), 一些问题在代码中说明. import torch import torchvision import argparse i ...

  3. 解读 pytorch对resnet的官方实现

    地址:https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 贴代码 import torch.nn as ...

  4. 【深度学习】基于Pytorch的ResNet实现

    目录 1. ResNet理论 2. pytorch实现 2.1 基础卷积 2.2 模块 2.3 使用ResNet模块进行迁移学习 1. ResNet理论 论文:https://arxiv.org/pd ...

  5. ResNet网络的Pytorch实现

    1.文章原文地址 Deep Residual Learning for  Image Recognition 2.文章摘要 神经网络的层次越深越难训练.我们提出了一个残差学习框架来简化网络的训练,这些 ...

  6. Pytorch构建ResNet

    学了几天Pytorch,大致明白代码在干什么了,贴一下.. import torch from torch.utils.data import DataLoader from torchvision ...

  7. 陈云pytorch学习笔记_用50行代码搭建ResNet

    import torch as t import torch.nn as nn import torch.nn.functional as F from torchvision import mode ...

  8. PyTorch对ResNet网络的实现解析

    PyTorch对ResNet网络的实现解析 1.首先导入需要使用的包 import torch.nn as nn import torch.utils.model_zoo as model_zoo # ...

  9. 【pytorch】改造resnet为全卷积神经网络以适应不同大小的输入

    为什么resnet的输入是一定的? 因为resnet最后有一个全连接层.正是因为这个全连接层导致了输入的图像的大小必须是固定的. 输入为固定的大小有什么局限性? 原始的resnet在imagenet数 ...

随机推荐

  1. python模块----paramicko模块 (ssh远程主机并命令或传文件)

    paramiko模块 paramicko模块是非标准库模块,需要pip下载 paramicko:模拟ssh登陆linux主机,也有上传下载功能.ansible自动化部署软件底层就有应用paramick ...

  2. php小程序-文章发布系统(mvc框架)

    php小程序-文章发布系统(mvc框架) 一 项目视图 二 项目经验 通过对mvc微型框架的实现,对mvc理论加深,有利于以后框架的学习 三 项目源码 http://files.cnblogs.com ...

  3. 【Spring-Security】Re01 入门上手

    一.所需的组件 SpringBoot项目需要的POM依赖: <dependency> <groupId>org.springframework.boot</groupId ...

  4. git的几种实用操作(合并代码与暂存复原代码)

    总述     git工具也用了很久,自己也写了几篇使用教程,今天继续给大家分享一些我工作中使用过的git操作. 1.git合并远程仓库的代码 2.git stash保存当前的修改 这两种情况大家应该都 ...

  5. B - B Silver Cow Party (最短路+转置)

    有n个农场,编号1~N,农场里奶牛将去X号农场.这N个农场之间有M条单向路(注意),通过第i条路将需要花费Ti单位时间.选择最短时间的最优路径来回一趟,花费在去的路上和返回农场的这些最优路径的最长时间 ...

  6. Codeforces Round #666 (Div. 2) Power Sequence、Multiples of Length 思维

    题目链接:Power Sequence 题意: 给你n个数vi,你可以对这个序列进行两种操作 1.可以改变其中任意个vi的位置,无成本 2.可以对vi进行加1或减1,每次操作成本为1 如果操作之后的v ...

  7. poj2001 Shortest Prefixes (trie树)

    Description A prefix of a string is a substring starting at the beginning of the given string. The p ...

  8. Slim Span POJ 3522 (最小差值生成树)

    题意: 最小生成树找出来最小的边权值总和使得n个顶点都连在一起.那么这找出来的边权值中的最大权值和最小权值之差就是本题的结果 但是题目要求让这个输出的结果最小,也就是差值最小.那么这就不是最小生成树了 ...

  9. 快速获取 Wi-Fi 密码——GitHub 热点速览 v.21.06

    作者:HelloGitHub-小鱼干 还有 2 天开启春节七天宅家生活,GitHub 也凑了一把春节热闹,wifi-password 连续霸榜 3 天,作为一个能快速让你连上 Wi-Fi 的小工具,春 ...

  10. Kubernets二进制安装(12)之部署Node节点服务的kube-Proxy

    kube-proxy是Kubernetes的核心组件,部署在每个Node节点上,它是实现Kubernetes Service的通信与负载均衡机制的重要组件; kube-proxy负责为Pod创建代理服 ...