CIFAR-10.(Canadian Institute for Advanced Research)是由 Alex Krizhevsky、Vinod Nair 与 Geoffrey Hinton 收集的一个用于图像识别的数据集,60000个32*32的彩色图像,50000个training data,10000个 test data 有10类,飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车,每类6000张图。与MNIST相比,色彩、颜色噪点较多,同一类物体大小不一、角度不同、颜色不同。

先要对该数据集进行分类

步骤如下
1.使用torchvision加载并预处理CIFAR-10数据集、
2.定义网络
3.定义损失函数和优化器
4.训练网络并更新网络参数
5.测试网络

 import torchvision as tv            #里面含有许多数据集
import torch
import torchvision.transforms as transforms #实现图片变换处理的包
from torchvision.transforms import ToPILImage #使用torchvision加载并预处理CIFAR10数据集
show = ToPILImage() #可以把Tensor转成Image,方便进行可视化
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean = (0.5,0.5,0.5),std = (0.5,0.5,0.5))])#把数据变为tensor并且归一化range [, ] -> [0.0,1.0]
trainset = tv.datasets.CIFAR10(root='data1/',train = True,download=True,transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=,shuffle=True,num_workers=)
testset = tv.datasets.CIFAR10('data1/',train=False,download=True,transform=transform)
testloader = torch.utils.data.DataLoader(testset,batch_size=,shuffle=True,num_workers=)
classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')
(data,label) = trainset[]
print(classes[label])#输出ship
show((data+)/).resize((,))
dataiter = iter(trainloader)
images, labels = dataiter.next()
print(' '.join('%11s'%classes[labels[j]] for j in range()))
show(tv.utils.make_grid((images+)/)).resize((,))#make_grid的作用是将若干幅图像拼成一幅图像 #定义网络
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv1 = nn.Conv2d(,,)
self.conv2 = nn.Conv2d(,,)
self.fc1 = nn.Linear(**,)
self.fc2 = nn.Linear(,)
self.fc3 = nn.Linear(,)
def forward(self,x):
x = F.max_pool2d(F.relu(self.conv1(x)),(,))
x = F.max_pool2d(F.relu(self.conv2(x)),)
x = x.view(x.size()[],-)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x net = Net()
print(net) #定义损失函数和优化器
from torch import optim
criterion = nn.CrossEntropyLoss()#定义交叉熵损失函数
optimizer = optim.SGD(net.parameters(),lr = 0.001,momentum=0.9) #训练网络
from torch.autograd import Variable
for epoch in range():
running_loss = 0.0
for i, data in enumerate(trainloader, ):#enumerate将其组成一个索引序列,利用它可以同时获得索引和值,enumerate还可以接收第二个参数,用于指定索引起始值
inputs, labels = data
inputs, labels = Variable(inputs), Variable(labels)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % ==:
print('[%d, %5d] loss: %.3f'%(epoch+,i+,running_loss/))
running_loss = 0.0
print("----------finished training---------")
dataiter = iter(testloader)
images, labels = dataiter.next()
print('实际的label: ',' '.join('%08s'%classes[labels[j]] for j in range()))
show(tv.utils.make_grid(images/ - 0.5)).resize((,))#?????
outputs = net(Variable(images))
_, predicted = torch.max(outputs.data,)#返回最大值和其索引
print('预测结果:',' '.join('%5s'%classes[predicted[j]] for j in range()))
correct =
total =
for data in testloader:
images, labels = data
outputs = net(Variable(images))
_, predicted = torch.max(outputs.data, )
total +=labels.size()
correct +=(predicted == labels).sum()
print('10000张测试集中的准确率为: %d %%'%(*correct/total))
if torch.cuda.is_available():
net.cuda()
images = images.cuda()
labels = labels.cuda()
output = net(Variable(images))
loss = criterion(output, Variable(labels))

学习率太大会很难逼近最优值,所以要注意在数据集小的情况下学习率尽量小一些,epoch尽量大一些。

这个例子是陈云的深度学习pytorch框架书上的一个demo,运行该代码需要注意的是数据集的下载问题,因为运行程序很可能数据集下载很慢或者直接下载失败,因此推荐使用迅雷根据指定网址直接下载,半分钟就可以下载好。

