使用TensorDataset和DataLoader来简化

 
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)
 
def get_data(train_ds, valid_ds, bs):
    return (
        DataLoader(train_ds, batch_size=bs, shuffle=True),
        DataLoader(valid_ds, batch_size=bs * 2),
    )
 
 
 
  • 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
  • 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
import numpy as np
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
    for step in range(steps):
        model.train()
        for xb, yb in train_dl:
            loss_batch(model, loss_func, xb, yb, opt)
        model.eval()
        with torch.no_grad():
            losses, nums = zip(
                *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
            )
        val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
        print('当前step:'+str(step), '验证集损失:'+str(val_loss))
 
 
from torch import optim
def get_model():
    model = Mnist_NN()
    return model, optim.SGD(model.parameters(), lr=0.001)
 
 
def loss_batch(model, loss_func, xb, yb, opt=None):
    loss = loss_func(model(xb), yb)
    if opt is not None:
        loss.backward()
        opt.step()
        opt.zero_grad()
    return loss.item(), len(xb)
 
 
 

三行搞定!

train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(25, model, loss_func, opt, train_dl, valid_dl)
 
 
 
 
 

Pytorch之数据处理的更多相关文章

  1. 【深度学习框架】使用PyTorch进行数据处理

      在深度学习中,数据的处理对于神经网络的训练来说十分重要,良好的数据(包括图像.文本.语音等)处理不仅可以加速模型的训练,同时也直接关系到模型的效果.本文以处理图像数据为例,记录一些使用PyTorc ...

  2. [源码解析] 机器学习参数服务器Paracel (3)------数据处理

    [源码解析] 机器学习参数服务器Paracel (3)------数据处理 目录 [源码解析] 机器学习参数服务器Paracel (3)------数据处理 0x00 摘要 0x01 切分需要 1.1 ...

  3. JuJu Beta Postmortem

    JuJu demo demo 项目github地址 JuJu   设想和目标 我们的软件要解决什么问题?是否定义得很清楚?是否对典型用户和典型场景有清晰的描述? 完成基于Julia语言的NER mod ...

  4. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  5. Neural Network Programming - Deep Learning with PyTorch with deeplizard.

    PyTorch Prerequisites - Syllabus for Neural Network Programming Series PyTorch先决条件 - 神经网络编程系列教学大纲 每个 ...

  6. 深度学习框架PyTorch一书的学习-第六章-实战指南

    参考:https://github.com/chenyuntc/pytorch-book/tree/v1.0/chapter6-实战指南 希望大家直接到上面的网址去查看代码,下面是本人的笔记 将上面地 ...

  7. 深度学习框架PyTorch一书的学习-第五章-常用工具模块

    https://github.com/chenyuntc/pytorch-book/blob/v1.0/chapter5-常用工具/chapter5.ipynb 希望大家直接到上面的网址去查看代码,下 ...

  8. 深度学习之PyTorch实战(3)——实战手写数字识别

    上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...

  9. 深度学习之PyTorch实战(2)——神经网络模型搭建和参数优化

    上一篇博客先搭建了基础环境,并熟悉了基础知识,本节基于此,再进行深一步的学习. 接下来看看如何基于PyTorch深度学习框架用简单快捷的方式搭建出复杂的神经网络模型,同时让模型参数的优化方法趋于高效. ...

  10. Pytorch 入门之Siamese网络

    首次体验Pytorch,本文参考于:github and PyTorch 中文网人脸相似度对比 本文主要熟悉Pytorch大致流程,修改了读取数据部分.没有采用原作者的ImageFolder方法:   ...

随机推荐

  1. JSP第八次作业

    数据库test 中建个表 stu(stuid 主键 自动增长 ,用户名,密码,年龄) 1.设计一个注册页面,实现用户注册功能2.设计一个登陆页面,实现用户名密码登陆3.两个页面可以互相超链接 1 pa ...

  2. python学习第三周总结

    文件操作 文件的读写模式 文件的操作模式 文件相关操作 文件内光标移动 文件内容修改 函数前戏 函数的语法结构 函数的定义和调用 函数的分类 函数的返回值 函数的参数 函数参数之位置参数 默认参数 可 ...

  3. 【学习日志】MongoDB为什么选择B树,而MySQL选择B+树实现索引

    先说B树和B+树的区别 B树:非叶子节点也存储数据 B+树:只有叶子节点存储数据,且所有叶子节点通过指针相连接. 为什么MongoDB选择B树而,MySQL选择B+树呢?两种数据结构的区别摆在上面了, ...

  4. Sentinel熔断与限流

    1.简介 在线文档: https://sentinelguard.io/zh-cn/docs/system-adaptive-protection.html 功能: 流量控制 速率控制 熔断和限流 和 ...

  5. 【vite】踩坑,首次点击路由跳转页面,发生回退,页面闪回,二次点击才能进入目标页面

    [vite]踩坑,首次点击路由跳转页面,发生回退,页面闪回,二次点击才能进入目标页面 最近在做移动端前端项目,使用的vite3+vue3+vant,组件和api挂载,使用的自动导入,unplugin- ...

  6. 记一次 .NET 某医保平台 CPU 爆高分析

    一:背景 1. 讲故事 一直在追这个系列的朋友应该能感受到,我给这个行业中无数的陌生人分析过各种dump,终于在上周有位老同学找到我,还是个大妹子,必须有求必应 . 妹子公司的系统最近在某次升级之后, ...

  7. [学习笔记]SQL server完全备份指南

    方式一,使用SQL Server Management Studio 准备工作 连接目标数据库服务器 在目标数据库上右键->属性,将数据库的恢复模式设置为"简单",兼容级别设 ...

  8. Java ”框架 = 注解 + 反射 + 设计模式“ 之 注解详解

    Java "框架 = 注解 + 反射 + 设计模式" 之 注解详解 每博一文案 刹那间我真想令时光停住,好让我回顾自己,回顾失去的年华,缅怀哪个穿一身短小的连衣裙 和瘦窄的短衫的小 ...

  9. Autoit 制作上传工具完美版

    一. 制作上传器 在ui自动化过程中经常遇到需要上传的动作,我们可以使用input标签来送值,但这样不太稳定,所以建议使用autoit制作出来的exe工具. 下面就教大家如何制作上传器,如何使用吧! ...

  10. HDMI转USB视频采集卡(ACASIS 1080P高清视频采集卡)--九五小庞

    ACASIS阿卡西斯是深圳市菲德越科技有限公司旗下数码科技品牌.菲德越是2008年成立的一家专注于采集卡.硬盘盒.集线器等专业3C配件产品,集研发.设计.生产.销售于一体的高新科技公司,我们公司以向客 ...