tensorflow学习笔记七----------RNN
和神经网络不同的是,RNN中的数据批次之间是有相互联系的。输入的数据需要是要求序列化的。
1.将数据处理成序列化;
2.将一号数据传入到隐藏层进行处理,在传入到RNN中进行处理,RNN产生两个结果,一个结果产生分类结果,另外一个结果传入到二号数据的RNN中;
3.所有数据都处理完。
导入数据
import tensorflow as tf
import from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
print ("Packages imported") mnist = input_data.read_data_sets("data/", one_hot=True)
trainimgs, trainlabels, testimgs, testlabels \
= mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
ntrain, ntest, dim, nclasses \
= trainimgs.shape[0], testimgs.shape[0], trainimgs.shape[1], trainlabels.shape[1]
print ("MNIST loaded")
将28*28像素的数据变成28条数据;隐藏层有128个神经元;定义好权重和偏置;
diminput = 28
dimhidden = 128
dimoutput = nclasses
nsteps = 28
weights = {
'hidden': tf.Variable(tf.random_normal([diminput, dimhidden])),
'out': tf.Variable(tf.random_normal([dimhidden, dimoutput]))
}
biases = {
'hidden': tf.Variable(tf.random_normal([dimhidden])),
'out': tf.Variable(tf.random_normal([dimoutput]))
}
定义RNN函数。将数据转化一下;计算隐藏层;将隐藏层切片;计算RNN产生的两个结果;预测值是最后一个RNN产生的LSTM_O
def _RNN(_X, _W, _b, _nsteps, _name):
# 1. Permute input from [batchsize, nsteps, diminput]
# => [nsteps, batchsize, diminput]
_X = tf.transpose(_X, [1, 0, 2])
# 2. Reshape input to [nsteps*batchsize, diminput]
_X = tf.reshape(_X, [-1, diminput])
# 3. Input layer => Hidden layer
_H = tf.matmul(_X, _W['hidden']) + _b['hidden']
# 4. Splite data to 'nsteps' chunks. An i-th chunck indicates i-th batch data
_Hsplit = tf.split(0, _nsteps, _H)
# 5. Get LSTM's final output (_LSTM_O) and state (_LSTM_S)
# Both _LSTM_O and _LSTM_S consist of 'batchsize' elements
# Only _LSTM_O will be used to predict the output.
with tf.variable_scope(_name) as scope: scope.reuse_variables()
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(dimhidden, forget_bias=1.0)
_LSTM_O, _LSTM_S = tf.nn.rnn(lstm_cell, _Hsplit,dtype=tf.float32)
# 6. Output
_O = tf.matmul(_LSTM_O[-1], _W['out']) + _b['out']
# Return!
return {
'X': _X, 'H': _H, 'Hsplit': _Hsplit,
'LSTM_O': _LSTM_O, 'LSTM_S': _LSTM_S, 'O': _O
}
print ("Network ready")
定义好RNN后,定义损失函数等
learning_rate = 0.001
x = tf.placeholder("float", [None, nsteps, diminput])
y = tf.placeholder("float", [None, dimoutput])
myrnn = _RNN(x, weights, biases, nsteps, 'basic')
pred = myrnn['O']
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) # Adam Optimizer
accr = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred,1), tf.argmax(y,1)), tf.float32))
init = tf.global_variables_initializer()
print ("Network Ready!")
进行训练
training_epochs = 5
batch_size = 16
display_step = 1
sess = tf.Session()
sess.run(init)
print ("Start optimization")
for epoch in range(training_epochs):
avg_cost = 0.
total_batch = int(mnist.train.num_examples/batch_size) # Loop over all batches
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
batch_xs = batch_xs.reshape((batch_size, nsteps, diminput))
# Fit training using batch data
feeds = {x: batch_xs, y: batch_ys}
sess.run(optm, feed_dict=feeds)
# Compute average loss
avg_cost += sess.run(cost, feed_dict=feeds)/total_batch
# Display logs per epoch step
if epoch % display_step == 0:
print ("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost))
feeds = {x: batch_xs, y: batch_ys}
train_acc = sess.run(accr, feed_dict=feeds)
print (" Training accuracy: %.3f" % (train_acc))
testimgs = testimgs.reshape((ntest, nsteps, diminput))
feeds = {x: testimgs, y: testlabels, istate: np.zeros((ntest, 2*dimhidden))}
test_acc = sess.run(accr, feed_dict=feeds)
print (" Test accuracy: %.3f" % (test_acc))
print ("Optimization Finished.")
tensorflow学习笔记七----------RNN的更多相关文章
- tensorflow学习笔记七----------卷积神经网络
卷积神经网络比神经网络稍微复杂一些,因为其多了一个卷积层(convolutional layer)和池化层(pooling layer). 使用mnist数据集,n个数据,每个数据的像素为28*28* ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)
tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...
- tensorflow学习笔记——自编码器及多层感知器
1,自编码器简介 传统机器学习任务很大程度上依赖于好的特征工程,比如对数值型,日期时间型,种类型等特征的提取.特征工程往往是非常耗时耗力的,在图像,语音和视频中提取到有效的特征就更难了,工程师必须在这 ...
- TensorFlow学习笔记——LeNet-5(训练自己的数据集)
在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...
- Tensorflow学习笔记No.10
多输出模型 使用函数式API构建多输出模型完成多标签分类任务. 数据集下载链接:https://pan.baidu.com/s/1JtKt7KCR2lEqAirjIXzvgg 提取码:2kbc 1.读 ...
- Tensorflow学习笔记2:About Session, Graph, Operation and Tensor
简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...
- (转)Qt Model/View 学习笔记 (七)——Delegate类
Qt Model/View 学习笔记 (七) Delegate 类 概念 与MVC模式不同,model/view结构没有用于与用户交互的完全独立的组件.一般来讲, view负责把数据展示 给用户,也 ...
- Learning ROS for Robotics Programming Second Edition学习笔记(七) indigo PCL xtion pro live
中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS forRobotics Pro ...
- Tensorflow学习笔记2019.01.22
tensorflow学习笔记2 edit by Strangewx 2019.01.04 4.1 机器学习基础 4.1.1 一般结构: 初始化模型参数:通常随机赋值,简单模型赋值0 训练数据:一般打乱 ...
随机推荐
- 安装SQL2012出现[HKLM\Software\Microsoft\Fusion!EnableLog] (DWORD)设置为 1
本人安装SQL2012出现这个错误,找了三天三夜,终于把问题找出来,共享给有需要的人们,不用重新换系统 错误如下: 1,此问题是系统.net Framework版本冲突,首先下载.net Framew ...
- hdu 1754 线段树 水题 单点更新 区间查询
I Hate It Time Limit: 9000/3000 MS (Java/Others) Memory Limit: 32768/32768 K (Java/Others)Total S ...
- [BZOJ3990]:[SDOI2015]排序(搜索)
题目传送门 题目描述 小A有一个1-${2}^{N}$的排列A[1..${2}^{N}$],他希望将A数组从小到大排序,小A可以执行的操作有N种,每种操作最多可以执行一次,对于所有的i(1≤i≤N), ...
- (转载)FM 算法
(转载)FM算法 https://zhuanlan.zhihu.com/p/33184179
- legend3---PHP使用阿里云短信服务
legend3---PHP使用阿里云短信服务 一.总结 一句话总结: 使用步骤照官方文档,代码拷贝即可 1.php使用阿里云短信服务的步骤? 入驻阿里云->开通短信服务->获取Access ...
- 【后台管理系统】—— Ant Design Pro 页面相关(三)
一.卡片Card分类 与普通卡片使用区别:底部按钮及内容样式 <Card hoverable bodyStyle={{ paddingBottom: 20 }} actions={[ // 卡片 ...
- 使用NSIS脚本制作一个安装包
大部分人第一次看到NSIS脚本都是一脸懵逼的.因为它这个脚本的结构乍一看上去就非常奇怪,不作说明的话是看不懂的. 编写脚本命令的时候要非常注意,命令要按照规定写在脚本中不同的段落里,也就是说,命令的先 ...
- 什么是HOOK功能?
HOOK API是一个永恒的话题,如果没有HOOK,许多技术将很难实现,也许根本不能实现.这里所说的API,是广义上的API,它包括DOS下的中断,WINDOWS里的API.中断服务.IFS和NDIS ...
- 【C++ STL 优先队列priority_queue】
https://www.cnblogs.com/fzuljz/p/6171963.html
- lua源码学习篇二:语法分析
一步步调试,在lparser.c文件中luaY_parser函数是语法分析的重点函数,词法分析也是在这个过程中调用的.在这个过程中,用到一些数据结构,下面会详细说. Proto *luaY_parse ...