《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili

Dataset & Dataloader

1、Dataset & Dataloader作用

※Dataset—加载数据集,用索引的方式取数

※DataLoader—Mini-Batch

通过获得DataSet的索引以及数据集大小,来自动得生成小批量训练集

DataLoader先对数据集进行Shuffle,再将数据集按照Batch_Size的长度划分为小的Batch,并按照Iterations进行加载,以方便通过循环对每个Batch进行操作

Shuffle=True:随机打乱顺序

2、Mini-Batch

利用Mini-Batch均衡训练性能和时间

在外层循环中,每一层是一个epoch(训练周期),在内层循环中,每一次是一个Mini-Batch(Batch的迭代)

for epoch in range(training_epochs):
for i in range(total_batch):

3、相关术语

※Epoch:所有样本都参与了一次训练

※Batch-size:进行一次训练(前馈、反馈、更新)的样本数

※Iteration:有多少个Batch,每次

Epoch = Batch-size * Iteration

 4、代码部分

在构造数据集时,两种对数据加载到内存中的处理方式如下:

①加载所有数据到dataset,每次使用getitem()读索引,适用于数据量小的情况

②只对dataset进行初始化,仅存文件名到列表,每次使用时再通过索引到内存中去读取,适用于数据量大(图像、语音…)的情况

import torch
import numpy as np
## Dataset为抽象类,不能被实例化,只能被其他子类继承
from torch.utils.data import Dataset
## 实例化DataLoader,用于加载数据
from torch.utils.data import DataLoader ## Prepare Data
class DiabetesDataset(Dataset):
def __init__(self, filepath):
xy = np.loadtxt(filepath, delimiter=',', dtype=np.float32)
## 获取数据集长度
self.len = xy.shape[0]
self.x_data = torch.from_numpy(xy[:, :-1])
self.y_data = torch.from_numpy(xy[:, [-1]]) ## 索引:下标操作
def __getitem__(self, index):
return self.x_data[index], self.y_data[index] ## 返回数据量
def __len__(self):
return self.len dataset = DiabetesDataset('diabetes.csv.gz') ##num_workers多线程
train_loader = DataLoader(dataset = dataset,
batch_size = 32,
shuffle = True,
num_workers = 0) ##Design Model ##构造类,继承torch.nn.Module类
class Model(torch.nn.Module):
## 构造函数,初始化对象
def __init__(self):
##super调用父类
super(Model, self).__init__()
##构造三层神经网络
self.linear1 = torch.nn.Linear(8, 6)
self.linear2 = torch.nn.Linear(6, 4)
self.linear3 = torch.nn.Linear(4, 1)
##激活函数,进行非线性变换
self.sigmoid = torch.nn.Sigmoid() ## 构造函数,前馈运算
def forward(self, x):
x = self.sigmoid(self.linear1(x))
x = self.sigmoid(self.linear2(x))
x = self.sigmoid(self.linear3(x))
return x model = Model() ##Construct Loss and Optimizer ##损失函数,传入y和y_pred,size_average--是否取平均
criterion = torch.nn.BCELoss(size_average = True) ##优化器,model.parameters()找出模型所有的参数,Lr--学习率
optimizer = torch.optim.SGD(model.parameters(), lr=0.01) ## Training cycle
## windows环境下DataLoader的num_workers设置为多线程,需要将主程序(对数据操作的程序)封装到函数中
if __name__ =='__main__':
for epoch in range(100):
#enumerate:可获得当前迭代的次数
for i, data in enumerate(train_loader, 0):
## 准备数据
inputs, lables = data
##前向传播
y_pred = model(inputs)
loss = criterion(y_pred, lables)
print(epoch, i, loss.item()) ##梯度归零
optimizer.zero_grad()
##反向传播
loss.backward()
##更新
optimizer.step()

!!两个问题!!

①DataLoader的参数num_workers设置 >0

在windows中利用多线程读取,需要将主程序(对数据操作的程序)封装到函数中

## Training cycle
## windows环境下DataLoader的num_workers设置为多线程,需要将主程序(对数据操作的程序)封装到函数中
if __name__ =='__main__':
for epoch in range(100):
#enumerate:可获得当前迭代的次数
for i, data in enumerate(train_loader, 0):

但是运行还是报错,只能把num_workers = 0

②运行结果:损失不会一直下降,改小了学习率也不行

