这是对莫凡python的学习笔记。

1.创建数据

import torch
import torch.utils.data as Data BATCH_SIZE = 8
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)

可以看到创建了两个一维数据,x:1~10,y:10~1

2.构造数据集对象,及数据加载器对象

torch_dataset = Data.TensorDataset(x,y)
loader = Data.DataLoader(
dataset = torch_dataset,
batch_size = BATCH_SIZE,
shuffle = False,
num_workers = 2)

num_workers应该指的是多线程

3.输出数据集,这一步主要是看一下batch长什么样子

for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
print('Epoch:',epoch,'| Step:', step, '| batch x:',
batch_x.numpy(), '| batch y:', batch_y.numpy())

输出如下

('Epoch:', 0, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), '| batch y:', array([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.], dtype=float32))
('Epoch:', 0, '| Step:', 1, '| batch x:', array([ 9., 10.], dtype=float32), '| batch y:', array([2., 1.], dtype=float32))
('Epoch:', 1, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6., 5., 4., 3.], dtype=float32))
('Epoch:', 1, '| Step:', 1, '| batch x:', array([ 9., 10.], dtype=float32), '| batch y:', array([2., 1.], dtype=float32))
('Epoch:', 2, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6., 5., 4., 3.], dtype=float32))
('Epoch:', 2, '| Step:', 1, '| batch x:', array([ 9., 10.], dtype=float32), '| batch y:', array([2., 1.], dtype=float32))

可以看到,batch_size等于8,则第二个bacth的数据只有两个。

将batch_size改为5,输出如下

('Epoch:', 0, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5.], dtype=float32), '| batch y:', array([10.,  9.,  8.,  7.,  6.], dtype=float32))
('Epoch:', 0, '| Step:', 1, '| batch x:', array([ 6., 7., 8., 9., 10.], dtype=float32), '| batch y:', array([5., 4., 3., 2., 1.], dtype=float32))
('Epoch:', 1, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6.], dtype=float32))
('Epoch:', 1, '| Step:', 1, '| batch x:', array([ 6., 7., 8., 9., 10.], dtype=float32), '| batch y:', array([5., 4., 3., 2., 1.], dtype=float32))
('Epoch:', 2, '| Step:', 0, '| batch x:', array([1., 2., 3., 4., 5.], dtype=float32), '| batch y:', array([10., 9., 8., 7., 6.], dtype=float32))
('Epoch:', 2, '| Step:', 1, '| batch x:', array([ 6., 7., 8., 9., 10.], dtype=float32), '| batch y:', array([5., 4., 3., 2., 1.], dtype=float32))

pytorch批训练数据构造的更多相关文章

  1. pytorch:EDSR 生成训练数据的方法

    Pytorch:EDSR 生成训练数据的方法 引言 Winter is coming 正文 pytorch提供的DataLoader 是用来包装你的数据的工具. 所以你要将自己的 (numpy arr ...

  2. [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader

    [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 目录 [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 0x00 摘要 0x01 ...

  3. Pytorch之训练器设置

    Pytorch之训练器设置 引言 深度学习训练的时候有很多技巧, 但是实际用起来效果如何, 还是得亲自尝试. 这里记录了一些个人尝试不同技巧的代码. tensorboardX 说起tensorflow ...

  4. [NN] 随机VS批训练

    本文翻译节选自1998-Efficient BackProp, Yann LeCun et al.. 4.1 随机VS批训练 每一次迭代, 传统训练方式都需要遍历所有数据集来计算平均梯度. 批训练也同 ...

  5. [Pytorch]PyTorch Dataloader自定义数据读取

    整理一下看到的自定义数据读取的方法,较好的有一下三篇文章, 其实自定义的方法就是把现有数据集的train和test分别用 含有图像路径与label的list返回就好了,所以需要根据数据集随机应变. 所 ...

  6. pytorch1.0批训练神经网络

    pytorch1.0批训练神经网络 import torch import torch.utils.data as Data # Torch 中提供了一种帮助整理数据结构的工具, 叫做 DataLoa ...

  7. libsvm的安装,数据格式,常见错误,grid.py参数选择,c-SVC过程,libsvm参数解释,svm训练数据,libsvm的使用详解,SVM核函数的选择

    直接conda install libsvm安装的不完整,缺几个.py文件. 第一种安装方法: 下载:http://www.csie.ntu.edu.tw/~cjlin/cgi-bin/libsvm. ...

  8. Alink漫谈(七) : 如何划分训练数据集和测试数据集

    Alink漫谈(七) : 如何划分训练数据集和测试数据集 目录 Alink漫谈(七) : 如何划分训练数据集和测试数据集 0x00 摘要 0x01 训练数据集和测试数据集 0x02 Alink示例代码 ...

  9. GitHub上YOLOv5开源代码的训练数据定义

    GitHub上YOLOv5开源代码的训练数据定义 代码地址:https://github.com/ultralytics/YOLOv5 训练数据定义地址:https://github.com/ultr ...

随机推荐

  1. java 多线程系列基础篇(四)之 synchronized关键字

    1. synchronized原理 在java中,每一个对象有且仅有一个同步锁.这也意味着,同步锁是依赖于对象而存在.当我们调用某对象的synchronized方法时,就获取了该对象的同步锁.例如,s ...

  2. 问题:C#控制台 停留;结果:c#控制台如何延时显示

    Thread.Sleep(毫秒数);//比如Thread.Sleep(2000)即为延时2秒需using System.Threading; 随笔5 - C#控制台窗口的显示与隐藏 1. 定义一个Co ...

  3. Delphi 转圈 原型进度条 AniIndicator 及线程配合使用

    Delphi FMX 转圈 原型进度条 progress AniIndicator TAniIndicator TFloatAnimation VCL下也有转圈菊花进度条 TActivityIndic ...

  4. Android Binder 系统学习笔记(一)Binder系统的基本使用方法

    1.什么是RPC(远程过程调用) Binder系统的目的是实现远程过程调用(RPC),即进程A去调用进程B的某个函数,它是在进程间通信(IPC)的基础上实现的.RPC的一个应用场景如下: A进程想去打 ...

  5. pthon爬虫(9)--Selenium的用法

    简介 Selenium 是什么?一句话,自动化测试工具.它支持各种浏览器,包括 Chrome,Safari,Firefox 等主流界面式浏览器,如果你在这些浏览器里面安装一个 Selenium 的插件 ...

  6. python爬虫(8)--Xpath语法与lxml库

    1.XPath语法 XPath 是一门在 XML 文档中查找信息的语言.XPath 可用来在 XML 文档中对元素和属性进行遍历.XPath 是 W3C XSLT 标准的主要元素,并且 XQuery ...

  7. python爬虫--编码问题y

    1)中文网站爬取下来的内容中文显示乱码 Python中文乱码是由于Python在解析网页时默认用Unicode去解析,而大多数网站是utf-8格式的,并且解析出来之后,python竟然再以Unicod ...

  8. spring分模块开发

  9. 一段PHP异常

    这是我写的一段代码,里面通过PHP异常功能,实现报错时显示出错代码所在行.当使用者操作出错时,截图给我,我可以很快得去追踪和排查错误! public function added_business_s ...

  10. 第4章_Java仿微信全栈高性能后台+移动客户端

    基于web端使用netty和websocket来做一个简单的聊天的小练习.实时通信有三种方式:Ajax轮询.Long pull.websocket,现在很多的业务场景,比方说聊天室.或者手机端onli ...