tensorflow学习笔记七----------RNN
和神经网络不同的是,RNN中的数据批次之间是有相互联系的。输入的数据需要是要求序列化的。
1.将数据处理成序列化;
2.将一号数据传入到隐藏层进行处理,在传入到RNN中进行处理,RNN产生两个结果,一个结果产生分类结果,另外一个结果传入到二号数据的RNN中;
3.所有数据都处理完。
导入数据
import tensorflow as tf
import from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
print ("Packages imported") mnist = input_data.read_data_sets("data/", one_hot=True)
trainimgs, trainlabels, testimgs, testlabels \
= mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
ntrain, ntest, dim, nclasses \
= trainimgs.shape[0], testimgs.shape[0], trainimgs.shape[1], trainlabels.shape[1]
print ("MNIST loaded")
将28*28像素的数据变成28条数据;隐藏层有128个神经元;定义好权重和偏置;
diminput = 28
dimhidden = 128
dimoutput = nclasses
nsteps = 28
weights = {
'hidden': tf.Variable(tf.random_normal([diminput, dimhidden])),
'out': tf.Variable(tf.random_normal([dimhidden, dimoutput]))
}
biases = {
'hidden': tf.Variable(tf.random_normal([dimhidden])),
'out': tf.Variable(tf.random_normal([dimoutput]))
}
定义RNN函数。将数据转化一下;计算隐藏层;将隐藏层切片;计算RNN产生的两个结果;预测值是最后一个RNN产生的LSTM_O
def _RNN(_X, _W, _b, _nsteps, _name):
# 1. Permute input from [batchsize, nsteps, diminput]
# => [nsteps, batchsize, diminput]
_X = tf.transpose(_X, [1, 0, 2])
# 2. Reshape input to [nsteps*batchsize, diminput]
_X = tf.reshape(_X, [-1, diminput])
# 3. Input layer => Hidden layer
_H = tf.matmul(_X, _W['hidden']) + _b['hidden']
# 4. Splite data to 'nsteps' chunks. An i-th chunck indicates i-th batch data
_Hsplit = tf.split(0, _nsteps, _H)
# 5. Get LSTM's final output (_LSTM_O) and state (_LSTM_S)
# Both _LSTM_O and _LSTM_S consist of 'batchsize' elements
# Only _LSTM_O will be used to predict the output.
with tf.variable_scope(_name) as scope: scope.reuse_variables()
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(dimhidden, forget_bias=1.0)
_LSTM_O, _LSTM_S = tf.nn.rnn(lstm_cell, _Hsplit,dtype=tf.float32)
# 6. Output
_O = tf.matmul(_LSTM_O[-1], _W['out']) + _b['out']
# Return!
return {
'X': _X, 'H': _H, 'Hsplit': _Hsplit,
'LSTM_O': _LSTM_O, 'LSTM_S': _LSTM_S, 'O': _O
}
print ("Network ready")
定义好RNN后,定义损失函数等
learning_rate = 0.001
x = tf.placeholder("float", [None, nsteps, diminput])
y = tf.placeholder("float", [None, dimoutput])
myrnn = _RNN(x, weights, biases, nsteps, 'basic')
pred = myrnn['O']
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) # Adam Optimizer
accr = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred,1), tf.argmax(y,1)), tf.float32))
init = tf.global_variables_initializer()
print ("Network Ready!")
进行训练
training_epochs = 5
batch_size = 16
display_step = 1
sess = tf.Session()
sess.run(init)
print ("Start optimization")
for epoch in range(training_epochs):
avg_cost = 0.
total_batch = int(mnist.train.num_examples/batch_size) # Loop over all batches
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
batch_xs = batch_xs.reshape((batch_size, nsteps, diminput))
# Fit training using batch data
feeds = {x: batch_xs, y: batch_ys}
sess.run(optm, feed_dict=feeds)
# Compute average loss
avg_cost += sess.run(cost, feed_dict=feeds)/total_batch
# Display logs per epoch step
if epoch % display_step == 0:
print ("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost))
feeds = {x: batch_xs, y: batch_ys}
train_acc = sess.run(accr, feed_dict=feeds)
print (" Training accuracy: %.3f" % (train_acc))
testimgs = testimgs.reshape((ntest, nsteps, diminput))
feeds = {x: testimgs, y: testlabels, istate: np.zeros((ntest, 2*dimhidden))}
test_acc = sess.run(accr, feed_dict=feeds)
print (" Test accuracy: %.3f" % (test_acc))
print ("Optimization Finished.")
tensorflow学习笔记七----------RNN的更多相关文章
- tensorflow学习笔记七----------卷积神经网络
卷积神经网络比神经网络稍微复杂一些,因为其多了一个卷积层(convolutional layer)和池化层(pooling layer). 使用mnist数据集,n个数据,每个数据的像素为28*28* ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)
tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...
- tensorflow学习笔记——自编码器及多层感知器
1,自编码器简介 传统机器学习任务很大程度上依赖于好的特征工程,比如对数值型,日期时间型,种类型等特征的提取.特征工程往往是非常耗时耗力的,在图像,语音和视频中提取到有效的特征就更难了,工程师必须在这 ...
- TensorFlow学习笔记——LeNet-5(训练自己的数据集)
在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...
- Tensorflow学习笔记No.10
多输出模型 使用函数式API构建多输出模型完成多标签分类任务. 数据集下载链接:https://pan.baidu.com/s/1JtKt7KCR2lEqAirjIXzvgg 提取码:2kbc 1.读 ...
- Tensorflow学习笔记2:About Session, Graph, Operation and Tensor
简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...
- (转)Qt Model/View 学习笔记 (七)——Delegate类
Qt Model/View 学习笔记 (七) Delegate 类 概念 与MVC模式不同,model/view结构没有用于与用户交互的完全独立的组件.一般来讲, view负责把数据展示 给用户,也 ...
- Learning ROS for Robotics Programming Second Edition学习笔记(七) indigo PCL xtion pro live
中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS forRobotics Pro ...
- Tensorflow学习笔记2019.01.22
tensorflow学习笔记2 edit by Strangewx 2019.01.04 4.1 机器学习基础 4.1.1 一般结构: 初始化模型参数:通常随机赋值,简单模型赋值0 训练数据:一般打乱 ...
随机推荐
- [转] 数值优化(Numerical Optimization)学习系列-目录
from:https://blog.csdn.net/fangqingan_java/article/details/48951191 概述数值优化对于最优化问题提供了一种迭代算法思路,通过迭代逐渐接 ...
- Mysql基本原理和概念
一.引言 随着互联网应用的广泛普及,海量数据的存储和访问成为了系统设计的瓶颈问题.对于一个大型的互联网应用,每天几十亿的PV无疑对数据库造成了相当高的负载.对于系统的稳定性和扩展性造成了极大的问题.通 ...
- windows10 gcc编译C程序(分步编译)
下面演示gcc对C源程序的分步编译过程: 1. 编译(Compile) gcc hello.cpp -c # 生成hello.o,目标文件名字和源文件名字一样,VC编译会生成.ojb文件,gcc编译器 ...
- 【转载】mysqld_safe Directory ‘/var/run/mysqld’ for UNIX socket file don’t exists.
This is about resetting the MySQL 5.7 root password in Ubuntu 16.04 LTS You probably tried something ...
- Navicat使用与python操作数据库
一.Navicat使用 1.下载地址: <https://pan.baidu.com/s/1bpo5mqj> 2.测试+链接数据库,新建库 3.新建表,新增字段+类型+约束 4.设计表:外 ...
- python之设计模式的装饰器9步学习
在继承的基础上增加新功能,重载,重写区别 装饰器: 函数a说,我是装饰器啊,其他哪个函数顶着我,我就吃了谁,然后吐出来我的和你的返回结果 testng的UI自动化,@beforetest,@befor ...
- 国内4G频段划分
国内4G频段划分 2015年 4G网络建设如火如荼地进行,换手机大家几乎都买的4G手机,那么看到如下参数怎么知道手机所支持的网络呢? SIM 1:4G TDD-LTE:TD38/39/40/41: ...
- .net sqlite 内存溢出 问题的分析与解决。
一个小的工具网站,用了sqlite数据库,在并发小的情况一切正常,但是并发量上来之后,就报"out of memory"错误了. 分析了代码,就是很简单的根据一个条件取一段数据,并 ...
- centos下面配置key登录
centos下需要配置使用key登录,并且要禁止root登录 下面的操作都是用root来设置的 1.添加新用户 例如用户名leisiyuan useradd leisiyuan 2.设置密码 pass ...
- leetcode-mid-backtracking -22. Generate Parentheses-NO
mycode 没有思路,大早上就有些萎靡,其实和上一个电话号码的题如出一辙啦 参考: class Solution(object): def generateParenthesis(self, ...