动手学深度学习6-认识Fashion_MNIST图像数据集
本节将使用torchvision包,它是服务于pytorch深度学习框架的,主要用来构建计算机视觉模型。
torchvision主要由以下几个部分构成:
- torchvision.datasets:一些加载数据的函数以及常用的数据集的接口
 - torchvision.models: 包含常用的模型结构(含预训练模型),例如AlexNet,VGG,ResNet;
 - torchvision.transforms:常用的图片变换,例如裁剪,旋转等;
 - torchvision.utils: 其他的一些有用的方法
 
获取数据集
导入本节需要的包或者模块
import torch
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import time
import sys
sys.path.append('..')  # 为了导入上层目录的d2lzh_pytorch
import d2lzh_pytorch as d2l
通过调用torchvision的torchvision.datasets来下载这个数据集
可以通过train参数获取指定的训练集或者测试集、
测试集只用了评估模型,并不用来训练模型
同时指定了参数transform = transform.ToTensor()使所有数据转化为Tensor,如果不进行转化,则返回的是PIL照片。
transform.ToTensor()将尺寸为(H,W,C)且数据位于[0,255]的PIL图片或者数据类型为np.unit8的Numpy数组转化为(CxHxW)且数据类型为torch.float32且位于[0.0,1.0]的Tensor。
- 如果用像素值(0,255)表示图片数据,一律将其类型设置为unit8,避免出问题
 
mnist_train= torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',download=True,train=True,transform=transforms.ToTensor())
mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST',download=True,train=False,transform=transforms.ToTensor())
print(type(mnist_train))
print(len(mnist_train), len(mnist_test))
<class 'torchvision.datasets.mnist.FashionMNIST'>
60000 10000
feature,label = mnist_train[0]
print(feature.shape,label)  # channel * height* width
torch.Size([1, 28, 28]) tensor(9)
feature对应的高和宽均为28像素的图像,由于我们使用了transforms.ToTensor(),所以每个像素的数值为[0,1]的32位浮点数。需要注意的是,feature的尺寸是(CxHxW)的,而不是(HxWxC)。第一维是通道数,因为数据集中是灰度图像,所以通道数为1,后面两维分别是图像的高和宽。
Fashion_MNIST中一共包括了10个类别,分别是t-shirt(T恤),trouser(裤子),pullover(套衫),dress(连衣裙),coat(外套),sandal(凉鞋),shirt(衬衫),sneaker(运动鞋),bag(包)和ankle boot(短靴)
import d2lzh_pytorch as d2l
def get_fashion_mnist_labels(labels):
    text_labels = ['t-shirt','trouser','pullover','dress','coat','sandal',
                  'shirt','sneaker','bag','ankle boost'
                  ]
    return [text_labels[int(i)] for i in labels]
def show_fashion_mnist(images,labels):
    d2l.use_svg_display()
    _,figs = plt.subplots(1,len(images),figsize=(12,12))  # 1行10列
    for f ,img,lbl in zip(figs,images,labels):
        f.imshow(img.view((28,28)).numpy())
        f.set_title(lbl)
        f.axes.get_xaxis().set_visible(False)
        f.axes.get_yaxis().set_visible(False)
    plt.show()
X,y = [],[]
for i in range(10):
    X.append(mnist_train[i][0])
    y.append(mnist_test[i][1])
show_fashion_mnist(X,get_fashion_mnist_labels(y))

读取小批量样本
我们将在训练集上训练模型,并将训练好的模型预测测试集上评估模型的表现。
可以用torch.utils.data.Dataloader来创建一个读取小批量样本的DataLoader实例。
在实际中,数据读取经常是训练的性能瓶颈,特别是当模型较为简单或者计算硬件性能较高时,pytorch的DataLoader中一个很方便的功能是允许使用多进程来加速数据读取。这里我们通过参数num_workers来设置进程数来加速读取数据
batch_size= 256
if sys.platform.startswith('win'):
    num_worker=0   # 表示不用额外的进程来加速读取数据
else:
    num_worker=4
train_iter = torch.utils.data.DataLoader(mnist_train,batch_size=batch_size,shuffle=True,num_workers=num_worker)
test_iter = torch.utils.data.DataLoader(mnist_test,batch_size=batch_size,shuffle=False,num_workers=num_worker)
start = time.time()
for X,y in train_iter:
    continue
print('%.2f sec' % (time.time()-start))
1.28 sec
小结
- Fashion_MNIST 是一个10类服饰的分类数据集,之后章节后使用它来验证不同算法的表现
 - 我们将高和宽分别是H和W像素的图像的形状记为HxW或(h,w)
 
动手学深度学习6-认识Fashion_MNIST图像数据集的更多相关文章
- 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())
		
在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...
 - 对比《动手学深度学习》 PDF代码+《神经网络与深度学习 》PDF
		
