一、知识点:

  • 相关包:torch.utils.data

import torch
import torch.utils.data as Data
  • 包装数据类:TensorDataset

【包装数据和目标张量的数据集,通过沿着第一个维度索引两个张量来】

class torch.utils.data.TensorDataset(data_tensor, target_tensor)
#data_tensor (Tensor) - 包含样本数据
#target_tensor (Tensor) - 包含样本目标(标签)
  • 加载数据类:DataLoader

【数据加载器。组合数据集和采样器,并在数据集上提供单进程或多进程迭代器。】

class torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
#num_workers (int, optional) – 用多少个子进程加载数据
#drop_last (bool, optional) – 如果数据集大小不能被batch size整除,则设置为True后可删除最后一个不完整的batch。如果设为False并且数据集的大小不能被batch size整除,则最后一个batch将更小。(默认: False)

二、利用torch.utils.data进行批数据训练:

导入包:

import torch
import torch.utils.data as Data

设置参数并创建数据:

Batch_size = 5

x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)

将数据包装到TensorDataset中:

torch_dataset = Data.TensorDataset(x , y)

加载数据:

loader = Data.DataLoader(
dataset = torch_dataset,
batch_size = Batch_size,
shuffle=True,
num_workers = 2, #采用两个进程来提取
)

epoch 3次,每次epoch的训练步数steps = 2【batch_size = 5,总数据量为10】:

若最后不够一个batch_size,就只拿剩下的。

for epoch in range(3):
for step , (batch_x,batch_y) in enumerate(loader):
#training……
print('epoch:',epoch,
'| step:',step,
'| batch_x:',batch_x.numpy(),
'| batch_y:',batch_y.numpy()
)

结果:

Pytorch基础(5)——批数据训练的更多相关文章

  1. pytorch中tensor张量数据基础入门

    pytorch张量数据类型入门1.对于pytorch的深度学习框架,其基本的数据类型属于张量数据类型,即Tensor数据类型,对于python里面的int,float,int array,flaot ...

  2. Pytorch基础——使用 RNN 生成简单序列

    一.介绍 内容 使用 RNN 进行序列预测 今天我们就从一个基本的使用 RNN 生成简单序列的例子中,来窥探神经网络生成符号序列的秘密. 我们首先让神经网络模型学习形如 0^n 1^n 形式的上下文无 ...

  3. PyTorch基础——使用卷积神经网络识别手写数字

    一.介绍 实验内容 内容包括用 PyTorch 来实现一个卷积神经网络,从而实现手写数字识别任务. 除此之外,还对卷积神经网络的卷积核.特征图等进行了分析,引出了过滤器的概念,并简单示了卷积神经网络的 ...

  4. 目标检测之Faster-RCNN的pytorch代码详解(模型训练篇)

    本文所用代码gayhub的地址:https://github.com/chenyuntc/simple-faster-rcnn-pytorch  (非本人所写,博文只是解释代码) 好长时间没有发博客了 ...

  5. [人工智能]Pytorch基础

    PyTorch基础 摘抄自<深度学习之Pytorch>. Tensor(张量) PyTorch里面处理的最基本的操作对象就是Tensor,表示的是一个多维矩阵,比如零维矩阵就是一个点,一维 ...

  6. 【新生学习】第一周:深度学习及pytorch基础

    DEADLINE: 2020-07-25 22:00 写在最前面: 本课程的主要思路还是要求大家大量练习 pytorch 代码,在写代码的过程中掌握深度学习的各类算法,希望大家能够坚持练习,相信经度过 ...

  7. Faster-RCNN 自己的数据训练

    参考网址:https://blog.csdn.net/l297969586/article/category/7178545(一呆飞仙)Faster-RCNN_TF代码解读,参考网址:https:// ...

  8. Hadoop基础-MapReduce的数据倾斜解决方案

    Hadoop基础-MapReduce的数据倾斜解决方案 作者:尹正杰 版权声明:原创作品,谢绝转载!否则将追究法律责任. 一.数据倾斜简介 1>.什么是数据倾斜 答:大量数据涌入到某一节点,导致 ...

  9. PyTorch 数据集类 和 数据加载类 的一些尝试

    最近在学习PyTorch,  但是对里面的数据类和数据加载类比较迷糊,可能是封装的太好大部分情况下是不需要有什么自己的操作的,不过偶然遇到一些自己导入的数据时就会遇到一些问题,因此自己对此做了一些小实 ...

随机推荐

  1. 网站配置https(腾讯云域名操作)

    我们都知道http协议是超文本传输协议,早期的网站使用的都是http,但是并不安全,数据在传输过程中容易被拦截篡改.所以后面有了https,也就是经过ssl加密的http协议.本文主要对网站配置htt ...

  2. N天学习一个Linux命令之free

    用途 查看系统内存(物理/虚拟/缓存/共享)使用情况 用法 free [-b | -k | -m | -g | -h] [-o] [-s delay ] [-c count ] [-a] [-t] [ ...

  3. 多个机器获取微信access-token导致的有效性问题

    多个机器获取微信access-token导致的有效性问题 单个机器获取的access-token,只有最后一个是有效的: 多个机器各自获取自己的access-token,都是各自有效的: 在服务器和本 ...

  4. BNU 13259.Story of Tomisu Ghost 分解质因子

    Story of Tomisu Ghost It is now 2150 AD and problem-setters are having a horrified time as the ghost ...

  5. 93.EXTJS Form之VTypes

    转自:http://blog.sina.com.cn/s/blog_7778950d0100y2pg.html 本文我们主要探讨一下EXTJS的Form中验证的问题,可能用过EXTJS的Form的人都 ...

  6. Python入门 不必自己造轮子

    操作list list切片 字符串的分割 字符串的索引和切片 读文件 f = file('data.txt') data = f.read() print data f.close() 写文件 dat ...

  7. C++ 对象的赋值和复制 基本的

    对象的赋值 如果对一个类定义了两个或多个对象,则这些对象之间是可以进行赋值,或者说,一个对象的值可以赋值给另一个同类的对象.这里所指的值是指对象中所有数       据的成员的值.对象之间进行赋值是“ ...

  8. Watchcow(欧拉回路)

    http://poj.org/problem?id=2230 题意:给出n个field及m个连接field的边,然后要求遍历每条边仅且2次,求出一条路径来. #include <stdio.h& ...

  9. [Apple开发者帐户帮助]五、管理标识符(5)创建一个iCloud容器

    您必须拥有一个或多个iCloud容器才能启用iCloud. 所需角色:帐户持有人或管理员. 在“ 证书”,“标识符和配置文件”中,从左侧的弹出菜单中选择操作系统. 在“标识符”下,选择“iCloud容 ...

  10. Oracle配置说明

    当Oracle安装完成后,为后续能够顺利得导出空表,特做一下配置(重点关注2.1) 1.1.查询空表select table_name from user_tables where NUM_ROWS= ...