rnn-手写数字识别-网络结构-shape
手写数字识别经典案例,目标是:
1. 掌握tf编写RNN的方法
2. 剖析RNN网络结构
tensorflow编程
#coding:utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data ### 注意
# init_state = tf.zeros(shape=[batch_size,rnn_cell.state_size])
# init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32) mnist=input_data.read_data_sets("./data",one_hot=True) # 常规参数
train_rate=0.001
train_step=10000
batch_size=1280
display_step=100 # rnn参数
frame_size=28 # 输入特征数
sequence_length=28 # 输入个数, 时序
hidden_num=100 # 隐层神经元个数
n_classes=10 # 定义输入,输出
# 此处输入格式是样本数*特征数,特征是把图片拉成一维的,当然一维还是二维自己定,改成相应的代码就行了
x=tf.placeholder(dtype=tf.float32,shape=[None,sequence_length*frame_size],name="inputx")
y=tf.placeholder(dtype=tf.float32,shape=[None,n_classes],name="expected_y") # 定义权值
# 注意权值设定只设定v, u和w无需设定
weights=tf.Variable(tf.truncated_normal(shape=[hidden_num,n_classes])) # 全连接层权重
bias=tf.Variable(tf.zeros(shape=[n_classes])) def RNN(x,weights,bias):
x=tf.reshape(x,shape=[-1,sequence_length,frame_size]) # 3维
rnn_cell=tf.nn.rnn_cell.BasicRNNCell(hidden_num) ### 注意
# init_state=tf.zeros(shape=[batch_size,rnn_cell.state_size]) # rnn_cell.state_size 100
init_state=rnn_cell.zero_state(batch_size, dtype=tf.float32) output,states=tf.nn.dynamic_rnn(rnn_cell,x,initial_state=init_state,dtype=tf.float32)
return tf.nn.softmax(tf.matmul(output[:,-1,:],weights)+bias,1) # y=softmax(vh+c) predy=RNN(x,weights,bias) # 以下所有神经网络大同小异
cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predy,labels=y))
train=tf.train.AdamOptimizer(train_rate).minimize(cost) correct_pred=tf.equal(tf.argmax(predy,1),tf.argmax(y,1))
accuracy=tf.reduce_mean(tf.to_float(correct_pred)) sess=tf.Session()
sess.run(tf.global_variables_initializer())
step=1
testx,testy=mnist.test.next_batch(batch_size)
while step<train_step:
batch_x,batch_y=mnist.train.next_batch(batch_size)
_loss,__=sess.run([cost,train],feed_dict={x:batch_x,y:batch_y})
if step % display_step ==0:
print()
acc,loss=sess.run([accuracy,cost],feed_dict={x:testx,y:testy})
print(step,acc,loss) step+=1
如果你非常熟悉rnn,代码整体上还是比较好理解的,但是里面涉及许多次的shape设置,比较让人头大,特别是后期写各种rnn时,很容易迷糊,所以每个模型都要理解透彻。
以上代码涉及到shape的变量有
x y w b x变形 init_state
其中比较难理解的是 x x变形 init_state
网络结构
首先回顾一下RNN网络,以便对上个问题进行深入分析。

公式简写如下:
h1 = f(x1w1 + h0w2)
o1 = h1w3 输出层就是简单的全连接,这里不做讨论
shape分析
我们把每个时刻的输入看做向量或者矩阵,因为如果只是一个数,没有shape可言,而且也很简单,没有讨论的必要。
首先有如下思考:
1. h是隐层的输出,也就是x传进去得到的输出,因此传一个x就有一个h(但这并不足以说明什么)
其次从公式层面考虑

从公式可以看出,x和h的行必须相同,列不必相同
图形表示

这是单节点隐层,那么多节点呢?
首先一个神经元节点对应一组weight,多个神经元就是多组weight
其次从公式层面考虑

从公式看出,h和x行相同,h列和神经元个数相同。
图形表示

综上所述,h0的shape是行为 x的行,即batch,列为神经元个数
也就是说一个神经元对应一个h0
对应到上述代码
init_state=tf.zeros(shape=[batch_size,rnn_cell.state_size]) # rnn_cell.state_size 100,100为节点数
init_state=rnn_cell.zero_state(batch_size, dtype=tf.float32)
对于输入x的shape,把代码转化成图

根据图来理解:
每次输入n张图片,也就是一次性输入所有时序的x,所有x的shape 为 [None,sequence_length*frame_size]
在rnn模型中因为要与权重相乘,所以需要转化为 [-1,sequence_length,frame_size] [ 样本数,时序数,特征数 ],把特征划分出来,
然后特征乘以权重,然后按时序向上传递,得到输出
结合其他代码分析,对应图片而言,rnn包括LSTM的输入必须是 一次性输入所有时序的x,即 [ 样本数,时序数,特征数 ]
其实这个网络应该是这样

