一、RNN实现

结构原理

代码实现

import torch
import torch.nn as nn class RNN(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(RNN, self).__init__()
self.hidden_size = hidden_size
self.i2h = nn.Linear(input_size + hidden_size, hidden_size)
self.i2o = nn.Linear(input_size + hidden_size, output_size)
self.softmax = nn.LogSoftmax(dim=1) def forward(self, input, hidden):
combined = torch.cat((input, hidden), 1)
hidden = self.i2h(combined) #全连接层
output: object = self.i2o(combined)
output = self.softmax(output)
return output, hidden def initHidden(self):
return torch.zeros(1, self.hidden_size)

二、LSTM实现

结构原理

封装好的LSTM

import torch
import torch.nn as nn class LSTMTagger(nn.Module):
def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size):
super(LSTMTagger, self).__init__()
self.hidden_dim = hidden_dim
self.word_embeddings = nn.Embedding(vocab_size, embedding_dim) # LSTM以word_embeddings作为输入, 输出维度为 hidden_dim 的隐藏状态值
self.lstm = nn.LSTM(embedding_dim, hidden_dim) # 线性层将隐藏状态空间映射到标注空间
self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
self.hidden = self.init_hidden() def init_hidden(self):
# 一开始并没有隐藏状态所以要先初始化一个
# 各个维度的含义是 (num_layers, minibatch_size, hidden_dim)
return (torch.zeros(1, 1, self.hidden_dim),
torch.zeros(1, 1, self.hidden_dim)) def forward(self, sentence):
embeds = self.word_embeddings(sentence)
lstm_out, self.hidden = self.lstm(embeds.view(len(sentence), 1, -1), self.hidden)
tag_space = self.hidden2tag(lstm_out.view(len(sentence), -1))
tag_scores = F.log_softmax(tag_space, dim=1)
return tag_scores

未封装的LSTM

import torch
import torch.nn as nn class LSTMCell(nn.Module):
def __init__(self, input_size, hidden_size, cell_size, output_size):
super(LSTMCell, self).__init__()
self.hidden_size = hidden_size
self.cell_size = cell_size
self.gate = nn.Linear(input_size + hidden_size, cell_size) # 门:线性全连接层
self.output = nn.Linear(hidden_size, output_size)
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
self.softmax = nn.LogSoftmax(dim=1) def forward(self, input, hidden, cell):
combined = torch.cat((input, hidden), 1) #维度上连接
f_gate = self.sigmoid(self.gate(combined)) #遗忘门
i_gate = self.sigmoid(self.gate(combined)) #输入门
o_gate = self.sigmoid(self.gate(combined)) #输出门
z_state = self.tanh(self.gate(combined))
cell = torch.add(torch.mul(cell, f_gate), torch.mul(z_state, i_gate))
"""
cell长期记忆细胞:(cell·f_gate)+(z_state·i_gate)
遗忘门经过sigmoid后,值在[0,1]之间:
当f_gate趋于0时,和cell矩阵相乘后,记忆细胞为0,忘记长期记忆;
当f_gate区域1时,cell全部输入,作为长期记忆。
"""
hidden = torch.mul(self.tanh(cell), o_gate) #隐藏层:长期记忆细胞cell先过一层tanh激活函数,然后和输出门o_gate矩阵相乘
output = self.output(hidden) #隐藏层作为输出层的输出
output = self.softmax(output)
return output, hidden, cell def initHidden(self):
return torch.zeros(1, self.hidden_size) def initCell(self):
return torch.zeros(1, self.cell_size)

三、GRU实现

结构原理

代码实现

