包括卷积神经网络(CNN)在内的各种前馈神经网络模型, 其一次前馈过程的输出只与当前输入有关与历史输入无关.

递归神经网络(Recurrent Neural Network, RNN)充分挖掘了序列数据中的信息, 在时间序列和自然语言处理方面有着重要的应用.

递归神经网络可以展开为普通的前馈神经网络:

长短期记忆模型(Long-Short Term Memory)是RNN的常用实现. 与一般神经网络的神经元相比, LSTM神经元多了一个遗忘门.

LSTM神经元的输出除了与当前输入有关外, 还与自身记忆有关. RNN的训练算法也是基于传统BP算法增加了时间考量, 称为BPTT(Back-propagation Through Time)算法.

使用tensorflow内置rnn

tensorflow内置了递归神经网络的实现:

from tensorflow.python.ops import rnn, rnn_cell

tensorflow目前正在快速迭代中, 上述路径可能会发生变化.在0.6.0版本中上述路径是有效的.

官方教程中已经加入了循环神经网络的部分, API可能不会发生太大变化.

Tensorflow有多种rnn神经元可供选择:

  • rnn_cell.BasicLSTMCell

  • rnn_cell.LSTMCell

  • rnn_cell.GRUCell

这里我们选用最简单的BasicLSTMCell, 需要设置神经元个数和forget_bias参数:

self.lstm_cell = rnn_cell.BasicLSTMCell(hidden_n, forget_bias=1.0)

可以直接调用cell对象获得输出和状态:

output, state = cell(inputs, state)

使用dropout避免过拟合问题:

from tensorflow.python.ops.rnn_cell import Dropoutwrapper

cells = DropoutWrapper(lstm_cell, input_keep_prob=0.5, output_keep_prob=0.5)

使用MultiRNNCell来创建多层神经网络:

from tensorflow.python.ops.rnn_cell import MultiRNNCell

cells = MultiRNNCell([lstm_cell_1, lstm_cell_2])

不过rnn.rnn可以替我们完成神经网络的构建工作:

outputs, states = rnn.rnn(self.lstm_cell, self.input_layer, dtype=tf.float32)

再加一个输出层进行输出:

self.prediction = tf.matmul(outputs[-1], self.weights) + self.biases

定义损失函数:

self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(self.prediction, self.label_layer))

使用Adam优化器进行训练:

self.trainer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(self.loss)

因为神经网络需要处理序列数据, 所以输入层略复杂:

self.input_layer = [tf.placeholder("float", [step_n, input_n]) for i in range(batch_size)]

tensorflow要求RNNCell的输入为一个列表, 列表中的每一项作为一个批次进行训练.

列表中的每一个元素代表一个序列, 每一行为序列中的一项. 这样每一项为一个形状为(序列长, 输入维数)的矩阵.

标签还是和原来一样为形如(序列长, 输出维度)的矩阵:

self.label_layer = tf.placeholder("float", [step_n, output_n])

执行训练:

self.session.run(initer)
for i in range(limit):
self.session.run(self.trainer, feed_dict={self.input_layer[0]: train_x[0], self.label_layer: train_y})

因为input_layer为列表, 而列表不能作为字典的键.所以我们只能采用{self.input_layer[0]: train_x[0]}这样的方式输入数据.

可以看到lable_layer也是二维的, 并没有输入多个批次的数据. 考虑到这两点, 目前这个实现并不具备多批次处理的能力.

序列的长度通常是不同的, 而目前的实现采用的是定长输入. 这是需要解决的另一个难题.

完整源代码可以在demo.py中查看.

