pytorch批训练数据构造
这是对莫凡python的学习笔记。
1.创建数据
import torch
import torch.utils.data as Data BATCH_SIZE = 8
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)
可以看到创建了两个一维数据,x:1~10,y:10~1
2.构造数据集对象,及数据加载器对象
torch_dataset = Data.TensorDataset(x,y)
loader = Data.DataLoader(
dataset = torch_dataset,
batch_size = BATCH_SIZE,
shuffle = False,
num_workers = 2)
num_workers应该指的是多线程
3.输出数据集,这一步主要是看一下batch长什么样子
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
print('Epoch:',epoch,'| Step:', step, '| batch x:',
batch_x.numpy(), '| batch y:', batch_y.numpy())
输出如下
('Epoch:', 0, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6., 5., 4., 3.], dtype=float32))
('Epoch:', 0, '| Step:', 1, '| batch x:', array([ 9., 10.], dtype=float32), '| batch y:', array([2., 1.], dtype=float32))
('Epoch:', 1, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6., 5., 4., 3.], dtype=float32))
('Epoch:', 1, '| Step:', 1, '| batch x:', array([ 9., 10.], dtype=float32), '| batch y:', array([2., 1.], dtype=float32))
('Epoch:', 2, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6., 5., 4., 3.], dtype=float32))
('Epoch:', 2, '| Step:', 1, '| batch x:', array([ 9., 10.], dtype=float32), '| batch y:', array([2., 1.], dtype=float32))
可以看到,batch_size等于8,则第二个bacth的数据只有两个。
将batch_size改为5,输出如下
('Epoch:', 0, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6.], dtype=float32))
('Epoch:', 0, '| Step:', 1, '| batch x:', array([ 6., 7., 8., 9., 10.], dtype=float32), '| batch y:', array([5., 4., 3., 2., 1.], dtype=float32))
('Epoch:', 1, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6.], dtype=float32))
('Epoch:', 1, '| Step:', 1, '| batch x:', array([ 6., 7., 8., 9., 10.], dtype=float32), '| batch y:', array([5., 4., 3., 2., 1.], dtype=float32))
('Epoch:', 2, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6.], dtype=float32))
('Epoch:', 2, '| Step:', 1, '| batch x:', array([ 6., 7., 8., 9., 10.], dtype=float32), '| batch y:', array([5., 4., 3., 2., 1.], dtype=float32))
pytorch批训练数据构造的更多相关文章
- pytorch:EDSR 生成训练数据的方法
Pytorch:EDSR 生成训练数据的方法 引言 Winter is coming 正文 pytorch提供的DataLoader 是用来包装你的数据的工具. 所以你要将自己的 (numpy arr ...
- [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader
[源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 目录 [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 0x00 摘要 0x01 ...
- Pytorch之训练器设置
Pytorch之训练器设置 引言 深度学习训练的时候有很多技巧, 但是实际用起来效果如何, 还是得亲自尝试. 这里记录了一些个人尝试不同技巧的代码. tensorboardX 说起tensorflow ...
- [NN] 随机VS批训练
本文翻译节选自1998-Efficient BackProp, Yann LeCun et al.. 4.1 随机VS批训练 每一次迭代, 传统训练方式都需要遍历所有数据集来计算平均梯度. 批训练也同 ...
- [Pytorch]PyTorch Dataloader自定义数据读取
整理一下看到的自定义数据读取的方法,较好的有一下三篇文章, 其实自定义的方法就是把现有数据集的train和test分别用 含有图像路径与label的list返回就好了,所以需要根据数据集随机应变. 所 ...
- pytorch1.0批训练神经网络
pytorch1.0批训练神经网络 import torch import torch.utils.data as Data # Torch 中提供了一种帮助整理数据结构的工具, 叫做 DataLoa ...
- libsvm的安装,数据格式,常见错误,grid.py参数选择,c-SVC过程,libsvm参数解释,svm训练数据,libsvm的使用详解,SVM核函数的选择
直接conda install libsvm安装的不完整,缺几个.py文件. 第一种安装方法: 下载:http://www.csie.ntu.edu.tw/~cjlin/cgi-bin/libsvm. ...
- Alink漫谈(七) : 如何划分训练数据集和测试数据集
Alink漫谈(七) : 如何划分训练数据集和测试数据集 目录 Alink漫谈(七) : 如何划分训练数据集和测试数据集 0x00 摘要 0x01 训练数据集和测试数据集 0x02 Alink示例代码 ...
- GitHub上YOLOv5开源代码的训练数据定义
GitHub上YOLOv5开源代码的训练数据定义 代码地址:https://github.com/ultralytics/YOLOv5 训练数据定义地址:https://github.com/ultr ...
随机推荐
- Oracle、SqlServer——临时表
一.oracle 1.概述: oracle数据库的临时表的特点: 临时表默认保存在TEMP中: 表结构一直存在,直到删除:即创建一次,永久使用: 不支持主外键. 可以索引临时表和在临时表基础上建立视图 ...
- Struts1使用技巧
转自:https://blog.csdn.net/chjttony/article/details/6099101 1.Struts1是Apache推出的java web开发领域一个比较早,同时也是使 ...
- java之静态函数和静态变量
静态变量: 静态变量好似一种成员变量,它的特点是前面有static. 普通变量会有多份,它在每个对象当中都存在,但是静态变量只有一份,它是属于类的. 静态变量的调用方法: 1.类名.变量名 Custo ...
- Ubuntu14.04 安装Source Insight
在Ubuntu中,安装Windows程序用wine,然后用wine安装Windows软件即可. 1.安装wine 在终端输入以下命令: sudo apt-get install wine 2.用win ...
- POJ 1220 高精度/进制转换
n进制转m进制,虽然知道短除法但是还是不太理解,看了代码理解一些了: 记住这个就好了: for(int k=0;l; ){ for(int i=l ; i>=1 ; i--){ num[i - ...
- apt-get默认下载路径
备忘: Ubuntu中apt-get下载的安装包都在哪里呢? 在/var/cache/apt/archives里,里边的安装包可以取出来以备后用.
- elasticsearch 6.2.4 安装 elasticsearch-analysis-ik 分词器 (windows 10下)
访问 https://github.com/medcl/elasticsearch-analysis-ik 找 releases 找到对应的 es 版本 下载 elasticsearch-analy ...
- Help Bubu UVALive - 4490
传送门 题目大意 有n本书,最多k次操作,每次操作可以把一本书拿出来,放到一个位置去,有一个指标较mess度,他是书的高度的段数,连续的书高度一样算一段,现在给你最先开始各个位置上的书的高度,求操作后 ...
- Luogu 3479 [POI2009]GAS-Fire Extinguishers
补上了这一道原题,感觉弱化版的要简单好多. 神贪心: 我们设$cov_{x, i}$表示在$x$的子树中与$x$距离为$i$的还没有被覆盖到的结点个数,设$rem_{x, i}$表示在$x$的子树中与 ...
- Git 之 修复bug
前面介绍了Git版本控制,为我们省去了很多不必要的麻烦. 回滚 有没有想过,在我们开发过程中,修改需要是常有的事,如果我们现在不想要这个功能了,那么如何回到之前的版本呢?回滚,回到上一个版本. 那如果 ...