cifar10主要是由32x32的三通道彩色图, 总共10个类别,这里我们使用残差网络构造网络结构

网络结构:

第一层:首先经过一个卷积,归一化,激活 32x32x16 -> 32x32x16

第二层:  通过一多个残差模型

残差模块的网络构造:

如果stride != 1 or in_channel != out_channel, 就构造downsample网络结构进行降采样操作

利用残差模块进行第一次残差卷积, 将downsample传入

连续进行多次的残差卷积

from torchvision import transforms
from torch import nn
# 首先对图片进行数据转换 train_transform = transforms.Compose([
transforms.Scale(40), # 相当于是resize操作,
transforms.RandomHorizontalFlip(), # 表示进行左右的翻转
transforms.RandomCrop(32), #表示进行随机的裁剪
transforms.ToTensor(), # 将数据转换为tensor格式
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 进行-均值 / 标准差, 将数据转换为-1, 1 之间 ]) test_transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
]) def conv3x3(in_channels, out_channels, stride=1):
return nn.Conv2d(in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
bias=False) class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1, downsample=None):
super(ResidualBlock, self).__init__()
self.conv1 = conv3x3(in_channels, out_channels, stride=1)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(True)
self.conv2 = conv3x3(out_channels, out_channels, stride=1)
self.bn = nn.BatchNorm2d(out_channels)
self.downsample = downsample def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn(x)
out = self.relu(x)
out = self.conv2(x)
out = self.bn(x)
if self.downsample:
residual = self.downsample(x)
out += residual
return self.relu(out) class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=10):
super(ResNet, self).__init__()
self.in_channels = 16
self.conv = conv3x3(3, 16)
self.bn = nn.BatchNorm2d(self.in_channels)
self.relu = nn.ReLU(True)
self.layers1 = self.make_block(block, 16, layers[0])
self.layers2 = self.make_block(block, 32, layers[0])
self.layers3 = self.make_block(block, 64, layers[1])
self.avg_pool = nn.AvgPool2d(8)
self.fc = nn.Linear(64, num_classes) def make_block(self, block, out_channels, blocks, stride=1):
downsample = None
if stride != 1 or out_channels != self.in_channels:
downsample = nn.Sequential(conv3x3(self.in_channels, out_channels, stride=stride),
nn.BatchNorm2d(out_channels))
layers = []
layers.append(block(self.in_channels, out_channels, stride=stride, downsample = downsample))
for i in blocks:
layers.append(block(self.out_channels, out_channels, stride=stride, downsample=downsample)) return nn.Sequential(*layers) def forward(self, x):
out = self.conv(x)
out = self.bn(out)
out = self.relu(out)
out = self.layers1(out)
out= self.layers2(out)
out = self.layers3(out)
out = self.avg_pool(out)
out = self.fc(out) return out

