import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms # 配置GPU或CPU设置
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 超参数设置
sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 100
num_epochs = 2
learning_rate = 0.01 # MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='./data/',
train=True,
transform=transforms.ToTensor(),# 将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1],归一化至[0-1]是直接除以255
download=True) test_dataset = torchvision.datasets.MNIST(root='./data/',
train=False,
transform=transforms.ToTensor())# 将PIL Image或者 ndarray 转换为tensor,并且归一化至[0-1],归一化至[0-1]是直接除以255 # 训练数据加载,按照batch_size大小加载,并随机打乱
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
# 测试数据加载,按照batch_size大小加载
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False) # Recurrent neural network (many-to-one) 多对一
class RNN(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, num_classes):
super(RNN, self).__init__() # 继承 __init__ 功能
self.hidden_size = hidden_size
self.num_layers = num_layers
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) # if use nn.RNN(), it hardly learns LSTM 效果要比 nn.RNN() 好多了
self.fc = nn.Linear(hidden_size, num_classes) def forward(self, x):
# Set initial hidden and cell states
h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)
c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) # Forward propagate LSTM
out, _ = self.lstm(x, (h0, c0)) # out: tensor of shape (batch_size, seq_length, hidden_size) # Decode the hidden state of the last time step
out = self.fc(out[:, -1, :])
return out model = RNN(input_size, hidden_size, num_layers, num_classes).to(device)
print(model)
# RNN((lstm): LSTM(28, 128, num_layers=2, batch_first=True)
# (fc): Linear(in_features=128, out_features=10, bias=True)) # 损失函数与优化器设置
# 损失函数
criterion = nn.CrossEntropyLoss()
# 优化器设置 ,并传入RNN模型参数和相应的学习率
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # 训练模型
total_step = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.reshape(-1, sequence_length, input_size).to(device)
labels = labels.to(device) # 前向传播
outputs = model(images)
# 计算损失 loss
loss = criterion(outputs, labels) # 反向传播与优化
# 清空上一步的残余更新参数值
optimizer.zero_grad()
# 反向传播
loss.backward()
# 将参数更新值施加到RNN model的parameters上
optimizer.step()
# 每迭代一定步骤,打印结果值
if (i + 1) % 100 == 0:
print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch + 1, num_epochs, i + 1, total_step, loss.item())) # 测试模型
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.reshape(-1, sequence_length, input_size).to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item() print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total)) # 保存已经训练好的模型
# Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

  

Recurrent neural network (RNN) - Pytorch版的更多相关文章

  1. Convolutional neural network (CNN) - Pytorch版

    import torch import torch.nn as nn import torchvision import torchvision.transforms as transforms # ...

  2. Recurrent Neural Network(循环神经网络)

    Reference:   Alex Graves的[Supervised Sequence Labelling with RecurrentNeural Networks] Alex是RNN最著名变种 ...

  3. 机器学习: Python with Recurrent Neural Network

    之前我们介绍了Recurrent neural network (RNN) 的原理: http://blog.csdn.net/matrix_space/article/details/5337404 ...

  4. Recurrent Neural Network系列2--利用Python,Theano实现RNN

    作者:zhbzz2007 出处:http://www.cnblogs.com/zhbzz2007 欢迎转载,也请保留这段声明.谢谢! 本文翻译自 RECURRENT NEURAL NETWORKS T ...

  5. Recurrent Neural Network系列3--理解RNN的BPTT算法和梯度消失

    作者:zhbzz2007 出处:http://www.cnblogs.com/zhbzz2007 欢迎转载,也请保留这段声明.谢谢! 这是RNN教程的第三部分. 在前面的教程中,我们从头实现了一个循环 ...

  6. 循环神经网络(Recurrent Neural Network,RNN)

    为什么使用序列模型(sequence model)?标准的全连接神经网络(fully connected neural network)处理序列会有两个问题:1)全连接神经网络输入层和输出层长度固定, ...

  7. 4.5 RNN循环神经网络(recurrent neural network)

     自己开发了一个股票智能分析软件,功能很强大,需要的点击下面的链接获取: https://www.cnblogs.com/bclshuai/p/11380657.html 1.1  RNN循环神经网络 ...

  8. Recurrent Neural Network系列1--RNN(循环神经网络)概述

    作者:zhbzz2007 出处:http://www.cnblogs.com/zhbzz2007 欢迎转载,也请保留这段声明.谢谢! 本文翻译自 RECURRENT NEURAL NETWORKS T ...

  9. Recurrent Neural Network系列4--利用Python,Theano实现GRU或LSTM

    yi作者:zhbzz2007 出处:http://www.cnblogs.com/zhbzz2007 欢迎转载,也请保留这段声明.谢谢! 本文翻译自 RECURRENT NEURAL NETWORK ...

随机推荐

  1. Deepin Create/Delete Folder refresh

    Did u have a problem whth the deepin file manager,Everthime I create/delete a Folder of File i have ...

  2. Ruby on Rails框架(1)-安装全攻略

    序 关于Rails的三句箴言 (1)DRY:Don't Repeat Yourself(不要重复你自己) rails的开发理念,不要用你的代码不停的重复,rails框架给开发者提供了一套非常完善的支持 ...

  3. datagrid其中某列需要动态隐藏或显示的mvvm绑定方式,也可以用在其他表格类型控件上

    版权归原作者所有. 引用地址 [WPF] HOW TO BIND TO DATA WHEN THE DATACONTEXT IS NOT INHERITED MARCH 21, 2011 THOMAS ...

  4. MYSQL | ERROR 1305(42000) SAVEPOINT *** DOES NOT EXIST

    autocommit模式:在开启情况下,对于每条statement来说,都会自动形成一个commit,也就是会即时对开始和结束一个事务.所以,当出现rollback to savepoint出现这个错 ...

  5. Firefox disable search in the address bar

    disable search in the address bar Hi oitconz, setting keyword.enabled to false prevents Firefox from ...

  6. Unity编辑器环境在Inspector面板中显示变量

    Serialize功能Unity3D 中提供了非常方便的功能可以帮助用户将 成员变量 在Inspector中显示,并且定义Serialize关系. 简单的说,在没有自定义Inspector的情况下所有 ...

  7. JS构造函数中有return

    function foo(name) { this.name = name; return name } console.log(new foo('光何')) function bar(name) { ...

  8. matplotlib显示黑白灰度图像颜色设置

    对于黑白灰度图像(矩阵) 1. 默认使用伪彩色拉升 2 cmap参数为 binary,可能导致颜色反转 3. cmap = gray,same color as origin, that is, wh ...

  9. 数据包分析中Drop和iDrop的区别

    数据包分析中Drop和iDrop的区别   在数据包分析中,Drop表示因为过滤丢弃的包.为了区分发送和接受环节的过滤丢弃,把Drop又分为iDrop和Drop.其中,iDrop表示接受环节丢弃的包, ...

  10. Vue-cli项目结构讲解

    |-- build // 项目构建(webpack)相关代码 | |-- build.js // 生产环境构建代码 | |-- check-version.js // 检查node.npm等版本 | ...