训练一个分类器

上一讲中已经看到如何去定义一个神经网络,计算损失值和更新网络的权重。
你现在可能在想下一步。

关于数据?

一般情况下处理图像、文本、音频和视频数据时,可以使用标准的Python包来加载数据到一个numpy数组中。
然后把这个数组转换成 torch.*Tensor

  • 图像可以使用 Pillow, OpenCV
  • 音频可以使用 scipy, librosa
  • 文本可以使用原始Python和Cython来加载,或者使用 NLTK或
    SpaCy 处理

特别的,对于图像任务,我们创建了一个包
torchvision,它包含了处理一些基本图像数据集的方法。这些数据集包括
Imagenet, CIFAR10, MNIST 等。除了数据加载以外,torchvision 还包含了图像转换器,
torchvision.datasetstorch.utils.data.DataLoader

torchvision包不仅提供了巨大的便利,也避免了代码的重复。

在这个教程中,我们使用CIFAR10数据集,它有如下10个类别
:‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’,
‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’。CIFAR-10的图像都是
3x32x32大小的,即,3颜色通道,32x32像素。

训练一个图像分类器

依次按照下列顺序进行:

  1. 使用torchvision加载和归一化CIFAR10训练集和测试集

  2. 定义一个卷积神经网络

  3. 定义损失函数

  4. 在训练集上训练网络

  5. 在测试集上测试网络

  6. 读取和归一化 CIFAR10


使用torchvision可以非常容易地加载CIFAR10。

import torch
import torchvision
import torchvision.transforms as transforms

torchvision的输出是[0,1]的PILImage图像,我们把它转换为归一化范围为[-1, 1]的张量。

transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2) testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2) classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./data\cifar-10-python.tar.gz

100%|███████████████████████████████████████████████████████████████████████████████▉| 170M/170M [20:39<00:00, 155kB/s]

Files already downloaded and verified

我们展示一些训练图像。

import matplotlib.pyplot as plt
import numpy as np # 展示图像的函数 def imshow(img):
img = img / 2 + 0.5 # unnormalize
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0))) # 获取随机数据
dataiter = iter(trainloader)
images, labels = dataiter.next() # 展示图像
imshow(torchvision.utils.make_grid(images))
# 显示图像标签
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))
171MB [20:51, 155kB/s]                                                                                                 

  cat   car   cat  ship
  1. 定义一个卷积神经网络

从之前的神经网络一节复制神经网络代码,并修改为输入3通道图像。

import torch.nn as nn
import torch.nn.functional as F class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16 * 5 * 5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x net = Net()
  1. 定义损失函数和优化器

我们使用交叉熵作为损失函数,使用带动量的随机梯度下降。

import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
  1. 训练网路

有趣的时刻开始了。
我们只需在数据迭代器上循环,将数据输入给网络,并优化。

for epoch in range(2):  # 多批次循环

    running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# 获取输入
inputs, labels = data # 梯度置0
optimizer.zero_grad() # 正向传播,反向传播,优化
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step() # 打印状态信息
running_loss += loss.item()
if i % 2000 == 1999: # 每2000批次打印一次
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 2000))
running_loss = 0.0 print('Finished Training')
  1. 在测试集上测试网络

我们在整个训练集上进行了2次训练,但是我们需要检查网络是否从数据集中学习到有用的东西。
通过预测神经网络输出的类别标签与实际情况标签进行对比来进行检测。
如果预测正确,我们把该样本添加到正确预测列表。
第一步,显示测试集中的图片并熟悉图片内容。

dataiter = iter(testloader)
images, labels = dataiter.next() # 显示图片
imshow(torchvision.utils.make_grid(images))
print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))
GroundTruth:    cat  ship  ship plane

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-7N9Ui0NR-1612164760383)(output_14_1.png)]

让我们看看神经网络认为以上图片是什么。

outputs = net(images)

输出是10个标签的能量。
一个类别的能量越大,神经网络越认为它是这个类别。所以让我们得到最高能量的标签。

_, predicted = torch.max(outputs, 1)

