用pytorch进行CIFAR-10数据集分类
CIFAR-10.(Canadian Institute for Advanced Research)是由 Alex Krizhevsky、Vinod Nair 与 Geoffrey Hinton 收集的一个用于图像识别的数据集,60000个32*32的彩色图像,50000个training data,10000个 test data 有10类,飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船、卡车,每类6000张图。与MNIST相比,色彩、颜色噪点较多,同一类物体大小不一、角度不同、颜色不同。

先要对该数据集进行分类
步骤如下:
1.使用torchvision加载并预处理CIFAR-10数据集、
2.定义网络
3.定义损失函数和优化器
4.训练网络并更新网络参数
5.测试网络
import torchvision as tv #里面含有许多数据集
import torch
import torchvision.transforms as transforms #实现图片变换处理的包
from torchvision.transforms import ToPILImage #使用torchvision加载并预处理CIFAR10数据集
show = ToPILImage() #可以把Tensor转成Image,方便进行可视化
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize(mean = (0.5,0.5,0.5),std = (0.5,0.5,0.5))])#把数据变为tensor并且归一化range [, ] -> [0.0,1.0]
trainset = tv.datasets.CIFAR10(root='data1/',train = True,download=True,transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=,shuffle=True,num_workers=)
testset = tv.datasets.CIFAR10('data1/',train=False,download=True,transform=transform)
testloader = torch.utils.data.DataLoader(testset,batch_size=,shuffle=True,num_workers=)
classes = ('plane','car','bird','cat','deer','dog','frog','horse','ship','truck')
(data,label) = trainset[]
print(classes[label])#输出ship
show((data+)/).resize((,))
dataiter = iter(trainloader)
images, labels = dataiter.next()
print(' '.join('%11s'%classes[labels[j]] for j in range()))
show(tv.utils.make_grid((images+)/)).resize((,))#make_grid的作用是将若干幅图像拼成一幅图像 #定义网络
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv1 = nn.Conv2d(,,)
self.conv2 = nn.Conv2d(,,)
self.fc1 = nn.Linear(**,)
self.fc2 = nn.Linear(,)
self.fc3 = nn.Linear(,)
def forward(self,x):
x = F.max_pool2d(F.relu(self.conv1(x)),(,))
x = F.max_pool2d(F.relu(self.conv2(x)),)
x = x.view(x.size()[],-)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x net = Net()
print(net) #定义损失函数和优化器
from torch import optim
criterion = nn.CrossEntropyLoss()#定义交叉熵损失函数
optimizer = optim.SGD(net.parameters(),lr = 0.001,momentum=0.9) #训练网络
from torch.autograd import Variable
for epoch in range():
running_loss = 0.0
for i, data in enumerate(trainloader, ):#enumerate将其组成一个索引序列,利用它可以同时获得索引和值,enumerate还可以接收第二个参数,用于指定索引起始值
inputs, labels = data
inputs, labels = Variable(inputs), Variable(labels)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % ==:
print('[%d, %5d] loss: %.3f'%(epoch+,i+,running_loss/))
running_loss = 0.0
print("----------finished training---------")
dataiter = iter(testloader)
images, labels = dataiter.next()
print('实际的label: ',' '.join('%08s'%classes[labels[j]] for j in range()))
show(tv.utils.make_grid(images/ - 0.5)).resize((,))#?????
outputs = net(Variable(images))
_, predicted = torch.max(outputs.data,)#返回最大值和其索引
print('预测结果:',' '.join('%5s'%classes[predicted[j]] for j in range()))
correct =
total =
for data in testloader:
images, labels = data
outputs = net(Variable(images))
_, predicted = torch.max(outputs.data, )
total +=labels.size()
correct +=(predicted == labels).sum()
print('10000张测试集中的准确率为: %d %%'%(*correct/total))
if torch.cuda.is_available():
net.cuda()
images = images.cuda()
labels = labels.cuda()
output = net(Variable(images))
loss = criterion(output, Variable(labels))
学习率太大会很难逼近最优值,所以要注意在数据集小的情况下学习率尽量小一些,epoch尽量大一些。
这个例子是陈云的深度学习pytorch框架书上的一个demo,运行该代码需要注意的是数据集的下载问题,因为运行程序很可能数据集下载很慢或者直接下载失败,因此推荐使用迅雷根据指定网址直接下载,半分钟就可以下载好。
用pytorch进行CIFAR-10数据集分类的更多相关文章
- 【翻译】TensorFlow卷积神经网络识别CIFAR 10Convolutional Neural Network (CNN)| CIFAR 10 TensorFlow
原网址:https://data-flair.training/blogs/cnn-tensorflow-cifar-10/ by DataFlair Team · Published May 21, ...
- 单向LSTM笔记, LSTM做minist数据集分类
单向LSTM笔记, LSTM做minist数据集分类 先介绍下torch.nn.LSTM()这个API 1.input_size: 每一个时步(time_step)输入到lstm单元的维度.(实际输入 ...
- PyTorch深度学习实践——多分类问题
多分类问题 目录 多分类问题 Softmax 在Minist数据集上实现多分类问题 作业 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩 ...
- 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化
一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...
- Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes
Python实现鸢尾花数据集分类问题——基于skearn的NaiveBayes 代码如下: # !/usr/bin/env python # encoding: utf-8 __author__ = ...
- Python实现鸢尾花数据集分类问题——基于skearn的LogisticRegression
Python实现鸢尾花数据集分类问题——基于skearn的LogisticRegression 一. 逻辑回归 逻辑回归(Logistic Regression)是用于处理因变量为分类变量的回归问题, ...
- Python实现鸢尾花数据集分类问题——基于skearn的SVM
Python实现鸢尾花数据集分类问题——基于skearn的SVM 代码如下: # !/usr/bin/env python # encoding: utf-8 __author__ = 'Xiaoli ...
- 3.keras-简单实现Mnist数据集分类
keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...
- 6.keras-基于CNN网络的Mnist数据集分类
keras-基于CNN网络的Mnist数据集分类 1.数据的载入和预处理 import numpy as np from keras.datasets import mnist from keras. ...
- pytorch构建自己的数据集
现在需要在json文件里面读取图片的URL和label,这里面可能会出现某些URL地址无效的情况. python读取json文件 此处只需要将json文件里面的内容读取出来就可以了 with open ...
随机推荐
- Python常用三方库安装
//首先更新pip python -m pip install --upgrade pip //一个类似Matlab的Plot绘制数据图的库. python -m pip install matplo ...
- php操作redis--字典(hash)篇
常用函数:hSet,hGet,hGetAll等. 应用场景:存储用户信息对象数据,包括id,姓名,年龄和生日,通过用户id来获取姓名,年龄等信息. 连接 $redis = new Redis(); $ ...
- Delphi GlobalAlloc、GlobalLock、GlobalUnlock、GlobalFree 函数
GlobalAlloc 函数 分配一块内存,该函数会返回分配的内存句柄. GlobalLock 函数 锁定内存块,该函数接受一个内存句柄作为参数,然后返回一个指向被锁定的内存块的指针. 您可以用该指针 ...
- 【Codeforces Beta Round #88 C】Cycle
[Link]:http://codeforces.com/problemset/problem/117/C [Description] 问你一张图里面有没有一个三元环,有的话就输出. [Solutio ...
- PHP导出excel word的代码
php导出为word原理 一般,有2种方法可以导出doc文档,一种是使用com,并且作为php的一个扩展库安装到服务器上,然后创建一个com,调用它的方法.安装过office的服务器可以调用一个叫wo ...
- HA配置
主T800 eth0 192.168.2.32备T600 eth1 192.168.2.33安装nginx yum install -y nginx 关闭主备的防火墙iptables.selinux ...
- pycharm中django同步数据库问题
一.Django数据同步过程中遇到的问题: 1.raise ImproperlyConfigured('mysqlclient 1.3.13 or newer is required; you hav ...
- CSS中复选框单选框与常用12px文字不对齐问题(转载)
原文 http://www.cnblogs.com/aobingyan/p/3823556.html,有删改 目前中文网站上面的文字,绝大多数网站的主流文字大小为12px,因为在目前高分辨率显示器屏 ...
- 二、springcloud微服务测试环境搭建
版本说明: springcloud:Greenwich.SR3 springboot:2.1.8 1.构建步骤 1.1.microservicecloud整体父工程Project 新建父工程micro ...
- DDOS到底是什么,怎么预防,看看就明白了
可怕的DDOS怎么预防 分布式拒绝服务(DDoS: distributed denial-of-service)攻击是恶意破坏目标服务器.服务或网络的正常通信量的企图,其方法是用大量Internet通 ...