rnn-手写数字识别-网络结构-shape
手写数字识别经典案例,目标是:
1. 掌握tf编写RNN的方法
2. 剖析RNN网络结构
tensorflow编程
#coding:utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data ### 注意
# init_state = tf.zeros(shape=[batch_size,rnn_cell.state_size])
# init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32) mnist=input_data.read_data_sets("./data",one_hot=True) # 常规参数
train_rate=0.001
train_step=10000
batch_size=1280
display_step=100 # rnn参数
frame_size=28 # 输入特征数
sequence_length=28 # 输入个数, 时序
hidden_num=100 # 隐层神经元个数
n_classes=10 # 定义输入,输出
# 此处输入格式是样本数*特征数,特征是把图片拉成一维的,当然一维还是二维自己定,改成相应的代码就行了
x=tf.placeholder(dtype=tf.float32,shape=[None,sequence_length*frame_size],name="inputx")
y=tf.placeholder(dtype=tf.float32,shape=[None,n_classes],name="expected_y") # 定义权值
# 注意权值设定只设定v, u和w无需设定
weights=tf.Variable(tf.truncated_normal(shape=[hidden_num,n_classes])) # 全连接层权重
bias=tf.Variable(tf.zeros(shape=[n_classes])) def RNN(x,weights,bias):
x=tf.reshape(x,shape=[-1,sequence_length,frame_size]) # 3维
rnn_cell=tf.nn.rnn_cell.BasicRNNCell(hidden_num) ### 注意
# init_state=tf.zeros(shape=[batch_size,rnn_cell.state_size]) # rnn_cell.state_size 100
init_state=rnn_cell.zero_state(batch_size, dtype=tf.float32) output,states=tf.nn.dynamic_rnn(rnn_cell,x,initial_state=init_state,dtype=tf.float32)
return tf.nn.softmax(tf.matmul(output[:,-1,:],weights)+bias,1) # y=softmax(vh+c) predy=RNN(x,weights,bias) # 以下所有神经网络大同小异
cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predy,labels=y))
train=tf.train.AdamOptimizer(train_rate).minimize(cost) correct_pred=tf.equal(tf.argmax(predy,1),tf.argmax(y,1))
accuracy=tf.reduce_mean(tf.to_float(correct_pred)) sess=tf.Session()
sess.run(tf.global_variables_initializer())
step=1
testx,testy=mnist.test.next_batch(batch_size)
while step<train_step:
batch_x,batch_y=mnist.train.next_batch(batch_size)
_loss,__=sess.run([cost,train],feed_dict={x:batch_x,y:batch_y})
if step % display_step ==0:
print()
acc,loss=sess.run([accuracy,cost],feed_dict={x:testx,y:testy})
print(step,acc,loss) step+=1
如果你非常熟悉rnn,代码整体上还是比较好理解的,但是里面涉及许多次的shape设置,比较让人头大,特别是后期写各种rnn时,很容易迷糊,所以每个模型都要理解透彻。
以上代码涉及到shape的变量有
x y w b x变形 init_state
其中比较难理解的是 x x变形 init_state
网络结构
首先回顾一下RNN网络,以便对上个问题进行深入分析。

公式简写如下:
h1 = f(x1w1 + h0w2)
o1 = h1w3 输出层就是简单的全连接,这里不做讨论
shape分析
我们把每个时刻的输入看做向量或者矩阵,因为如果只是一个数,没有shape可言,而且也很简单,没有讨论的必要。
首先有如下思考:
1. h是隐层的输出,也就是x传进去得到的输出,因此传一个x就有一个h(但这并不足以说明什么)
其次从公式层面考虑

从公式可以看出,x和h的行必须相同,列不必相同
图形表示

这是单节点隐层,那么多节点呢?
首先一个神经元节点对应一组weight,多个神经元就是多组weight
其次从公式层面考虑

从公式看出,h和x行相同,h列和神经元个数相同。
图形表示

综上所述,h0的shape是行为 x的行,即batch,列为神经元个数
也就是说一个神经元对应一个h0
对应到上述代码
init_state=tf.zeros(shape=[batch_size,rnn_cell.state_size]) # rnn_cell.state_size 100,100为节点数
init_state=rnn_cell.zero_state(batch_size, dtype=tf.float32)
对于输入x的shape,把代码转化成图

根据图来理解:
每次输入n张图片,也就是一次性输入所有时序的x,所有x的shape 为 [None,sequence_length*frame_size]
在rnn模型中因为要与权重相乘,所以需要转化为 [-1,sequence_length,frame_size] [ 样本数,时序数,特征数 ],把特征划分出来,
然后特征乘以权重,然后按时序向上传递,得到输出
结合其他代码分析,对应图片而言,rnn包括LSTM的输入必须是 一次性输入所有时序的x,即 [ 样本数,时序数,特征数 ]
其实这个网络应该是这样

