今天终于弄明白,TensorFlow和Keras中LSTM神经网络的输入输出层到底应该怎么设置和连接了。写个备忘。

https://machinelearningmastery.com/how-to-develop-lstm-models-for-time-series-forecasting/

Stacked LSTM

Multiple hidden LSTM layers can be stacked one on top of another in what is referred to as a Stacked LSTM model.

An LSTM layer requires a three-dimensional input and LSTMs by default will produce a two-dimensional output as an interpretation from the end of the sequence.

We can address this by having the LSTM output a value for each time step in the input data by setting the return_sequences=True argument on the layer. This allows us to have 3D output from hidden LSTM layer as input to the next.

We can, therefore, define a Stacked LSTM as follows.

# define model
model = Sequential()
model.add(LSTM(50, activation='relu', return_sequences=True, input_shape=(n_steps, n_features)))
model.add(LSTM(50, activation='relu'))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mse')
X_train.shape
(500, 40, 1)
y_train.shape
(500, 40, 1)
from keras.models import Sequential
from keras import layers
from keras.optimizers import RMSprop model = Sequential()
model.add(layers.GRU(100, input_shape=(None, X_train.shape[-1]), return_sequences=True))
model.add(layers.Dense(1))
model.compile(optimizer=RMSprop(), loss='mae')
history = model.fit(X_train, y_train,steps_per_epoch=25,epochs=20)
reset_graph()