import torch
import torch.nn as nn class GRUCell(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(GRUCell, self).__init__()
self.hidden_size = hidden_size
self.gate = nn.Linear(input_size + hidden_size, hidden_size)
self.output = nn.Linear(hidden_size, output_size)
self.sigmoid = nn.Sigmoid()
self.tanh = nn.Tanh()
self.softmax = nn.LogSoftmax(dim=1) def forward(self, input, hidden):
combined = torch.cat((input, hidden), 1)
z_gate = self.sigmoid(self.gate(combined)) #重置门
r_gate = self.sigmoid(self.gate(combined)) #更新门
combined01 = torch.cat((input, torch.mul(hidden, r_gate)), 1)
h1_state = self.tanh(self.gate(combined01)) h_state = torch.add(torch.mul((1 - z_gate), hidden), torch.mul(h1_state, z_gate))
output = self.output(h_state)
output = self.softmax(output)
return output, h_state def initHidden(self):
return torch.zeros(1, self.hidden_size)

四、程序分析

1、RNN(Recurrent Natural Network,循环神经网络)

PyTorch提供了两个版本的循环神经网络接口,单元版的输入是每个时间步,或循环神经网络的一个循环,而封装版的是一个序列。

2、LSTM(Long Short-TermMemory,长短时记忆网络)

LSTM是在RNN基础上增加了长时间记忆功能,具体通过增加一个状态C及利用3个门(Gate)实现对信息的更精准控制。
        LSTM比标准的RNN多了3个线性变换,多出的3个线性变换的权重合在一起是RNN的4倍,偏移量也是RNN的4倍。所以,LSTM的参数个数是RNN的4倍。
        除了参数的区别外,隐含状态除h0外,多了一个c0,两者形状相同,都是(num_layers*num_directions,batch,hidden_size),它们合在一起构成了LSTM的隐含状态。所以,LSTM的输入隐含状态为(h0,c0),输出的隐含状态为(hn,cn),其他输入与输出与RNN相同。

3、GRU(Gated Recurrent Unit,门控循环单元)

GRU网络结构与LSTM基本相同,主要区别是LSTM共有3个门,两个隐含状态;而GRU只有两个门,一个隐含状态。其参数是标准RNN的3倍。

PyTorch程序练习(二):循环神经网络的PyTorch实现的更多相关文章

  1. [Pytorch框架] 2.5 循环神经网络

    文章目录 2.5 循环神经网络 2.5.1 RNN简介 RNN的起因 为什么需要RNN RNN都能做什么 2.5.2 RNN的网络结构及原理 RNN LSTM GRU 2.5.3 循环网络的向后传播( ...

  2. Pytorch循环神经网络LSTM时间序列预测风速

    #时间序列预测分析就是利用过去一段时间内某事件时间的特征来预测未来一段时间内该事件的特征.这是一类相对比较复杂的预测建模问题,和回归分析模型的预测不同,时间序列模型是依赖于事件发生的先后顺序的,同样大 ...

  3. pytorch循环神经网络实现回归预测 代码

    pytorch循环神经网络实现回归预测 学习视频:莫烦python # RNN for classification import torch import numpy as np import to ...

  4. 『PyTorch』第十弹_循环神经网络

    RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...

  5. 神经网络架构PYTORCH-初相识(3W)

    who? Python是基于Torch的一种使用Python作为开发语言的开源机器学习库.主要是应用领域是在自然语言的处理和图像的识别上.它主要的开发者是Facebook人工智能研究院(FAIR)团队 ...

  6. MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网络训练实现及比较(三)

    版权声明:本文为博主原创文章,欢迎转载,并请注明出处.联系方式:460356155@qq.com 在前两篇文章MINIST深度学习识别:python全连接神经网络和pytorch LeNet CNN网 ...

  7. 解决运行pytorch程序多线程问题

    当我使用pycharm运行  (https://github.com/Joyce94/cnn-text-classification-pytorch )  pytorch程序的时候,在Linux服务器 ...

  8. 神经网络架构PYTORCH-几个概念

    使用Pytorch之前,有几个概念需要弄清楚. 什么是Tensors(张量)? 这个概念刚出来的时候,物理科班出身的我都感觉有点愣住了,好久没有接触过物理学的概念了. 这个概念,在物理学中怎么解释呢? ...

  9. PyTorch专栏(二)

    专栏目录: 第一章:PyTorch之简介与下载 PyTorch简介 PyTorch环境搭建 第二章:PyTorch之60min入门 PyTorch 入门 PyTorch 自动微分 PyTorch 神经 ...

  10. NLP与深度学习(二)循环神经网络

    1. 循环神经网络 在介绍循环神经网络之前,我们先考虑一个大家阅读文章的场景.一般在阅读一个句子时,我们是一个字或是一个词的阅读,而在阅读的同时,我们能够记住前几个词或是前几句的内容.这样我们便能理解 ...

随机推荐

  1. fastposter v2.8.4 发布 电商海报生成器

    fastposter v2.8.4 发布 电商海报生成器 fastposter海报生成器,电商海报编辑器,电商海报设计器,fast快速生成海报 海报制作 海报开发.贰维海报,图片海报,分享海报贰维码推 ...

  2. gRPC入门学习之旅(八)

    gRPC入门学习之旅(一) gRPC入门学习之旅(二) gRPC入门学习之旅(三) gRPC入门学习之旅(四) gRPC入门学习之旅(五) gRPC入门学习之旅(六) gRPC入门学习之旅(七) 3. ...

  3. 网络拓扑—NAT内外网映射

    使用Windows Server 2003 网络拓扑 Router 外网:NATIP 网段 = 192.168.17.0/24 内网:仅主机模式IP = 172.16.29.4 Client1:仅主机 ...

  4. mogodb replication set复制集

    replication set复制集 简要命令 replication set复制集 replicattion set 多台服务器维护相同的数据副本,提高服务器的可用性. Replication se ...

  5. 【漏洞复现】用友NC uapjs RCE漏洞(CNVD-C-2023-76801)

    产品介绍 用友NC是一款企业级ERP软件.作为一种信息化管理工具,用友NC提供了一系列业务管理模块,包括财务会计.采购管理.销售管理.物料管理.生产计划和人力资源管理等,帮助企业实现数字化转型和高效管 ...

  6. debug技巧之使用Arthes调试

    一.前言 大家好啊,我是summo,今天给大家分享一下我平时是怎么调试代码的,不是权威也不是教学,就是简单分享一下,如果大家还有更好的调试方式也可以多多交流哦. 前面我介绍了本地调试和远程调试,今天再 ...

  7. RBD与Cephfs

    目录 1. RBD 1. RBD特性 2. 创建rbd池并使用 2.1 创建rbd 2.2 创建用户 2.3 下发用户key与ceph.conf 2.4 客户端查看pool 2.5 创建rbd块 2. ...

  8. Maven项目中整合SSH(pom.xml文件的配置详解)

    Maven项目中整合SSH比较繁琐,需要解决版本冲突问题,博主在下面给出了pom.xml文件的配置信息,改配置文件整合的是:struts2-2.3.24.spring4.2.4.hibernate5. ...

  9. Swoole 源码分析之 TCP Server 模块

    首发原文链接:https://mp.weixin.qq.com/s/KxgxseLEz84wxUPjzSUd3w 大家好,我是码农先森. 今天我们来分析 TCP Server 模块 的实现原理,下面这 ...

  10. Android 12(S) MultiMedia(十四)ESQueue

    之前看到在ATSParser::Pogram::Stream中会创建一个ESQueue,用于存储解析出来的ES data,这个ESQueue到底是用来做什么的呢?这节就来研究研究. 1.构造函数 ES ...