关于概念:

  BRNN连接两个相反的隐藏层到同一个输出.基于生成性深度学习,输出层能够同时的从前向和后向接收信息.该架构是1997年被Schuster和Paliwal提出的.引入BRNNS是为了增加网络所用的输入信息量.例如,多层感知机(MLPS)和延时神经网络(TDNNS)在输入数据的灵活性方面是非常有局限性的.因为他们需要输入的数据是固定的.标准的递归神静

网络也有局限,就是将来的数据数据不能用现在状态来表达.BRNN恰好能够弥补他们的劣势.它不需要输入的数据固定,与此同时,将来的输入数据也能从现在的状态到达.

  BRNN的原理是将正则RNN的神经元分成两个方向。一个用于正时方向(正向状态),另一个用于负时间方向(反向状态).这两个状态的输出没有连接到相反状态的输入。通过这两个时间方向,可以使用来自当前时间帧的过去和将来作为输入信息。

  双向RNN的思想和原始版RNN有一些许不同,只要是它考虑到当前的输出不止和之前的序列元素有关系,还和之后的序列元素也是有关系的。举个例子来说,如果我们现在要去填一句话中空缺的词,那我们直观就会觉得这个空缺的位置填什么词其实和前后的内容都有关系,对吧。双向RNN其实也非常简单,我们直观理解一下,其实就可以把它们看做2个RNN的叠加。输出的结果这个 时候就是基于2个RNN的隐状态计算得到的。

  

  关于训练:

  BRNNS可以使用RNNS类似的算法来做训练.因为两个方向的神经元没有任何相互作用。然而,当应用反向传播时,由于不能同时更新输入和输出层,因此需要额外的过程。训练的一般流程如下:对于前向传递,先传递正向状态和后向状态,然后输出神经元通过.对于后向传递,首先输出神经元,然后传递正向状态和后退状态。在进行前向和后向传递之后,更新权重。

  关于源码:

  首先看一下BRNN的定义,定义中使用了两层的网络,使用的模型是nn.LSTM.这里的LSTM是一类可以处理长期依赖问题的特殊的RNN,由Hochreiter 和 Schmidhuber于1977年提出,目前已有多种改进,且广泛用于各种各样的问题中。LSTM主要用来处理长期依赖问题,与传统RNN相比,长时间的信息记忆能力是与生俱来的。参数bidirectional=True是表示

该网路是一个双向的网络.这里的参数batch_first=True,因为nn.lstm()接受的数据输入是(序列长度,batch,输入维数),这和我们cnn输入的方式不太一致,所以使用batch_first,我们可以将输入变成(batch,序列长度,输入维数)

# Bidirectional recurrent neural network (many-to-one)
class BiRNN(nn.Module):
def __init__(self, input_size, hidden_size, num_layers, num_classes):
super(BiRNN, self).__init__()
self.hidden_size = hidden_size
self.num_layers = num_layers
self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True, bidirectional=True)
self.fc = nn.Linear(hidden_size*2, num_classes) # 2 for bidirection def forward(self, x):
# Set initial states
h0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device) # 2 for bidirection
c0 = torch.zeros(self.num_layers*2, x.size(0), self.hidden_size).to(device) # Forward propagate LSTM
out, _ = self.lstm(x, (h0, c0)) # out: tensor of shape (batch_size, seq_length, hidden_size*2) # Decode the hidden state of the last time step
out = self.fc(out[:, -1, :])
return out

    在实现函数中,首先设置初始化的状态:h0,c0,然后根据初始化的状态来输出决策后的内容,把结果线性插值法过滤后输出.

 这个神经网络的其他部分使用和别的网络是一样的,训练部分和测试就不再一一介绍了,想知道的朋友可以参考我前面的文章的介绍.下面给出整体的源码:

  最终的可运行源码:

 import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms # Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # Hyper-parameters
