一、学习单步的RNN:RNNCell

如果要学习TensorFlow中的RNN,第一站应该就是去了解“RNNCell”,它是TensorFlow中实现RNN的基本单元,每个RNNCell都有一个call方法,使用方式是:(output, next_state) = call(input, state)。
也就是说,每调用一次RNNCell的call方法,就相当于在时间上“推进了一步”,这就是RNNCell的基本功能。

在代码实现上,RNNCell只是一个抽象类,我们用的时候都是用的它的两个子类BasicRNNCell和BasicLSTMCell。顾名思义,前者是RNN的基础类,后者是LSTM的基础类。

找到源码中BasicRNNCell的调用函数实现:

def调用(self,inputs,state):
“”“最基本的RNN:output = new_state = act(W * input + U * state + B)。”“”
output = self._activation(_linear([inputs,state] ,self._num_units,True))
return 输出,输出

"return输出,输出”说明在BasicRNNCell中,输出其实和隐状态的值是一样的。因此还需要额外对输出定义新的变换才能得到真正的输出y。由于输出和隐状态是一回事,所以在BasicRNNCell中,state_size永远等于output_size

除了call方法外,对于RNNCell,还有两个类属性比较重要:

state_size
output_size
前者是隐层的大小,后者是输出的大小。比如我们通常是将一个batch送入模型计算,设输入数据的形状为(batch_size, input_size),那么计算时得到的隐层状态就是(batch_size, state_size),输出就是(batch_size, output_size)。

对于单层RNN:

import tensorflow as tf
import numpy as np cell = tf.nn.rnn_cell.BasicRNNCell(num_units=) # state_size =
print(cell.state_size) # inputs = tf.placeholder(np.float32, shape=(, )) # 是 batch_size,100是input_size shape = (batch_size, input_size)
h0 = cell.zero_state(, np.float32) # 通过zero_state得到一个全0的初始状态,形状为(batch_size, state_size)
output, h1 = cell.call(inputs, h0) #调用call函数 print(h1.shape) # (, )

对于多层RNN:

import tensorflow as tf
import numpy as np num_layers = 2 #层数
hidden_size = [128,256] #每一层的隐节点个数(可以不一样)
rnn_cells = [] #包含所有层的列表 for i in range(num_layers):
# 构建一个基本rnn单元(一层)
rnn_cell = tf.nn.rnn_cell.BasicRNNCell(lstm_size[i])
# 可以添加dropout
drop_cell = tf.nn.rnn_cell.DropoutWrapper(rnn_cell , output_keep_prob=keep_prob)
rnn_cells.append(drop_cell)
# 堆叠多个LSTM单元
cell = tf.nn.rnn_cell.MultiRNNCell(rnn_cells)
initial_state = cell.zero_state(batch_size, tf.float32)
return cell, initial_state '''
注:对于老版本的tensorflow,堆叠多层RNN(或LSTM):
cell = tf.nn.rnn_cell.MultiRNNCell([rnn_cell for _ in
range(num_layers)]) '''

对于BasicLSTMCell,情况有些许不同,因为LSTM可以看做有两个隐状态h和c,对应的隐层就是一个Tuple,每个都是(batch_size, state_size)的形状:

import tensorflow as tf
import numpy as np
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units=)
inputs = tf.placeholder(np.float32, shape=(, )) # 是 batch_size
h0 = lstm_cell.zero_state(, np.float32) # 通过zero_state得到一个全0的初始状态
output, h1 = lstm_cell.call(inputs, h0) #h1包含两个隐状态 print(h1.h) # shape=(, )
print(h1.c) # shape=(, )

二、学习如何一次执行多步:tf.nn.dynamic_rnn

基础的RNNCell有一个很明显的问题:对于单个的RNNCell,我们使用它的call函数进行运算时,只是在序列时间上前进了一步。比如使用x1、h0得到h1,通过x2、h1得到h2等。这样的h话,如果我们的序列长度为10,就要调用10次call函数,比较麻烦。对此,TensorFlow提供了一个tf.nn.dynamic_rnn函数,使用该函数就相当于调用了n次call函数。即通过{h0,x1, x2, …., xn}直接得{h1,h2…,hn}。

具体来说,设我们输入数据的格式为(batch_size, time_steps, input_size),其中time_steps表示序列本身的长度,如在Char RNN中,长度为10的句子对应的time_steps就等于10。最后的input_size就表示输入数据单个序列单个时间维度上固有的长度。另外我们已经定义好了一个RNNCell,调用该RNNCell的call函数time_steps次,对应的代码就是:

# inputs: shape = (batch_size, time_steps, input_size)
# cell: RNNCell
# initial_state: shape = (batch_size, cell.state_size)。初始状态。一般可以取零矩阵

# inputs: shape = (batch_size, time_steps, input_size)
# cell: RNNCell
# initial_state: shape = (batch_size, cell.state_size)。初始状态。一般可以取零矩阵


import tensorflow as tf


tf.reset_default_graph()
batch_size = 32 # batch大小
input_size = 100 # 输入向量xt维度
state_size = 128 # 隐藏状态ht维度
time_steps = 10 # 序列长度


inputs = tf.random_normal(shape=[batch_size, time_steps, input_size], dtype=tf.float32)
print("inputs.shape:",inputs.shape) #(32,10,100)


lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(num_units = state_size)
print(lstm_cell.state_size) #(c=128,h=128)


initial_state = lstm_cell.zero_state(batch_size, dtype = tf.float32)
print(initial_state.h, initial_state.c)  #(32,128),(32,128)


