训练分类器

 

1. 数据

 

处理图像,文本,音频或视频数据时,可以使用将数据加载到 NumPy 数组中的标准 Python 包。 然后,将该数组转换为torch.*Tensor

 
  • 对于图像,Pillow,OpenCV 等包很有用
  • 对于音频,请使用 SciPy 和 librosa 等包
  • 对于文本,基于 Python 或 Cython 的原始加载,或者 NLTK 和 SpaCy 很有用
 

专门针对视觉,一个名为torchvision的包,其中包含用于常见数据集(例如 Imagenet,CIFAR10,MNIST 等)的数据加载器,以及用于图像(即torchvision.datasets和torch.utils.data.DataLoader)的数据转换器

 

我们将使用 CIFAR10 数据集。 它具有以下类别:“飞机”,“汽车”,“鸟”,“猫”,“鹿”,“狗”,“青蛙”,“马”,“船”,“卡车”。 CIFAR-10 中的图像尺寸为3x32x32,即尺寸为32x32像素的 3 通道彩色图像

 

数据集来源:CIFAR-10 and CIFAR-100 datasets

airplane
automobile
bird
cat
deer
dog
frog
horse
ship
truck
 

由于图片地址在国外,以上图片的加载可能不如人意,大致就是这个图像:

 

2. 训练一个分类器

 

我们将会按顺序做以下步骤:

 
  1. 用torchvision 加载和标准化CIFAR10训练和测试数据
  2. 定义一个神经网络
  3. 定义一个损失函数
  4. 使用训练数据训练网络
  5. 使用测试数据测试网络
 

2.1. 加载数据并标准化

 

使用torchvision加载CIFAR10数据十分简单:

In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
 

输出的torchvision数据集是PILImage图像,其范围是[0,1]。我们将它转化为Tensor的标准范围[-1,1]

In [2]:
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) batch_size = 4 trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
shuffle=True, num_workers=0) testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
shuffle=False, num_workers=0) classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
 
Files already downloaded and verified
Files already downloaded and verified
 
  • 注意:如果在Windows上运行并且得到BrankPipeError,请尝试将Torch.utils.Data.Dataloader()的Num_Worker设置为0。官网示例是Num_Worker设置为2
 

让我们显示一下训练的图片:

In [3]:
import matplotlib.pyplot as plt
import numpy as np # functions to show an image def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show() # get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next() # show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))
 
 
dog   frog  dog   cat
 

2.2.定义一个卷积神经网络

In [4]:
import torch.nn as nn
import torch.nn.functional as F class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1) # flatten all dimensions except batch
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x net = Net()
 

2.3.定义一个损失函数和优化器

 

让我们使用分类交叉熵损失和带有动量的 SGD

In [5]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
 

2.4.训练网络

 

有趣的事情开始了,我们只需要循环我们的迭代器,并反馈到网络进行优化

In [6]:
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data # zero the parameter gradients
optimizer.zero_grad() # forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step() # print statistics
running_loss += loss.item()
if i % 2000 == 1999: # print every 2000 mini-batches
print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
running_loss = 0.0 print('Finished Training')
 
[1,  2000] loss: 2.193
[1, 4000] loss: 1.847
[1, 6000] loss: 1.661
[1, 8000] loss: 1.569
[1, 10000] loss: 1.488
[1, 12000] loss: 1.445
[2, 2000] loss: 1.405
[2, 4000] loss: 1.355
[2, 6000] loss: 1.329
[2, 8000] loss: 1.320
[2, 10000] loss: 1.277
[2, 12000] loss: 1.250
Finished Training
 

快速保存训练模型:

In [7]:
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
 

2.5.使用测试集测试网络

 

显示测试集中的图像:

In [8]:
dataiter = iter(testloader)
images, labels = dataiter.next() # print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))
 
 
GroundTruth:  cat   ship  ship  plane
 

加载保存的模型:

In [9]:
net = Net()
net.load_state_dict(torch.load(PATH))
Out[9]:
<All keys matched successfully>
 

使用神经网络进行预测:

