Pytorch之数据处理
使用TensorDataset和DataLoader来简化
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
train_ds = TensorDataset(x_train, y_train)
train_dl = DataLoader(train_ds, batch_size=bs, shuffle=True)
valid_ds = TensorDataset(x_valid, y_valid)
valid_dl = DataLoader(valid_ds, batch_size=bs * 2)
def get_data(train_ds, valid_ds, bs):
return (
DataLoader(train_ds, batch_size=bs, shuffle=True),
DataLoader(valid_ds, batch_size=bs * 2),
)
- 一般在训练模型时加上model.train(),这样会正常使用Batch Normalization和 Dropout
- 测试的时候一般选择model.eval(),这样就不会使用Batch Normalization和 Dropout
import numpy as np
def fit(steps, model, loss_func, opt, train_dl, valid_dl):
for step in range(steps):
model.train()
for xb, yb in train_dl:
loss_batch(model, loss_func, xb, yb, opt)
model.eval()
with torch.no_grad():
losses, nums = zip(
*[loss_batch(model, loss_func, xb, yb) for xb, yb in valid_dl]
)
val_loss = np.sum(np.multiply(losses, nums)) / np.sum(nums)
print('当前step:'+str(step), '验证集损失:'+str(val_loss))
from torch import optim
def get_model():
model = Mnist_NN()
return model, optim.SGD(model.parameters(), lr=0.001)
def loss_batch(model, loss_func, xb, yb, opt=None):
loss = loss_func(model(xb), yb)
if opt is not None:
loss.backward()
opt.step()
opt.zero_grad()
return loss.item(), len(xb)
三行搞定!
train_dl, valid_dl = get_data(train_ds, valid_ds, bs)
model, opt = get_model()
fit(25, model, loss_func, opt, train_dl, valid_dl)
Pytorch之数据处理的更多相关文章
- 【深度学习框架】使用PyTorch进行数据处理
在深度学习中,数据的处理对于神经网络的训练来说十分重要,良好的数据(包括图像.文本.语音等)处理不仅可以加速模型的训练,同时也直接关系到模型的效果.本文以处理图像数据为例,记录一些使用PyTorc ...
- [源码解析] 机器学习参数服务器Paracel (3)------数据处理
[源码解析] 机器学习参数服务器Paracel (3)------数据处理 目录 [源码解析] 机器学习参数服务器Paracel (3)------数据处理 0x00 摘要 0x01 切分需要 1.1 ...
- JuJu Beta Postmortem
JuJu demo demo 项目github地址 JuJu 设想和目标 我们的软件要解决什么问题?是否定义得很清楚?是否对典型用户和典型场景有清晰的描述? 完成基于Julia语言的NER mod ...
- 【转载】PyTorch系列 (二):pytorch数据读取
原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...
- Neural Network Programming - Deep Learning with PyTorch with deeplizard.
PyTorch Prerequisites - Syllabus for Neural Network Programming Series PyTorch先决条件 - 神经网络编程系列教学大纲 每个 ...
- 深度学习框架PyTorch一书的学习-第六章-实战指南
参考:https://github.com/chenyuntc/pytorch-book/tree/v1.0/chapter6-实战指南 希望大家直接到上面的网址去查看代码,下面是本人的笔记 将上面地 ...
- 深度学习框架PyTorch一书的学习-第五章-常用工具模块
https://github.com/chenyuntc/pytorch-book/blob/v1.0/chapter5-常用工具/chapter5.ipynb 希望大家直接到上面的网址去查看代码,下 ...
- 深度学习之PyTorch实战(3)——实战手写数字识别
上一节,我们已经学会了基于PyTorch深度学习框架高效,快捷的搭建一个神经网络,并对模型进行训练和对参数进行优化的方法,接下来让我们牛刀小试,基于PyTorch框架使用神经网络来解决一个关于手写数字 ...
- 深度学习之PyTorch实战(2)——神经网络模型搭建和参数优化
上一篇博客先搭建了基础环境,并熟悉了基础知识,本节基于此,再进行深一步的学习. 接下来看看如何基于PyTorch深度学习框架用简单快捷的方式搭建出复杂的神经网络模型,同时让模型参数的优化方法趋于高效. ...
- Pytorch 入门之Siamese网络
首次体验Pytorch,本文参考于:github and PyTorch 中文网人脸相似度对比 本文主要熟悉Pytorch大致流程,修改了读取数据部分.没有采用原作者的ImageFolder方法: ...
随机推荐
- MySQL优化六,锁
一,MySQL中的锁 InnoDB中锁非常多,总的来说,可以如下分类: 这些锁都是做什么的?具体含义是什么?我们现在来一一学习. 1.2,解决并发事务问题 我们已经知道事务并发执行时可能带来的各种问题 ...
- vue3 | shallowReactive 、shallowRef、triggerRef
shallowReactive 使用 reactive 声明的变量为递归监听,使用 shallowReactive 声明的变量为非递归监听(通俗的讲就是 reactive 创建的对象将会被 vue 内 ...
- 11月28日内容总结——多表查询的两种方法及部分小知识点、可视化软件Navicat安装及简单使用讲解及多表查询练习题、python代码操作MySQL(pymysql模块)
目录 一.多表查询的两种方法 方式1:连表操作 inner join(内连接) left join(左连接) right join(右连接) union(全连接) 方式2:子查询 二.小知识点补充说明 ...
- MySQL的简单安装配置
一.简单了解MySQL 1.在了解MySQL之前因该了解的东西 数据库(Database)指长期存储在计算机内的.有组织的.可共享的数据集合.数据库实际上就是一个文件集合,是一个存储数据的仓库,本质就 ...
- LM算法详解
1. 高斯牛顿法 残差函数f(x)为非线性函数,对其一阶泰勒近似有: 这里的J是残差函数f的雅可比矩阵,带入损失函数的: 令其一阶导等于0,得: 这就是论文里常看到的normal equation. ...
- 字符串、函数、bug
字符串 字符串驻留机制 仅保存一份相同且不可变字符串的方法,不同的值别存放在字符串的驻留池中,Python的驻留机制对相同的字符串只保留一份拷贝,后续穿件相同的字符串时,不会开辟新的空间,而是把字符串 ...
- 在react项目如何捕获错误
在React项目是如何捕获错误的? 一.是什么 错误在我们日常编写代码是非常常见的 举个例子,在react项目中去编写组件内JavaScript代码错误会导致 React 的内部状态被破坏,导致整个应 ...
- Vue组件之间的通信方式都有哪些?
一.组件间通信的概念 我们通常把组件间通信这个词进行拆分 组件 通信 都知道组件是vue最强大的功能之一,vue中每一个.vue我们都可以视之为一个组件通信指的是发送者通过某种媒体以某种格式来传递信息 ...
- Redis入门级简单安装使用
最近突然就想学一下Redis,于是就各种找教程,前两天实际操作了一下,也不是想象中的很难 但是今天想写一个使用Redis的demo,突然就不会使用Redis了,在网上也是查找了半天,还是想起来了点 ...
- Unix时间戳转化成普通日期
var time = 1630634462000; //13位数 var unixTimestamp = new Date(time); var commonTime = unixTimestamp. ...