CIFAR-10.(Canadian Institute for Advanced Research)是由 Alex Krizhevsky、Vinod Nair 与 Geoffrey Hinton 收集的一个用于图像识别的数据集,60000个32*32的彩色图像,50000个training data,10000个 test data 有10类,飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车,每类6000张图。与MNIST相比,色彩、颜色噪点较多,同一类物体大小不一、角度不同、颜色不同。

先要对该数据集进行分类

步骤如下
1.使用torchvision加载并预处理CIFAR-10数据集、
2.定义网络
3.定义损失函数和优化器
4.训练网络并更新网络参数
5.测试网络

 import torchvision as tv            #里面含有许多数据集
import torch
import torchvision.transforms as transforms #实现图片变换处理的包
from torchvision.transforms import ToPILImage #使用torchvision加载并预处理CIFAR10数据集
show = ToPILImage() #可以把Tensor转成Image,方便进行可视化
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean = (0.5,0.5,0.5),std = (0.5,0.5,0.5))])#把数据变为tensor并且归一化range [, ] -> [0.0,1.0]
trainset = tv.datasets.CIFAR10(root='data1/',train = True,download=True,transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=,shuffle=True,num_workers=)
testset = tv.datasets.CIFAR10('data1/',train=False,download=True,transform=transform)
testloader = torch.utils.data.DataLoader(testset,batch_size=,shuffle=True,num_workers=)
classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')
(data,label) = trainset[]
print(classes[label])#输出ship
show((data+)/).resize((,))
dataiter = iter(trainloader)
images, labels = dataiter.next()
print(' '.join('%11s'%classes[labels[j]] for j in range()))
show(tv.utils.make_grid((images+)/)).resize((,))#make_grid的作用是将若干幅图像拼成一幅图像 #定义网络
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv1 = nn.Conv2d(,,)
self.conv2 = nn.Conv2d(,,)
self.fc1 = nn.Linear(**,)
self.fc2 = nn.Linear(,)
self.fc3 = nn.Linear(,)
def forward(self,x):
x = F.max_pool2d(F.relu(self.conv1(x)),(,))
x = F.max_pool2d(F.relu(self.conv2(x)),)
x = x.view(x.size()[],-)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x net = Net()
print(net) #定义损失函数和优化器
from torch import optim
criterion = nn.CrossEntropyLoss()#定义交叉熵损失函数
optimizer = optim.SGD(net.parameters(),lr = 0.001,momentum=0.9) #训练网络
from torch.autograd import Variable
for epoch in range():
running_loss = 0.0
for i, data in enumerate(trainloader, ):#enumerate将其组成一个索引序列,利用它可以同时获得索引和值,enumerate还可以接收第二个参数,用于指定索引起始值
inputs, labels = data
inputs, labels = Variable(inputs), Variable(labels)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % ==:
print('[%d, %5d] loss: %.3f'%(epoch+,i+,running_loss/))
running_loss = 0.0
print("----------finished training---------")
dataiter = iter(testloader)
images, labels = dataiter.next()
print('实际的label: ',' '.join('%08s'%classes[labels[j]] for j in range()))
show(tv.utils.make_grid(images/ - 0.5)).resize((,))#?????
outputs = net(Variable(images))
_, predicted = torch.max(outputs.data,)#返回最大值和其索引
print('预测结果:',' '.join('%5s'%classes[predicted[j]] for j in range()))
correct =
total =
for data in testloader:
images, labels = data
outputs = net(Variable(images))
_, predicted = torch.max(outputs.data, )
total +=labels.size()
correct +=(predicted == labels).sum()
print('10000张测试集中的准确率为: %d %%'%(*correct/total))
if torch.cuda.is_available():
net.cuda()
images = images.cuda()
labels = labels.cuda()
output = net(Variable(images))
loss = criterion(output, Variable(labels))

学习率太大会很难逼近最优值,所以要注意在数据集小的情况下学习率尽量小一些,epoch尽量大一些。