input_size = 784
hidden_size = 500
num_classes = 10
#input_size = 84
#hidden_size = 50
#num_classes = 2
num_epochs = 5
batch_size = 100
learning_rate = 0.001 # MNIST dataset
train_dataset = torchvision.datasets.MNIST(root='../../data',
train=True,
transform=transforms.ToTensor(),
download=True) test_dataset = torchvision.datasets.MNIST(root='../../data',
train=False,
transform=transforms.ToTensor()) # Data loader
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True) test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False) # Fully connected neural network with one hidden layer
class NeuralNet(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(NeuralNet, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(hidden_size, num_classes) def forward(self, x):
out = self.fc1(x)
out = self.relu(out)
out = self.fc2(out)
return out model = NeuralNet(input_size, hidden_size, num_classes).to(device) # Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # Train the model
total_step = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
# Move tensors to the configured device
images = images.reshape(-1, 28*28).to(device)
labels = labels.to(device) # Forward pass
outputs = model(images)
loss = criterion(outputs, labels) # Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step() if (i+1) % 100 == 0:
print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch+1, num_epochs, i+1, total_step, loss.item()))
# Test the model
# In test phase, we don't need to compute gradients (for memory efficiency)
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.reshape(-1, 28*28).to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
#print(predicted)
correct += (predicted == labels).sum().item() print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total)) # Save the model checkpoint
torch.save(model.state_dict(), 'model.ckpt')

  结果这里就不再贴出来了,想知道的朋友可以自己运行一下.

参考文档:

1 https://cloud.tencent.com/developer/article/1134467

