官方实现

PyTorch已经实现了一个RNN类,就在torch.nn工具包中,通过torch.nn.RNN调用。

使用步骤:

  1. 实例化类;
  2. 将输入层向量和隐藏层向量初始状态值传给实例化后的对象,获得RNN的输出。

在实例化该类时,需要传入如下属性:

  • input_size:输入层神经元个数;
  • hidden_size:每层隐藏层的神经元个数;
  • num_layers:隐藏层层数,默认设置为1层;
  • nonlinearity:激活函数的选择,可选是'tanh'或者'relu',默认设置为'tanh';
  • bias:偏置系数,可选是'True'或者'False',默认设置为'True';
  • batch_first:可选是'True'或者'False',默认设置为'False';
  • dropout:默认设置为0。若为非0,将在除最后一层的每层RNN输出上引入Dropout层,dropout概率就是该非零值;
  • bidirectional:默认设置为False。若为True,即为双向RNN。

RNN的输入有两个,一个是input,一个是h0。input就是输入层向量,h0就是隐藏层初始状态值。

若没有采用批量输入,则输入层向量的形状为(L, Hin);

若采用批量输入,且batch_first为False,则输入层向量的形状为(L, N, Hin);

若采用批量输入,且batch_first为True,则输入层向量的形状为(N, L, Hin);

对于(N, L, Hin),在文本输入时,可以按顺序理解为(每次输入几句话,每句话有几个字,每个字由多少维的向量表示)。

若没有采用批量输入,则隐藏层向量的形状为(D * num_layers, Hout);

若采用批量输入,则隐藏层向量的形状为(D * num_layers, N, Hout);

注意,batch_first的设置对隐藏层向量的形状不起作用。

RNN的输出有两个,一个是output,一个是hn。output包含了每个时间步最后一层的隐藏层状态,hn包含了最后一个时间步每层的隐藏层状态。

若没有采用批量输入,则输出层向量的形状为(L, D * Hout);

若采用批量输入,且batch_first为False,则输出层向量的形状为(L, N, D * Hout);

若采用批量输入,且batch_first为True,则输出层向量的形状为(N, L, D * Hout)。

参数解释:

  • N代表的是批量大小;
  • L代表的是输入的序列长度;
  • 若是双向RNN,则D的值为2;若是单向RNN,则D的值为1;
  • Hin在数值上是输入层神经元个数;
  • Hout在数值上是隐藏层神经元个数。
import torch
import torch.nn as nn
rnn = nn.RNN(10, 20, 1, batch_first=True) # 实例化一个单向单层RNN
input = torch.randn(5, 3, 10)
h0 = torch.randn(1, 5, 20)
output, hn = rnn(input, h0)

手写复现

复现代码