print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]
for j in range(4)))
Predicted:  plane plane plane plane

结果看来不错。

接下来让看看网络在整个测试集上的结果如何。

correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item() print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total))
Accuracy of the network on the 10000 test images: 9 %

结果看起来不错,至少比随机选择要好,随机选择的正确率为10%。
似乎网络学习到了一些东西。

在识别哪一个类的时候好,哪一个不好呢?

class_correct = list(0. for i in range(10))
class_total = list(0. for i in range(10))
with torch.no_grad():
for data in testloader:
images, labels = data
outputs = net(images)
_, predicted = torch.max(outputs, 1)
c = (predicted == labels).squeeze()
for i in range(4):
label = labels[i]
class_correct[label] += c[i].item()
class_total[label] += 1 for i in range(10):
print('Accuracy of %5s : %2d %%' % (
classes[i], 100 * class_correct[i] / class_total[i]))
Accuracy of plane : 99 %
Accuracy of car : 0 %
Accuracy of bird : 0 %
Accuracy of cat : 0 %
Accuracy of deer : 0 %
Accuracy of dog : 0 %
Accuracy of frog : 0 %
Accuracy of horse : 0 %
Accuracy of ship : 0 %
Accuracy of truck : 0 %

下一步?

我们如何在GPU上运行神经网络呢?

在GPU上训练

把一个神经网络移动到GPU上训练就像把一个Tensor转换GPU上一样简单。并且这个操作会递归遍历有所模块,并将其参数和缓冲区转换为CUDA张量。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# 确认我们的电脑支持CUDA,然后显示CUDA信息:

print(device)

本节的其余部分假定device是CUDA设备。

然后这些方法将递归遍历所有模块并将模块的参数和缓冲区
转换成CUDA张量:

    net.to(device)

记住:inputs, targets 和 images 也要转换。

        inputs, labels = inputs.to(device), labels.to(device)

为什么我们没注意到GPU的速度提升很多?那是因为网络非常的小。

实践:
尝试增加你的网络的宽度(第一个nn.Conv2d的第2个参数,第二个nn.Conv2d的第一个参数,它们需要是相同的数字),看看你得到了什么样的加速。

实现的目标:

  • 深入了解了PyTorch的张量库和神经网络
  • 训练了一个小网络来分类图片

译者注:后面我们教程会训练一个真正的网络,使识别率达到90%以上。

多GPU训练

如果你想使用所有的GPU得到更大的加速,
请查看数据并行处理

下一步?

  • :doc:训练神经网络玩电子游戏 </intermediate/reinforcement_q_learning>
  • 在ImageNet上训练最好的ResNet
  • 使用对抗生成网络来训练一个人脸生成器
  • 使用LSTM网络训练一个字符级的语言模型
  • 更多示例
  • 更多教程
  • 在论坛上讨论PyTorch
  • Slack上与其他用户讨论

