PyTorch学习笔记之DataLoaders
A DataLoader wraps a Dataset and provides minibatching, shuffling, multithreading, for you。
import torch
from torch.autograd import Variable
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader # define our whole model as a single Module
class TwoLayerNet(nn.Module):
# Initializer sets up two children (Modules can contain modules)
def _init_(self, D_in, H, D_out):
super(TwoLayerNet, self)._init_()
self.linear1 = torch.nn.Linear(D_in, H)
self.linear2 = torch.nn.Linear(H, D_out) # Define forward pass using child modules and autograd ops on Variables
# No need to define backward - autograd will handle it
def forward(self, x):
h_relu = self.linear1(x).clamp(min=0)
y_pred = self.linear2(h_relu)
return y_pred N, D_in, H, D_out = 64, 1000, 100, 10
x = Variable(torch.randn(N, D_in))
y = Variable(torch.randn(N, D_out)) # When you need to load custom data, just write your own Dataset class
loader = DataLoader(TensorDataset(x, y), batch_size=8) model = TwoLayerNet(D_in, H, D_out) criterion = torch.nn.MSELoss(size_average=False)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-4)
for epoch in range(10):
# Iterate(遍历) over loader to form minibatches
for x_batch, y_batch in loader:
# Loader gives Tensors so you need to wrap in Variables
x_var, y_var = Variable(x), Variable(y)
y_pred = model(x_var)
loss = criterion(y_pred, y_var) optimizer.zero_grad()
loss.backward()
optimizer.step()
PyTorch学习笔记之DataLoaders的更多相关文章
- Pytorch学习笔记(二)---- 神经网络搭建
记录如何用Pytorch搭建LeNet-5,大体步骤包括:网络的搭建->前向传播->定义Loss和Optimizer->训练 # -*- coding: utf-8 -*- # Al ...
- Pytorch学习笔记(一)---- 基础语法
书上内容太多太杂,看完容易忘记,特此记录方便日后查看,所有基础语法以代码形式呈现,代码和注释均来源与书本和案例的整理. # -*- coding: utf-8 -*- # All codes and ...
- 【pytorch】pytorch学习笔记(一)
原文地址:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html 什么是pytorch? pytorch是一个基于p ...
- 【深度学习】Pytorch 学习笔记
目录 Pytorch Leture 05: Linear Rregression in the Pytorch Way Logistic Regression 逻辑回归 - 二分类 Lecture07 ...
- Pytorch学习笔记(一)——简介
一.Tensor Tensor是Pytorch中重要的数据结构,可以认为是一个高维数组.Tensor可以是一个标量.一维数组(向量).二维数组(矩阵)或者高维数组等.Tensor和numpy的ndar ...
- [PyTorch 学习笔记] 1.3 张量操作与线性回归
本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson1/linear_regression.py 张量的操作 拼 ...
- [PyTorch 学习笔记] 1.1 PyTorch 简介与安装
PyTorch 的诞生 2017 年 1 月,FAIR(Facebook AI Research)发布了 PyTorch.PyTorch 是在 Torch 基础上用 python 语言重新打造的一款深 ...
- [PyTorch 学习笔记] 1.4 计算图与动态图机制
本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson1/computational_graph.py 计算图 深 ...
- [PyTorch 学习笔记] 2.2 图片预处理 transforms 模块机制
PyTorch 的数据增强 我们在安装PyTorch时,还安装了torchvision,这是一个计算机视觉工具包.有 3 个主要的模块: torchvision.transforms: 里面包括常用的 ...
随机推荐
- perl-basic-数据类型&引用
我觉得这一系列的标题应该是:PERL,从入门到放弃 USE IT OR U WILL LOSE IT 参考资料: https://qntm.org/files/perl/perl.html 在线per ...
- 【HIHOCODER 1403】后缀数组一·重复旋律(后缀数组)
描述 小Hi平时的一大兴趣爱好就是演奏钢琴.我们知道一个音乐旋律被表示为长度为 N 的数构成的数列. 小Hi在练习过很多曲子以后发现很多作品自身包含一样的旋律.旋律是一段连续的数列,相似的旋律在原数列 ...
- poj 3281 Dining(网络流+拆点)
Dining Time Limit: 2000MS Memory Limit: 65536K Total Submissions: 20052 Accepted: 8915 Descripti ...
- BZOJ 5215: [Lydsy2017省队十连测]商店购物
裸题 注意+特判 #include<cstdio> using namespace std; const int mod=1e9+7; int F[1000005],mi[10000005 ...
- Linux中TTY是什么意思
终端是一种字符型设备,它有多种类型,通常使用tty来简称各种类型的终端设备.tty是Teletype的缩写.Teletype是最早出现的一种终端 设备,很象电传打字机(或者说就是),是由Telety ...
- Baum-Welch算法(EM算法)对HMM模型的训练
Baum-Welch算法就是EM算法,所以首先给出EM算法的Q函数 \[\sum_zP(Z|Y,\theta')\log P(Y,Z|\theta)\] 换成HMM里面的记号便于理解 \[Q(\lam ...
- LiveScript 操作符
The LiveScript Book The LiveScript Book 操作符 数字 标准的数学操作符: 1.1 + 2 # => 32.3 - 4 # => -13.6 ...
- Multi-Dimensional Recurrent Neural Networks
Multi-Dimensional Recurrent Neural Networks The basic idea of MDRNNs is to replace the single recurr ...
- Codecraft-18 and Codeforces Round #458 (Div. 1 + Div. 2, combined)
我真的是太菜了 A. Perfect Squares time limit per test 1 second memory limit per test 256 megabytes input st ...
- Python日志记录(Logging)
日志记录跟程序的测试相关,并且在大幅度更改程序内核时很有用,它可以帮助我们找到问题和错误的所在.日志记录基本上就是收集与程序运行有关的数据,这样可以在随后进行检查或者累计数据. 1.简单示例 在Pyt ...