这个例子是陈云的深度学习pytorch框架书上的一个demo,运行该代码需要注意的是数据集的下载问题,因为运行程序很可能数据集下载很慢或者直接下载失败,因此推荐使用迅雷根据指定网址直接下载,半分钟就可以下载好。

用pytorch进行CIFAR-10数据集分类的更多相关文章

  1. 【翻译】TensorFlow卷积神经网络识别CIFAR 10Convolutional Neural Network (CNN)| CIFAR 10 TensorFlow

    原网址:https://data-flair.training/blogs/cnn-tensorflow-cifar-10/ by DataFlair Team · Published May 21, ...

  2. 单向LSTM笔记, LSTM做minist数据集分类

    单向LSTM笔记, LSTM做minist数据集分类 先介绍下torch.nn.LSTM()这个API 1.input_size: 每一个时步(time_step)输入到lstm单元的维度.(实际输入 ...

  3. PyTorch深度学习实践——多分类问题

    多分类问题 目录 多分类问题 Softmax 在Minist数据集上实现多分类问题 作业 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩 ...

  4. 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化

    一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...

  5. Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes

    Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes 代码如下: # !/usr/bin/env python # encoding: utf-8 __author__ = ...

  6. Python实现鸢尾花数据集分类问题——基于skearn的LogisticRegression

    Python实现鸢尾花数据集分类问题——基于skearn的LogisticRegression 一. 逻辑回归 逻辑回归(Logistic Regression)是用于处理因变量为分类变量的回归问题, ...

  7. Python实现鸢尾花数据集分类问题——基于skearn的SVM

    Python实现鸢尾花数据集分类问题——基于skearn的SVM 代码如下: # !/usr/bin/env python # encoding: utf-8 __author__ = 'Xiaoli ...

  8. 3.keras-简单实现Mnist数据集分类

    keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...

  9. 6.keras-基于CNN网络的Mnist数据集分类

    keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...

  10. pytorch构建自己的数据集

    现在需要在json文件里面读取图片的URL和label,这里面可能会出现某些URL地址无效的情况. python读取json文件 此处只需要将json文件里面的内容读取出来就可以了 with open ...

随机推荐

  1. spark算子之Aggregate

    Aggregate函数 一.源码定义 /** * Aggregate the elements of each partition, and then the results for all the ...

  2. Map集合类(二.其他map集合jdk1.8)

    java集合笔记一 java集合笔记二 java集合笔记三 1.hashtable(线程安全) 1.存储数据为数组+链表2.存储键值对或获取时通过hash值取模数组长度确定节点在数组中的下标位置 in ...

  3. php随机生成数字加字母的字符串

    function getRandomString($len, $chars=null) { if (is_null($chars)) { $chars = "ABCDEFGHIJKLMNOP ...

  4. ASP.NET 服务器控件对应的HTML标签

    label----------<span/> button---------<input type="submit"/> textbox--------&l ...

  5. ubuntu终端仿真程序和文件管理程序

    1.SecureCRT是一款支持SSH(SSH1和SSH2)的终端仿真程序,简单的说是Windows下登录UNIX或Linux服务器主机的软件.可以理解为ubuntu下的Terminal. 如果Sec ...

  6. java求两个数中的大数

    java求两个数中的大数 java中的max函数在Math中 应用如下: int a=34: int b=45: int ans=Math.max(34,45); 那么ans的值就是45.

  7. (转)C#中String跟string的“区别”

    string是c#中的类,String是.net Framework的类(在C# IDE中不会显示蓝色) C# string映射为.net Framework的String 如果用string,编译器 ...

  8. PAT_A1101#Quick Sort

    Source: PAT A1101 Quick Sort (25 分) Description: There is a classical process named partition in the ...

  9. 1060 Are They Equal (25 分)

    1060 Are They Equal (25 分)   If a machine can save only 3 significant digits, the float numbers 1230 ...

  10. Codeforces 1176B - Merge it!

    题目链接:http://codeforces.com/problemset/problem/1176/B 题意:给定序列,任意俩个元素可以相加成一个元素,求序列元素能被3整除的最大数量. 思路: 对于 ...