Pytorch之数据处理
使用TensorDataset和DataLoader来简化
- from torch.utils.data import TensorDataset
- from torch.utils.data import DataLoader
-
- train_ds = TensorDataset(x_train, y_train)
- train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
-
- valid_ds = TensorDataset(x_valid, y_valid)
- valid_dl = DataLoader(valid_ds, batch_size=bs * 2)
- def get_data(train_ds, valid_ds, bs):
- return (
- DataLoader(train_ds, batch_size=bs, shuffle=True),
- DataLoader(valid_ds, batch_size=bs * 2),
- )
- 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
- 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
- import numpy as np
-
- def fit(steps, model, loss_func, opt, train_dl, valid_dl):
- for step in range(steps):
- model.train()
- for xb, yb in train_dl:
- loss_batch(model, loss_func, xb, yb, opt)
-
- model.eval()
- with torch.no_grad():
- losses, nums = zip(
- *[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
- )
- val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
- print('当前step:'+str(step), '验证集损失:'+str(val_loss))
- from torch import optim
- def get_model():
- model = Mnist_NN()
- return model, optim.SGD(model.parameters(), lr=0.001)
- def loss_batch(model, loss_func, xb, yb, opt=None):
- loss = loss_func(model(xb), yb)
-
- if opt is not None:
- loss.backward()
- opt.step()
- opt.zero_grad()
-
- return loss.item(), len(xb)
三行搞定!
- train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
- model, opt = get_model()
- fit(25, model, loss_func, opt, train_dl, valid_dl)
-
Pytorch之数据处理的更多相关文章
- 【深度学习框架】使用PyTorch进行数据处理
在深度学习中,数据的处理对于神经网络的训练来说十分重要,良好的数据(包括图像.文本.语音等)处理不仅可以加速模型的训练,同时也直接关系到模型的效果.本文以处理图像数据为例,记录一些使用PyTorc ...
- [源码解析] 机器学习参数服务器Paracel (3)------数据处理
[源码解析] 机器学习参数服务器Paracel (3)------数据处理 目录 [源码解析] 机器学习参数服务器Paracel (3)------数据处理 0x00 摘要 0x01 切分需要 1.1 ...
- JuJu Beta Postmortem
JuJu demo demo 项目github地址 JuJu 设想和目标 我们的软件要解决什么问题?是否定义得很清楚?是否对典型用户和典型场景有清晰的描述? 完成基于Julia语言的NER mod ...
- 【转载】PyTorch系列 (二):pytorch数据读取
原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...
- Neural Network Programming - Deep Learning with PyTorch with deeplizard.
PyTorch Prerequisites - Syllabus for Neural Network Programming Series PyTorch先决条件 - 神经网络编程系列教学大纲 每个 ...
- 深度学习框架PyTorch一书的学习-第六章-实战指南
参考:https://github.com/chenyuntc/pytorch-book/tree/v1.0/chapter6-实战指南 希望大家直接到上面的网址去查看代码,下面是本人的笔记 将上面地 ...
- 深度学习框架PyTorch一书的学习-第五章-常用工具模块
https://github.com/chenyuntc/pytorch-book/blob/v1.0/chapter5-常用工具/chapter5.ipynb 希望大家直接到上面的网址去查看代码,下 ...
- 深度学习之PyTorch实战(3)——实战手写数字识别
上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...
- 深度学习之PyTorch实战(2)——神经网络模型搭建和参数优化
上一篇博客先搭建了基础环境,并熟悉了基础知识,本节基于此,再进行深一步的学习. 接下来看看如何基于PyTorch深度学习框架用简单快捷的方式搭建出复杂的神经网络模型,同时让模型参数的优化方法趋于高效. ...
- Pytorch 入门之Siamese网络
首次体验Pytorch,本文参考于:github and PyTorch 中文网人脸相似度对比 本文主要熟悉Pytorch大致流程,修改了读取数据部分.没有采用原作者的ImageFolder方法: ...
随机推荐
- UICC Send USSD 加密信息解析
已以下指令为例 已知发送 USSD 的格式为:Send USSD aaaxxxxxxxxxxxx*yyyyyyy# ;Fetch:==:SEND USSD SEND 801200001F ASSERT ...
- 浅谈Python中的in,可能有你不知道的
Python中的in,没那么简单,虽然也不难 https://docs.python.org/zh-cn/3.9/reference/expressions.html#membership-test- ...
- 【读书笔记】JS函数式编程指南
第一章 海鸥群可以合并和繁育 conjoin breed var result = flock_a.conjoin(flock_c).breed(flock_b).conjoin(flo ck_a.b ...
- Cannot find module ‘xxx\node_modules\yorkie\bin\install.js‘
1.出现问题原因 安装一个新仓库代码的依赖包,如输入npm install或yarn install,出现如题错误 2.解决办法 1)升级node.js 下载地址:https://nodejs.org ...
- Windows 串口代码
#pragma once #include <Windows.h> #define DEFAULT_THREAD_TERMINATED_TIME 2000 class CAutoThrea ...
- 深度学习-RNN
目录 I.前言 介绍RNN的概念和应用 II. RNN基础 RNN的概念和结构 RNN的前向传播和反向传播算法 前向传播算法 反向传播 RNN的变种:LSTM和GRU LSTM GRU III. RN ...
- 使用一个文件集中管理你的 Nuget 依赖版本号
在 .net 7 以前,项目对于 nuget 依赖项的版本依赖散落与解决方案的各个角落.这导致升级维护和查看的时候都比较麻烦.在 .net 7 中,你可以使用一个文件来集中管理你的 Nuget 依赖版 ...
- 国内“谁”能实现chatgpt,短期穷出的类ChatGPT简评(算法侧角度为主),以及对MOSS、ChatYuan给出简评,一文带你深入了解宏观技术路线。
1.ChatGPT简介[核心技术.技术局限] ChatGPT(全名:Chat Generative Pre-trained Transformer),美国OpenAI 研发的聊天机器人程序 ,于202 ...
- ubuntu18.04 server版安装教程
转载博客园: Ubuntu18.04 Server版安装(详细版) - 运维密码 - 博客园 (cnblogs.com)
- IP转换
IP转换 目录 IP转换 1 127.1 ? 2 IPv4两段点分十进制表示 3 IPv4一段十进制表示 4 IPv4地址有效地变换形式 5 IP地址进制转换网站 6 参考博客 1 127.1 ? 首 ...