LSTM 神经网络输入输出层
今天终于弄明白,TensorFlow和Keras中LSTM神经网络的输入输出层到底应该怎么设置和连接了。写个备忘。
https://machinelearningmastery.com/how-to-develop-lstm-models-for-time-series-forecasting/
Stacked LSTM
Multiple hidden LSTM layers can be stacked one on top of another in what is referred to as a Stacked LSTM model.
An LSTM layer requires a three-dimensional input and LSTMs by default will produce a two-dimensional output as an interpretation from the end of the sequence.
We can address this by having the LSTM output a value for each time step in the input data by setting the return_sequences=True argument on the layer. This allows us to have 3D output from hidden LSTM layer as input to the next.
We can, therefore, define a Stacked LSTM as follows.
# define model
model = Sequential()
model.add(LSTM(50, activation='relu', return_sequences=True, input_shape=(n_steps, n_features)))
model.add(LSTM(50, activation='relu'))
model.add(Dense(1))
model.compile(optimizer='adam', loss='mse')
X_train.shape
(500, 40, 1)
y_train.shape
(500, 40, 1)
from keras.models import Sequential
from keras import layers
from keras.optimizers import RMSprop
model = Sequential()
model.add(layers.GRU(100, input_shape=(None, X_train.shape[-1]), return_sequences=True))
model.add(layers.Dense(1))
model.compile(optimizer=RMSprop(), loss='mae')
history = model.fit(X_train, y_train,steps_per_epoch=25,epochs=20)
reset_graph()
n_steps = 40
n_inputs = 1
n_neurons = 100
X = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.float32, [None, n_steps, n_outputs])
num_units = [500, 200, 100]
cells = [tf.nn.rnn_cell.GRUCell(num_units=n) for n in num_units]
stacked_rnn_cell = tf.nn.rnn_cell.MultiRNNCell(cells)
rnn_outputs, states = tf.nn.dynamic_rnn(stacked_rnn_cell, X, dtype=tf.float32)
# 先去掉一个维度,用一个Dense层连上,再把n_steps这个维度加回去
# [batch_size, n_steps, n_neurons]
# [batch_size * n_steps, n_neurons]
# [batch_size, n_steps, n_neurons]
stacked_rnn_outputs = tf.reshape(rnn_outputs, [-1, n_neurons])
stacked_outputs = tf.layers.dense(stacked_rnn_outputs, n_outputs)
outputs = tf.reshape(stacked_outputs, [-1, n_steps, n_outputs])
loss = tf.reduce_mean(tf.square(outputs - y))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
training_op = optimizer.minimize(loss)
init = tf.global_variables_initializer()
saver = tf.train.Saver()
n_iterations = 5000
batch_size = 100
with tf.Session() as sess:
init.run()
for iteration in range(n_iterations):
X_batch, y_batch = next_batch(batch_size, n_steps)
sess.run(training_op, feed_dict={X: X_batch, y: y_batch})
if iteration % 100 == 0:
mse = loss.eval(feed_dict={X: X_batch, y: y_batch})
print(iteration, "\tMSE:", mse)
X_new = time_series(np.array(t_instance[:-1].reshape(-1, n_steps, n_inputs)))
y_pred = sess.run(outputs, feed_dict={X: X_new})
saver.save(sess, "./my_time_series_model")
- 与
TensorFlow不同, Keras 中LSTM层默认只输出最后一个时间步
LSTM 神经网络输入输出层的更多相关文章
- LSTM神经网络输入输出究竟是怎样的?
LSTM图和词向量输入分析
- LSTM神经网络
LSTM是什么 LSTM即Long Short Memory Network,长短时记忆网络.它其实是属于RNN的一种变种,可以说它是为了克服RNN无法很好处理远距离依赖而提出的. 我们说RNN不能处 ...
- (转)LSTM神经网络介绍
原文链接:http://www.atyun.com/16821.html 扩展阅读: https://machinelearningmastery.com/time-series-prediction ...
- (转) 干货 | 图解LSTM神经网络架构及其11种变体(附论文)
干货 | 图解LSTM神经网络架构及其11种变体(附论文) 2016-10-02 机器之心 选自FastML 作者:Zygmunt Z. 机器之心编译 参与:老红.李亚洲 就像雨季后非洲大草原许多野 ...
- 关于LeNet-5卷积神经网络 S2层与C3层连接的参数计算的思考???
https://blog.csdn.net/saw009/article/details/80590245 关于LeNet-5卷积神经网络 S2层与C3层连接的参数计算的思考??? 首先图1是LeNe ...
- MLP神经网络 隐含层节点数的设置】如何设置神经网络隐藏层 的神经元个数
神经网络 隐含层节点数的设置]如何设置神经网络隐藏层 的神经元个数 置顶 2017年10月24日 14:25:07 开心果汁 阅读数:12968 版权声明:本文为博主原创文章,未经博主允许不得转 ...
- tensorflow学习之(十一)RNN+LSTM神经网络的构造
#RNN 循环神经网络 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data tf.se ...
- 深入浅出LSTM神经网络
转自:https://www.csdn.net/article/2015-06-05/2824880 LSTM递归神经网络RNN长短期记忆 摘要:根据深度学习三大牛的介绍,LSTM网络已被证明比传 ...
- Tensorflow之基于LSTM神经网络写唐诗
最近看了不少关于写诗的博客,在前人的基础上做了一些小的改动,因比较喜欢一次输入很长的开头句,所以让机器人输出压缩为一个开头字生成两个诗句,写五言和七言诗,当然如果你想写更长的诗句是可以继续改动的. 在 ...
随机推荐
- ORACLE异机增量备份恢复
PROD异机增量备份恢复验证实施文档 准备工作:source 源库:PROD数据库备份策略:周日0级RMAN备份,周一至周六1级差异增量备份0 4 * * 0 /data/rmanlev0.sh &g ...
- python 卡方检验例子
python 求拒绝域和卡方值 import scipy.stats as ss obs=[107,198,192,125,132,248] exp=[167]*6 #拒绝域 1%的显著水平,自由度5 ...
- Delphi XE2 之 FireMonkey 入门(10) - 常用结构 TPoint、TPointF、TSmallPoint、TSize、TRect、TRectF 及相关方法
它们都是结构, TPointF.TRectF 属新增, 其它也都有升级; 现在都拥有丰富的方法和方便的运算符重载; 且有一组相关的公共函数. 这组内容重要的是它们都来自 System.Types 单元 ...
- [10期]浅谈SSRF安全漏洞
引子:SSRF 服务端请求伪造攻击 很多web应用都提供从其他服务器上获取数据的功能.使用用户指定的URL,web应用可以从其他服务器获取图片,下载文件,读取文件内容等. 这个功能被恶意使用的话,可以 ...
- 关于Vue的理解以及与React框架的对比
1.Vue的理解 概念: Vue是一套用于构建用户界面的渐进式框架: Vue的核心库只关注视图层: 是一个数据驱动的MVVM框架: 特性: 确实轻量:体积比较小: 数据绑定简单.方便: 有一些简单的内 ...
- MSF——客户端渗透之VBScript感染
弱点扫描 根据信息收集的结果搜索漏洞利用模块 结合外部漏洞扫描系统对大IP地址段进行批量扫描 误报率.漏报率 VNC密码破解 客户端渗透 VBScript感染方式 利用 宏 感染word.exce ...
- mysql 主从 设置
总结:1.如果是虚拟克隆mysql 请注意auto.cnf的uuid保证不一样,即删除auto.cnf 重新启动即可2.默认安装的mysql配置文件mysqld.cnf可能绑定了127.0.0.1 只 ...
- 项目被os x占用
xattr -d com.apple.FinderInfo 空格后拖入项目回车就行了
- PCB电路设计 altiumdesigner(项目软件总结)
1.Altium designer 10在PCB里面复制粘贴,比CAD里面多一个动作,就是点击ctrl+C后,要左键点一下复制基点,比如某根线端点或者焊盘,再粘贴,就是基于刚才点的那个为基点粘贴了.2 ...
- Spring MVC-学习笔记(1)认识spring mvc
1.基于XML Schema.Controller接口的spring mvc简单例子 1>创建一个动态Web项目,选择同时创建web.xml文件 2>在WEB-INF/lib中粘贴spri ...