我的理解:像图像这种所有时序的特征结合起来才能确定y的模型用多对一RNN,且每次输入所有时序的特征,而词语预测不然。
rnn-手写数字识别-网络结构-shape的更多相关文章
- keras和tensorflow搭建DNN、CNN、RNN手写数字识别
		
MNIST手写数字集 MNIST是一个由美国由美国邮政系统开发的手写数字识别数据集.手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载,总共4个.gz后缀的压缩文件,该文件 ...
 - 5 TensorFlow入门笔记之RNN实现手写数字识别
		
------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...
 - TensorFlow使用RNN实现手写数字识别
		
学习,笔记,有时间会加注释以及函数之间的逻辑关系. # https://www.cnblogs.com/felixwang2/p/9190664.html # https://www.cnblogs. ...
 - 学习笔记CB009:人工神经网络模型、手写数字识别、多层卷积网络、词向量、word2vec
		
人工神经网络,借鉴生物神经网络工作原理数学模型. 由n个输入特征得出与输入特征几乎相同的n个结果,训练隐藏层得到意想不到信息.信息检索领域,模型训练合理排序模型,输入特征,文档质量.文档点击历史.文档 ...
 - 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
		
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
 - 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型
		
持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...
 - [Python]基于CNN的MNIST手写数字识别
		
目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...
 - 深度学习面试题12:LeNet(手写数字识别)
		
目录 神经网络的卷积.池化.拉伸 LeNet网络结构 LeNet在MNIST数据集上应用 参考资料 LeNet是卷积神经网络的祖师爷LeCun在1998年提出,用于解决手写数字识别的视觉任务.自那时起 ...
 - Mnist手写数字识别 Tensorflow
		
Mnist手写数字识别 Tensorflow 任务目标 了解mnist数据集 搭建和测试模型 编辑环境 操作系统:Win10 python版本:3.6 集成开发环境:pycharm tensorflo ...
 
随机推荐
- BigInteger 类 和 BigDecimal 类
			
一 .BigInteger BigInteger类在计算和处理任意大小的整数方面是很有用的. BigInteger 任意大的整数,原则上是,只要你的计算机的内存足够大,可以有无限位的. BigInte ...
 - java设计模式之生产者/消费者模式
			
什么是生产者/消费者模式? 某个模块负责产生数据,这些数据由另一个模块来负责处理(此处的模块是广义的,可以是类.函数.线程.进程等).产生数据的模块,就形象地称为生产者:而处理数据的模块,就称为消费者 ...
 - 1003. Check If Word Is Valid After Substitutions Medium检查替换后的词是否有效
			
网址:https://leetcode.com/problems/check-if-word-is-valid-after-substitutions/ 参考:https://leetcode.com ...
 - PostgreSQL导出一张表到MySQL
			
1. 查看PostgreSQL表结构,数据量,是否有特殊字段值 region_il=# select count(*) from result_basic; count --------- ( row ...
 - Qt Widgets——菜单和菜单栏
			
主窗口MainWindow需要菜单栏QMenuBar及菜单QMenu来组成自身,一般应用程序的所有功能都能在菜单中找到.接下来就来说说它们. QMenu 它添加了很多动作QAction,并用自身组成了 ...
 - Microsoft Windows远程桌面协议中间人攻击漏洞(CVE-2005-1794)漏洞解决方案(Windows server2003)
			
1.启动“终端服务配置” 2.选择“连接”,看到“RDP-Tcp”,在其上右键,选择“属性” 3.“常规”选项卡,将加密级别修改为“符合FIPS标准”,点击应用 应用即可,实验发现并不需要重启服务或操 ...
 - 使用Redis数据库(1)(三十三)
			
Spring Boot中除了对常用的关系型数据库提供了优秀的自动化支持之外,对于很多NoSQL数据库一样提供了自动化配置的支持,包括:Redis, MongoDB, Elasticsearch, So ...
 - 一篇文章有若干行,以空行作为输入结束的条件。统计一篇文章中单词the(不管大小写,单词the是由空格隔开的)的个数。
			
#include <iostream>using namespace std; int k = 0;int n = 0;int main() { char c; char a[1000]; ...
 - BeanUtils.copyProperties(A,B)使用注意事项
			
***最近项目中用到BeanUtils.copyProperties(),然后踩了一些坑,也在网上查看了很多同行的测试和总结,现在将自己的测试.整理的注意事项分享如下,希望大家一起学习进步. ***注 ...
 - Notation, First Definitions 转 http://brnt.eu/phd/node9.html
			
LaTeX command Equivalent to Output style Remarks \textnormal{...} {\normalfont...} document font fam ...