pytorch自定义RNN结构(附代码)
pytorch自定义LSTM结构(附代码)
有时我们可能会需要修改LSTM的结构,比如用分段线性函数替代非线性函数,这篇博客主要写如何用pytorch自定义一个LSTM结构,并在IMDB数据集上搭建了一个单层反向的LSTM网络,验证了自定义LSTM结构的功能。
@
一、整体程序框架
如果要处理一个维度为【batch_size, length, input_dim】的输入,则需要的LSTM结构如图1所示:
layers表示LSTM的层数,batch_size表示批处理大小,length表示长度,input_dim表示每个输入的维度。
其中,每个LSTMcell执行的表达式如下所示:
二、LSTMcell
LSTMcell的计算函数如下所示;其中nn.Parameter表示该张量为模型可训练参数;
class LSTMCell(nn.Module):
def __init__(self, input_size, hidden_size):
super(LSTMCell, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.weight_cx = nn.Parameter(torch.Tensor(hidden_size, input_size)) #初始化8个权重矩阵
self.weight_ch = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
self.weight_fx = nn.Parameter(torch.Tensor(hidden_size, input_size))
self.weight_fh = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
self.weight_ix = nn.Parameter(torch.Tensor(hidden_size, input_size))
self.weight_ih = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
self.weight_ox = nn.Parameter(torch.Tensor(hidden_size, input_size))
self.weight_oh = nn.Parameter(torch.Tensor(hidden_size, hidden_size))
self.bias_c = nn.Parameter(torch.Tensor(hidden_size)) #初始化4个偏置bias
self.bias_f = nn.Parameter(torch.Tensor(hidden_size))
self.bias_i = nn.Parameter(torch.Tensor(hidden_size))
self.bias_o = nn.Parameter(torch.Tensor(hidden_size))
self.reset_parameters() #初始化参数
def reset_parameters(self):
stdv = 1.0 / math.sqrt(self.hidden_size)
for weight in self.parameters():
weight.data.uniform_(-stdv, stdv)
def forward(self, input, hc):
h, c = hc
i = F.linear(input, self.weight_ix, self.bias_i) + F.linear(h, self.weight_ih) #执行矩阵乘法运算
f = F.linear(input, self.weight_fx, self.bias_f) + F.linear(h, self.weight_fh)
g = F.linear(input, self.weight_cx, self.bias_c) + F.linear(h, self.weight_ch)
o = F.linear(input, self.weight_ox, self.bias_o) + F.linear(h, self.weight_oh)
i = F.sigmoid(i) #激活函数
f = F.sigmoid(f)
g = F.tanh(g)
o = F.sigmoid(o)
c = f * c + i * g
h = o * F.tanh(c)
return h, c
三、LSTM整体程序
如图1所示,一个完整的LSTM是由很多LSTMcell操作组成的,LSTMcell的数量,取决于layers的大小;每个LSTMcell运行的次数取决于length的大小
1. 多层LSTMcell
需要的库函数:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import math
假如我们设计的LSTM层数layers大于1,第一层的LSTM输入维度是input_dim,输出维度是hidden_dim,那么其他各层的输入维度和输出维度都是hidden_dim(下层的输出会成为上层的输入),因此,定义layers个LSTMcell的函数如下所示:
self.lay0 = LSTMCell(input_size,hidden_size)
if layers > 1:
for i in range(1, layers):
lay = LSTMCell(hidden_size,hidden_size)
setattr(self, 'lay{}'.format(i), lay)
其中setattr()函数的作用是,把lay变成self.lay 'i' ,如果layers = 3,那么这段程序就和下面这段程序是一样的
self.lay0 = LSTMCell(input_size,hidden_size)
self.lay1 = LSTMCell(hidden_size,hidden_size)
self.lay2 = LSTMCell(hidden_size,hidden_size)
2. 多层LSTM处理不同长度的输入
每个LSTMcell都需要(h_t-1和c_t-1)作为状态信息输入,若没有指定初始状态,我们就自定义一个值为0的初始状态
if initial_states is None:
zeros = Variable(torch.zeros(input.size(0), self.hidden_size))
initial_states = [(zeros, zeros), ] * self.layers #初始状态
states = initial_states
outputs = []
length = input.size(1)
for t in range(length):
x = input[:, t, :]
for l in range(self.layers):
hc = getattr(self, 'lay{}'.format(l))(x, states[l])
states[l] = hc #如图1所示,左面的输出(h,c)做右面的状态信息输入
x = hc[0] #如图1所示,下面的LSTMcell的输出h做上面的LSTMcell的输入
outputs.append(hc) #将得到的最上层的输出存储起来
其中getattr()函数的作用是,获得括号内的字符串所代表的属性;若l = 3,则下面这两段代码等价:
hc = getattr(self, 'lay{}'.format(l))(x, states[l])
hc = self.lay3(x, states[3])
3. 整体程序
class LSTM(nn.Module):
def __init__(self, input_size, hidden_size, layers=1, sequences=True):
super(LSTM, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.layers = layers
self.sequences = sequences
self.lay0 = LSTMCell(input_size,hidden_size)
if layers > 1:
for i in range(1, layers):
lay = LSTMCell(hidden_size,hidden_size)
setattr(self, 'lay{}'.format(i), lay)
def forward(self, input, initial_states=None):
if initial_states is None:
zeros = Variable(torch.zeros(input.size(0), self.hidden_size))
initial_states = [(zeros, zeros), ] * self.layers
states = initial_states
outputs = []
length = input.size(1)
for t in range(length):
x = input[:, t, :]
for l in range(self.layers):
hc = getattr(self, 'lay{}'.format(l))(x, states[l])
states[l] = hc
x = hc[0]
outputs.append(hc)
if self.sequences: #是否需要图1最上层里从左到右所有的LSTMcell的输出
hs, cs = zip(*outputs)
h = torch.stack(hs).transpose(0, 1)
c = torch.stack(cs).transpose(0, 1)
output = (h, c)
else:
output = outputs[-1] # #只输出图1最右上角的LSTMcell的输出
return output
三、反向LSTM
定义两个LSTM,然后将输入input1反向,作为input2,就可以了
代码如下所示:
import torch
input1 = torch.rand(2,3,4)
inp = input1.unbind(1)[::-1] #从batch_size所在维度拆开,并倒序排列
input2 = inp[0]
for i in range(1, len(inp)): #倒序后的tensor连接起来
input2 = torch.cat((input2, inp[i]), dim=1)
x, y, z = input1.size() #两个输入同维度
input2 = input2.resize(x, y, z)
OK,反向成功
四、实验
在IMDB上搭建一个单层,双向,LSTM结构,加一个FC层;
self.rnn1 = LSTM(embedding_dim, hidden_dim, layers = n_layers, sequences=False)
self.rnn2 = LSTM(embedding_dim, hidden_dim, layers = n_layers, sequences=False)
self.fc = nn.Linear(hidden_dim * 2, output_dim)
运行结果如图:
时间有限,只迭代了6次,实验证明,自定义的RNN程序,可以收敛。
pytorch自定义RNN结构(附代码)的更多相关文章
- Winform中实现向窗体中拖放照片并显示以及拖放文件夹显示树形结构(附代码下载)
场景 向窗体中拖拽照片并显示效果 向窗体中拖拽文件夹并显示树形结构效果 注: 博客主页: https://blog.csdn.net/badao_liumang_qizhi 关注公众号 霸道的程序猿 ...
- 使用achartengine实现自定义折线图 ----附代码 调试OK
achartengine作为android开发中最常用的实现图标的开源框架,使用比较方便,参考官方文档谢了如下Demo,实现了自定义折线图. package edu.ustb.chart; impor ...
- 分布式消息总线,基于.NET Socket Tcp的发布-订阅框架之离线支持,附代码下载
一.分布式消息总线以及基于Socket的实现 在前面的分享一个分布式消息总线,基于.NET Socket Tcp的发布-订阅框架,附代码下载一文之中给大家分享和介绍了一个极其简单也非常容易上的基于.N ...
- 刚开始学python——数据结构——“自定义队列结构“
自定义队列结构 (学习队列后,自己的码) 主要功能:用列表模拟队列结构,考虑了入队,出队,判断队列是否为空,是否已满以及改变队列大小等基本操作. 下面是封装的一个类,把代码保存在myQueue.py ...
- pytorch实现rnn并且对mnist进行分类
1.RNN简介 rnn,相比很多人都已经听腻,但是真正用代码操练起来,其中还是有很多细节值得琢磨. 虽然大家都在说,我还是要强调一次,rnn实际上是处理的是序列问题,与之形成对比的是cnn,cnn不能 ...
- Spring Security教程(三):自定义表结构
在上一篇博客中讲解了用Spring Security自带的默认数据库存储用户和权限的数据,但是Spring Security默认提供的表结构太过简单了,其实就算默认提供的表结构很复杂,也不一定能满足项 ...
- (转)Uri详解之——Uri结构与代码提取
前言:依然没有前言…… 相关博客:1.<Uri详解之——Uri结构与代码提取>2.<Uri详解之二——通过自定义Uri外部启动APP与Notification启动> 上几篇给大 ...
- Uri详解之——Uri结构与代码提取
目录(?)[+] 前言:依然没有前言…… 相关博客:1.<Uri详解之——Uri结构与代码提取>2.<Uri详解之二——通过自定义Uri外部启动APP与Notification启动& ...
- iOS开发 swift 3dTouch实现 附代码
iOS开发 swift 3dTouch实现 附代码 一.What? 从iphone6s开始,苹果手机加入了3d touch技术,最简单的理解就是可以读取用户的点击屏幕力度大小,根据力度大小给予不同的反 ...
- 从实例一步一步入门学习SpringCloud的Eureka、Ribbon、Feign、熔断器、Zuul的简单使用(附代码下载)
场景 SpringCloud -创建统一的依赖管理: https://blog.csdn.net/BADAO_LIUMANG_QIZHI/article/details/102530574 Sprin ...
随机推荐
- CentOS7 登录到控制台后无法联网
登录到控制台, ping 不通网络 解决方法 通过命令找到网卡的配置文件见 ll /etc/sysconfig/network-scripts/ | grep ifcfg-en 编辑配置文件 vi i ...
- Vue29 $nextTick
https://www.jianshu.com/p/f1906903b609 1 介绍 Vue 在修改数据之后,视图不会立即更新,而是等待同一事件循环中的所有数据变化完成之后,再统一进行视图更新.而 ...
- window系统增强优化工具
计算机系统优化的作用很多,它可以清理WINDOWS临时文件夹中的临时文件,释放硬盘空间:可以清理注册表里的垃圾文件,减少系统错误的产生:它还能加快开机速度,阻止一些程序开机自动执行:还可以加快上网和关 ...
- C++练习9 函数的重载
函数的重载是用一个函数名定义多个函数,但是这些同名函数的形参列表(参数个数,类型,顺序)必须不同. 函数重载的规则: 1.函数名称必须相同. 2.参数列表必须不同(个数不同.类型不同.参数排列顺序不同 ...
- 什么是整体设备效率(OEE)?
整体设备效率 (OEE) 用于监控制造效率.得到的OEE百分比是通用的,可以跨不同行业和流程进行比较. OEE可用性 OEE可用性=实际运行时间/生产时间 OEE可用性是实际运行时间和计划生产时间之间 ...
- CF1625D.Binary Spiders
\(\text{Problem}\) 大概就是给出 \(n\) 个数和 \(m\),要从中选最多的数使得两两异或值大于等于 \(m\) 输出方案 \(\text{Solution}\) 一开始的想法很 ...
- JZOJ 4279. 【NOIP2015模拟10.29B组】树上路径
题目 现在有一棵n个点的无向树,每个点的编号在1-n之间,求出每个点所在的最长路. 思路 换根 \(dp\),这里只是记下怎么打 \(Code\) #include<cstdio> #in ...
- JZOJ 4213. 【五校联考1day2】对你的爱深不见底
题目 思路 结论题,我不会证明: 找到第一个 \(|S_n| \leq m + 1\),那么答案就是 \(m - |S_{n-2}|\) 证明?我说了我不会,就当结论用吧 这已经很恶心了 然而这题还要 ...
- CF845F - Guards In The Storehouse
题意:在 \((x,y)\) 放一个哨兵,可以监视本行后面的所有格子直到障碍.本列后面所有的格子直到障碍.求使全盘最多一个位置不被监视的方案总数. 我们发现,因为 \(nm\le 250\),所以 \ ...
- Postgresql执行计划浅析与案例
一.前言 PostgreSQL为每个收到查询产生一个查询计划. 选择正确的计划来匹配查询结构和数据的属性对于好的性能来说绝对是最关键的,因此系统包含了一个复杂的规划器来尝试选择好的计划. 你可以使用E ...