用pytorch进行CIFAR-10数据集分类的更多相关文章

  1. 【翻译】TensorFlow卷积神经网络识别CIFAR 10Convolutional Neural Network (CNN)| CIFAR 10 TensorFlow

    原网址:https://data-flair.training/blogs/cnn-tensorflow-cifar-10/ by DataFlair Team · Published May 21, ...

  2. 单向LSTM笔记, LSTM做minist数据集分类

    单向LSTM笔记, LSTM做minist数据集分类 先介绍下torch.nn.LSTM()这个API 1.input_size: 每一个时步(time_step)输入到lstm单元的维度.(实际输入 ...

  3. PyTorch深度学习实践——多分类问题

    多分类问题 目录 多分类问题 Softmax 在Minist数据集上实现多分类问题 作业 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩 ...

  4. 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化

    一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...

  5. Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes

    Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes 代码如下: # !/usr/bin/env python # encoding: utf-8 __author__ = ...

  6. Python实现鸢尾花数据集分类问题——基于skearn的LogisticRegression

    Python实现鸢尾花数据集分类问题——基于skearn的LogisticRegression 一. 逻辑回归 逻辑回归(Logistic Regression)是用于处理因变量为分类变量的回归问题, ...

  7. Python实现鸢尾花数据集分类问题——基于skearn的SVM

    Python实现鸢尾花数据集分类问题——基于skearn的SVM 代码如下: # !/usr/bin/env python # encoding: utf-8 __author__ = 'Xiaoli ...

  8. 3.keras-简单实现Mnist数据集分类

    keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...

  9. 6.keras-基于CNN网络的Mnist数据集分类

    keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...

  10. pytorch构建自己的数据集

    现在需要在json文件里面读取图片的URL和label,这里面可能会出现某些URL地址无效的情况. python读取json文件 此处只需要将json文件里面的内容读取出来就可以了 with open ...

随机推荐

  1. 服务注册与发现---spring cloud

    Eureka基本架构 Register Service :服务注册中心,它是一个 Eureka Server ,提供服务注册和发现的功能. Provider Service :服务提供者,它是 Eur ...

  2. Fiddler,对数据进行抓包,拦截,修改等操作

    转载....... 一.fiddler设置fiddler默认是只能抓取http网络格式的,所以我们要先设置下使fiddler可以获取到https网络格式 首先tools→options→https进去 ...

  3. psecurity配置

    <?xml version="1.0" encoding="UTF-8"?><beans xmlns="http://www.spr ...

  4. python的代码块缓存机制,小数据池机制。

    同一代码块的缓存机制 在python中一个模块,一个函数,一个类,一个文件等都是一个代码块. 机制内容:Python在执行同一个代码块的初始化对象的命令时,会检查是否其值是否已经存在,如果存在,会将其 ...

  5. delphi 文件操作(信息获取)

    delphi获取Exe文件版本信息的函数 Type TFileVersionInfo = Record FixedInfo:TVSFixedFileInfo; {版本信息} CompanyName:S ...

  6. apk签名原理及实现

    发布过Android应用的朋友们应该都知道,Android APK的发布是需要签名的.签名机制在Android应用和框架中有着十分重要的作用. 例如,Android系统禁止更新安装签名不一致的APK: ...

  7. 【SVN】提交报错:×××文件is not under version control

    解决方法:1.删除出错的文件,然后在出错文件所在文件夹执行还原操作 2.VS中可将文件先排除在项目外,再包含在项目内,即可正常提交

  8. (动态改变数据源遇到的问题)ORACLE11g:No Dialect mapping for JDBC type: -9解决方案

    在动态改变数据源时 hibernate配置不能使用Oracle官方的方言(org.hibernate.dialect.Oracle10gDialect) 做法写一个方言扩展类,缺什么类型,添加什么类型 ...

  9. Codeforces Round #536 E. Lunar New Year and Red Envelopes /// 贪心 记忆化搜索 multiset取最大项

    题目大意: 给定n m k:(1≤n≤1e5, 0≤m≤200, 1≤k≤1e5) 表示n个时间长度内 最多被打扰m次 k个红包 接下来k行描述红包 s t d w:(1≤s≤t≤d≤n , 1≤w≤ ...

  10. ArcGis相接面补节点c#

    相接(Touch)面执行切割后 新面与原相接面会缺少公共节点. private void AddPointToTouchesPolygon(IFeatureCursor newFeatureCurso ...