一、RNN实现

结构原理

代码实现

import torch
import torch.nn as nn class RNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(RNN, self).__init__()
self.hidden_size = hidden_size
self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
self.i2o = nn.Linear(input_size + hidden_size, output_size)
self.softmax = nn.LogSoftmax(dim=1) def forward(self, input, hidden):
combined = torch.cat((input, hidden), 1)
hidden = self.i2h(combined) #全连接层
output: object = self.i2o(combined)
output = self.softmax(output)
return output, hidden def initHidden(self):
return torch.zeros(1, self.hidden_size)

二、LSTM实现

结构原理

封装好的LSTM

import torch
import torch.nn as nn class LSTMTagger(nn.Module):
def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
super(LSTMTagger, self).__init__()
self.hidden_dim = hidden_dim
self.word_embeddings = nn.Embedding(vocab_size, embedding_dim) # LSTM以word_embeddings作为输入, 输出维度为 hidden_dim 的隐藏状态值
self.lstm = nn.LSTM(embedding_dim, hidden_dim) # 线性层将隐藏状态空间映射到标注空间
self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
self.hidden = self.init_hidden() def init_hidden(self):
# 一开始并没有隐藏状态所以要先初始化一个
# 各个维度的含义是 (num_layers, minibatch_size, hidden_dim)
return (torch.zeros(1, 1, self.hidden_dim),
torch.zeros(1, 1, self.hidden_dim)) def forward(self, sentence):
embeds = self.word_embeddings(sentence)
lstm_out, self.hidden = self.lstm(embeds.view(len(sentence), 1, -1), self.hidden)
tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
tag_scores = F.log_softmax(tag_space, dim=1)
return tag_scores

未封装的LSTM

import torch
import torch.nn as nn class LSTMCell(nn.Module):
def __init__(self, input_size, hidden_size, cell_size, output_size):
super(LSTMCell, self).__init__()
self.hidden_size = hidden_size
self.cell_size = cell_size
self.gate = nn.Linear(input_size + hidden_size, cell_size) # 门:线性全连接层
self.output = nn.Linear(hidden_size, output_size)
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
self.softmax = nn.LogSoftmax(dim=1) def forward(self, input, hidden, cell):
combined = torch.cat((input, hidden), 1) #维度上连接
f_gate = self.sigmoid(self.gate(combined)) #遗忘门
i_gate = self.sigmoid(self.gate(combined)) #输入门
o_gate = self.sigmoid(self.gate(combined)) #输出门
z_state = self.tanh(self.gate(combined))
cell = torch.add(torch.mul(cell, f_gate), torch.mul(z_state, i_gate))
"""
cell长期记忆细胞:(cell·f_gate)+(z_state·i_gate)
遗忘门经过sigmoid后,值在[0,1]之间:
当f_gate趋于0时,和cell矩阵相乘后,记忆细胞为0,忘记长期记忆;
当f_gate区域1时,cell全部输入,作为长期记忆。
"""
hidden = torch.mul(self.tanh(cell), o_gate) #隐藏层:长期记忆细胞cell先过一层tanh激活函数,然后和输出门o_gate矩阵相乘
output = self.output(hidden) #隐藏层作为输出层的输出
output = self.softmax(output)
return output, hidden, cell def initHidden(self):
return torch.zeros(1, self.hidden_size) def initCell(self):
return torch.zeros(1, self.cell_size)

三、GRU实现

结构原理

代码实现