pytorch-cifar10分类网络结构的更多相关文章

  1. 深度学习之 cnn 进行 CIFAR10 分类

    深度学习之 cnn 进行 CIFAR10 分类 import torchvision as tv import torchvision.transforms as transforms from to ...

  2. [深度应用]·实战掌握PyTorch图片分类简明教程

    [深度应用]·实战掌握PyTorch图片分类简明教程 个人网站--> http://www.yansongsong.cn/ 项目GitHub地址--> https://github.com ...

  3. TensorFlow基础笔记(3) cifar10 分类学习

    TensorFlow基础笔记(3) cifar10 分类学习 CIFAR-10 is a common benchmark in machine learning for image recognit ...

  4. TF Boys (TensorFlow Boys ) 养成记(四):TensorFlow 简易 CIFAR10 分类网络

    前面基本上把 TensorFlow 的在图像处理上的基础知识介绍完了,下面我们就用 TensorFlow 来搭建一个分类 cifar10 的神经网络. 首先准备数据: cifar10 的数据集共有 6 ...

  5. 原 CNN--卷积神经网络从R-CNN到Faster R-CNN的理解(CIFAR10分类代码)

    1. 什么是CNN 卷积神经网络(Convolutional Neural Networks, CNN)是一类包含卷积计算且具有深度结构的前馈神经网络(Feedforward Neural Netwo ...

  6. Pytorch1.0入门实战三:ResNet实现cifar-10分类,利用visdom可视化训练过程

    人的理想志向往往和他的能力成正比. —— 约翰逊 最近一直在使用pytorch深度学习框架,很想用pytorch搞点事情出来,但是框架中一些基本的原理得懂!本次,利用pytorch实现ResNet神经 ...

  7. softmax实现cifar10分类

    将cifar10改成单一通道后,套用前面的softmax分类,分类率40%左右,想哭... .caret, .dropup > .btn > .caret { border-top-col ...

  8. caffe搭建--WINDOWS+VS2013下生成caffe并进行cifar10分类测试

    http://blog.csdn.net/naaaa/article/details/52118437 标签: windowsvs2013caffecifar10 2016-08-04 15:33 1 ...

  9. Edgeboard试用 — 基于CIFAR10分类模型的移植

    前言 在上一次的测试中,我们按照官方给的流程,使用EasyDL快速实现了一个具有性别检测功能的人脸识别系统,那么今天,我们将要试一下通过Paddlepaddle从零开始,训练一个自己的多分类模型,并进 ...

  10. Pytorch文本分类(imdb数据集),含DataLoader数据加载,最优模型保存

    用pytorch进行文本分类,数据集为keras内置的imdb影评数据(二分类),代码包含六个部分(详见代码) 使用环境: pytorch:1.1.0 cuda:10.0 gpu:RTX2070 (1 ...

随机推荐

  1. 【坑】Mybatis原始获取配置方式,获取配置失败

    错误环境: mysql版本:6.0.6 mybatis 3.4.1 idea 2017.1.2 maven 3.5.0 错误描述: 配置经路径见图1,classpath是java文件夹 获取配置的代码 ...

  2. VUE【一、概述】

    早上写的忘了保存..还有很多唠叨的内容...哎又得重新写一遍..想吐槽那个自动保存有卵用.. 今天周一,早上起来继续 由于周六加了一整天班,导致周日无心学习,一天都在玩游戏看电影,到了晚上反而更加空虚 ...

  3. 第一章、Django概述

    目录 第一章.Django概述 一.了解软件开发架构 二.HTTP协议 三.响应状态码 四.请求方式 五.基于wsgiref模块 六..动静态网页 七.python三大主流web框架 八.安装Djan ...

  4. springboot Invalid character found in the request target. The valid characters are defined in RFC 7230 and RFC 3986

    报错如下: 在请求目标中发现无效字符.有效字符在RFC 7230和RFC 3986中定义. 原因是Tomcat在 7.0.73, 8.0.39, 8.5.7 版本后,添加了对于http头的验证. 就是 ...

  5. git回退到历史版本

    问题描述 在开发的过程中,想要修改一个参数的命名.然后修改各种地方,并且push上码云的远程仓库.然后突然发现还要改很多地方,突然后悔不想改动了.那该怎么办呢? 处理步骤 回退本地的git版本 将本地 ...

  6. jQuery表单验证正则表达式-简单

    <html xmlns="http://www.w3.org/1999/xhtml"> <head> <meta http-equiv="C ...

  7. linux usb驱动记录(一)

    一.linux 下的usb驱动框架 在linux系统中,usb驱动可以从两个角度去观察,一个是主机侧,一个是设备侧.linux usb 驱动的总体框架如下图所示:   从主机侧看usb驱动可分为四层: ...

  8. Is there a difference between `==` and `is` in Python?

    is will return True if two variables point to the same object, == if the objects referred to by the ...

  9. C#信号量(Semaphore,SemaphoreSlim)

    Object->MarshalByRefObject->WaitHandle->Semaphore 1.作用: 多线程环境下,可以控制线程的并发数量来限制对资源的访问 2.举例: S ...

  10. js过滤时间格式

    Date.prototype.Format = function(fmt) { //author: meizz var o = { "M+" : this.getMonth()+1 ...