PyTorch--双向递归神经网络(B-RNN)概念,源码分析的更多相关文章

  1. 递归神经网络(RNN,Recurrent Neural Networks)和反向传播的指南 A guide to recurrent neural networks and backpropagation(转载)

    摘要 这篇文章提供了一个关于递归神经网络中某些概念的指南.与前馈网络不同,RNN可能非常敏感,并且适合于过去的输入(be adapted to past inputs).反向传播学习(backprop ...

  2. vue双向绑定的原理及实现双向绑定MVVM源码分析

    vue双向绑定的原理及实现双向绑定MVVM源码分析 双向数据绑定的原理是:可以将对象的属性绑定到UI,具体的说,我们有一个对象,该对象有一个name属性,当我们给这个对象name属性赋新值的时候,新值 ...

  3. 鸿蒙内核源码分析(并发并行篇) | 听过无数遍的两个概念 | 百篇博客分析OpenHarmony源码 | v25.01

    百篇博客系列篇.本篇为: v25.xx 鸿蒙内核源码分析(并发并行篇) | 听过无数遍的两个概念 | 51.c.h .o 任务管理相关篇为: v03.xx 鸿蒙内核源码分析(时钟任务篇) | 触发调度 ...

  4. jQuery 2.0.3 源码分析 Deferred概念

    JavaScript编程几乎总是伴随着异步操作,传统的异步操作会在操作完成之后,使用回调函数传回结果,而回调函数中则包含了后续的工作.这也是造成异步编程困难的主要原因:我们一直习惯于“线性”地编写代码 ...

  5. jQuery 2.0.3 源码分析 Deferrred概念

    转载http://www.cnblogs.com/aaronjs/p/3348569.html JavaScript编程几乎总是伴随着异步操作,传统的异步操作会在操作完成之后,使用回调函数传回结果,而 ...

  6. Redux源码分析之基本概念

    Redux源码分析之基本概念 Redux源码分析之createStore Redux源码分析之bindActionCreators Redux源码分析之combineReducers Redux源码分 ...

  7. 【Java】HashMap源码分析——基本概念

    在JDK1.8后,对HashMap源码进行了更改,引入了红黑树.在这之前,HashMap实际上就是就是数组+链表的结构,由于HashMap是一张哈希表,其会产生哈希冲突,为了解决哈希冲突,HashMa ...

  8. Linux内核2.6.14源码分析-双向循环链表代码分析(巨详细)

    Linux内核源码分析-链表代码分析 分析人:余旭 分析时间:2005年11月17日星期四 11:40:10 AM 雨 温度:10-11度 编号:1-4 类别:准备工作 Email:yuxu97101 ...

  9. 鸿蒙内核源码分析(索引节点篇) | 谁是文件系统最重要的概念 | 百篇博客分析OpenHarmony源码 | v64.01

    百篇博客系列篇.本篇为: v64.xx 鸿蒙内核源码分析(索引节点篇) | 谁是文件系统最重要的概念 | 51.c.h.o 文件系统相关篇为: v62.xx 鸿蒙内核源码分析(文件概念篇) | 为什么 ...

  10. 鸿蒙内核源码分析(文件概念篇) | 为什么说一切皆是文件 | 百篇博客分析OpenHarmony源码 | v62.01

    百篇博客系列篇.本篇为: v62.xx 鸿蒙内核源码分析(文件概念篇) | 为什么说一切皆是文件 | 51.c.h.o 本篇开始说文件系统,它是内核五大模块之一,甚至有Linux的设计哲学是" ...

随机推荐

  1. Python第十六天 类的实例化

    首先 , 先定义一个 简单的 Person 类 class Person: head = 1 ear = 2 def eat(self): print('吃饭') 关于什么是类, 定义类, 类对象,类 ...

  2. 服务管理之samba

    目录 samba 1.samba的简介 2. samba访问 1.搭建用户认证共享服务器 2.搭建匿名用户共享服务器 samba 1.samba的简介 Samba是在Linux和UNIX系统上实现SM ...

  3. [原创]networkx 画中文节点

    一直想分享一些自己遇到的坑,但确实很多时候走的太快 很多想做的事情会被快节奏的生活冲淡, 在公司做事反而比学校还自在, 因为是悠闲的实习期... 几点小建议写在前头--xdj: 遇到问题,大多数人首先 ...

  4. #error "OpenCV 4.x+ requires enabled C++11 support"解决方法

    报错的本质是需要c++11的支持,顾名思义,当前的编译环境是c++11以下的版本.我用的cmake编译,因此再cmakelists文件内添加设置c++标准为14就可以编译通过. )

  5. Python Day 12

    阅读目录: 内容回顾 函数默认值的细节 三元表达式 列表与字典推导式 函数对象 名称空间 函数嵌套的定义 作用域 ##内容回顾 # 字符串的比较 -- 按照从左往右比较每一个字符,通过字符对应的asc ...

  6. 关于根据模板生成pdf文档,差入图片和加密

    import com.alibaba.fastjson.JSONObject; import com.aliyun.oss.OSSClient; import com.itextpdf.text.pd ...

  7. CodeForces 935E Fafa and Ancient Mathematics (树形DP)

    题意:给定一个表达式,然后让你添加 n 个加号,m 个减号,使得表达式的值最大. 析:首先先要建立一个表达式树,这个应该很好建立,就不说了,dp[u][i][0] 表示 u 这个部分表达式,添加 i ...

  8. 如何从本地远程访问虚拟机内的Mysql服务器?

    假设重装了操作系统,则本地的很多软件可能都需要重新安装,比如数据库.但是,假设我们把一些重要的软件安装在虚拟机当中,则在重装操作系统之前,数据库服务器可以和虚拟机一起进行备份.重装操作系统之后,原先的 ...

  9. [转载]ECMall模板解析语法与机制

    ECMall模板解析语法与机制 2011-05-22 在ECMall模板中,用"{"开头,以"}"结尾就构成一个标签单元,"{"紧接着的单词 ...

  10. Touch365现已上架!

    欢迎体验由武宇亭.诸子轩.梁国伟.张裕浩.孔维喆.邱亚威同学开发的创意照片浏览软件Touch365,现已上架Microsoft官方商城! https://www.microsoft.com/zh-cn ...