tensorflow学习笔记七----------RNN
和神经网络不同的是,RNN中的数据批次之间是有相互联系的。输入的数据需要是要求序列化的。
1.将数据处理成序列化;
2.将一号数据传入到隐藏层进行处理,在传入到RNN中进行处理,RNN产生两个结果,一个结果产生分类结果,另外一个结果传入到二号数据的RNN中;
3.所有数据都处理完。
导入数据
import tensorflow as tf
import from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
print ("Packages imported") mnist = input_data.read_data_sets("data/", one_hot=True)
trainimgs, trainlabels, testimgs, testlabels \
= mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
ntrain, ntest, dim, nclasses \
= trainimgs.shape[0], testimgs.shape[0], trainimgs.shape[1], trainlabels.shape[1]
print ("MNIST loaded")
将28*28像素的数据变成28条数据;隐藏层有128个神经元;定义好权重和偏置;
diminput = 28
dimhidden = 128
dimoutput = nclasses
nsteps = 28
weights = {
'hidden': tf.Variable(tf.random_normal([diminput, dimhidden])),
'out': tf.Variable(tf.random_normal([dimhidden, dimoutput]))
}
biases = {
'hidden': tf.Variable(tf.random_normal([dimhidden])),
'out': tf.Variable(tf.random_normal([dimoutput]))
}
定义RNN函数。将数据转化一下;计算隐藏层;将隐藏层切片;计算RNN产生的两个结果;预测值是最后一个RNN产生的LSTM_O
def _RNN(_X, _W, _b, _nsteps, _name):
# 1. Permute input from [batchsize, nsteps, diminput]
# => [nsteps, batchsize, diminput]
_X = tf.transpose(_X, [1, 0, 2])
# 2. Reshape input to [nsteps*batchsize, diminput]
_X = tf.reshape(_X, [-1, diminput])
# 3. Input layer => Hidden layer
_H = tf.matmul(_X, _W['hidden']) + _b['hidden']
# 4. Splite data to 'nsteps' chunks. An i-th chunck indicates i-th batch data
_Hsplit = tf.split(0, _nsteps, _H)
# 5. Get LSTM's final output (_LSTM_O) and state (_LSTM_S)
# Both _LSTM_O and _LSTM_S consist of 'batchsize' elements
# Only _LSTM_O will be used to predict the output.
with tf.variable_scope(_name) as scope: scope.reuse_variables()
lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(dimhidden, forget_bias=1.0)
_LSTM_O, _LSTM_S = tf.nn.rnn(lstm_cell, _Hsplit,dtype=tf.float32)
# 6. Output
_O = tf.matmul(_LSTM_O[-1], _W['out']) + _b['out']
# Return!
return {
'X': _X, 'H': _H, 'Hsplit': _Hsplit,
'LSTM_O': _LSTM_O, 'LSTM_S': _LSTM_S, 'O': _O
}
print ("Network ready")
定义好RNN后,定义损失函数等
learning_rate = 0.001
x = tf.placeholder("float", [None, nsteps, diminput])
y = tf.placeholder("float", [None, dimoutput])
myrnn = _RNN(x, weights, biases, nsteps, 'basic')
pred = myrnn['O']
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y))
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) # Adam Optimizer
accr = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred,1), tf.argmax(y,1)), tf.float32))
init = tf.global_variables_initializer()
print ("Network Ready!")
进行训练
training_epochs = 5
batch_size = 16
display_step = 1
sess = tf.Session()
sess.run(init)
print ("Start optimization")
for epoch in range(training_epochs):
avg_cost = 0.
total_batch = int(mnist.train.num_examples/batch_size) # Loop over all batches
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
batch_xs = batch_xs.reshape((batch_size, nsteps, diminput))
# Fit training using batch data
feeds = {x: batch_xs, y: batch_ys}
sess.run(optm, feed_dict=feeds)
# Compute average loss
avg_cost += sess.run(cost, feed_dict=feeds)/total_batch
# Display logs per epoch step
if epoch % display_step == 0:
print ("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost))
feeds = {x: batch_xs, y: batch_ys}
train_acc = sess.run(accr, feed_dict=feeds)
print (" Training accuracy: %.3f" % (train_acc))
testimgs = testimgs.reshape((ntest, nsteps, diminput))
feeds = {x: testimgs, y: testlabels, istate: np.zeros((ntest, 2*dimhidden))}
test_acc = sess.run(accr, feed_dict=feeds)
print (" Test accuracy: %.3f" % (test_acc))
print ("Optimization Finished.")
tensorflow学习笔记七----------RNN的更多相关文章
- tensorflow学习笔记七----------卷积神经网络
卷积神经网络比神经网络稍微复杂一些,因为其多了一个卷积层(convolutional layer)和池化层(pooling layer). 使用mnist数据集,n个数据,每个数据的像素为28*28* ...
- tensorflow学习笔记——使用TensorFlow操作MNIST数据(2)
tensorflow学习笔记——使用TensorFlow操作MNIST数据(1) 一:神经网络知识点整理 1.1,多层:使用多层权重,例如多层全连接方式 以下定义了三个隐藏层的全连接方式的神经网络样例 ...
- tensorflow学习笔记——自编码器及多层感知器
1,自编码器简介 传统机器学习任务很大程度上依赖于好的特征工程,比如对数值型,日期时间型,种类型等特征的提取.特征工程往往是非常耗时耗力的,在图像,语音和视频中提取到有效的特征就更难了,工程师必须在这 ...
- TensorFlow学习笔记——LeNet-5(训练自己的数据集)
在之前的TensorFlow学习笔记——图像识别与卷积神经网络(链接:请点击我)中了解了一下经典的卷积神经网络模型LeNet模型.那其实之前学习了别人的代码实现了LeNet网络对MNIST数据集的训练 ...
- Tensorflow学习笔记No.10
多输出模型 使用函数式API构建多输出模型完成多标签分类任务. 数据集下载链接:https://pan.baidu.com/s/1JtKt7KCR2lEqAirjIXzvgg 提取码:2kbc 1.读 ...
- Tensorflow学习笔记2:About Session, Graph, Operation and Tensor
简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...
- (转)Qt Model/View 学习笔记 (七)——Delegate类
Qt Model/View 学习笔记 (七) Delegate 类 概念 与MVC模式不同,model/view结构没有用于与用户交互的完全独立的组件.一般来讲, view负责把数据展示 给用户,也 ...
- Learning ROS for Robotics Programming Second Edition学习笔记(七) indigo PCL xtion pro live
中文译著已经出版,详情请参考:http://blog.csdn.net/ZhangRelay/article/category/6506865 Learning ROS forRobotics Pro ...
- Tensorflow学习笔记2019.01.22
tensorflow学习笔记2 edit by Strangewx 2019.01.04 4.1 机器学习基础 4.1.1 一般结构: 初始化模型参数:通常随机赋值,简单模型赋值0 训练数据:一般打乱 ...
随机推荐
- postman批量调用接口并发测试
本文出自:https://www.cnblogs.com/2186009311CFF/p/11425913.html 接口测试在开发中很容易遇到,下面是请教别人学会的并发测试,希望能帮到需要用到的你, ...
- jquery 3.1 tets
r.extend = r.fn.extend = function () { var a, b, c, d, e, f, g = arguments[0] || {}, h = 1, i = argu ...
- window.location.hash(hash应用)---跳转到hash值制定的具体页面
location是javascript里边管理地址栏的内置对象,比如location.href就管理页面的url,用location.href=url就可以直接将页面重定向url.而location. ...
- Luogu P4550 收集邮票
题目链接:Click here Solution: 本题直接推价格似乎很难,考虑先从购买次数入手 设购买次数\(g(i)\)为当前有\(i\)种不同的邮票,要买到\(n\)种的期望购买次数 可以由期望 ...
- 分布式架构基石-TCP通信协议
为什么会有TCP/IP协议 在世界上各地,各种各样的电脑运行着各自不同的操作系统为大家服务,这些电脑在表达同一种信息的时候所使用的方法是千差万别.就好像圣经中上帝打乱了各地人的口音,让他们无法合作一样 ...
- AtCoder AGC001F Wide Swap (线段树、拓扑排序)
题目链接: https://atcoder.jp/contests/agc001/tasks/agc001_f 题解: 先变成排列的逆,要求\(1\)的位置最小,其次\(2\)的位置最小,依次排下去( ...
- Katalon Studio用迅雷快速下载历史版本方法
一.下载说明 官网正版--历史版本下载地址: https://github.com/katalon-studio/katalon-studio/releases 说明1:这里需要注册账户才可以下载,但 ...
- 第三周syh
第三周作业 7-1 判断上三角矩阵 (15 分) 上三角矩阵指主对角线以下的元素都为0的矩阵:主对角线为从矩阵的左上角至右下角的连线. 本题要求编写程序,判断一个给定的方阵是否上三角矩阵. 输入格 ...
- Qt数据库之数据库连接池
前面的章节里,我们使用了下面的函数创建和取得数据库连接: void createConnectionByName(const QString &connectionName) { QSql ...
- linux下插入U盘自动挂载后,用C获取其挂载点(cat /proc/mounts)
现在已经能够通过libudev获取U盘插入时它的节点名(通过函数udev_device_get_devnode()),是/dev/sdb1 我现在的做法是读取/proc/mounts文件,找到有/de ...