PyTorch教程之Training a classifier
我们已经了解了如何定义神经网络,计算损失并对网络的权重进行更新。
接下来的问题就是:
一、What about data?
通常处理图像、文本、音频或视频数据时,可以使用标准的python包将数据加载到numpy数组中。然后你可以将这个数组转换成一个torch.Tensor.
对于图片, 涉及到的库有Pillowh和OpenCV。
对于音频,涉及到的库有scipy和librosa
对于文本,无论是原始的Python还是基于Cython的加载,都会用到NLTK或者SpaCy。
我们已经创建了一个名为torchvision的软件包。
torchvision为像Imagenet、CIFAR10、MNIST等普通数据集提供数据加载器,并为图像、viz、torchvision提供数据转换器,也就是torchvision.datasets torch.utils.data.DataLoader.
我们在这里使用的是CIFAR10数据集。它的类包括:“飞机”、“汽车”、“鸟”、“猫”、“鹿”、“狗”、“青蛙”、“马”、“船”、“卡车”。cifar - 10中的图像大小为3x32x32,即3 - channel彩色图像,大小为32x32像素。
二、Training an image classifier
我们将按顺序进行以下步骤:
1使用torchvision对CIFAR10训练和测试数据集进行加载和规范化
2.定义一个卷积神经网络
3.定义一个损失函数
4.在训练数据上训练神经网络
5.在测试数据上测试神经网络
1加载并规范化CIFAR10
import相关类:
import torch
import torchvision
import torchvision.transforms as transforms
创建transform来处理图像数据
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
下载训练数据集到./data/目录下:
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
查看下载的数据:
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
image, label = trainset[0]
print(image.size())
print(label)
print(classes[label])
输出结果:
torch.Size([3, 32, 32])
6
frog
torchvision数据集的输出是范围[0,1]的PILImage图像。我们将它们转换为标准化范围的Tensor[- 1,1]:
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
同理,我们下载测试数据集并将其转化为Tensor:
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
2.定义一个卷积神经网络
从PyTorch教程之Neural Networks复制代码,然后修改成获取3通道图像(而不是原本定义为1通道的图像)。
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x net = Net()
3.定义一个损失函数
这里使用的损失函数为Classification Cross-Entropy loss and SGD with momentum:
import torch.optim as optim criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
4.在训练数据上训练神经网络
我们只需要对数据迭代器进行循环,并将输入反馈到网络并进行优化。
for epoch in range(2): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# get the inputs
inputs, labels = data
# wrap them in Variable
inputs, labels = Variable(inputs), Variable(labels)
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.data[0]
if i % 2000 == 1999: # print every 2000 mini-batches
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0 print('Finished Training')
输出结果:
[1, 2000] loss: 2.224
[1, 4000] loss: 1.896
[1, 6000] loss: 1.721
[1, 8000] loss: 1.591
[1, 10000] loss: 1.542
[1, 12000] loss: 1.471
[2, 2000] loss: 1.411
[2, 4000] loss: 1.377
[2, 6000] loss: 1.334
[2, 8000] loss: 1.316
[2, 10000] loss: 1.290
[2, 12000] loss: 1.281
5.在测试数据上测试神经网络
我们将通过预测神经网络输出的类标签来检查它,如果预测是正确的,我们将样本添加到正确预测的列表中。
获取前四个测试数据的GroundTruth:
dataiter = iter(testloader)
images, labels = dataiter.next() print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
输出结果:
GroundTruth: cat ship ship plane
神经网络输出是10类分别对应的energy,一个类的 energy量越高,神经网络就认为图像属于该类可能性越高,我们将energy最高的类作为我们预测结果:
outputs = net(Variable(images))
_, predicted = torch.max(outputs.data, 1) print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
for j in range(4)))
输出结果:
Predicted: cat ship ship ship
我们在整个测试数据集上进行测试:
correct = 0
total = 0
for data in testloader:
images, labels = data
outputs = net(Variable(images))
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum() print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total))
输出结果显示正确率为56%
Accuracy of the network on the 10000 test images: 56 %
我们对不同的类识别效果进行分别统计:
class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
for data in testloader:
images, labels = data
outputs = net(Variable(images))
_, predicted = torch.max(outputs.data, 1)
c = (predicted == labels).squeeze()
for i in range(4):
label = labels[i]
class_correct[label] += c[i]
class_total[label] += 1 for i in range(10):
print('Accuracy of %5s : %2d %%' % (
classes[i], 100 * class_correct[i] / class_total[i]))
结果显示:
Accuracy of plane : 52 %
Accuracy of car : 73 %
Accuracy of bird : 45 %
Accuracy of cat : 26 %
Accuracy of deer : 39 %
Accuracy of dog : 42 %
Accuracy of frog : 73 %
Accuracy of horse : 73 %
Accuracy of ship : 75 %
Accuracy of truck : 63 %
PyTorch教程之Training a classifier的更多相关文章
- pytorch例子学习——TRAINING A CLASSIFIER
参考:https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#sphx-glr-beginner-blitz-cifar1 ...
- PyTorch教程之Neural Networks
我们可以通过torch.nn package构建神经网络. 现在我们已经了解了autograd,nn基于autograd来定义模型并对他们有所区分. 一个 nn.Module模块由如下部分构成:若干层 ...
- PyTorch教程之Autograd
在PyTorch中,autograd是所有神经网络的核心内容,为Tensor所有操作提供自动求导方法. 它是一个按运行方式定义的框架,这意味着backprop是由代码的运行方式定义的. 一.Varia ...
- PyTorch教程之Tensors
Tensors类似于numpy的ndarrays,但是可以在GPU上使用来加速计算. 一.Tensors的构建 from __future__ import print_function import ...
- Kail Linux渗透测试教程之Recon-NG框架
Kail Linux渗透测试教程之Recon-NG框架 信息收集 信息收集是网络攻击最重要的阶段之一.要想进行渗透攻击,就需要收集目标的各类信息.收集到的信息越多,攻击成功的概率也就越大.本章将介绍信 ...
- [转]搬瓦工教程之九:通过Net-Speeder为搬瓦工提升网速
搬瓦工教程之九:通过Net-Speeder为搬瓦工提升网速 有的同学反映自己的搬瓦工速度慢,丢包率高.这其实和你的网络服务提供商有关.据我所知一部分上海电信的同学就有这种问题.那么碰到了坑爹的网络服务 ...
- jQuery EasyUI教程之datagrid应用(三)
今天继续之前的整理,上篇整理了datagrid的数据显示及其分页功能 获取数据库数据显示在datagrid中:jQuery EasyUI教程之datagrid应用(一) datagrid实现分页功能: ...
- jQuery EasyUI教程之datagrid应用(二)
上次写到了让数据库数据在网页datagrid显示,我们只是单纯的实现了显示,仔细看的话显示的信息并没有达到我们理想的效果,这里我们丰富一下: 上次显示的结果是这样的 点击查看上篇:jQuery Eas ...
- jQuery EasyUI教程之datagrid应用(一)
最近一段时间都在做人事系统的项目,主要用到了EasyUI,数据库操作,然后抽点时间整理一下EasyUI的内容. 这里我们就以一个简洁的电话簿软件为基础,具体地说一下datagrid应用吧 datagr ...
随机推荐
- 迈向angularjs2系列(2):angular2指令详解
一:angular2 helloworld! 为了简单快速的运行一个ng2的app,那么通过script引入预先编译好的angular2版本和页面的基本框架. index.html: <!DOC ...
- AHD-模拟高清芯片RN6752替代TVP5150/CJC5150
RN6752功能:2路CVBS转BT656/601/1302高清视频解码芯片.产品特征输入:2路CVBS(不能同时输入)和S-VIDEO(一般不用)输出:BT601/656/1302支持NTSC,PA ...
- [js高手之路]Vue2.0基于vue-cli+webpack同级组件之间的通信教程
我们接着上文继续,本文我们讲解兄弟组件的通信,项目结构还是跟上文一样. 在src/assets目录下建立文件EventHandler.js,该文件的作用在于给同级组件之间传递事件 EventHandl ...
- JavaScript中你所不知道的Object(一)
Object实在是JavaScript中很基础的东西了,在工作中,它只有那么贫瘠的几个用法,让人感觉不过尔尔,但是我们真的了解它吗? 1. 当我们习惯用 var a = { name: 'tarol' ...
- 团队作业1——团队展示&博客作业查重系统
团队展示: 1.队名:六个核桃 2.队员学号: 王婧(201421123065).柯怡芳(201421123067组长).陈艺菡(201421123068). 钱惠(201421123071).尼玛( ...
- 团队作业8——第二次项目冲刺(Beta阶段)5.27
1.当天站立式会议照片 会议内容: 本次会议为第七次会议 本次会议在陆大楼2楼召开,本次会议内容: ①:检查总结上次任务完成情况 ②:安排今天的分工 ③:对昨天的问题进行讨论 2. 每个人的工作 (有 ...
- 201521123044 《Java程序设计》第3周学习总结
1. 本章学习总结 2. 书面作业 1. 代码阅读 public class Test1 { private int i = 1;//这行不能修改 private static int j = 2; ...
- 201521123028 《Java程序设计》第3周学习总结
1. 本周学习总结 2. 书面作业 Q1.代码阅读 public class Test1 { private int i = 1;//这行不能修改 private static int j = 2; ...
- 下载安装ActiveMQ(消息队列)
安装步骤: 第一步.安装jdk环境,因为ActiveMQ是使用java语言编写. 第二步.将下载好的activemq压缩包上传至Linux系统,进行解压. 第三步.进入解压后的bin/目录,进行启动a ...
- 才趟过的一个坑,css造成的Validform表单提交按钮点击无效
最近入手的一个项目,在开发的过程中,遇到了一个以前没遇到过的问题,废了半天的功夫才弄懂原因,留下足迹,警醒后人,下面开始讲故事啦! 在一个昏天暗地的上午,我一个人照常坐在办公室安静的工作中!项目编码已 ...