outputs, state = tf.nn.dynamic_rnn(lstm_cell, inputs, initial_state = initial_state)


print(outputs)  #(32,10,128)
print(state)    #(32,128) state是最终(最后一个time_step)的状态
print(state.h, state.c) #(32,128),(32,128)

转自知乎:https://zhuanlan.zhihu.com/p/28196873

超详细的RNN代码实现(tensorflow)的更多相关文章

  1. 超详细的Xcode代码格式化教程,可自定义样式。

    超详细的Xcode代码格式化教程,可自定义样式. 为什么要格式化代码 当团队内有多人开发的时候,每个人写的代码格式都有自己的喜好,也可能会忙着写代码而忽略了格式的问题.在之前,我们可能会写完代码后,再 ...

  2. 超详细的Xcode代码格式化教程,可自定义样式

    为什么要格式化代码 当团队内有多人开发的时候,每个人写的代码格式都有自己的喜好,也可能会忙着写代码而忽略了格式的问题. 在之前,我们可能会写完代码后,再一点一点去调格式,很浪费时间. 有了ClangF ...

  3. 【智能算法】粒子群算法(Particle Swarm Optimization)超详细解析+入门代码实例讲解

    喜欢的话可以扫码关注我们的公众号哦,更多精彩尽在微信公众号[程序猿声] 01 算法起源 粒子群优化算法(PSO)是一种进化计算技术(evolutionary computation),1995 年由E ...

  4. 微信支付接口--超详细带注释代码--Demo

    如果本文对你有用,请爱心点个赞,提高排名,帮助更多的人.谢谢大家!❤ 如果解决不了,可以在文末进群交流. 如果对你有帮助的话麻烦点个[推荐]~最好还可以follow一下我的GitHub~感谢观看! 微 ...

  5. Anaconda安装tensorflow和keras(gpu版,超详细)

    本人配置:window10+GTX 1650+tensorflow-gpu 1.14+keras-gpu 2.2.5+python 3.6,亲测可行 一.Anaconda安装 直接到清华镜像网站下载( ...

  6. Github上传代码菜鸟超详细教程【转】

    最近需要将课设代码上传到Github上,之前只是用来fork别人的代码. 这篇文章写得是windows下的使用方法. 第一步:创建Github新账户 第二步:新建仓库 第三部:填写名称,简介(可选), ...

  7. 数据挖掘领域十大经典算法之—C4.5算法(超详细附代码)

    https://blog.csdn.net/fuqiuai/article/details/79456971 相关文章: 数据挖掘领域十大经典算法之—K-Means算法(超详细附代码)        ...

  8. (超详细)使用git命令行将本地仓库代码上传到github或gitlab远程仓库

    (超详细)使用git命令行将本地仓库代码上传到github或gitlab远程仓库 本地创建了一个 xcode 工程项目,现通过 命令行 将该项目上传到 github 或者 gitlab 远程仓库,具体 ...

  9. Keras代码超详细讲解LSTM实现细节

    1.首先我们了解一下keras中的Embedding层:from keras.layers.embeddings import Embedding: Embedding参数如下: 输入尺寸:(batc ...

随机推荐

  1. oracle数据库中 impdb/expdb 详解

    创建逻辑目录,该命令不会在操作系统创建真正的目录,最好以system等管理员创建.create directory dpdata as 'd:\test\dump'; 二.查看管理理员目录(同时查看操 ...

  2. TinyMCE 工具栏配置

    plugins: { type: [String, Array], default: 'lists image media wordcount advlist bbcode code charmap ...

  3. [Functional Programming] Rewrite a reducer with functional state ADT

    For example we have a feature reducer like this: // selectCard :: String -> Action String export ...

  4. php类的定义与实例化方法

    php类的定义 类是对某个对象的定义.它包含有关对象动作方式的信息,包括它的名称.方法.属性和事件.实际上它本身并不是对象,因为它不存在于内存中.当引用类的代码运行时,类的一个新的实例,即对象,就在内 ...

  5. yum -y install 问题解决

    1.错误如下: Last login: Thu Jul 26 09:04:14 2018 from 192.168.3.250[root@diagbot01 ~]# yum -y install do ...

  6. 最短路--Bellman-Ford

    Bellman-Ford 贝尔曼-福特 算法思想 贝尔曼-福特算法(英语:Bellman–Ford algorithm),求解单源最短路径问题的一种算法,由理查德·贝尔曼 和 莱斯特·福特 创立的.它 ...

  7. mongodb mongod.lock文件及oplog文件

    在mongodb的启动时,在数据目录下,会生成一个mongod.lock文件.如果在正常退出时,会清除这个mongod.lock文件,若要是异常退出,在下次启动的时候,会禁止启动,从而保留一份干净的一 ...

  8. SQLServer常见查询问题

     http://bbs.csdn.net/topics/340078327 1.生成若干行记录 --自然数表1-1M CREATE TABLE Nums(n int NOT NULL PRIMAR ...

  9. LOJ3102. 「JSOI2019」神经网络 [DP,容斥,生成函数]

    传送门 思路 大部分是感性理解,不保证完全正确. 不能算是神仙题,但我还是不会qwq 这题显然就是求:把每一棵树分成若干条链,然后把链拼成一个环,使得相邻的链不来自同一棵树,的方案数.(我才不告诉你们 ...

  10. [报错解决] k8s 删除pv一直处于terminating 两种解决方法

    第一种 直接到etcd中删除 1.将所有的etcd中的key值取到一个keys.yam里面,便于查询 ETCDCTL_API=3 etcdctl get "" --from-key ...