import torch
import torch.nn as nn class MyRNN(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.weight_ih = torch.randn(self.hidden_size, self.input_size) * 0.01
self.weight_hh = torch.randn(self.hidden_size, self.hidden_size) * 0.01
self.bias_ih = torch.randn(self.hidden_size)
self.bias_hh = torch.randn(self.hidden_size) def forward(self, input, h0):
N, L, input_size = input.shape
output = torch.zeros(N, L, self.hidden_size)
for t in range(L):
x = input[:, t, :].unsqueeze(2) # 获得当前时刻的输入特征,[N, input_size, 1]。unsqueeze(n),在第n维上增加一维
w_ih_batch = self.weight_ih.unsqueeze(0).tile(N, 1, 1) # [N, hidden_size, input_size]
w_hh_batch = self.weight_hh.unsqueeze(0).tile(N, 1, 1) # [N, hidden_size, hidden_size]
w_times_x = torch.bmm(w_ih_batch, x).squeeze(-1) # [N, hidden_size]。squeeze(n),在第n维上减小一维
w_times_h = torch.bmm(w_hh_batch, h0.unsqueeze(2)).squeeze(-1) # [N, hidden_size]
h0 = torch.tanh(w_times_x + self.bias_ih + w_times_h + self.bias_hh)
output[:, t, :] = h0
return output, h0.unsqueeze(0)

验证正确性

my_rnn = MyRNN(10, 20)
input = torch.randn(5, 3, 10)
h0 = torch.randn(5, 20)
my_output, my_hn = my_rnn(input, h0)
print(output.shape == my_output.shape, hn.shape == my_hn.shape)
True True

主要参考

官方说明文档

RNN的PyTorch实现的更多相关文章

  1. Pytorch基础——使用 RNN 生成简单序列

    一.介绍 内容 使用 RNN 进行序列预测 今天我们就从一个基本的使用 RNN 生成简单序列的例子中,来窥探神经网络生成符号序列的秘密. 我们首先让神经网络模型学习形如 0^n 1^n 形式的上下文无 ...

  2. pytorch_08_RNN

    1.循环神经网络的提出是基于记忆模型的想法,期望网络能够记住前面出现的特征,并依据特征推断后面的结果,而且整体的网络结构不断循环,因而得名循环神经网络. 2.循环神经网络的基本结构特别简单,就是将网络 ...

  3. “你什么意思”之基于RNN的语义槽填充(Pytorch实现)

    1. 概况 1.1 任务 口语理解(Spoken Language Understanding, SLU)作为语音识别与自然语言处理之间的一个新兴领域,其目的是为了让计算机从用户的讲话中理解他们的意图 ...

  4. Pytorch系列教程-使用字符级RNN生成姓名

    前言 本系列教程为pytorch官网文档翻译.本文对应官网地址:https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutor ...

  5. Pytorch系列教程-使用字符级RNN对姓名进行分类

    前言 本系列教程为pytorch官网文档翻译.本文对应官网地址:https://pytorch.org/tutorials/intermediate/char_rnn_classification_t ...

  6. pytorch实现rnn并且对mnist进行分类

    1.RNN简介 rnn,相比很多人都已经听腻,但是真正用代码操练起来,其中还是有很多细节值得琢磨. 虽然大家都在说,我还是要强调一次,rnn实际上是处理的是序列问题,与之形成对比的是cnn,cnn不能 ...

  7. pytorch中如何处理RNN输入变长序列padding

    一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练 ...

  8. pytorch之 RNN 参数解释

    上次通过pytorch实现了RNN模型,简易的完成了使用RNN完成mnist的手写数字识别,但是里面的参数有点不了解,所以对问题进行总结归纳来解决. 总述:第一次看到这个函数时,脑袋有点懵,总结了下总 ...

  9. PyTorch快速入门教程七(RNN做自然语言处理)

    以下内容均来自: https://ptorch.com/news/11.html word embedding也叫做word2vec简单来说就是语料中每一个单词对应的其相应的词向量,目前训练词向量的方 ...

  10. pytorch rnn 2

    import torch import torch.nn as nn import numpy as np import torch.optim as optim class RNN(nn.Modul ...

随机推荐

  1. 理解 Spring IoC 容器

    控制反转与大家熟知的依赖注入同理, 这是通过依赖注入对象的过程. 创建 Bean 后, 依赖的对象由控制反转容器通过构造参数 工厂方法参数或者属性注入. 创建过程相对于普通创建对象的过程是反向, 称之 ...

  2. ElasticSearch介绍和基本用法(一)

    ElasticSearch 引言 1.在海量数据中执行搜索功能时,如果使用MySQL, 效率太低. 2.如果关键字输入的不准确,一样可以搜索到想要的数据. 3.将搜索关键字,以红色的字体展示. 介绍: ...

  3. 快Key:按一下鼠标【滚轮】,帮你自动填写用户名密码,快速登录,可制作U盘随身(开源免费-附安装文件和源代码)

    * 代码以本文所附下载文件包为准,安装文件和源文件包均在本文尾部可下载. * 快Key及本文所有内容仅供交流使用,使用者责任自负,由快Key对使用者及其相关人员或组织造成的任何损失均由使用者自负,与本 ...

  4. HashMap的哈希函数为何用(n - 1) & hash

    前言 在上一篇 Java 中HashMap详解(含HashTable, ConcurrentHashMap) 中提到在map.put(key, value)的过程中,计算完key的hash值, 是通过 ...

  5. 前端必读2.0:如何在React 中使用SpreadJS导入和导出 Excel 文件

    最近我们公司接到一个客户的需求,要求为正在开发的项目加个功能.项目的前端使用的是React,客户想添加具备Excel 导入/导出功能的电子表格模块. 经过几个小时的原型构建后,技术团队确认所有客户需求 ...

  6. Elasitcsearch7.X集群/索引备份与恢复实战

    文章转载自:https://mp.weixin.qq.com/s/_0RlojDsE30CeDSyLNP44w 1.问题引出 ES中文社区中,有如下问题: 问题1:存储数据,data目录从一个机器直接 ...

  7. 使用kubeoperator自带的nginx-ingress-controller设置服务的ingress规则进行访问

    情况说明 当使用kubeoperator安装k8s集群的时候,在组件设置部分选择的ingress 类型是nginx-ingress yaml文件 k8s集群安装后,可以在节点的master主机的这个目 ...

  8. CentOS7.x安装VNC

    VNC需要系统安装的有桌面,如果是生产环境服务器,安装时使用的最小化安装,那么进行下面操作安装GNOME 桌面. # 列出的组列表里有GNOME Desktop. yum grouplist #安装 ...

  9. Windows 下JDK绿色免安装制作教程

    java自从被oracle收购后,windows下新的版本只有安装版.没有zip免安装. windows安装版有一下坏处 会写注册表 会将java.exe,javaw.exe 等解压到C:\Windo ...

  10. Gitlab基础知识介绍

    GitLab架构图 Gitlab各组件作用 -Nginx:静态web服务器. -gitlab-shell:用于处理Git命令和修改authorized keys列表. -gitlab-workhors ...