用pytorch进行CIFAR-10数据集分类
CIFAR-10.(Canadian Institute for Advanced Research)是由 Alex Krizhevsky、Vinod Nair 与 Geoffrey Hinton 收集的一个用于图像识别的数据集,60000个32*32的彩色图像,50000个training data,10000个 test data 有10类,飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车,每类6000张图。与MNIST相比,色彩、颜色噪点较多,同一类物体大小不一、角度不同、颜色不同。

先要对该数据集进行分类
步骤如下:
1.使用torchvision加载并预处理CIFAR-10数据集、
2.定义网络
3.定义损失函数和优化器
4.训练网络并更新网络参数
5.测试网络
import torchvision as tv #里面含有许多数据集
import torch
import torchvision.transforms as transforms #实现图片变换处理的包
from torchvision.transforms import ToPILImage #使用torchvision加载并预处理CIFAR10数据集
show = ToPILImage() #可以把Tensor转成Image,方便进行可视化
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean = (0.5,0.5,0.5),std = (0.5,0.5,0.5))])#把数据变为tensor并且归一化range [, ] -> [0.0,1.0]
trainset = tv.datasets.CIFAR10(root='data1/',train = True,download=True,transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=,shuffle=True,num_workers=)
testset = tv.datasets.CIFAR10('data1/',train=False,download=True,transform=transform)
testloader = torch.utils.data.DataLoader(testset,batch_size=,shuffle=True,num_workers=)
classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')
(data,label) = trainset[]
print(classes[label])#输出ship
show((data+)/).resize((,))
dataiter = iter(trainloader)
images, labels = dataiter.next()
print(' '.join('%11s'%classes[labels[j]] for j in range()))
show(tv.utils.make_grid((images+)/)).resize((,))#make_grid的作用是将若干幅图像拼成一幅图像 #定义网络
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv1 = nn.Conv2d(,,)
self.conv2 = nn.Conv2d(,,)
self.fc1 = nn.Linear(**,)
self.fc2 = nn.Linear(,)
self.fc3 = nn.Linear(,)
def forward(self,x):
x = F.max_pool2d(F.relu(self.conv1(x)),(,))
x = F.max_pool2d(F.relu(self.conv2(x)),)
x = x.view(x.size()[],-)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x net = Net()
print(net) #定义损失函数和优化器
from torch import optim
criterion = nn.CrossEntropyLoss()#定义交叉熵损失函数
optimizer = optim.SGD(net.parameters(),lr = 0.001,momentum=0.9) #训练网络
from torch.autograd import Variable
for epoch in range():
running_loss = 0.0
for i, data in enumerate(trainloader, ):#enumerate将其组成一个索引序列,利用它可以同时获得索引和值,enumerate还可以接收第二个参数,用于指定索引起始值
inputs, labels = data
inputs, labels = Variable(inputs), Variable(labels)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % ==:
print('[%d, %5d] loss: %.3f'%(epoch+,i+,running_loss/))
running_loss = 0.0
print("----------finished training---------")
dataiter = iter(testloader)
images, labels = dataiter.next()
print('实际的label: ',' '.join('%08s'%classes[labels[j]] for j in range()))
show(tv.utils.make_grid(images/ - 0.5)).resize((,))#?????
outputs = net(Variable(images))
_, predicted = torch.max(outputs.data,)#返回最大值和其索引
print('预测结果:',' '.join('%5s'%classes[predicted[j]] for j in range()))
correct =
total =
for data in testloader:
images, labels = data
outputs = net(Variable(images))
_, predicted = torch.max(outputs.data, )
total +=labels.size()
correct +=(predicted == labels).sum()
print('10000张测试集中的准确率为: %d %%'%(*correct/total))
if torch.cuda.is_available():
net.cuda()
images = images.cuda()
labels = labels.cuda()
output = net(Variable(images))
loss = criterion(output, Variable(labels))
学习率太大会很难逼近最优值,所以要注意在数据集小的情况下学习率尽量小一些,epoch尽量大一些。
这个例子是陈云的深度学习pytorch框架书上的一个demo,运行该代码需要注意的是数据集的下载问题,因为运行程序很可能数据集下载很慢或者直接下载失败,因此推荐使用迅雷根据指定网址直接下载,半分钟就可以下载好。
用pytorch进行CIFAR-10数据集分类的更多相关文章
- 【翻译】TensorFlow卷积神经网络识别CIFAR 10Convolutional Neural Network (CNN)| CIFAR 10 TensorFlow
原网址:https://data-flair.training/blogs/cnn-tensorflow-cifar-10/ by DataFlair Team · Published May 21, ...
- 单向LSTM笔记, LSTM做minist数据集分类
单向LSTM笔记, LSTM做minist数据集分类 先介绍下torch.nn.LSTM()这个API 1.input_size: 每一个时步(time_step)输入到lstm单元的维度.(实际输入 ...
- PyTorch深度学习实践——多分类问题
多分类问题 目录 多分类问题 Softmax 在Minist数据集上实现多分类问题 作业 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩 ...
- 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化
一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...
- Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes
Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes 代码如下: # !/usr/bin/env python # encoding: utf-8 __author__ = ...
- Python实现鸢尾花数据集分类问题——基于skearn的LogisticRegression
Python实现鸢尾花数据集分类问题——基于skearn的LogisticRegression 一. 逻辑回归 逻辑回归(Logistic Regression)是用于处理因变量为分类变量的回归问题, ...
- Python实现鸢尾花数据集分类问题——基于skearn的SVM
Python实现鸢尾花数据集分类问题——基于skearn的SVM 代码如下: # !/usr/bin/env python # encoding: utf-8 __author__ = 'Xiaoli ...
- 3.keras-简单实现Mnist数据集分类
keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...
- 6.keras-基于CNN网络的Mnist数据集分类
keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...
- pytorch构建自己的数据集
现在需要在json文件里面读取图片的URL和label,这里面可能会出现某些URL地址无效的情况. python读取json文件 此处只需要将json文件里面的内容读取出来就可以了 with open ...
随机推荐
- vue 插槽 slot
<!DOCTYPE html> <html> <head> <meta charset="UTF-8"> <title> ...
- Case Closed?
Finally, we'll want a way to test whether a file we've opened is closed. Sometimes we'll have a lot ...
- 【leetcode】970. Powerful Integers
题目如下: Given two non-negative integers x and y, an integer is powerful if it is equal to x^i + y^j fo ...
- python读取ini配置文件的示例代码(仅供参考)
这篇文章主要介绍了python读取ini配置文件过程示范,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下 安装 pip install configp ...
- cocos2D-X Download
{ http://cocos2d-x.org/download https://github.com/cocos2d/cocos2d-x https://github.com/fusijie/Coco ...
- 使用springBoot和mybatis整合时出现如下错误:org.apache.ibatis.binding.BindingException: Invalid bound statement (not found)解决方案
在pom.xml文件中添加如下: <build> <resources> <resource> & ...
- 暴力字符串hash——cf1200E
#include<bits/stdc++.h> using namespace std; #define ll long long #define N 1000005 #define mo ...
- API应用实例
API声明透明 {API声明} type TSetLayeredWindowAttributes = function(wnd: HWND; crKey: DWORD; bAlpha: BYTE; d ...
- 转-C++之虚函数不能定义成内联函数的原因
转自:https://blog.csdn.net/flydreamforever/article/details/61429140 在C++中,inline关键字和virtual关键字分别用来定义c+ ...
- 数据可视化(Echart) :柱状图、折线图、饼图等六种基本图表的特点及适用场合
数据可视化(Echart) 柱状图.折线图.饼图等六种基本图表的特点及适用场合 参考网址 效果图 源码 <!DOCTYPE html> <html> <head> ...