我的理解:像图像这种所有时序的特征结合起来才能确定y的模型用多对一RNN,且每次输入所有时序的特征,而词语预测不然。
rnn-手写数字识别-网络结构-shape的更多相关文章
- keras和tensorflow搭建DNN、CNN、RNN手写数字识别
MNIST手写数字集 MNIST是一个由美国由美国邮政系统开发的手写数字识别数据集.手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载,总共4个.gz后缀的压缩文件,该文件 ...
- 5 TensorFlow入门笔记之RNN实现手写数字识别
------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...
- TensorFlow使用RNN实现手写数字识别
学习,笔记,有时间会加注释以及函数之间的逻辑关系. # https://www.cnblogs.com/felixwang2/p/9190664.html # https://www.cnblogs. ...
- 学习笔记CB009:人工神经网络模型、手写数字识别、多层卷积网络、词向量、word2vec
人工神经网络,借鉴生物神经网络工作原理数学模型. 由n个输入特征得出与输入特征几乎相同的n个结果,训练隐藏层得到意想不到信息.信息检索领域,模型训练合理排序模型,输入特征,文档质量.文档点击历史.文档 ...
- 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
- 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型
持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...
- [Python]基于CNN的MNIST手写数字识别
目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...
- 深度学习面试题12:LeNet(手写数字识别)
目录 神经网络的卷积.池化.拉伸 LeNet网络结构 LeNet在MNIST数据集上应用 参考资料 LeNet是卷积神经网络的祖师爷LeCun在1998年提出,用于解决手写数字识别的视觉任务.自那时起 ...
- Mnist手写数字识别 Tensorflow
Mnist手写数字识别 Tensorflow 任务目标 了解mnist数据集 搭建和测试模型 编辑环境 操作系统:Win10 python版本:3.6 集成开发环境:pycharm tensorflo ...
随机推荐
- 【消息队列】从各方面比较下kafka、activemq、rabbitmq、rocketmq之间的区别
一.单机吞吐量ActiveMQ:万级,吞吐量比RocketMQ和Kafka要低了一个数量级RabbitMQ:万级,吞吐量比RocketMQ和Kafka要低了一个数量级RocketMQ:10万级,Roc ...
- Python 2.7.x 使用Requests发起https请求时报Warning的问题
warning :如下 /usr/local/lib/python2.7/dist-packages/requests/packages/urllib3/connectionpool.py:852: ...
- 发送http请求,get和post两种请求方式
GET请求 GetMethod getMethod=null; String datas = "json=" + plain; HttpClient httpClient = ne ...
- vue 关于npm run build 的小问题
vue项目使用npm run build命令进行打包操作,打包之后试运行报错,报错为: 且命令行警告信息为: 解决办法: 找到项目目录下的config文件夹里的index.js文件,将build对象下 ...
- 数据结构与算法之PHP查找算法(二分查找)
二分查找又称折半查找,只对有序的数组有效. 优点是比较次数少,查找速度快,平均性能好,占用系统内存较少: 缺点是要求待查表为有序表,且插入删除困难. 因此,折半查找方法适用于不经常变动而查找频繁的有序 ...
- [Codeforces Round #340 (Div. 2)]
[Codeforces Round #340 (Div. 2)] vp了一场cf..(打不了深夜的场啊!!) A.Elephant 水题,直接贪心,能用5步走5步. B.Chocolate 乘法原理计 ...
- 数据库恢复(database restore)之兵不血刃——半小时恢复客户数据库
昨天,一个客户打打来电话,说他们的数据库坏了,不能用了,需要我帮助恢复下,这马上要放假了,居然出了这事儿,自己也不太喜欢恢复数据库这类,尤其是他们的数据库是个win上的库,但心里很清楚,客户比咱着急, ...
- spring boot cloud
eclipse spring boot 项目创建 https://www.cnblogs.com/shuaihan/p/8027082.html https://www.cnblogs.com/LUA ...
- Maven常见jar包依赖
<!-- servlet --> <dependency> <groupId>javax.servlet</groupId> <artifactI ...
- nyoj 1091 还是01背包(超大数dp)
nyoj 1091 还是01背包 描述 有n个重量和价值分别为 wi 和 vi 的物品,从这些物品中挑选总重量不超过W的物品,求所有挑选方案中价值总和的最大值 1 <= n <=40 1 ...