cnn卷积神经网络在前面已经有所了解了,目前博主也使用它进行了一个图像分类问题,基于kaggle里面的food-101进行的图像识别,识别率有点感人,基于数据集的关系,大致来说还可行。
下面我就继续学习rnn神经网络。

rnn神经网络(递归/循环神经网络)模式如下:

我们在处理文字等问题的时候,我们的输入会把上一个时间输出的数据作为下一个时间的输入数据进行处理。
例如:我们有一段话,我们将其分词,得到t个数据,我们分别将每一个词传入到x0,x1....xt里面,当x0传入后,会得到一个结果h0,同时我们会将处理后的数据传入到下个时间,到下个时间的时候,我们会再传入一个数据x1,同时还有上一个时间处理后的数据,将这两个数据进行整合计算,然后再向下传输,一直到结束。
rnn本质来说还是一个bp回路,不过他只是比bp网络多一个环节,即它可以反馈上一时间点处理后的数据。

上图细化如下:


rnn实际上还是存在梯度消失的问题,因此如上图所示,当我们在第一个时间输入的数据,可能在很久之后他就已经梯度消失了(影响很小),因此我们使用lstm(long short trem memory)

上图有三个门:输入门    忘记门   输出门
1.输入门:通过input * g 来判断是否输入,如果不输入就为0,输入就是0,以此判断信号是否输入
2.忘记门:这个信号是否需要衰减多少,可能为50%,衰减是根据信号来判断。
3.输入门:通过判断是否输出,或者输出多少,例如输出50%。
因此上述图可化为:

可以看出,这三个门,所有得影响都是关于输入和上一个数据得输出来进行计算的。

可以看下图:

我们使用lstm得话,通过三个门决定信号是否向下传输,传输多少都可以控制,是否传入信号,输出信息都进行控制。

下面我们还是用tensorflow实现,数据集还是手写数字,虽然rnn主要是用在文字和语言上,但是它依旧可以用在图片上。
下面给出代码:

```python
import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import input_data
mnist=input_data.read_data_sets("MNNIST_data",one_hot=True) #输入图片为 28*28
n_inputs=28#输入一行,一行有28个像素
max_time=28#一共28行,所以为28*28
lstm_size=100#100个隐藏单元
batch_size=50
n_classes=10
n_batch=mnist.train.num_examples//batch_size#计算一共多少批次 #这里none表示第一个维度可以是任意长度
x=tf.placeholder(tf.float32,[None,784]) y=tf.placeholder(tf.float32,[None,10]) #初始化权值
weights=tf.Variable(tf.truncated_normal([lstm_size,n_classes],stddev=0.1))
#初始化偏置值
biases=tf.Variable(tf.constant(0.1,shape=[n_classes])) ##定义Rnn 网络
def RNN(X,weights,biases):
inputs=tf.reshape(X,[-1,max_time,n_inputs])
#定义lstm基本cell
lstm_cell = rnn.BasicLSTMCell(lstm_size)
#lstm_cell=tf.contrib.rnn.core_rnn_cell.BasicLSTMCell(lstm_size)
outputs,final_state=tf.nn.dynamic_rnn(lstm_cell,inputs,dtype=tf.float32)
results=tf.nn.softmax(tf.matmul(final_state[1],weights)+biases)
return results
prediction=RNN(x,weights,biases)
#损失函数
cross_entropy=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction,labels=y))
#优化器
train_step=tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
#保存结果
correct_prediction=tf.equal(tf.argmax(y,1),tf.argmax(prediction,1)) accuracy=tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) init=tf.global_variables_initializer() with tf.Session() as sess:
sess.run(init)
for epoch in range(6):
for batch in range(n_batch):
batch_xs,batch_ys=mnist.train.next_batch(batch_size)
sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys}) acc=sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("iter:"+str(epoch)+"testing accuracy"+str(acc))

```
运行结果如下:

