官方github上已经有了pytorch基础模型的实现,链接

但是其中一些模型,尤其是resnet,都是用函数生成的各个层,自己看起来是真的难受!

所以自己按照caffe的样子,写一个pytorch的resnet18模型,当然和1000分类模型不同,模型做了一些修改,输入48*48的3通道图片,输出7类。

import torch.nn as nn
import torch.nn.functional as F class ResNet18Model(nn.Module):
def __init__(self):
super().__init__() self.bn64_0 = nn.BatchNorm2d(64)
self.bn64_1 = nn.BatchNorm2d(64)
self.bn64_2 = nn.BatchNorm2d(64)
self.bn64_3 = nn.BatchNorm2d(64)
self.bn64_4 = nn.BatchNorm2d(64) self.bn128_0 = nn.BatchNorm2d(128)
self.bn128_1 = nn.BatchNorm2d(128)
self.bn128_2 = nn.BatchNorm2d(128)
self.bn128_3 = nn.BatchNorm2d(128) self.bn256_0 = nn.BatchNorm2d(256)
self.bn256_1 = nn.BatchNorm2d(256)
self.bn256_2 = nn.BatchNorm2d(256)
self.bn256_3 = nn.BatchNorm2d(256) self.bn512_0 = nn.BatchNorm2d(512)
self.bn512_1 = nn.BatchNorm2d(512)
self.bn512_2 = nn.BatchNorm2d(512)
self.bn512_3 = nn.BatchNorm2d(512) self.shortcut_straight_0 = nn.Sequential()
self.shortcut_straight_1 = nn.Sequential()
self.shortcut_straight_2 = nn.Sequential()
self.shortcut_straight_3 = nn.Sequential()
self.shortcut_straight_4 = nn.Sequential() self.shortcut_conv_bn_64_128_0 = nn.Sequential(nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),nn.BatchNorm2d(128)) self.shortcut_conv_bn_128_256_0 = nn.Sequential(nn.Conv2d(128, 256, kernel_size=1, stride=2, bias=False),nn.BatchNorm2d(256)) self.shortcut_conv_bn_256_512_0 = nn.Sequential(nn.Conv2d(256, 512, kernel_size=1, stride=2, bias=False),nn.BatchNorm2d(512)) self.conv_w3_h3_in3_out64_s1_p1_0 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False) self.conv_w3_h3_in64_out64_s1_p1_0 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in64_out64_s1_p1_1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in64_out64_s1_p1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in64_out64_s1_p1_3 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False) self.conv_w3_h3_in64_out128_s2_p1_0 = nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False) self.conv_w3_h3_in128_out128_s1_p1_0 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in128_out128_s1_p1_1 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in128_out128_s1_p1_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False) self.conv_w3_h3_in128_out256_s2_p1_0 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False) self.conv_w3_h3_in256_out256_s1_p1_0 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in256_out256_s1_p1_1 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in256_out256_s1_p1_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False) self.conv_w3_h3_in256_out512_s2_p1_0 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False) self.conv_w3_h3_in512_out512_s1_p1_0 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in512_out512_s1_p1_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False)
self.conv_w3_h3_in512_out512_s1_p1_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False) self.avg_pool_0 = nn.AdaptiveAvgPool2d((1, 1))
self.fc_512_7_0 = nn.Linear(512, 7)
self.dropout_0 = nn.Dropout(p=0.5) def forward(self, x): # 48*48*3
t = self.conv_w3_h3_in3_out64_s1_p1_0(x) #48*48*64
t = self.bn64_0(t)
y1 = F.relu(t) t = self.conv_w3_h3_in64_out64_s1_p1_0(y1) #48*48*64
t = self.bn64_1(t)
y2 = F.relu(t) t = self.conv_w3_h3_in64_out64_s1_p1_1(y2) #48*48*64
t = self.bn64_2(t)
t += self.shortcut_straight_0(y1)
y3 = F.relu(t) t = self.conv_w3_h3_in64_out64_s1_p1_2(y3) #48*48*64
t = self.bn64_3(t)
y4 = F.relu(t) t = self.conv_w3_h3_in64_out64_s1_p1_3(y4) #48*48*64
t = self.bn64_4(t)
t += self.shortcut_straight_1(y3)
y5 = F.relu(t) t = self.conv_w3_h3_in64_out128_s2_p1_0(y5) #24*24*128
t = self.bn128_0(t)
y6 = F.relu(t) t = self.conv_w3_h3_in128_out128_s1_p1_0(y6) #24*24*128
t = self.bn128_1(t)
t += self.shortcut_conv_bn_64_128_0(y5)
y7 = F.relu(t) t = self.conv_w3_h3_in128_out128_s1_p1_1(y7) #24*24*128
t = self.bn128_2(t)
y8 = F.relu(t) t = self.conv_w3_h3_in128_out128_s1_p1_2(y8) #24*24*128
t = self.bn128_3(t)
t += self.shortcut_straight_2(y7)
y9 = F.relu(t) t = self.conv_w3_h3_in128_out256_s2_p1_0(y9) #12*12*256
t = self.bn256_0(t)
y10 = F.relu(t) t = self.conv_w3_h3_in256_out256_s1_p1_0(y10) #12*12*256
t = self.bn256_1(t)
t += self.shortcut_conv_bn_128_256_0(y9)
y11 = F.relu(t) t = self.conv_w3_h3_in256_out256_s1_p1_1(y11) #12*12*256
t = self.bn256_2(t)
y12 = F.relu(t) t = self.conv_w3_h3_in256_out256_s1_p1_2(y12) #12*12*256
t = self.bn256_3(t)
t += self.shortcut_straight_3(y11)
y13 = F.relu(t) t = self.conv_w3_h3_in256_out512_s2_p1_0(y13) #6*6*512
t = self.bn512_0(t)
y14 = F.relu(t) t = self.conv_w3_h3_in512_out512_s1_p1_0(y14) #6*6*512
t = self.bn512_1(t)
t += self.shortcut_conv_bn_256_512_0(y13)
y15 = F.relu(t) t = self.conv_w3_h3_in512_out512_s1_p1_1(y15) #6*6*512
t = self.bn512_2(t)
y16 = F.relu(t) t = self.conv_w3_h3_in512_out512_s1_p1_2(y16) #6*6*512
t = self.bn512_3(t)
t += self.shortcut_straight_4(y15)
y17 = F.relu(t) out = self.avg_pool_0(y17) #1*1*512
out = out.view(out.size(0), -1)
out = self.dropout_0(out)
out = self.fc_512_7_0(out) return out if __name__ == '__main__':
net = ResNet18Model()
# print(net) import torch
net_in = torch.rand(1, 3, 48, 48)
net_out = net(net_in)
print(net_out)
print(net_out.size())

  

