PyTorch 介绍 | TRANSFORMS
数据并不总是满足机器学习算法所需的格式。我们使用transform对数据进行一些操作,使得其能适用于训练。
所有的TorchVision数据集都有两个参数,用以接受包含transform逻辑的可调用项-transform
修改features,targe_transform
修改标签。torchvision.transforms提供了几种现成的常用转换操作。
FashionMNIST features是PIL Image格式,标签是整型。为了训练,我们需要将其转换为标准的tensors,并且标签是one-hot编码的tensor。为了完成这些转换,使用 ToTensor
和 Lambda
。
import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda
ds = datasets.FashionMNIST(
root='data',
train=True,
download=True,
transform=ToTensor(),
# 在创建的具有10个0值数组中,单独取第一个维度的y位置(原始标签),赋为1,即为one-hot编码
target_tansform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0,
torch.tensor(y), value=1))
)
输出:
点击查看代码
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-images-idx3-ubyte.gz to data/FashionMNIST/raw/train-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/train-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/train-labels-idx1-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-images-idx3-ubyte.gz to data/FashionMNIST/raw
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz
Downloading http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/FashionMNIST/raw/t10k-labels-idx1-ubyte.gz to data/FashionMNIST/raw
ToTensor()
ToTensor将PIL图像或NumPy ndarray
转换为 FloatTensor
。并且将图片像素值缩放到范围[0., 1.]
Lambda Transforms
Lambda转换可使用任何用户定义的lambda函数。这里,我们定义了一个函数,可以将整型转换成one-hot编码的tensor,首先创建一个大小为10的0值tensor,根据给定标签 y
得到索引位置,调用scatter_将其赋为1。
target_transform = Lambda(lambda y: torch.zeros(
10,dtype=torch.float).scatter_(dim=0, index=torch.tensor(y), value=1))
延伸阅读
PyTorch 介绍 | TRANSFORMS的更多相关文章
- PyTorch 介绍 | DATSETS & DATALOADERS
用于处理数据样本的代码可能会变得凌乱且难以维护:理想情况下,我们希望数据集代码和模型训练代码解耦(分离),以获得更好的可读性和模块性.PyTorch提供了两个data primitives:torch ...
- PyTorch 介绍 | BUILD THE NEURAL NETWORK
神经网络由对数据进行操作的layers/modules组成.torch.nn 命名空间提供了所有你需要的构建块,用于构建你自己的神经网络.PyTorch的每一个module都继承自nn.Module. ...
- PyTorch 介绍 | AUTOMATIC DIFFERENTIATION WITH TORCH.AUTOGRAD
训练神经网络时,最常用的算法就是反向传播.在该算法中,参数(模型权重)会根据损失函数关于对应参数的梯度进行调整. 为了计算这些梯度,PyTorch内置了名为 torch.autograd 的微分引擎. ...
- pytorch随笔
pytorch中transform函数 一般用Compose把多个步骤整合到一起: 比如说 transforms.Compose([ transforms.CenterCrop(10), transf ...
- Keras vs. PyTorch in Transfer Learning
We perform image classification, one of the computer vision tasks deep learning shines at. As traini ...
- Pytorch(一)
一.Pytorch介绍 Pytorch 是Torch在Python上的衍生物 和Tensorflow相比: Pytorch建立的神经网络是动态的,而Tensorflow建立的神经网络是静态的 Tens ...
- PyTorch 实战:计算 Wasserstein 距离
PyTorch 实战:计算 Wasserstein 距离 2019-09-23 18:42:56 This blog is copied from: https://mp.weixin.qq.com/ ...
- Generative Adversarial Network (GAN) - Pytorch版
import os import torch import torchvision import torch.nn as nn from torchvision import transforms f ...
- Tensorflow和pytorch安装(windows安装)
一. Tensorflow安装 1. Tensorflow介绍 Tensorflow是广泛使用的实现机器学习以及其它涉及大量数学运算的算法库之一.Tensorflow由Google开发,是GitHub ...
随机推荐
- 【LeetCode】540. Single Element in a Sorted Array 解题报告(Python & C++)
作者: 负雪明烛 id: fuxuemingzhu 个人博客: http://fuxuemingzhu.cn/ 目录 题目描述 解题方法 方法一:异或 方法二:判断相邻元素是否相等 方法三:二分查找 ...
- Docker 与 K8S学习笔记(十)—— 容器的端口映射
我们一般将应用部署在容器里面,而一个服务器上会有许许多多的容器,那么外界该如何访问我们的应用呢?答案是:端口映射. Docker可以将容器对外提供服务的端口映射到host的某个端口上,外网通过此端口访 ...
- 设计模式学习——JAVA动态代理原理分析
一.JDK动态代理执行过程 上一篇我们讲了JDK动态代理的简单使用,今天我们就来研究一下它的原理. 首先我们回忆下上一篇的代码: public class Main { public static v ...
- Redis常见使用场景
缓存 数据共享分布式 分布式锁 全局ID 计数器 限流 位统计 购物车 用户消息时间线timeline 消息队列 抽奖 点赞.签到.打卡 商品标签 商品筛选 用户关注.推荐模型 排行榜 1.缓存 St ...
- web服务之nginx部署
本期内容概要 了解web服务 Nginx和Apache的对比 部署Nginx 内容详细 1.什么是web服务 Web服务是一种服务导向架构的技术,通过标准的Web协议提供服务,目的是保证不同平台的应用 ...
- JavaScript交互式网页设计 • 【第3章 JavaScript浏览器对象模型】
全部章节 >>>> 本章目录 3.1 浏览器对象模型 3.1.1 浏览器对象模型 3.2 window 对象 3.2.1 window 对象的常用属性及方法 3.2.2 使 ...
- golang vue 使用 websocket 的例子
一. 编写golang服务端 1.导入必要的websocket包,golang.org/x/net/websocket 或 github.com/golang/net/websocket 2.编写消息 ...
- MyBatis 一级缓存实现详解及使用注意事项
一级缓存介绍 在应用运行过程中,我们有可能在一次数据库会话中,执行多次查询条件完全相同的SQL,MyBatis提供了一级缓存的方案优化这部分场景,如果是相同的SQL语句,会优先命中一级缓存,避免直接对 ...
- Linux登录时,下游回显非常慢
问题现象 登录linux时,远程连接正常,[root@...]回显非常慢,在执行脚本时,很容易导致命令下发错乱 原因分析 家目录下的.bash_history文件太大,导致每次登陆时读取这个文件耗时太 ...
- 基于ShardingJDBC的分库分表及读写分离整理
ShardingJDBC的核心流程主要分成六个步骤,分别是:SQL解析->SQL优化->SQL路由->SQL改写->SQL执行->结果归并,流程图如下: sharding ...