In [10]:
outputs = net(images)
In [11]:
outputs
Out[11]:
tensor([[-0.4519, -2.6896,  1.1111,  2.4411, -1.2739,  0.9407,  1.2027, -0.9218,
-0.3061, -1.4944],
[ 4.0095, 5.7177, -1.3274, -3.2596, -4.4239, -6.4377, -5.2835, -5.2639,
8.8550, 3.4490],
[ 2.2643, 1.9055, 0.2977, -1.2159, -1.5517, -2.6117, -2.5904, -2.0696,
3.1488, 0.7971],
[ 3.6302, 0.2553, 0.3926, -1.3850, 0.2644, -2.8077, -2.8192, -1.0332,
1.9776, 0.4094]], grad_fn=<AddmmBackward0>)
 

输出是 10 类的能量。 一个类别的能量越高,网络就认为该图像属于特定类别。 因此,让我们获取最高能量的指数:

In [12]:
_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}'
for j in range(4)))
 
Predicted:  cat   ship  ship  plane
 

此次结果看起来不错

 

我们看看这个网络在整个数据集的表现:

In [13]:
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
for data in testloader:
images, labels = data
# calculate outputs by running images through the network
outputs = net(images)
# the class with the highest energy is what we choose as prediction
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item() print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')
 
Accuracy of the network on the 10000 test images: 56 %
 

这看起来是比偶然更好(偶然的准确率是10%,即从10个类别中选择一个),看起来这个网络学到了一些东西

 

看看这个这个分类器在哪些类别分类好,哪些类别分类差:

In [14]:
# prepare to count predictions for each class
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes} # again no gradients needed
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predictions = torch.max(outputs, 1)
# collect the correct predictions for each class
for label, prediction in zip(labels, predictions):
if label == prediction:
correct_pred[classes[label]] += 1
total_pred[classes[label]] += 1 # print accuracy for each class
for classname, correct_count in correct_pred.items():
accuracy = 100 * float(correct_count) / total_pred[classname]
print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
 
Accuracy for class: plane is 65.5 %
Accuracy for class: car is 67.1 %
Accuracy for class: bird is 30.4 %
Accuracy for class: cat is 53.5 %
Accuracy for class: deer is 44.2 %
Accuracy for class: dog is 35.9 %
Accuracy for class: frog is 68.2 %
Accuracy for class: horse is 70.3 %
Accuracy for class: ship is 68.9 %
Accuracy for class: truck is 60.4 %
 

2.6.在GPU上训练

 

如果可以使用 CUDA,首先将我们的设备定义为第一个可见的 cuda 设备:

In [15]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Assuming that we are on a CUDA machine, this should print a CUDA device:

print(device)
 
cuda:0
 

然后,这些方法将递归遍历所有模块,并将其参数和缓冲区转换为 CUDA 张量:

In [16]:
net.to(device)
Out[16]:
Net(
(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
 

还必须将每一步的输入和目标也发送到 GPU:

In [17]:
inputs, labels = data[0].to(device), data[1].to(device)
 

3.参考资料

pytorch学习笔记四之训练分类器的更多相关文章

  1. 莫烦PyTorch学习笔记(四)——回归

    下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...

  2. ensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试

    http://www.cnblogs.com/denny402/p/5852983.html ensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试   刚开始学习tf时,我们从 ...

  3. 官网实例详解-目录和实例简介-keras学习笔记四

    官网实例详解-目录和实例简介-keras学习笔记四 2018-06-11 10:36:18 wyx100 阅读数 4193更多 分类专栏: 人工智能 python 深度学习 keras   版权声明: ...

  4. C#可扩展编程之MEF学习笔记(四):见证奇迹的时刻

    前面三篇讲了MEF的基础和基本到导入导出方法,下面就是见证MEF真正魅力所在的时刻.如果没有看过前面的文章,请到我的博客首页查看. 前面我们都是在一个项目中写了一个类来测试的,但实际开发中,我们往往要 ...

  5. IOS学习笔记(四)之UITextField和UITextView控件学习

    IOS学习笔记(四)之UITextField和UITextView控件学习(博客地址:http://blog.csdn.net/developer_jiangqq) Author:hmjiangqq ...

  6. java之jvm学习笔记四(安全管理器)

    java之jvm学习笔记四(安全管理器) 前面已经简述了java的安全模型的两个组成部分(类装载器,class文件校验器),接下来学习的是java安全模型的另外一个重要组成部分安全管理器. 安全管理器 ...

  7. Learning ROS for Robotics Programming Second Edition学习笔记(四) indigo devices

    中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS for Robotics Pr ...

  8. Typescript 学习笔记四:回忆ES5 中的类

    中文网:https://www.tslang.cn/ 官网:http://www.typescriptlang.org/ 目录: Typescript 学习笔记一:介绍.安装.编译 Typescrip ...

  9. ES6学习笔记<四> default、rest、Multi-line Strings

    default 参数默认值 在实际开发 有时需要给一些参数默认值. 在ES6之前一般都这么处理参数默认值 function add(val_1,val_2){ val_1 = val_1 || 10; ...

  10. muduo网络库学习笔记(四) 通过eventfd实现的事件通知机制

    目录 muduo网络库学习笔记(四) 通过eventfd实现的事件通知机制 eventfd的使用 eventfd系统函数 使用示例 EventLoop对eventfd的封装 工作时序 runInLoo ...

随机推荐

  1. 【Java EE】Day08 HTML&CSS

    一.表单 1.概述 采集数据,与服务器交互包括 form表单,需要提交则必须设置name属性,action方式主要包 get:参数会在地址栏显示(请求行),长度有限 post:参数被封装在请求体,参数 ...

  2. 【实时数仓】Day01-数据采集层:数仓分层、实时需求、架构分析、日志数据采集(采集到指定topic和落盘)、业务数据采集(MySQL-kafka)、Nginx反向代理、Maxwell、Canel

    一.数仓分层介绍 1.实时计算与实时数仓 实时计算实时性高,但无中间结果,导致复用性差 实时数仓基于数据仓库,对数据处理规划.分层,目的是提高数据的复用性 2.电商数仓的分层 ODS:原始日志数据和业 ...

  3. 如何通过Java应用程序压缩PDF文档

    PDF文档是我们日常办公中使用最频繁的文档格式.但因为大多数PDF文档都包含很多页面图像或大量图片,这就导致PDF文档过大,处理起来较为麻烦.PDF文件过大,就会导致传输或者下载的速度变慢,也会增加传 ...

  4. [攻防世界][江苏工匠杯]unseping

    打开靶机对应的url 上来就是代码审计 <?php highlight_file(__FILE__); class ease{ private $method; private $args; f ...

  5. 【转载】VFP编写DLL,并调用

    1. 编制DLL文件 ,保存为Temp.prg Define Class vfptools As Session OlePublic    Procedure Add As Integer    Lp ...

  6. JavaScript 中如何拦截全局 Fetch API 的请求和响应?

    本文翻译自 Intercepting JavaScript Fetch API requests and responses 拦截器是可用于预处理或后处理 HTTP 请求的代码块,有助于全局错误处理. ...

  7. Profiler导致的严重性能问题

    背景 客户反馈系统突然运行非常缓慢,持续了近20分钟的时间,通过SQL专家云定位到有人开启了Profiler导致,但是不能定位是谁开启的,请我们协助分析. 现象  登录SQL专家云,进入实时可视化页面 ...

  8. flask_apscheduler

    网上很多文章包括官网给出来的,更多的是执行的函数和初始话app在同一文件 按照app 路由 模型类 视图 往下延申,项目已经成型,所以上述的方法不适用. (不用操心模型类与app互相导入的错误)下面分 ...

  9. Java学习记录:2022年1月13日(其二)

    Java学习记录:2022年1月13日(其二) ​ 摘要:本篇笔记主要记录了在设计类时的一些注意事项,类加载时类中各个部分的加载顺序以及继承和多态的知识. 目录 Java学习记录:2022年1月13日 ...

  10. Less常用功能使用

    Less 是一门 CSS 预处理语言,它扩充了 CSS 语言,增加了诸如变量.混合(mixin).函数等功能,让 CSS 更易维护.方便制作主题.扩充.Less 可以运行在 Node 或浏览器端. L ...