pytorch resnet实现的更多相关文章

  1. PyTorch ResNet 使用与源码解析

    本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson8/resnet_inference.py 这篇文章首先会简 ...

  2. [源码解读] ResNet源码解读(pytorch)

    自己看读完pytorch封装的源码后,自己又重新写了一边(模仿其书写格式), 一些问题在代码中说明. import torch import torchvision import argparse i ...

  3. 解读 pytorch对resnet的官方实现

    地址:https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 贴代码 import torch.nn as ...

  4. 【深度学习】基于Pytorch的ResNet实现

    目录 1. ResNet理论 2. pytorch实现 2.1 基础卷积 2.2 模块 2.3 使用ResNet模块进行迁移学习 1. ResNet理论 论文:https://arxiv.org/pd ...

  5. ResNet网络的Pytorch实现

    1.文章原文地址 Deep Residual Learning for  Image Recognition 2.文章摘要 神经网络的层次越深越难训练.我们提出了一个残差学习框架来简化网络的训练,这些 ...

  6. Pytorch构建ResNet

    学了几天Pytorch,大致明白代码在干什么了,贴一下.. import torch from torch.utils.data import DataLoader from torchvision ...

  7. 陈云pytorch学习笔记_用50行代码搭建ResNet

    import torch as t import torch.nn as nn import torch.nn.functional as F from torchvision import mode ...

  8. PyTorch对ResNet网络的实现解析

    PyTorch对ResNet网络的实现解析 1.首先导入需要使用的包 import torch.nn as nn import torch.utils.model_zoo as model_zoo # ...

  9. 【pytorch】改造resnet为全卷积神经网络以适应不同大小的输入

    为什么resnet的输入是一定的? 因为resnet最后有一个全连接层.正是因为这个全连接层导致了输入的图像的大小必须是固定的. 输入为固定的大小有什么局限性? 原始的resnet在imagenet数 ...