n_steps = 40
n_inputs = 1
n_neurons = 100 X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.float32, [None, n_steps, n_outputs]) num_units = [500, 200, 100]
cells = [tf.nn.rnn_cell.GRUCell(num_units=n) for n in num_units]
stacked_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(cells)
rnn_outputs, states = tf.nn.dynamic_rnn(stacked_rnn_cell, X, dtype=tf.float32) # 先去掉一个维度,用一个Dense层连上,再把n_steps这个维度加回去
# [batch_size, n_steps, n_neurons]
# [batch_size * n_steps, n_neurons]
# [batch_size, n_steps, n_neurons] stacked_rnn_outputs = tf.reshape(rnn_outputs, [-1, n_neurons])
stacked_outputs = tf.layers.dense(stacked_rnn_outputs, n_outputs)
outputs = tf.reshape(stacked_outputs, [-1, n_steps, n_outputs]) loss = tf.reduce_mean(tf.square(outputs - y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(loss) init = tf.global_variables_initializer()
saver = tf.train.Saver() n_iterations = 5000
batch_size = 100 with tf.Session() as sess:
init.run()
for iteration in range(n_iterations):
X_batch, y_batch = next_batch(batch_size, n_steps)
sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
if iteration % 100 == 0:
mse = loss.eval(feed_dict={X: X_batch, y: y_batch})
print(iteration, "\tMSE:", mse) X_new = time_series(np.array(t_instance[:-1].reshape(-1, n_steps, n_inputs)))
y_pred = sess.run(outputs, feed_dict={X: X_new}) saver.save(sess, "./my_time_series_model")
  • TensorFlow不同, Keras 中 LSTM 层默认只输出最后一个时间步

LSTM 神经网络输入输出层的更多相关文章

  1. LSTM神经网络输入输出究竟是怎样的?

    LSTM图和词向量输入分析

  2. LSTM神经网络

    LSTM是什么 LSTM即Long Short Memory Network,长短时记忆网络.它其实是属于RNN的一种变种,可以说它是为了克服RNN无法很好处理远距离依赖而提出的. 我们说RNN不能处 ...

  3. (转)LSTM神经网络介绍

    原文链接:http://www.atyun.com/16821.html 扩展阅读: https://machinelearningmastery.com/time-series-prediction ...

  4. (转) 干货 | 图解LSTM神经网络架构及其11种变体(附论文)

    干货 | 图解LSTM神经网络架构及其11种变体(附论文) 2016-10-02 机器之心 选自FastML 作者:Zygmunt Z. 机器之心编译  参与:老红.李亚洲 就像雨季后非洲大草原许多野 ...

  5. 关于LeNet-5卷积神经网络 S2层与C3层连接的参数计算的思考???

    https://blog.csdn.net/saw009/article/details/80590245 关于LeNet-5卷积神经网络 S2层与C3层连接的参数计算的思考??? 首先图1是LeNe ...

  6. MLP神经网络 隐含层节点数的设置】如何设置神经网络隐藏层 的神经元个数

    神经网络 隐含层节点数的设置]如何设置神经网络隐藏层 的神经元个数 置顶 2017年10月24日 14:25:07 开心果汁 阅读数:12968    版权声明:本文为博主原创文章,未经博主允许不得转 ...

  7. tensorflow学习之(十一)RNN+LSTM神经网络的构造

    #RNN 循环神经网络 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data tf.se ...

  8. 深入浅出LSTM神经网络

    转自:https://www.csdn.net/article/2015-06-05/2824880 LSTM递归神经网络RNN长短期记忆   摘要:根据深度学习三大牛的介绍,LSTM网络已被证明比传 ...

  9. Tensorflow之基于LSTM神经网络写唐诗

    最近看了不少关于写诗的博客,在前人的基础上做了一些小的改动,因比较喜欢一次输入很长的开头句,所以让机器人输出压缩为一个开头字生成两个诗句,写五言和七言诗,当然如果你想写更长的诗句是可以继续改动的. 在 ...

随机推荐

  1. jobs的后台进程程序如何终止?

    好像没有专门的jobs相关的命令来终止后台进程, 只有通过 jobs -l看 后台进程的pid, 然后用kill来终止. 摘录: (( 进程的终止 后台进程的终止: 方法一: 通过jobs命令查看jo ...

  2. 【原创】基于phpGrace+uniApp开发之:5.登录界面增加图片验证码

    1.目的: 采用phpGrace中的图片验证码,在用户名+密码登录时使用图片验证码进行验证. 2.文档地址: 图片验证码的文档地址:http://www.phpgrace.com/tools/info ...

  3. springmvc 读写分离

    推荐第四种:https://github.com/shawntime/shawn-rwdb 4种不方的读写分离实现方法 http://blog.csdn.net/lixiucheng005/artic ...

  4. Jmeter之Switch Controller

    在测试过程中,各种不同的情况需要执行不同的操作,这个时候用if控制器比较麻烦,此时就可以使用Switch Controller代替. 一.界面显示 二.配置说明 1.名称:标识 2.注释:备注 3.S ...

  5. Spring Boot 之 RabbitMQ 消息队列中间件的三种模式

    开门见山(文末附有消息队列的几个基本概念) 1.直接模式( Direct)模式 直白的说就是一对一,生产者对应唯一的消费者(当然同一个消费者可以开启多个服务). 虽然使用了自带的交换器(Exchang ...

  6. 第一个spring boot应用

    前提 首先要确保已经安装了java和maven: $ java -version java version "1.8.0_102" Java(TM) SE Runtime Envi ...

  7. node.js多版本管理 nvm

    安装nvm 1.nvm压缩包下载地址 https://github.com/coreybutler/nvm-windows/releases 注意:Windows用户下载nvm-setup.zip 2 ...

  8. Web API入门二(实例)

    学习编程的最好方法就是实例,本人用的是VS2015 1.创建ASP.NET Web空项目 点击确定后即创建了空"WebApi"项目 2.下面,我们需要使用NuGet包管理器添加最新 ...

  9. Java学习day8面向对象编程2-类的属性和方法

    一.类的属性 1.语法格式 修饰符 类型 属性名 = 初值 说明:修饰符private:该属性只能由该类的方法使用.在同一类内可见.使用对象:变量.方法. 注意:不能修饰类(外部类)    修饰符pu ...

  10. 重载与重写、多态——java

    方法的重写(Overriding)和重载(Overloading)是java多态性的不同表现,重写是父类与子类之间多态性的一种表现,重载可以理解成多态的具体表现形式. (1)方法重载是一个类中定义了多 ...