本节将使用torchvision包,它是服务于pytorch深度学习框架的,主要用来构建计算机视觉模型。

torchvision主要由以下几个部分构成:

  1. torchvision.datasets:一些加载数据的函数以及常用的数据集的接口
  2. torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet,VGG,ResNet;
  3. torchvision.transforms:常用的图片变换,例如裁剪,旋转等;
  4. torchvision.utils: 其他的一些有用的方法
获取数据集

导入本节需要的包或者模块

import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append('..') # 为了导入上层目录的d2lzh_pytorch
import d2lzh_pytorch as d2l

通过调用torchvision的torchvision.datasets来下载这个数据集

可以通过train参数获取指定的训练集或者测试集、

测试集只用了评估模型,并不用来训练模型

同时指定了参数transform = transform.ToTensor()使所有数据转化为Tensor,如果不进行转化,则返回的是PIL照片。

transform.ToTensor()将尺寸为(H,W,C)且数据位于[0,255]的PIL图片或者数据类型为np.unit8的Numpy数组转化为(CxHxW)且数据类型为torch.float32且位于[0.0,1.0]的Tensor。

  • 如果用像素值(0,255)表示图片数据,一律将其类型设置为unit8,避免出问题
mnist_train= torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',download=True,train=True,transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',download=True,train=False,transform=transforms.ToTensor())
print(type(mnist_train))
print(len(mnist_train), len(mnist_test))
<class 'torchvision.datasets.mnist.FashionMNIST'>
60000 10000
feature,label = mnist_train[0]
print(feature.shape,label) # channel * height* width
torch.Size([1, 28, 28]) tensor(9)

feature对应的高和宽均为28像素的图像,由于我们使用了transforms.ToTensor(),所以每个像素的数值为[0,1]的32位浮点数。需要注意的是,feature的尺寸是(CxHxW)的,而不是(HxWxC)。第一维是通道数,因为数据集中是灰度图像,所以通道数为1,后面两维分别是图像的高和宽。

Fashion_MNIST中一共包括了10个类别,分别是t-shirt(T恤),trouser(裤子),pullover(套衫),dress(连衣裙),coat(外套),sandal(凉鞋),shirt(衬衫),sneaker(运动鞋),bag(包)和ankle boot(短靴)

import d2lzh_pytorch as d2l
def get_fashion_mnist_labels(labels):
text_labels = ['t-shirt','trouser','pullover','dress','coat','sandal',
'shirt','sneaker','bag','ankle boost'
]
return [text_labels[int(i)] for i in labels] def show_fashion_mnist(images,labels):
d2l.use_svg_display()
_,figs = plt.subplots(1,len(images),figsize=(12,12)) # 1行10列
for f ,img,lbl in zip(figs,images,labels):
f.imshow(img.view((28,28)).numpy())
f.set_title(lbl)
f.axes.get_xaxis().set_visible(False)
f.axes.get_yaxis().set_visible(False)
plt.show()
X,y = [],[]
for i in range(10):
X.append(mnist_train[i][0])
y.append(mnist_test[i][1])
show_fashion_mnist(X,get_fashion_mnist_labels(y))

读取小批量样本

我们将在训练集上训练模型,并将训练好的模型预测测试集上评估模型的表现。

可以用torch.utils.data.Dataloader来创建一个读取小批量样本的DataLoader实例。

在实际中,数据读取经常是训练的性能瓶颈,特别是当模型较为简单或者计算硬件性能较高时,pytorch的DataLoader中一个很方便的功能是允许使用多进程来加速数据读取。这里我们通过参数num_workers来设置进程数来加速读取数据

batch_size= 256

if sys.platform.startswith('win'):
num_worker=0 # 表示不用额外的进程来加速读取数据 else:
num_worker=4
train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size,shuffle=True,num_workers=num_worker)
test_iter = torch.utils.data.DataLoader(mnist_test,batch_size=batch_size,shuffle=False,num_workers=num_worker)
start = time.time()
for X,y in train_iter:
continue
print('%.2f sec' % (time.time()-start))
1.28 sec
小结
  • Fashion_MNIST 是一个10类服饰的分类数据集,之后章节后使用它来验证不同算法的表现
  • 我们将高和宽分别是H和W像素的图像的形状记为HxW或(h,w)