RNN-LSTM讲解-基于tensorflow实现的更多相关文章

  1. TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人

    简介 TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人. 文章包括一下几个部分: 1.为什么要尝试做这个项目? 2.为 ...

  2. TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人。

    简介 TensorFlow-Bitcoin-Robot:一个基于 TensorFlow LSTM 模型的 Bitcoin 价格预测机器人. 文章包括一下几个部分: 1.为什么要尝试做这个项目? 2.为 ...

  3. tensorflow学习之(十一)RNN+LSTM神经网络的构造

    #RNN 循环神经网络 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data tf.se ...

  4. Tensorflow - Tutorial (7) : 利用 RNN/LSTM 进行手写数字识别

    1. 经常使用类 class tf.contrib.rnn.BasicLSTMCell BasicLSTMCell 是最简单的一个LSTM类.没有实现clipping,projection layer ...

  5. 基于TensorFlow的循环神经网络(RNN)

    RNN适用场景 循环神经网络(Recurrent Neural Network)适合处理和预测时序数据 RNN的特点 RNN的隐藏层之间的节点是有连接的,他的输入是输入层的输出向量.extend(上一 ...

  6. 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识

    用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...

  7. 基于Tensorflow + Opencv 实现CNN自定义图像分类

    摘要:本篇文章主要通过Tensorflow+Opencv实现CNN自定义图像分类案例,它能解决我们现实论文或实践中的图像分类问题,并与机器学习的图像分类算法进行对比实验. 本文分享自华为云社区< ...

  8. [NL系列] RNN & LSTM 网络结构及应用

    http://www.jianshu.com/p/f3bde26febed/ 这篇是 The Unreasonable Effectiveness of Recurrent Neural Networ ...

  9. LSTM和双向LSTM讲解及实践

    LSTM和双向LSTM讲解及实践 目录 RNN的长期依赖问题LSTM原理讲解双向LSTM原理讲解Keras实现LSTM和双向LSTM 一.RNN的长期依赖问题 在上篇文章中介绍的循环神经网络RNN在训 ...

随机推荐

  1. 百万年薪python之路 -- 并发编程之 多线程 一

    多线程 1.进程: 生产者消费者模型 一种编程思想,模型,设计模式,理论等等,都是交给你一种编程的方法,以后遇到类似的情况,套用即可 生产者与消费者模型的三要素: 生产者:产生数据的 消费者:接收数据 ...

  2. 微信小程序中的canvas基础应用

    学了东西还是要记录一下,刚入职的小萌新啊,运气好分到一个项目不是很急的组原以为时间多了可以多学一些东西,但是发现好像不知道从哪里开始下手,我太南了.... 看旁边的实习生同事一直在搞canvas,自己 ...

  3. Apache源码包在LINUX(CENTOS6.8)中的安装(出现问题及解决)

    任务:在CENT6.8系统中安装Apache(版本为:httpd-2.4.41) 前提:由于源码包必须先编译后安装,所以必须先安装编译器:gcc 理论步骤: 1.检测gcc软件包,如果不存在则进行安装 ...

  4. Java多线程编程(五)定时器Timer

    一.定时器Timer的使用 在JDK库中Timer类主要负责计划任务的功能,也就是在指定的时间开始执行某一个任务.Timer类的主要作用就是设置计划任务,但封装任务的类确实TimerTask类,执行计 ...

  5. 【xinsir】函数库,持续更新

    1.遍历文件-node // 递归遍历目录下的文件 function readDirSync (path) { var pa = fs.readdirSync(path); pa.forEach(fu ...

  6. Java HashSet对txt文本内容去重(统计小说用过的字或字数)

    Java HashSet对txt文本内容去重(统计小说用过的字或字数) 基本思路: 1.字节流读需要去重的txt文本.(展示demo为当前workspace下名为utf-8.txt的文本) 2.对读取 ...

  7. 作为一名程序员,你真正了解CDN技术吗?

    本文导读: 物流仓库配送如何加速 静态资源文件部署方式 静态资源加速之CDN技术 解析过程中的名词解释 最后的总结 1.物流仓库配送如何加速 我们还是从生活中购物的例子来展开. 将时光倒回到几年前,在 ...

  8. gym102302E_Chi's performance

    题意 给n个二元组(v,p),要求排序使得v从小到大,而且总价值最大,价值定义为相邻两个v值不同的p值之差绝对值之和. 分析 in a row原来是相邻的意思. 对于每个相同v值的块来说,有用的数只有 ...

  9. ABAP中时间戳的处理

    UTC(UTC, Universal Time Coordinated,通用协调时)时间戳,分为长时间戳和段时间戳,其中长时间戳餐开始的系统的数据元素TIMESTAMPL,类型为DEC(21,7):而 ...

  10. 第三十二章 System V信号量(三)

    n哲学家进餐问题描述有五个哲学家,他们的生活方式是交替地进行思考和进餐,n哲学家们共用一张圆桌,分别坐在周围的五张椅子上,在圆桌上有五个碗和五支筷子,n平时哲学家进行思考,饥饿时便试图取其左.右最靠近 ...