随着AlphaGo与李世石大战的落幕,人工智能成为话题焦点.AlphaGo背后的工作原理"深度学习"也跳入大众的视野.什么是深度学习,什么是神经网络,为何一段程序在精密的围棋大赛中 ...
 - 【动手学深度学习】Jupyter notebook中 import mxnet出错
		
问题描述 打开d2l-zh目录,使用jupyter notebook打开文件运行,import mxnet 出现无法导入mxnet模块的问题, 但是命令行运行是可以导入mxnet模块的. 原因: 激活 ...
 - 《动手学深度学习》系列笔记—— 1.2 Softmax回归与分类模型
		
目录 softmax的基本概念 交叉熵损失函数 模型训练和预测 获取Fashion-MNIST训练集和读取数据 get dataset softmax从零开始的实现 获取训练集数据和测试集数据 模型参 ...
 - 动手学深度学习14- pytorch Dropout 实现与原理
		
方法 从零开始实现 定义模型参数 网络 评估函数 优化方法 定义损失函数 数据提取与训练评估 pytorch简洁实现 小结 针对深度学习中的过拟合问题,通常使用丢弃法(dropout),丢弃法有很多的 ...
 - 动手学深度学习9-多层感知机pytorch
		
多层感知机 隐藏层 激活函数 小结 多层感知机 之前已经介绍过了线性回归和softmax回归在内的单层神经网络,然后深度学习主要学习多层模型,后续将以多层感知机(multilayer percetro ...
 - 动手学深度学习1- pytorch初学
		
pytorch 初学 Tensors 创建空的tensor 创建随机的一个随机数矩阵 创建0元素的矩阵 直接从已经数据创建tensor 创建新的矩阵 计算操作 加法操作 转化形状 tensor 与nu ...
 - mxnet 动手学深度学习
		
http://zh.gluon.ai/chapter_crashcourse/introduction.html 强化学习(Reinforcement Learning) 如果你真的有兴趣用机器学习开 ...
 - 动手学深度学习10- pytorch多层感知机从零实现
		
多层感知机 定义模型的参数 定义激活函数 定义模型 定义损失函数 训练模型 小结 多层感知机 import torch import numpy as np import sys sys.path.a ...
 
随机推荐
- Delphi - 操作Excel数据公式的实现
			
procedure TF_SMP_FT_NEW.RzBitBtn_StartToChangeClick(Sender: TObject); var i, j, ni, nj, iRows, iCol, ...
 - Oracle - 数字处理 - 取上取整、向下取整、保留N位小数、四舍五入、数字格式化
			
用oracle sql对数字进行操作: 取上取整.向下取整.保留N位小数.四舍五入.数字格式化 取整(向下取整): select floor(5.534) from dual; select trun ...
 - List的Clear方法与RemoveAll方法用法小结
			
转自:https://blog.csdn.net/yl2isoft/article/details/17059093 结果分析 执行List的Clear方法和RemoveAll方法,List将清除指定 ...
 - Java网上体育商城系统ssh
			
网上体育商城的主要功能包括:前台用户登录退出.注册.在线购物.修改个人信息.后台商品管理等等.本系统结构如下:(1)商品浏览模块: 首页浏览最新上市商品,按销量排行显示商品 ...
 - SSO单点登录和CAS
			
一.单点登录流程 =====客户端====== 1.拦截客户端的请求判断是否有局部的session 2.1如果有局部的session,放行请求. 2.2如果没有局部session 2.2.1请求中有携 ...
 - Java自学-集合框架 遍历
			
遍历ArrayList的三种方法 步骤 1 : 用for循环遍历 通过前面的学习,知道了可以用size()和get()分别得到大小,和获取指定位置的元素,结合for循环就可以遍历出ArrayList的 ...
 - swift(一)基础变量类型
			
import Foundation println("Hello, World!") /* int a; */ var a = //隐式类型转换 a = println(a) le ...
 - HeadFirst设计模式---装饰者
			
定义装饰者模式 装饰者模式动态地将责任附加到对象上,若要扩展功能,装饰者提供了比继承更有弹性的替代方案.这句话摘自书中,给人读得很生硬难懂.通俗地来说,装饰者和被装饰者有相同的父类,装饰者的行为组装着 ...
 - traceroute在linux中的使用方法
			
traceroute在linux中的使用方法 一.traceroute的实现原理 二.traceroute命令使用方法 1.命令格式 2.常用命令参数 3.使用实例 一.traceroute的实现原理 ...
 - Rust中的Rc--引用计数智能指针
			
大部分情况下所有权是非常明确的:可以准确的知道哪个变量拥有某个值.然而,有些情况单个值可能会有多个所有者.例如,在图数据结构中,多个边可能指向相同的结点,而这个结点从概念上讲为所有指向它的边所拥有.结 ...