动手学深度学习6-认识Fashion_MNIST图像数据集的更多相关文章

  1. 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())

    在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...

  2. 对比《动手学深度学习》 PDF代码+《神经网络与深度学习 》PDF

    随着AlphaGo与李世石大战的落幕,人工智能成为话题焦点.AlphaGo背后的工作原理"深度学习"也跳入大众的视野.什么是深度学习,什么是神经网络,为何一段程序在精密的围棋大赛中 ...

  3. 【动手学深度学习】Jupyter notebook中 import mxnet出错

    问题描述 打开d2l-zh目录,使用jupyter notebook打开文件运行,import mxnet 出现无法导入mxnet模块的问题, 但是命令行运行是可以导入mxnet模块的. 原因: 激活 ...

  4. 《动手学深度学习》系列笔记—— 1.2 Softmax回归与分类模型

    目录 softmax的基本概念 交叉熵损失函数 模型训练和预测 获取Fashion-MNIST训练集和读取数据 get dataset softmax从零开始的实现 获取训练集数据和测试集数据 模型参 ...

  5. 动手学深度学习14- pytorch Dropout 实现与原理

    方法 从零开始实现 定义模型参数 网络 评估函数 优化方法 定义损失函数 数据提取与训练评估 pytorch简洁实现 小结 针对深度学习中的过拟合问题,通常使用丢弃法(dropout),丢弃法有很多的 ...

  6. 动手学深度学习9-多层感知机pytorch

    多层感知机 隐藏层 激活函数 小结 多层感知机 之前已经介绍过了线性回归和softmax回归在内的单层神经网络,然后深度学习主要学习多层模型,后续将以多层感知机(multilayer percetro ...

  7. 动手学深度学习1- pytorch初学

    pytorch 初学 Tensors 创建空的tensor 创建随机的一个随机数矩阵 创建0元素的矩阵 直接从已经数据创建tensor 创建新的矩阵 计算操作 加法操作 转化形状 tensor 与nu ...

  8. mxnet 动手学深度学习

    http://zh.gluon.ai/chapter_crashcourse/introduction.html 强化学习(Reinforcement Learning) 如果你真的有兴趣用机器学习开 ...

  9. 动手学深度学习10- pytorch多层感知机从零实现

    多层感知机 定义模型的参数 定义激活函数 定义模型 定义损失函数 训练模型 小结 多层感知机 import torch import numpy as np import sys sys.path.a ...

随机推荐

  1. Spring源码分析之IOC的三种常见用法及源码实现(三)

    上篇文章我们分析了AnnotationConfigApplicationContext的构造器里refresh方法里的invokeBeanFactoryPostProcessors,了解了@Compo ...

  2. Java生鲜电商平台-RBAC系统权限的设计与架构

    Java生鲜电商平台-RBAC系统权限的设计与架构 说明:根据上面的需求描述以及对需求的分析,我们得知通常的一个中小型系统对于权限系统所需实现的功能以及非功能性的需求,在下面我们将根据需求从技术角度上 ...

  3. 进程调度算法spf,fpf,时间片轮转算法实现

    调度的基本概念:从就绪队列中按照一定的算法选择一个进程并将处理机分配给它运行,以实现进程并发地执行. 进程信息 struct node { string name;//进程名称 int id;//进程 ...

  4. React中的三大属性

    一.前言: 属性1:state 属性2:props 属性3:ref 与事件处理 二.主要内容: 属性1:state 1,认识: 1) state 是组件对象中最重要的属性,值是一个对象(可以包含多个数 ...

  5. kuangbin专题简单搜索题目几道题目

    1.POJ1321棋盘问题 Description 在一个给定形状的棋盘(形状可能是不规则的)上面摆放棋子,棋子没有区别.要求摆放时任意的两个棋子不能放在棋盘中的同一行或者同一列,请编程求解对于给定形 ...

  6. js 设计模式——单例模式

    单例模式 保证一个类仅有一个实例,并提供一个访问它的全局访问点. 单例模式是一种常用的模式,有一些对象我们往往只需要一个,比如线程池.全局缓存.浏览器中的 window 对象等. JavaScript ...

  7. es6中,promise使用过程的小总结

    参考资料传送门:戳一戳 1.是什么 Promise是异步编程的一种解决方案,有三种状态:pending(进行中).fulfilled(已成功)和rejected(已失败); 一般成功了状态用resol ...

  8. 探究分析---利用sql批量更新部分时间的同比数据

    问题:如何将social_kol_tmp表 中的字段cost_YA中日期为201901-201909中的值替换为相同brand和pltform对应18年月份的col_cost字段的数据,其他日期的co ...

  9. 3-3 groupby操作

    Pandas章节应用的数据可以在以下链接下载:  https://files.cnblogs.com/files/AI-robort/Titanic_Data-master.zip .caret, . ...

  10. 外网穿透-natapp安装配置(windows)

    natapp官网 natapp服务器更新:全面支持HTTPS协议以及本地SSL证书,支持WSS协议.同时支持HTTP/2 WEB协议,支持微信小程序本地开发.全面自动支持泛子域名与访客真实IP地址. ...