RNN的PyTorch实现
官方实现
PyTorch已经实现了一个RNN类,就在torch.nn工具包中,通过torch.nn.RNN调用。
使用步骤:
- 实例化类;
- 将输入层向量和隐藏层向量初始状态值传给实例化后的对象,获得RNN的输出。
在实例化该类时,需要传入如下属性:
- input_size:输入层神经元个数;
- hidden_size:每层隐藏层的神经元个数;
- num_layers:隐藏层层数,默认设置为1层;
- nonlinearity:激活函数的选择,可选是'tanh'或者'relu',默认设置为'tanh';
- bias:偏置系数,可选是'True'或者'False',默认设置为'True';
- batch_first:可选是'True'或者'False',默认设置为'False';
- dropout:默认设置为0。若为非0,将在除最后一层的每层RNN输出上引入Dropout层,dropout概率就是该非零值;
- bidirectional:默认设置为False。若为True,即为双向RNN。
RNN的输入有两个,一个是input,一个是h0。input就是输入层向量,h0就是隐藏层初始状态值。
若没有采用批量输入,则输入层向量的形状为(L, Hin);
若采用批量输入,且batch_first为False,则输入层向量的形状为(L, N, Hin);
若采用批量输入,且batch_first为True,则输入层向量的形状为(N, L, Hin);
对于(N, L, Hin),在文本输入时,可以按顺序理解为(每次输入几句话,每句话有几个字,每个字由多少维的向量表示)。
若没有采用批量输入,则隐藏层向量的形状为(D * num_layers, Hout);
若采用批量输入,则隐藏层向量的形状为(D * num_layers, N, Hout);
注意,batch_first的设置对隐藏层向量的形状不起作用。
RNN的输出有两个,一个是output,一个是hn。output包含了每个时间步最后一层的隐藏层状态,hn包含了最后一个时间步每层的隐藏层状态。
若没有采用批量输入,则输出层向量的形状为(L, D * Hout);
若采用批量输入,且batch_first为False,则输出层向量的形状为(L, N, D * Hout);
若采用批量输入,且batch_first为True,则输出层向量的形状为(N, L, D * Hout)。
参数解释:
- N代表的是批量大小;
- L代表的是输入的序列长度;
- 若是双向RNN,则D的值为2;若是单向RNN,则D的值为1;
- Hin在数值上是输入层神经元个数;
- Hout在数值上是隐藏层神经元个数。
import torch
import torch.nn as nn
rnn = nn.RNN(10, 20, 1, batch_first=True) # 实例化一个单向单层RNN
input = torch.randn(5, 3, 10)
h0 = torch.randn(1, 5, 20)
output, hn = rnn(input, h0)
手写复现
复现代码
import torch
import torch.nn as nn
class MyRNN(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.weight_ih = torch.randn(self.hidden_size, self.input_size) * 0.01
self.weight_hh = torch.randn(self.hidden_size, self.hidden_size) * 0.01
self.bias_ih = torch.randn(self.hidden_size)
self.bias_hh = torch.randn(self.hidden_size)
def forward(self, input, h0):
N, L, input_size = input.shape
output = torch.zeros(N, L, self.hidden_size)
for t in range(L):
x = input[:, t, :].unsqueeze(2) # 获得当前时刻的输入特征,[N, input_size, 1]。unsqueeze(n),在第n维上增加一维
w_ih_batch = self.weight_ih.unsqueeze(0).tile(N, 1, 1) # [N, hidden_size, input_size]
w_hh_batch = self.weight_hh.unsqueeze(0).tile(N, 1, 1) # [N, hidden_size, hidden_size]
w_times_x = torch.bmm(w_ih_batch, x).squeeze(-1) # [N, hidden_size]。squeeze(n),在第n维上减小一维
w_times_h = torch.bmm(w_hh_batch, h0.unsqueeze(2)).squeeze(-1) # [N, hidden_size]
h0 = torch.tanh(w_times_x + self.bias_ih + w_times_h + self.bias_hh)
output[:, t, :] = h0
return output, h0.unsqueeze(0)
验证正确性
my_rnn = MyRNN(10, 20)
input = torch.randn(5, 3, 10)
h0 = torch.randn(5, 20)
my_output, my_hn = my_rnn(input, h0)
print(output.shape == my_output.shape, hn.shape == my_hn.shape)
True True
主要参考
RNN的PyTorch实现的更多相关文章
- Pytorch基础——使用 RNN 生成简单序列
一.介绍 内容 使用 RNN 进行序列预测 今天我们就从一个基本的使用 RNN 生成简单序列的例子中,来窥探神经网络生成符号序列的秘密. 我们首先让神经网络模型学习形如 0^n 1^n 形式的上下文无 ...
- pytorch_08_RNN
1.循环神经网络的提出是基于记忆模型的想法,期望网络能够记住前面出现的特征,并依据特征推断后面的结果,而且整体的网络结构不断循环,因而得名循环神经网络. 2.循环神经网络的基本结构特别简单,就是将网络 ...
- “你什么意思”之基于RNN的语义槽填充(Pytorch实现)
1. 概况 1.1 任务 口语理解(Spoken Language Understanding, SLU)作为语音识别与自然语言处理之间的一个新兴领域,其目的是为了让计算机从用户的讲话中理解他们的意图 ...
- Pytorch系列教程-使用字符级RNN生成姓名
前言 本系列教程为pytorch官网文档翻译.本文对应官网地址:https://pytorch.org/tutorials/intermediate/char_rnn_generation_tutor ...
- Pytorch系列教程-使用字符级RNN对姓名进行分类
前言 本系列教程为pytorch官网文档翻译.本文对应官网地址:https://pytorch.org/tutorials/intermediate/char_rnn_classification_t ...
- pytorch实现rnn并且对mnist进行分类
1.RNN简介 rnn,相比很多人都已经听腻,但是真正用代码操练起来,其中还是有很多细节值得琢磨. 虽然大家都在说,我还是要强调一次,rnn实际上是处理的是序列问题,与之形成对比的是cnn,cnn不能 ...
- pytorch中如何处理RNN输入变长序列padding
一.为什么RNN需要处理变长输入 假设我们有情感分析的例子,对每句话进行一个感情级别的分类,主体流程大概是下图所示: 思路比较简单,但是当我们进行batch个训练数据一起计算的时候,我们会遇到多个训练 ...
- pytorch之 RNN 参数解释
上次通过pytorch实现了RNN模型,简易的完成了使用RNN完成mnist的手写数字识别,但是里面的参数有点不了解,所以对问题进行总结归纳来解决. 总述:第一次看到这个函数时,脑袋有点懵,总结了下总 ...
- PyTorch快速入门教程七(RNN做自然语言处理)
以下内容均来自: https://ptorch.com/news/11.html word embedding也叫做word2vec简单来说就是语料中每一个单词对应的其相应的词向量,目前训练词向量的方 ...
- pytorch rnn 2
import torch import torch.nn as nn import numpy as np import torch.optim as optim class RNN(nn.Modul ...
随机推荐
- 理解 Spring IoC 容器
控制反转与大家熟知的依赖注入同理, 这是通过依赖注入对象的过程. 创建 Bean 后, 依赖的对象由控制反转容器通过构造参数 工厂方法参数或者属性注入. 创建过程相对于普通创建对象的过程是反向, 称之 ...
- ElasticSearch介绍和基本用法(一)
ElasticSearch 引言 1.在海量数据中执行搜索功能时,如果使用MySQL, 效率太低. 2.如果关键字输入的不准确,一样可以搜索到想要的数据. 3.将搜索关键字,以红色的字体展示. 介绍: ...
- 快Key:按一下鼠标【滚轮】,帮你自动填写用户名密码,快速登录,可制作U盘随身(开源免费-附安装文件和源代码)
* 代码以本文所附下载文件包为准,安装文件和源文件包均在本文尾部可下载. * 快Key及本文所有内容仅供交流使用,使用者责任自负,由快Key对使用者及其相关人员或组织造成的任何损失均由使用者自负,与本 ...
- HashMap的哈希函数为何用(n - 1) & hash
前言 在上一篇 Java 中HashMap详解(含HashTable, ConcurrentHashMap) 中提到在map.put(key, value)的过程中,计算完key的hash值, 是通过 ...
- 前端必读2.0:如何在React 中使用SpreadJS导入和导出 Excel 文件
最近我们公司接到一个客户的需求,要求为正在开发的项目加个功能.项目的前端使用的是React,客户想添加具备Excel 导入/导出功能的电子表格模块. 经过几个小时的原型构建后,技术团队确认所有客户需求 ...
- Elasitcsearch7.X集群/索引备份与恢复实战
文章转载自:https://mp.weixin.qq.com/s/_0RlojDsE30CeDSyLNP44w 1.问题引出 ES中文社区中,有如下问题: 问题1:存储数据,data目录从一个机器直接 ...
- 使用kubeoperator自带的nginx-ingress-controller设置服务的ingress规则进行访问
情况说明 当使用kubeoperator安装k8s集群的时候,在组件设置部分选择的ingress 类型是nginx-ingress yaml文件 k8s集群安装后,可以在节点的master主机的这个目 ...
- CentOS7.x安装VNC
VNC需要系统安装的有桌面,如果是生产环境服务器,安装时使用的最小化安装,那么进行下面操作安装GNOME 桌面. # 列出的组列表里有GNOME Desktop. yum grouplist #安装 ...
- Windows 下JDK绿色免安装制作教程
java自从被oracle收购后,windows下新的版本只有安装版.没有zip免安装. windows安装版有一下坏处 会写注册表 会将java.exe,javaw.exe 等解压到C:\Windo ...
- Gitlab基础知识介绍
GitLab架构图 Gitlab各组件作用 -Nginx:静态web服务器. -gitlab-shell:用于处理Git命令和修改authorized keys列表. -gitlab-workhors ...