随机推荐

  1. Language Guide (proto3) | proto3 语言指南(十二)定义服务

    Defining Services - 定义服务 如果要在RPC(Remote Procedure Call,远程过程调用)系统中使用消息类型,可以在.proto文件中定义RPC服务接口,协议缓冲区编 ...

  2. MIT 6.S081 Lab File System

    前言 打开自己的blog一看,居然三个月没更新了...回想一下前几个月,开题 + 实验室杂活貌似也没占非常多的时间,还是自己太懈怠了吧,掉线城和文明6真的是时间刹手( 不过好消息是把15445的所有l ...

  3. spark整合Phoenix相关案例

    spark 读取Phoenix hbase table表到 DataFrame的方式 Demo1: 方式一:spark read读取各数据库的通用方式 方式二:spark.load 方式三:phoen ...

  4. spark SQL (四)数据源 Data Source----Parquet 文件的读取与加载

    spark SQL Parquet 文件的读取与加载 是由许多其他数据处理系统支持的柱状格式.Spark SQL支持阅读和编写自动保留原始数据模式的Parquet文件.在编写Parquet文件时,出于 ...

  5. Linux常用命令详解(第三章)(ping、kill、seq、du、df、free、date、tar)

    本章命令(共7个): 1 2 3 4 5 6 7 8 ping kill seq du df free date tar 1." ping " 作用:向网络主机发送ICMP(检测主 ...

  6. Broken robot CodeForces - 24D (三对角矩阵简化高斯消元+概率dp)

    题意: 有一个N行M列的矩阵,机器人最初位于第i行和第j列.然后,机器人可以在每一步都转到另一个单元.目的是转到最底部(第N个)行.机器人可以停留在当前单元格处,向左移动,向右移动或移动到当前位置下方 ...

  7. 分块 && 例题 I Hate It HDU - 1754

    分块算法: 分块就是对暴力方法的一种优化:                          _ 假设我们总共的序列长度为n,然后我们把它切成√n 块,然后把每一块里的东西当成一个整体来看,完整块:被 ...

  8. 最新版gradle安装使用简介

    目录 简介 安装gradle和解决gradle安装的问题 Gradle特性 标准task Build phases Gradle Wrapper wrapper的使用 wrapper的升级 一个简单的 ...

  9. Linux系统CentOS进入单用户模式和救援模式详解

    一.概述 目前在运维日常工作中,经常会遇到服务器异常断电.忘记root密码.系统引导文件损坏无法进入系统等等操作系统层面的问题,给运维带来诸多不便,现将上述现象的解决方法和大家分享一下,本次主要以Ce ...

  10. Lightoj 1038 - Race to 1 Again【期望+dp】

    题目:戳这里 题意:一个数字n不断迭代地除以自身的因子得到1.求这个过程中操作除法次数的期望. 解题思路: 求概率基本都是从一个最基础的状态开始延伸推出公式,得出答案.因为每个数都有个共同的最终状态1 ...