Pytorch实战学习(四):加载数据集的更多相关文章

  1. pytorch 加载数据集

    pytorch初学者,想加载自己的数据,了解了一下数据类型.维度等信息,方便以后加载其他数据. 1 torchvision.transforms实现数据预处理 transforms.Totensor( ...

  2. PyTorch保存模型与加载模型+Finetune预训练模型使用

    Pytorch 保存模型与加载模型 PyTorch之保存加载模型 参数初始化参 数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了da ...

  3. [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler

    [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampler 目录 [源码解析] PyTorch 分布式(1) --- 数据加载之DistributedSampl ...

  4. [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader

    [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 目录 [源码解析] PyTorch 分布式(2) --- 数据加载之DataLoader 0x00 摘要 0x01 ...

  5. cesium 学习(五) 加载场景模型

    cesium 学习(五) 加载场景模型 一.前言 现在开始实际的看看效果,目前我所接触到基本上都是使用Cesium加载模型这个内容,以及在模型上进行操作.So,现在进行一些加载模型的学习,数据的话可以 ...

  6. SciKit-Learn 加载数据集

    章节 SciKit-Learn 加载数据集 SciKit-Learn 数据集基本信息 SciKit-Learn 使用matplotlib可视化数据 SciKit-Learn 可视化数据:主成分分析(P ...

  7. 【 js 模块加载 】深入学习模块化加载(node.js 模块源码)

    一.模块规范 说到模块化加载,就不得先说一说模块规范.模块规范是用来约束每个模块,让其必须按照一定的格式编写.AMD,CMD,CommonJS 是目前最常用的三种模块化书写规范.  1.AMD(Asy ...

  8. 【 js 模块加载 】【源码学习】深入学习模块化加载(node.js 模块源码)

    文章提纲: 第一部分:介绍模块规范及之间区别 第二部分:以 node.js 实现模块化规范 源码,深入学习. 一.模块规范 说到模块化加载,就不得先说一说模块规范.模块规范是用来约束每个模块,让其必须 ...

  9. WebGL three.js学习笔记 加载外部模型以及Tween.js动画

    WebGL three.js学习笔记 加载外部模型以及Tween.js动画 本文的程序实现了加载外部stl格式的模型,以及学习了如何把加载的模型变为一个粒子系统,并使用Tween.js对该粒子系统进行 ...

  10. Springboot学习01- 配置文件加载优先顺序和本地配置加载

    Springboot学习01-配置文件加载优先顺序和本地配置加载 1-项目内部配置文件加载优先顺序 spring boot 启动会扫描以下位置的application.properties或者appl ...

随机推荐

  1. allure环境搭建

    allure环境搭建 在搭建之前你应该有python.pycharm jdk也需要(文中忽略,可以参考网上文档安装,可以用jdk1.8) 以windows为例,mac.linux你用到这些操作系统,这 ...

  2. GitHub实用开源项目

    第一款 JSON Crack JSON Crack 是一个很方便的 JSON 数据可视化工具. 该项目不是简单的展示 JSON 数据,而是将其转化为类似思维导图的形式,支持放大/缩小.展开/收缩.搜索 ...

  3. 加密,各种加密,耙梳加密算法(Encryption)种类以及开发场景中的运用(Python3.10)

    不用说火爆一时,全网热议的Web3.0区块链技术,也不必说诸如微信支付.支付宝支付等人们几乎每天都要使用的线上支付业务,单是一个简简单单的注册/登录功能,也和加密技术脱不了干系,本次我们耙梳各种经典的 ...

  4. 分布式共识算法随笔 —— 从 Quorum 到 Paxos

    分布式共识算法随笔 -- 从 Quorum 到 Paxos 本文主要参考各类英文文献,部分专业术语翻译较为生硬,望谅解. 概览: 为什么需要共识算法? 昨夜西风凋碧树,独上高楼,望尽天涯路 复制(Re ...

  5. K8s 网络新手教程(Kubernetes Networking Guide for Beginners)

    K8s 网络新手教程(Kubernetes Networking Guide for Beginners) 原文链接: Kubernetes Networking Guide for Beginner ...

  6. 插头dp 模板

    [JLOI2009]神秘的生物 只需要维护连通情况,采用最小表示法,表示此格是否存在,也即插头是否存在 分情况讨论当前格子的轮廓线上方格子和左方格子状态,转移考虑当前格子选不选,决策后状态最后要能合法 ...

  7. 微信小程序的全局弹窗以及全局实例

    全局组件 微信小程序组件关系中,父组件使用子组件需要在父组件index.json中引入子组件,然后在父组件页面中使用,这种组件的对应状态是一对一的,一个组件对应一个页面.如果有一个全局弹窗(登录),那 ...

  8. 使用Shapefile-js读取shp文件并使用WebGL绘制

    1. 引言 坐标数据是空间数据文件的核心,空间数据的数据量往往是很大的.数据可视化是GIS的一个核心应用,绘制海量的坐标数据始终是一个考验设备性能的难题,使用GPU进行绘制可有效减少CPU的负载,提升 ...

  9. 在Django中显示MySQL语句

    在setting中添加以下内容 LOGGING = { 'version': 1, 'disable_existing_loggers': False, 'handlers': { 'console' ...

  10. JDK8在xp安装办法

    jdk默认不支持xp,用到了后来的api,直接没法安装. 具体安装步骤: 1.用7-zip打开jdk8的版本8u231之前的版本. 2.导航到 .rsrc\1033\JAVA_CAB10\111把里面 ...