import torch
import torch.nn as nn class GRUCell(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(GRUCell, self).__init__()
self.hidden_size = hidden_size
self.gate = nn.Linear(input_size + hidden_size, hidden_size)
self.output = nn.Linear(hidden_size, output_size)
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
self.softmax = nn.LogSoftmax(dim=1) def forward(self, input, hidden):
combined = torch.cat((input, hidden), 1)
z_gate = self.sigmoid(self.gate(combined)) #重置门
r_gate = self.sigmoid(self.gate(combined)) #更新门
combined01 = torch.cat((input, torch.mul(hidden, r_gate)), 1)
h1_state = self.tanh(self.gate(combined01)) h_state = torch.add(torch.mul((1 - z_gate), hidden), torch.mul(h1_state, z_gate))
output = self.output(h_state)
output = self.softmax(output)
return output, h_state def initHidden(self):
return torch.zeros(1, self.hidden_size)

四、程序分析

1、RNN(Recurrent Natural Network,循环神经网络)

PyTorch提供了两个版本的循环神经网络接口,单元版的输入是每个时间步,或循环神经网络的一个循环,而封装版的是一个序列。

2、LSTM(Long Short-TermMemory,长短时记忆网络)

LSTM是在RNN基础上增加了长时间记忆功能,具体通过增加一个状态C及利用3个门(Gate)实现对信息的更精准控制。
        LSTM比标准的RNN多了3个线性变换,多出的3个线性变换的权重合在一起是RNN的4倍,偏移量也是RNN的4倍。所以,LSTM的参数个数是RNN的4倍。
        除了参数的区别外,隐含状态除h0外,多了一个c0,两者形状相同,都是(num_layers*num_directions,batch,hidden_size),它们合在一起构成了LSTM的隐含状态。所以,LSTM的输入隐含状态为(h0,c0),输出的隐含状态为(hn,cn),其他输入与输出与RNN相同。

3、GRU(Gated Recurrent Unit,门控循环单元)

GRU网络结构与LSTM基本相同,主要区别是LSTM共有3个门,两个隐含状态;而GRU只有两个门,一个隐含状态。其参数是标准RNN的3倍。

PyTorch程序练习(二):循环神经网络的PyTorch实现的更多相关文章

  1. [Pytorch框架] 2.5 循环神经网络

    文章目录 2.5 循环神经网络 2.5.1 RNN简介 RNN的起因 为什么需要RNN RNN都能做什么 2.5.2 RNN的网络结构及原理 RNN LSTM GRU 2.5.3 循环网络的向后传播( ...

  2. Pytorch循环神经网络LSTM时间序列预测风速

    #时间序列预测分析就是利用过去一段时间内某事件时间的特征来预测未来一段时间内该事件的特征.这是一类相对比较复杂的预测建模问题,和回归分析模型的预测不同,时间序列模型是依赖于事件发生的先后顺序的,同样大 ...

  3. pytorch循环神经网络实现回归预测 代码

    pytorch循环神经网络实现回归预测 学习视频:莫烦python # RNN for classification import torch import numpy as np import to ...

  4. 『PyTorch』第十弹_循环神经网络

    RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...

  5. 神经网络架构PYTORCH-初相识(3W)

    who? Python是基于Torch的一种使用Python作为开发语言的开源机器学习库.主要是应用领域是在自然语言的处理和图像的识别上.它主要的开发者是Facebook人工智能研究院(FAIR)团队 ...

  6. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  7. 解决运行pytorch程序多线程问题

    当我使用pycharm运行  (https://github.com/Joyce94/cnn-text-classification-pytorch )  pytorch程序的时候,在Linux服务器 ...

  8. 神经网络架构PYTORCH-几个概念

    使用Pytorch之前,有几个概念需要弄清楚. 什么是Tensors(张量)? 这个概念刚出来的时候,物理科班出身的我都感觉有点愣住了,好久没有接触过物理学的概念了. 这个概念,在物理学中怎么解释呢? ...

  9. PyTorch专栏(二)

    专栏目录: 第一章:PyTorch之简介与下载 PyTorch简介 PyTorch环境搭建 第二章:PyTorch之60min入门 PyTorch 入门 PyTorch 自动微分 PyTorch 神经 ...

  10. NLP与深度学习(二)循环神经网络

    1. 循环神经网络 在介绍循环神经网络之前,我们先考虑一个大家阅读文章的场景.一般在阅读一个句子时,我们是一个字或是一个词的阅读,而在阅读的同时,我们能够记住前几个词或是前几句的内容.这样我们便能理解 ...

随机推荐

  1. spire.Doc -Index was out of the range

    一直以来用的好好的,突然有一天出现:Index was out of the range   ED04211_邵武市易逸行软件技术服务有限公司(万顺出行)_其他 升级后问题: 1.合并单元格出现问题 ...

  2. installshield 安装jdk并配置环境变量

    今天来通过installshield安装jdk以及配置环境变量,本质上是调用第三方安装程序. 首先将jdk的安装文件添加到我们的安装程序中 然后编写我们的脚本 选择BEHAVIOR AND LOGIC ...

  3. 已经调试成功的Protues工程用了一段时间后不能用的问题

    已经调试成功的Protues工程,经过一段时间后不能用的问题 主要现象:(1)可以打开,运行时没有效果:(2)可以打开,运行时闪退 解决办法:(1)删除原ARM芯片:(2)重新找到ARM芯片,重新加载 ...

  4. Django国际化与本地化指南

    title: Django国际化与本地化指南 date: 2024/5/12 16:51:04 updated: 2024/5/12 16:51:04 categories: 后端开发 tags: D ...

  5. windows 文件夹添加备注

    1,选中希望改动的文件夹,然后右键"单击",选择"属性"按钮. 2,打开"自定义"面板,选择"更改图标",将原来的默认文 ...

  6. kubernetes 之 Rolling Update 滚动升级

    滚动升级 1.错误的yml文件 [machangwei@mcwk8s-master ~]$ cat mcwHttpd.yml apiVersion: apps/v1 kind: Deployment ...

  7. HC32L110(六) AS06-VTB07H V5.0测试板AT指令固件

    目录 HC32L110(一) HC32L110芯片介绍和Win10下的烧录 HC32L110(二) HC32L110在Ubuntu下的烧录 HC32L110(三) HC32L110的GCC工具链和VS ...

  8. go 使用 consul api filter 过滤注意点

    当你的value里面有-特殊符号的时候你应该像这样使用Service == "foo-bar"

  9. mysql in不走索引可能的情况

    在MySQL 5.7.3以及之前的版本中,eq_range_index_dive_limit的默认值为10,之 后的版本默认值为200.所以如果大家采用的是5.7.3以及之前的版本的话,很容易采用索引 ...

  10. C#WPF的多屏显示问题

    如果想让窗口在第二个屏幕中显示 public MainWindow() { InitializeComponent(); Screen[] _screens = Screen.AllScreens; ...