PyTorch 介绍 | TRANSFORMS
数据并不总是满足机器学习算法所需的格式。我们使用transform对数据进行一些操作,使得其能适用于训练。
所有的TorchVision数据集都有两个参数,用以接受包含transform逻辑的可调用项-transform 修改features,targe_transform 修改标签。torchvision.transforms提供了几种现成的常用转换操作。
FashionMNIST features是PIL Image格式,标签是整型。为了训练,我们需要将其转换为标准的tensors,并且标签是one-hot编码的tensor。为了完成这些转换,使用 ToTensor 和 Lambda。
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
root='data',
train=True,
download=True,
transform=ToTensor(),
# 在创建的具有10个0值数组中,单独取第一个维度的y位置(原始标签),赋为1,即为one-hot编码
target_tansform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0,
torch.tensor(y), value=1))
)
输出:
点击查看代码
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
ToTensor()
ToTensor将PIL图像或NumPy ndarray 转换为 FloatTensor。并且将图片像素值缩放到范围[0., 1.]
Lambda Transforms
Lambda转换可使用任何用户定义的lambda函数。这里,我们定义了一个函数,可以将整型转换成one-hot编码的tensor,首先创建一个大小为10的0值tensor,根据给定标签 y得到索引位置,调用scatter_将其赋为1。
target_transform = Lambda(lambda y: torch.zeros(
10,dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
延伸阅读
PyTorch 介绍 | TRANSFORMS的更多相关文章
- PyTorch 介绍 | DATSETS & DATALOADERS
用于处理数据样本的代码可能会变得凌乱且难以维护:理想情况下,我们希望数据集代码和模型训练代码解耦(分离),以获得更好的可读性和模块性.PyTorch提供了两个data primitives:torch ...
- PyTorch 介绍 | BUILD THE NEURAL NETWORK
神经网络由对数据进行操作的layers/modules组成.torch.nn 命名空间提供了所有你需要的构建块,用于构建你自己的神经网络.PyTorch的每一个module都继承自nn.Module. ...
- PyTorch 介绍 | AUTOMATIC DIFFERENTIATION WITH TORCH.AUTOGRAD
训练神经网络时,最常用的算法就是反向传播.在该算法中,参数(模型权重)会根据损失函数关于对应参数的梯度进行调整. 为了计算这些梯度,PyTorch内置了名为 torch.autograd 的微分引擎. ...
- pytorch随笔
pytorch中transform函数 一般用Compose把多个步骤整合到一起: 比如说 transforms.Compose([ transforms.CenterCrop(10), transf ...
- Keras vs. PyTorch in Transfer Learning
We perform image classification, one of the computer vision tasks deep learning shines at. As traini ...
- Pytorch(一)
一.Pytorch介绍 Pytorch 是Torch在Python上的衍生物 和Tensorflow相比: Pytorch建立的神经网络是动态的,而Tensorflow建立的神经网络是静态的 Tens ...
- PyTorch 实战:计算 Wasserstein 距离
PyTorch 实战:计算 Wasserstein 距离 2019-09-23 18:42:56 This blog is copied from: https://mp.weixin.qq.com/ ...
- Generative Adversarial Network (GAN) - Pytorch版
import os import torch import torchvision import torch.nn as nn from torchvision import transforms f ...
- Tensorflow和pytorch安装(windows安装)
一. Tensorflow安装 1. Tensorflow介绍 Tensorflow是广泛使用的实现机器学习以及其它涉及大量数学运算的算法库之一.Tensorflow由Google开发,是GitHub ...
随机推荐
- 【LeetCode】350. Intersection of Two Arrays II 解题报告(Java & Python)
作者: 负雪明烛 id: fuxuemingzhu 个人博客: http://fuxuemingzhu.cn/ 目录 题目描述 解题方法 Java排序+双指针 Python排序+双指针 Python解 ...
- 【LeetCode】657. Judge Route Circle 解题报告
[LeetCode]657. Judge Route Circle 标签(空格分隔): LeetCode 题目地址:https://leetcode.com/problems/judge-route- ...
- 【LeetCode】14. Longest Common Prefix 最长公共前缀
作者: 负雪明烛 id: fuxuemingzhu 个人博客:http://fuxuemingzhu.cn/ 个人公众号:负雪明烛 本文关键词:prefix, 公共前缀,题解,leetcode, 力扣 ...
- Java 将Excel转为OFD
OFD是一种开放版式文档(Open Fixed-layout Document )的英文缩写,是我国国家版式文档格式标准.本文,通过Java后端程序代码展示如何将Excel转为OFD格式.方法步骤如下 ...
- Java初学者作业——编写Java程序,实现判断所输入字符的类型(数字、小写字母、大写字母或其他字符)
返回本章节 返回作业目录 需求说明: 编写Java程序,实现判断所输入字符的类型(数字.小写字母.大写字母或其他字符) 实现思路: 声明变量c,用于存储用户输入的字符. 通过Scanner接收用户输入 ...
- Error: Cannot find module '@dcloudio/uni-cli-i18n' 解决方案
这个错误是因为node_modules缺少了 '@dcloudio/uni-cli-i18n' 以下是错误信息 解决方案: yarn add -D @dcloudio/uni-cli-i18n ...
- ANT 通配符使用说明
通配符说明 通配符 说明 ? 匹配任意一个字符 * 匹配零个.一个.多个字符 ** 匹配零个.一个.多个目录 使用示例 URL路径 说明 /app/p?ttern 匹配 /app/pattern 和 ...
- 大厂必问的Java集合面试题
本文目录: 常见的集合有哪些? List .Set和Map 的区别 ArrayList 了解吗? ArrayList 的扩容机制? 怎么在遍历 ArrayList 时移除一个元素? Arraylist ...
- JMeter_用户自定义变量
在实际测试过程中,我们经常会碰到脚本开发时与测试执行时的服务地址不一样的情况,为了方便,我们会把访问地址参数化,当访问地址变化了,我们只需要把参数对应的值改动一下就可以了. 一.添加用户自定义变量元件 ...
- 详谈 Java工厂 --- 抽象工厂模式
1.前言 感觉工厂模式都好鸡肋,还特别绕来绕去,当然,好处还是有的,将一些类似的业务都集成到工厂了, 不需要理会底层是怎么运行的,我只需要向调用工厂即可获取我要的结果,也不需要考虑工厂返回的东西类型, ...