[Pytorch框架] 1.6 训练一个分类器的更多相关文章

  1. PyTorch Tutorials 4 训练一个分类器

    %matplotlib inline 训练一个分类器 上一讲中已经看到如何去定义一个神经网络,计算损失值和更新网络的权重. 你现在可能在想下一步. 关于数据? 一般情况下处理图像.文本.音频和视频数据 ...

  2. 【PyTorch深度学习60分钟快速入门 】Part4:训练一个分类器

      太棒啦!到目前为止,你已经了解了如何定义神经网络.计算损失,以及更新网络权重.不过,现在你可能会思考以下几个方面: 0x01 数据集 通常,当你需要处理图像.文本.音频或视频数据时,你可以使用标准 ...

  3. 小白学习之pytorch框架(3)-模型训练三要素+torch.nn.Linear()

    模型训练的三要素:数据处理.损失函数.优化算法    数据处理(模块torch.utils.data) 从线性回归的的简洁实现-初始化模型参数(模块torch.nn.init)开始 from torc ...

  4. 使用weka训练一个分类器

    1 训练集数据 1.1 csv格式 5.1,3.5,1.4,0.2,Iris-setosa 4.9,3.0,1.4,0.2,Iris-setosa 4.7,3.2,1.3,0.2,Iris-setos ...

  5. Fine-tuning Convolutional Neural Networks for Biomedical Image Analysis: Actively and Incrementally如何使用尽可能少的标注数据来训练一个效果有潜力的分类器

    作者:AI研习社链接:https://www.zhihu.com/question/57523080/answer/236301363来源:知乎著作权归作者所有.商业转载请联系作者获得授权,非商业转载 ...

  6. 迁移学习算法之TrAdaBoost ——本质上是在用不同分布的训练数据,训练出一个分类器

    迁移学习算法之TrAdaBoost from: https://blog.csdn.net/Augster/article/details/53039489 TradaBoost算法由来已久,具体算法 ...

  7. 如何使用Pytorch迅速实现Mnist数据及分类器

    一段时间没有更新博文,想着也该写两篇文章玩玩了.而从一个简单的例子作为开端是一个比较不错的选择.本文章会手把手地教读者构建一个简单的Mnist(Fashion-Mnist同理)的分类器,并且会使用相对 ...

  8. 【chainer框架】【pytorch框架】

    教程: https://bennix.github.io/ https://bennix.github.io/blog/2017/12/14/chain_basic/ https://bennix.g ...

  9. pytorch:EDSR 生成训练数据的方法

    Pytorch:EDSR 生成训练数据的方法 引言 Winter is coming 正文 pytorch提供的DataLoader 是用来包装你的数据的工具. 所以你要将自己的 (numpy arr ...

  10. 手写数字识别 卷积神经网络 Pytorch框架实现

    MNIST 手写数字识别 卷积神经网络 Pytorch框架 谨此纪念刚入门的我在卷积神经网络上面的摸爬滚打 说明 下面代码是使用pytorch来实现的LeNet,可以正常运行测试,自己添加了一些注释, ...

随机推荐

  1. linux下文件重命名

    Ubuntu下执行上面举例的重命名时,命令是这样的:rename 's/a/xxx/g' *.txt

  2. MyBatis-Plus 代码生成器超详细讲解

    参见:    https://www.jianshu.com/p/9d8ab1bb84bb

  3. ubuntu主机连接家里的网线

    第一步,先让物理机连接网络: 注释掉/etc/network/interfaces文件的最后一行,即: 意思是不要手动设置网络了,而是转为自动设置.这样,主机就可以联网了. 参考链接:https:// ...

  4. win10开启休眠

    powercfg /hibernate on 管理员模式下的命令提示符

  5. MacOS ssh config 配置

    Host 别名 #password 注释,保存密码 HostName IP User 服务器账号#root Port 端口 IdentityFile ~/.ssh/id_rsa #指定密钥 Remot ...

  6. python-魔法函数-__str__ __repr__ 的一次异常

    # encoding: utf-8import loggingERROR_NOT_FOUNDED_FILE = "error_not_founded_file"class Gene ...

  7. ajax缓存和fiddler——http协议调试代理工具

    1.在ie9下,ajax请求可能会有缓存,需要在请求上一个随机数 如:Math.random(); 2.fiddler2 打开以后可以查看所有的http请求情况,也可以使用本地脚本代替要请求的js文件 ...

  8. windows下 mstsc 远程Ubuntu 图形界面

    安装及设置xrdp ------------------------------------------------------ touch ~/installXrdp.sh  cat > ~/ ...

  9. Android笔记--内容提供者+Server端+Client端

    什么是内容提供者ContentProvider 为App存取内部数据提供的统一的外部接口,让不同的应用之间得以实现数据共享 Client App端 用户输入数据的一端,或者说是用户读取到存储的数据的一 ...

  10. Javaweb学习笔记第十四弹---对于Cookie和Filter的学习

    Apache Tomcat - Tomcat Native Downloads 会话追踪技术 会话:打开浏览器,建立连接,直到一方断开连接,会话才会结束:在一次会议中,可以有多次请求. 会话追踪:在多 ...