2.2 RNN
RNN
A plain RNN cannot recall memories from many steps back.

LSTM
LSTM (Long Short-Term Memory) addresses vanishing and exploding gradients: a gradient scaled by 0.9 at every step decays as 0.9^n → 0, while a factor of 1.1 grows as 1.1^n → ∞.
It does this by adding gates to the RNN.
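A quick numeric check of this decay and growth (a minimal sketch; 0.9 and 1.1 are just the illustrative per-step factors from above):

import numpy as np

# repeated multiplication of a gradient over n time steps
n = np.arange(0, 101, 20)
print(0.9 ** n)  # [1. 0.12 0.015 0.0018 0.00022 0.000027] -> vanishes
print(1.1 ** n)  # [1. 6.7 45.3 304.5 2048.4 13780.6]      -> explodes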



Example




So a plain RNN cannot recall memories from long ago. To solve this, the LSTM adds three controllers (gates), which delay forgetting so that memories persist much longer.

This can be understood in terms of a main line and a branch line: the cell state is the main storyline carrying long-term memory, while the hidden state is the branch storyline computed at each step. That is how the LSTM delays forgetting; a minimal sketch of one LSTM step follows.
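A minimal NumPy sketch of a single LSTM step (an illustration only, not the TensorFlow implementation used later; the packed weight matrix W and bias b are hypothetical parameters):

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x, h_prev, c_prev, W, b):
    # x: (input_size,), h_prev/c_prev: (cell_size,)
    # W: (input_size + cell_size, 4 * cell_size), b: (4 * cell_size,) -- hypothetical packed parameters
    z = np.concatenate([x, h_prev]) @ W + b
    i, f, o, g = np.split(z, 4)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)  # input / forget / output gates
    g = np.tanh(g)                                # candidate memory
    c = f * c_prev + i * g   # main line: cell state, updated additively
    h = o * np.tanh(c)       # branch line: hidden state (per-step output)
    return h, c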

Truncated backpropagation through time (truncated BPTT)
https://r2rt.com/styles-of-truncated-backpropagation.html
TensorFlow's style of truncated backpropagation (training on subsequences of length n) is qualitatively different from "backpropagating errors at most n steps"; the pattern is sketched below.
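The TensorFlow-style pattern, written with the names from the script further down (a sketch of the idea, not standalone code): the final LSTM state of one run is fed back as the initial state of the next run, so the numeric state flows across subsequence boundaries while gradients stop at each boundary.

# state carried across length-TIME_STEPS subsequences; backprop stops at each boundary
feed_dict = {model.xs: seq,
             model.ys: res,
             model.cell_init_state: state}   # numeric state only, no gradient flows through it
_, state = sess.run([model.train_op, model.cell_final_state], feed_dict=feed_dict)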

Code: an LSTM fitting a sine curve (a regression problem)
"""
Please note, this code is only for python 3+. If you are using python 2+, please modify the code accordingly.
Run this script on tensorflow r0.10. Errors appear when using lower versions.
"""
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt BATCH_START = 0
TIME_STEPS = 20
BATCH_SIZE = 50
INPUT_SIZE = 1
OUTPUT_SIZE = 1
CELL_SIZE = 10
LR = 0.006 def get_batch():
global BATCH_START, TIME_STEPS
# xs shape (50batch, 20steps)
xs = np.arange(BATCH_START, BATCH_START+TIME_STEPS*BATCH_SIZE).reshape((BATCH_SIZE, TIME_STEPS)) / (10*np.pi)
seq = np.sin(xs)
res = np.cos(xs)
BATCH_START += TIME_STEPS
# plt.plot(xs[0, :], res[0, :], 'r', xs[0, :], seq[0, :], 'b--')
# plt.show()
# returned seq, res and xs: shape (batch, step, input)
return [seq[:, :, np.newaxis], res[:, :, np.newaxis], xs] class LSTMRNN(object):
def __init__(self, n_steps, input_size, output_size, cell_size, batch_size):
self.n_steps = n_steps
self.input_size = input_size
self.output_size = output_size
self.cell_size = cell_size
self.batch_size = batch_size
with tf.name_scope('inputs'):
self.xs = tf.placeholder(tf.float32, [None, n_steps, input_size], name='xs')
self.ys = tf.placeholder(tf.float32, [None, n_steps, output_size], name='ys')
with tf.variable_scope('in_hidden'):
self.add_input_layer()
with tf.variable_scope('LSTM_cell'):
self.add_cell()
with tf.variable_scope('out_hidden'):
self.add_output_layer()
with tf.name_scope('cost'):
self.compute_cost()
with tf.name_scope('train'):
self.train_op = tf.train.AdamOptimizer(LR).minimize(self.cost) def add_input_layer(self,):
l_in_x = tf.reshape(self.xs, [-1, self.input_size], name='2_2D') # (batch*n_step, in_size)
# Ws (in_size, cell_size)
Ws_in = self._weight_variable([self.input_size, self.cell_size])
# bs (cell_size, )
bs_in = self._bias_variable([self.cell_size,])
# l_in_y = (batch * n_steps, cell_size)
with tf.name_scope('Wx_plus_b'):
l_in_y = tf.matmul(l_in_x, Ws_in) + bs_in
# reshape l_in_y ==> (batch, n_steps, cell_size)
self.l_in_y = tf.reshape(l_in_y, [-1, self.n_steps, self.cell_size], name='2_3D') def add_cell(self):
lstm_cell = tf.contrib.rnn.BasicLSTMCell(self.cell_size, forget_bias=1.0, state_is_tuple=True)
with tf.name_scope('initial_state'):
self.cell_init_state = lstm_cell.zero_state(self.batch_size, dtype=tf.float32)
self.cell_outputs, self.cell_final_state = tf.nn.dynamic_rnn(
lstm_cell, self.l_in_y, initial_state=self.cell_init_state, time_major=False) def add_output_layer(self):
# shape = (batch * steps, cell_size)
l_out_x = tf.reshape(self.cell_outputs, [-1, self.cell_size], name='2_2D')
Ws_out = self._weight_variable([self.cell_size, self.output_size])
bs_out = self._bias_variable([self.output_size, ])
# shape = (batch * steps, output_size)
with tf.name_scope('Wx_plus_b'):
self.pred = tf.matmul(l_out_x, Ws_out) + bs_out def compute_cost(self):
losses = tf.contrib.legacy_seq2seq.sequence_loss_by_example(
[tf.reshape(self.pred, [-1], name='reshape_pred')],
[tf.reshape(self.ys, [-1], name='reshape_target')],
[tf.ones([self.batch_size * self.n_steps], dtype=tf.float32)],
average_across_timesteps=True,
softmax_loss_function=self.ms_error,
name='losses'
)
with tf.name_scope('average_cost'):
self.cost = tf.div(
tf.reduce_sum(losses, name='losses_sum'),
self.batch_size,
name='average_cost')
tf.summary.scalar('cost', self.cost) @staticmethod
def ms_error(labels, logits):
return tf.square(tf.subtract(labels, logits)) def _weight_variable(self, shape, name='weights'):
initializer = tf.random_normal_initializer(mean=0., stddev=1.,)
return tf.get_variable(shape=shape, initializer=initializer, name=name) def _bias_variable(self, shape, name='biases'):
initializer = tf.constant_initializer(0.1)
return tf.get_variable(name=name, shape=shape, initializer=initializer) if __name__ == '__main__':
model = LSTMRNN(TIME_STEPS, INPUT_SIZE, OUTPUT_SIZE, CELL_SIZE, BATCH_SIZE)
sess = tf.Session()
merged = tf.summary.merge_all()
writer = tf.summary.FileWriter("logs", sess.graph)
# tf.initialize_all_variables() no long valid from
# 2017-03-02 if using tensorflow >= 0.12
if int((tf.__version__).split('.')[1]) < 12 and int((tf.__version__).split('.')[0]) < 1:
init = tf.initialize_all_variables()
else:
init = tf.global_variables_initializer()
sess.run(init)
# relocate to the local dir and run this line to view it on Chrome (http://0.0.0.0:6006/):
# $ tensorboard --logdir='logs' plt.ion()
plt.show()
for i in range(200):
seq, res, xs = get_batch()
if i == 0:
feed_dict = {
model.xs: seq,
model.ys: res,
# create initial state
}
else:
feed_dict = {
model.xs: seq,
model.ys: res,
model.cell_init_state: state # use last state as the initial state for this run
} _, cost, state, pred = sess.run(
[model.train_op, model.cost, model.cell_final_state, model.pred],
feed_dict=feed_dict) # plotting 绘制训练sin图像的过程
plt.plot(xs[0, :], res[0].flatten(), 'r', xs[0, :], pred.flatten()[:TIME_STEPS], 'b--')
plt.ylim((-1.2, 1.2))
plt.draw()
plt.pause(0.3) #每隔3秒运行一次 if i % 20 == 0:
print('cost: ', round(cost, 4))
result = sess.run(merged, feed_dict)
writer.add_summary(result, i)
Program output
The figures below show the sine curve being fitted over the course of training.
