pytorch学习笔记四之训练分类器
训练分类器¶
1. 数据¶
处理图像,文本,音频或视频数据时,可以使用将数据加载到 NumPy 数组中的标准 Python 包。 然后,将该数组转换为torch.*Tensor
- 对于图像,Pillow,OpenCV 等包很有用
- 对于音频,请使用 SciPy 和 librosa 等包
- 对于文本,基于 Python 或 Cython 的原始加载,或者 NLTK 和 SpaCy 很有用
专门针对视觉,一个名为torchvision的包,其中包含用于常见数据集(例如 Imagenet,CIFAR10,MNIST 等)的数据加载器,以及用于图像(即torchvision.datasets和torch.utils.data.DataLoader)的数据转换器
我们将使用 CIFAR10 数据集。 它具有以下类别:“飞机”,“汽车”,“鸟”,“猫”,“鹿”,“狗”,“青蛙”,“马”,“船”,“卡车”。 CIFAR-10 中的图像尺寸为3x32x32,即尺寸为32x32像素的 3 通道彩色图像
数据集来源:CIFAR-10 and CIFAR-100 datasets
airplane | ![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
automobile | ![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
bird | ![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
cat | ![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
deer | ![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
dog | ![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
frog | ![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
horse | ![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
ship | ![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
truck | ![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
![]() |
由于图片地址在国外,以上图片的加载可能不如人意,大致就是这个图像:
2. 训练一个分类器¶
我们将会按顺序做以下步骤:
- 用torchvision 加载和标准化CIFAR10训练和测试数据
- 定义一个神经网络
- 定义一个损失函数
- 使用训练数据训练网络
- 使用测试数据测试网络
2.1. 加载数据并标准化¶
使用torchvision加载CIFAR10数据十分简单:
import torch
import torchvision
import torchvision.transforms as transforms
输出的torchvision数据集是PILImage图像,其范围是[0,1]。我们将它转化为Tensor的标准范围[-1,1]
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
batch_size = 4
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
shuffle=True, num_workers=0)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
shuffle=False, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Files already downloaded and verified
Files already downloaded and verified
- 注意:如果在Windows上运行并且得到BrankPipeError,请尝试将Torch.utils.Data.Dataloader()的Num_Worker设置为0。官网示例是Num_Worker设置为2
让我们显示一下训练的图片:
import matplotlib.pyplot as plt
import numpy as np
# functions to show an image
def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
# get some random training images
dataiter = iter(trainloader)
images, labels = dataiter.next()
# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))

dog frog dog cat
2.2.定义一个卷积神经网络¶
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1) # flatten all dimensions except batch
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
2.3.定义一个损失函数和优化器¶
让我们使用分类交叉熵损失和带有动量的 SGD
import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
2.4.训练网络¶
有趣的事情开始了,我们只需要循环我们的迭代器,并反馈到网络进行优化
for epoch in range(2): # loop over the dataset multiple times
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
running_loss += loss.item()
if i % 2000 == 1999: # print every 2000 mini-batches
print(f'[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}')
running_loss = 0.0
print('Finished Training')
[1, 2000] loss: 2.193
[1, 4000] loss: 1.847
[1, 6000] loss: 1.661
[1, 8000] loss: 1.569
[1, 10000] loss: 1.488
[1, 12000] loss: 1.445
[2, 2000] loss: 1.405
[2, 4000] loss: 1.355
[2, 6000] loss: 1.329
[2, 8000] loss: 1.320
[2, 10000] loss: 1.277
[2, 12000] loss: 1.250
Finished Training
快速保存训练模型:
PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)
2.5.使用测试集测试网络¶
显示测试集中的图像:
dataiter = iter(testloader)
images, labels = dataiter.next()
# print images
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))

GroundTruth: cat ship ship plane
加载保存的模型:
net = Net()
net.load_state_dict(torch.load(PATH))
<All keys matched successfully>
使用神经网络进行预测:
outputs = net(images)
outputs
tensor([[-0.4519, -2.6896, 1.1111, 2.4411, -1.2739, 0.9407, 1.2027, -0.9218,
-0.3061, -1.4944],
[ 4.0095, 5.7177, -1.3274, -3.2596, -4.4239, -6.4377, -5.2835, -5.2639,
8.8550, 3.4490],
[ 2.2643, 1.9055, 0.2977, -1.2159, -1.5517, -2.6117, -2.5904, -2.0696,
3.1488, 0.7971],
[ 3.6302, 0.2553, 0.3926, -1.3850, 0.2644, -2.8077, -2.8192, -1.0332,
1.9776, 0.4094]], grad_fn=<AddmmBackward0>)
输出是 10 类的能量。 一个类别的能量越高,网络就认为该图像属于特定类别。 因此,让我们获取最高能量的指数:
_, predicted = torch.max(outputs, 1)
print('Predicted: ', ' '.join(f'{classes[predicted[j]]:5s}'
for j in range(4)))
Predicted: cat ship ship plane
此次结果看起来不错
我们看看这个网络在整个数据集的表现:
correct = 0
total = 0
# since we're not training, we don't need to calculate the gradients for our outputs
with torch.no_grad():
for data in testloader:
images, labels = data
# calculate outputs by running images through the network
outputs = net(images)
# the class with the highest energy is what we choose as prediction
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print(f'Accuracy of the network on the 10000 test images: {100 * correct // total} %')
Accuracy of the network on the 10000 test images: 56 %
这看起来是比偶然更好(偶然的准确率是10%,即从10个类别中选择一个),看起来这个网络学到了一些东西
看看这个这个分类器在哪些类别分类好,哪些类别分类差:
# prepare to count predictions for each class
correct_pred = {classname: 0 for classname in classes}
total_pred = {classname: 0 for classname in classes}
# again no gradients needed
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predictions = torch.max(outputs, 1)
# collect the correct predictions for each class
for label, prediction in zip(labels, predictions):
if label == prediction:
correct_pred[classes[label]] += 1
total_pred[classes[label]] += 1
# print accuracy for each class
for classname, correct_count in correct_pred.items():
accuracy = 100 * float(correct_count) / total_pred[classname]
print(f'Accuracy for class: {classname:5s} is {accuracy:.1f} %')
Accuracy for class: plane is 65.5 %
Accuracy for class: car is 67.1 %
Accuracy for class: bird is 30.4 %
Accuracy for class: cat is 53.5 %
Accuracy for class: deer is 44.2 %
Accuracy for class: dog is 35.9 %
Accuracy for class: frog is 68.2 %
Accuracy for class: horse is 70.3 %
Accuracy for class: ship is 68.9 %
Accuracy for class: truck is 60.4 %
2.6.在GPU上训练¶
如果可以使用 CUDA,首先将我们的设备定义为第一个可见的 cuda 设备:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Assuming that we are on a CUDA machine, this should print a CUDA device:
print(device)
cuda:0
然后,这些方法将递归遍历所有模块,并将其参数和缓冲区转换为 CUDA 张量:
net.to(device)
Net(
(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=400, out_features=120, bias=True)
(fc2): Linear(in_features=120, out_features=84, bias=True)
(fc3): Linear(in_features=84, out_features=10, bias=True)
)
还必须将每一步的输入和目标也发送到 GPU:
inputs, labels = data[0].to(device), data[1].to(device)
3.参考资料¶
[2]训练分类器
pytorch学习笔记四之训练分类器的更多相关文章
- 莫烦PyTorch学习笔记(四)——回归
下面的代码说明个整个神经网络模拟回归的过程,代码含有详细注释,直接贴下来了 import torch from torch.autograd import Variable import torch. ...
- ensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试
http://www.cnblogs.com/denny402/p/5852983.html ensorflow学习笔记四:mnist实例--用简单的神经网络来训练和测试 刚开始学习tf时,我们从 ...
- 官网实例详解-目录和实例简介-keras学习笔记四
官网实例详解-目录和实例简介-keras学习笔记四 2018-06-11 10:36:18 wyx100 阅读数 4193更多 分类专栏: 人工智能 python 深度学习 keras 版权声明: ...
- C#可扩展编程之MEF学习笔记(四):见证奇迹的时刻
前面三篇讲了MEF的基础和基本到导入导出方法,下面就是见证MEF真正魅力所在的时刻.如果没有看过前面的文章,请到我的博客首页查看. 前面我们都是在一个项目中写了一个类来测试的,但实际开发中,我们往往要 ...
- IOS学习笔记(四)之UITextField和UITextView控件学习
IOS学习笔记(四)之UITextField和UITextView控件学习(博客地址:http://blog.csdn.net/developer_jiangqq) Author:hmjiangqq ...
- java之jvm学习笔记四(安全管理器)
java之jvm学习笔记四(安全管理器) 前面已经简述了java的安全模型的两个组成部分(类装载器,class文件校验器),接下来学习的是java安全模型的另外一个重要组成部分安全管理器. 安全管理器 ...
- Learning ROS for Robotics Programming Second Edition学习笔记(四) indigo devices
中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS for Robotics Pr ...
- Typescript 学习笔记四:回忆ES5 中的类
中文网:https://www.tslang.cn/ 官网:http://www.typescriptlang.org/ 目录: Typescript 学习笔记一:介绍.安装.编译 Typescrip ...
- ES6学习笔记<四> default、rest、Multi-line Strings
default 参数默认值 在实际开发 有时需要给一些参数默认值. 在ES6之前一般都这么处理参数默认值 function add(val_1,val_2){ val_1 = val_1 || 10; ...
- muduo网络库学习笔记(四) 通过eventfd实现的事件通知机制
目录 muduo网络库学习笔记(四) 通过eventfd实现的事件通知机制 eventfd的使用 eventfd系统函数 使用示例 EventLoop对eventfd的封装 工作时序 runInLoo ...
随机推荐
- java并发数据结构之CopyOnWriteArrayList
CopyOnWriteArrayList是一个线程安全的List实现,其在对对象进行读操作时,由于对象没有发生改变,因此不需要加锁,反之在对象进行增删等修改操作时,它会先复制一个对象副本,然后对副本进 ...
- ChatGPT能做什么?ChatGPT保姆级注册教程
最近 OpenAI 发布的 ChatGPT 聊天机器人很火,该聊天机器人可以在模仿人类说话风格的同时回答大量的问题. 在现实世界之中,例如数字营销.线上内容创作.回答客户服务查询,甚至可以用来帮助调试 ...
- TypeScript 之 Class
class private 和 # 的区别 前缀 private 只是TS语法,在运行时不起作用,外部能够访问,但是类型检查器会报错 class Bag { private item: any } 修 ...
- uni框架引入外部图标
说明 在使用uni框架的uni-nav-bar自定义导航栏的时候我想要引用外部的图标,但是似乎这个好像只能引入uni框架内置的图标 所以我只能把uni的图标进行增加处理,这样引入图标的方式就和正常的引 ...
- python 实现RSA公钥加密,私钥解密
from Crypto.PublicKey import RSA from Crypto.Cipher import PKCS1_v1_5 as Cipher_pkcs1_v1_5 from Cryp ...
- k8s本地联调工具kt-connect
1.Kt Connect简介 KT Connect ( Kubernetes Developer Tool ) 是轻量级的面向 Kubernetes 用户的开发测试环境治理辅助工具.其核心是通过建立本 ...
- “喜提”一个P2级故障—CMSGC太频繁,你知道这是什么鬼?
大家好,我是陶朱公Boy. 背景 今天跟大家分享一个前几天在线上碰到的一个GC故障- "CMSGC太频繁". 不知道大家看到这条告警内容后,是什么感触?我当时是一脸懵逼的,一万个为 ...
- 使用 GPG 签名提交
GPG 签名是对代码提交者进行身份验证的一种补充,即证明代码提交来密钥持有者,理论上可以确保在目前的破译技术水平下无法篡改内容.您可以使用 GPG 工具 (GNU Privacy Guard) 生成密 ...
- Spring 和 Spring MVC的区别
Spring 和 Spring MVC的区别 学习Spring MVC也有几天时间了,那么Spring和Spring MVC的区别到底在哪里,二者是什么关系呢?认为二者是一个东西那肯定是不对的,而 ...
- Spring Cloud服务发现组件Eureka
简介 Netflix Eureka是微服务系统中最常用的服务发现组件之一,非常简单易用.当客户端注册到Eureka后,客户端可以知道彼此的hostname和端口等,这样就可以建立连接,不需要配置. E ...