tensorflow实现循环神经网络的更多相关文章

  1. 基于TensorFlow的循环神经网络(RNN)

    RNN适用场景 循环神经网络(Recurrent Neural Network)适合处理和预测时序数据 RNN的特点 RNN的隐藏层之间的节点是有连接的,他的输入是输入层的输出向量.extend(上一 ...

  2. tensorflow RNN循环神经网络 (分类例子)-【老鱼学tensorflow】

    之前我们学习过用CNN(卷积神经网络)来识别手写字,在CNN中是把图片看成了二维矩阵,然后在二维矩阵中堆叠高度值来进行识别. 而在RNN中增添了时间的维度,因为我们会发现有些图片或者语言或语音等会在时 ...

  3. Tensorflow中循环神经网络及其Wrappers

    tf.nn.rnn_cell.LSTMCell 又名:tf.nn.rnn_cell.BasicLSTMCell.tf.contrib.rnn.LSTMCell 参见: tf.nn.rnn_cell.L ...

  4. TensorFlow系列专题(七):一文综述RNN循环神经网络

    欢迎大家关注我们的网站和系列教程:http://panchuang.net/ ,学习更多的机器学习.深度学习的知识! 目录: 前言 RNN知识结构 简单循环神经网络 RNN的基本结构 RNN的运算过程 ...

  5. 4.5 RNN循环神经网络(recurrent neural network)

     自己开发了一个股票智能分析软件,功能很强大,需要的点击下面的链接获取: https://www.cnblogs.com/bclshuai/p/11380657.html 1.1  RNN循环神经网络 ...

  6. 学习笔记TF057:TensorFlow MNIST,卷积神经网络、循环神经网络、无监督学习

    MNIST 卷积神经网络.https://github.com/nlintz/TensorFlow-Tutorials/blob/master/05_convolutional_net.py .Ten ...

  7. 学习笔记TF053:循环神经网络,TensorFlow Model Zoo,强化学习,深度森林,深度学习艺术

    循环神经网络.https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/3_NeuralNetworks/re ...

  8. TensorFlow——循环神经网络基本结构

    1.导入依赖包,初始化一些常量 import collections import numpy as np import tensorflow as tf TRAIN_DATA = "./d ...

  9. TensorFlow学习笔记(六)循环神经网络

    一.循环神经网络简介 循环神经网络的主要用途是处理和预测序列数据.循环神经网络刻画了一个序列当前的输出与之前信息的关系.从网络结构上,循环神经网络会记忆之前的信息,并利用之前的信息影响后面节点的输出. ...

随机推荐

  1. mark 三年工作总结

    在新公司加班,正在看<HBase 权威指南>,看Michael Stack为本书写的序,介绍HBase最初的发展,Lars在HBase 使用和推广做出的贡献. 突然想到,我还有一篇工作三年 ...

  2. 2.3.7synchronized代码块有volatile同步的功能

    关键字synchronized可以使多个线程访问同一个资源具有同步性,而且他还具有将线程工作内存中的私有变量与公共内存中的变量同步的功能. package com.cky.thread; /** * ...

  3. Integer Array Ladder questions

    1.这个题不难,关键在于把题目意思理解好了.这个题问的不清楚.要求return new length,很容易晕掉.其实就是return 有多少个单独的数. import java.util.Array ...

  4. 使用命令行管理maven项目

    创建maven java项目 自己创建一个文件夹,进入cmd,(shift+鼠标右键)这样创建的maven[java]项目就在该文件夹下了. 打开cmd第一种方式 打开cmd第二种方式 命令:mvn ...

  5. Android-Java-子类实例化过程(内存图)

    案例一: package android.java.oop15; // 描述Person对象 class Person { // 构造方法就算不写 默认有一个隐式的无参构造方法:public Pers ...

  6. ubuntu16.04下idea、webstorm等开发工具不能输入中文问题

    问题: ubuntu16.04下idea.webstorm开发工具不能输入中文,就算切换到中文输入法输入的也是英文字母. 解决方案: 1.vim打开开发工具的启动文件(idea下就是idea.sh) ...

  7. zabbix docker - 安装和初始化配置

    zabbix docker - 安装和初始化配置 安装zabbix server docker-mysql版本 zabbix server支持不同的数据库(详见zabbix的docker hub),这 ...

  8. 宽字符————_T、_TEXT、L、TEXT之间的区别

    _T._TEXT.L.TEXT之间的区别 在分析前先对三者做一个简单的分类 _T._TEXT.TEXT三者都是根据编译器的环境进行ANSI/UNICODE变换的,_T和_TEXT是根据_UNICODE ...

  9. openresty + lua 1、openresty 连接 mysql,实现 crud

    最近开发一个项目,公司使用的是 openresty + lua,所以就研究了 openresty + lua.介绍的话,我就不多说了,网上太多了. 写这个博客主要是记录一下,在学习的过程中遇到的一些坑 ...

  10. Spring Boot Debug调试

    在使用spring-boot:run进行启动的时候,如果设置的断点进不去,要进行以下的设置. 1.添加jvm参数配置 在spring-boot的maven插件加上jvmArguments配置. < ...