数据并不总是满足机器学习算法所需的格式。我们使用transform对数据进行一些操作,使得其能适用于训练。

所有的TorchVision数据集都有两个参数,用以接受包含transform逻辑的可调用项-transform 修改features,targe_transform 修改标签。torchvision.transforms提供了几种现成的常用转换操作。

FashionMNIST features是PIL Image格式,标签是整型。为了训练,我们需要将其转换为标准的tensors,并且标签是one-hot编码的tensor。为了完成这些转换,使用 ToTensorLambda

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda ds = datasets.FashionMNIST(
root='data',
train=True,
download=True,
transform=ToTensor(),
# 在创建的具有10个0值数组中,单独取第一个维度的y位置(原始标签),赋为1,即为one-hot编码
target_tansform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0,
torch.tensor(y), value=1))
)

输出:

点击查看代码
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw

ToTensor()

ToTensor将PIL图像或NumPy ndarray 转换为 FloatTensor。并且将图片像素值缩放到范围[0., 1.]

Lambda Transforms

Lambda转换可使用任何用户定义的lambda函数。这里,我们定义了一个函数,可以将整型转换成one-hot编码的tensor,首先创建一个大小为10的0值tensor,根据给定标签 y得到索引位置,调用scatter_将其赋为1。

target_transform = Lambda(lambda y: torch.zeros(
10,dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))

延伸阅读

PyTorch 介绍 | TRANSFORMS的更多相关文章

  1. PyTorch 介绍 | DATSETS & DATALOADERS

    用于处理数据样本的代码可能会变得凌乱且难以维护:理想情况下,我们希望数据集代码和模型训练代码解耦(分离),以获得更好的可读性和模块性.PyTorch提供了两个data primitives:torch ...

  2. PyTorch 介绍 | BUILD THE NEURAL NETWORK

    神经网络由对数据进行操作的layers/modules组成.torch.nn 命名空间提供了所有你需要的构建块,用于构建你自己的神经网络.PyTorch的每一个module都继承自nn.Module. ...

  3. PyTorch 介绍 | AUTOMATIC DIFFERENTIATION WITH TORCH.AUTOGRAD

    训练神经网络时,最常用的算法就是反向传播.在该算法中,参数(模型权重)会根据损失函数关于对应参数的梯度进行调整. 为了计算这些梯度,PyTorch内置了名为 torch.autograd 的微分引擎. ...

  4. pytorch随笔

    pytorch中transform函数 一般用Compose把多个步骤整合到一起: 比如说 transforms.Compose([ transforms.CenterCrop(10), transf ...

  5. Keras vs. PyTorch in Transfer Learning

    We perform image classification, one of the computer vision tasks deep learning shines at. As traini ...

  6. Pytorch(一)

    一.Pytorch介绍 Pytorch 是Torch在Python上的衍生物 和Tensorflow相比: Pytorch建立的神经网络是动态的,而Tensorflow建立的神经网络是静态的 Tens ...

  7. PyTorch 实战:计算 Wasserstein 距离

    PyTorch 实战:计算 Wasserstein 距离 2019-09-23 18:42:56 This blog is copied from: https://mp.weixin.qq.com/ ...

  8. Generative Adversarial Network (GAN) - Pytorch版

    import os import torch import torchvision import torch.nn as nn from torchvision import transforms f ...

  9. Tensorflow和pytorch安装(windows安装)

    一. Tensorflow安装 1. Tensorflow介绍 Tensorflow是广泛使用的实现机器学习以及其它涉及大量数学运算的算法库之一.Tensorflow由Google开发,是GitHub ...

随机推荐

  1. 【LeetCode】1. Two Sum 两数之和

    作者: 负雪明烛 id: fuxuemingzhu 个人博客:http://fuxuemingzhu.cn/ 个人公众号:负雪明烛 本文关键词:two sum, 两数之和,题解,leetcode, 力 ...

  2. 【LeetCode】155. Min Stack 最小栈 (Python&C++)

    作者: 负雪明烛 id: fuxuemingzhu 个人博客:http://fuxuemingzhu.cn/ 目录 题目描述 解题方法 栈同时保存当前值和最小值 辅助栈 同步栈 不同步栈 日期 题目地 ...

  3. BestCoder Round #66 (div.2)B GTW likes gt

    思路:一个O(n)O(n)的做法.我们发现b_1,b_2,...,b_xb​1​​,b​2​​,...,b​x​​都加11就相当于b_{x+1},b_{x+2},...,b_nb​x+1​​,b​x+ ...

  4. React MobX 开始

    MobX 用于状态管理,简单高效.本文将于 React 上介绍如何开始,包括了: 了解 MobX 概念 从零准备 React 应用 MobX React.FC 写法 MobX React.Compon ...

  5. Orcale

    oracleoracle中不存在引擎的概念,数据处理大致可以分成两大类:联机事务处理OLTP(on-line transaction processing).联机分析处理OLAP(On-Line An ...

  6. eclipse的安装及最大子数组求和

    我安装的是eclipse.由于eclipse是一个基于Java的课扩展开发平台,所以在安装eclipse之前要先安装Java的开发工具JDK(Java Devolopment Dit),且安装JDK需 ...

  7. Kronecker Products and Stack Operator

    目录 定义 Stack Operator Kronecker Product 性质 Stack Operator Kronecker Product 半线性 Whitcomb L. Notes on ...

  8. 使用 JavaScript 根据消费金额和消费者是否为会员确定折扣,最终核算实际应该支付的金额

    查看本章节 查看作业目录 需求说明: 根据消费金额和消费者是否为会员确定折扣,最终核算实际应该支付的金额 消费金额在 200 元以上的会员折扣是 7.5 折,消费金额没有达到 200 元的会员折扣是 ...

  9. LDAP客户端安装

    安装环境: 10.43.159.7 客户端 使用ldap客户端验证登陆: 用户为10.43.159.9服务端上面创建的ldap:zdh1234 1.安装LDAP client认证需要的pam包 yum ...

  10. Sentry 企业级数据安全解决方案 - Relay 运行模式

    内容整理自官方开发文档 Relay 可以在几种主要模式之一下运行,如果您正在配置 Relay server 而不是使用默认设置,那么事先了解这些模式至关重要. 模式存储在配置文件中,该文件包含 rel ...