这篇博文不介绍基础的RNN理论知识,只是初步探索如何使用Tensorflow,之后会用笔推导RNN的公式和理论,现在时间紧迫所以先使用为主~~

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow.examples.tutorials.mnist.input_data as input_data
from tensorflow.contrib import rnn mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
trainimgs, trainlabels, testimgs, testlabels \
= mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
ntrain, ntest, dim, nclasses \
= trainimgs.shape[0], testimgs.shape[0], trainimgs.shape[1], trainlabels.shape[1]
print ("MNIST loaded")
dim_input = 28 #28*1
dim_hidden = 128 #28*128
dim_output = 10 #
nsteps = 28
weight = {
"hidden":tf.Variable(tf.random_normal([dim_input,dim_hidden])),
"out" :tf.Variable(tf.random_normal([dim_hidden,dim_output]))
}
biases = {
"hidden":tf.Variable(tf.random_normal([dim_hidden])),
"out" :tf.Variable(tf.random_normal([dim_output]))
} def RNN(_X,_W,_b,_nsteps,_name):
#[batchsize,nsteps*dim_input]-->>[batchsize,nsteps,dim_input]=[num,28,28]
_X = tf.reshape(_X,[-1,28,28])
#-->>[nsteps,batchsize,dim_input]==[28,num,28]
_X = tf.transpose(_X,[1,0,2])
#-->>[nsteps*batchsize,input]==[28*num,28]
_X = tf.reshape(_X,[-1,28])
#这里是共享权重,nsteps个weights全部一样的.
_H = tf.matmul(_X,_W['hidden']) + _b["hidden"]
_Hsplit = tf.split(_H,num_or_size_splits=nsteps,axis=0)
with tf.variable_scope(_name,reuse=tf.AUTO_REUSE):#重复使用参数节约空间,防止报错
#版本更新弃用
#scop.reuse_variables()
#设计一个计算单元
lstm_cell = rnn.BasicLSTMCell(128,forget_bias=1.0)
#版本更新已经弃用
#lstm_cell = rnn_cell.BasicLSTMCell(dim_hidden,forget_bias=1.0)
#利用RNN单元搭建网络,这里用的最简单的,其它以后在说
_LSTM_0,_LSTM_S = rnn.static_rnn(lstm_cell,_Hsplit,dtype=tf.float32)
#版本更新已经弃用
#_LSTM_O, _LSTM_S = tf.nn.rnn(lstm_cell, _Hsplit,dtype=tf.float32)
return tf.matmul(_LSTM_0[-1],_W["out"])+_b["out"]
#使用GPU按需增长模式
config = tf.ConfigProto(allow_soft_placement=True)
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.3)
config.gpu_options.allow_growth = True
if __name__== "__main__":
learning_rate = 0.001
x = tf.placeholder(dtype=tf.float32,shape=[None,28*28],name="input_x")
y = tf.placeholder(dtype=tf.float32,shape=[None,10],name="output_y")
pred = RNN(x,weight,biases,nsteps,"basic")
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
optm = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
accr = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred,1),tf.argmax(y,1)),dtype=tf.float32))
init = tf.global_variables_initializer()
print("RNN Already") training_epochs = 50
batch_size = 16
display_step = 1
sess = tf.Session(config=config)
sess.run(init)
print("Start optimization")
for epoch in range(training_epochs):
avg_cost = 0.
total_batch = int(mnist.train.num_examples/batch_size)
#total_batch = 100
# Loop over all batches
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
# Fit training using batch data
feeds = {x: batch_xs, y: batch_ys}
sess.run(optm, feed_dict=feeds)
# Compute average loss
avg_cost += sess.run(cost, feed_dict=feeds) / total_batch
# Display logs per epoch step
if epoch % display_step == 0:
print("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost))
feeds = {x: batch_xs, y: batch_ys}
train_acc = sess.run(accr, feed_dict=feeds)
print(" Training accuracy: %.3f" % (train_acc))
feeds = {x: testimgs, y: testlabels}
test_acc = sess.run(accr, feed_dict=feeds)
print(" Test accuracy: %.3f" % (test_acc))
print("Optimization Finished.")

  • 没有训练结束,使用的GTX1060训练了大概8分钟,如果训练结束感觉应该可以达到97%左右
  • 因为是单层网络,深度不够,也没处理数据~~
  • 这只是简单了解RNN工作流程,和如何用TF操作RNN
  • 以后会慢慢补上~~

参考:

  • 唐迪宇课程,因为版本问题会出现很多代码更新
  • 其它中间忘记记录了,如有侵权请联系博主,抱歉~

RNN探索(2)之手写数字识别的更多相关文章

  1. 5 TensorFlow入门笔记之RNN实现手写数字识别

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  2. keras和tensorflow搭建DNN、CNN、RNN手写数字识别

    MNIST手写数字集 MNIST是一个由美国由美国邮政系统开发的手写数字识别数据集.手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载,总共4个.gz后缀的压缩文件,该文件 ...

  3. TensorFlow使用RNN实现手写数字识别

    学习,笔记,有时间会加注释以及函数之间的逻辑关系. # https://www.cnblogs.com/felixwang2/p/9190664.html # https://www.cnblogs. ...

  4. C#中调用Matlab人工神经网络算法实现手写数字识别

    手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化  投影  矩阵  目标定位  Matlab 手写数字图像识别简介: 手写 ...

  5. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  6. 学习笔记CB009:人工神经网络模型、手写数字识别、多层卷积网络、词向量、word2vec

    人工神经网络,借鉴生物神经网络工作原理数学模型. 由n个输入特征得出与输入特征几乎相同的n个结果,训练隐藏层得到意想不到信息.信息检索领域,模型训练合理排序模型,输入特征,文档质量.文档点击历史.文档 ...

  7. Tensorflow实战 手写数字识别(Tensorboard可视化)

    一.前言 为了更好的理解Neural Network,本文使用Tensorflow实现一个最简单的神经网络,然后使用MNIST数据集进行测试.同时使用Tensorboard对训练过程进行可视化,算是打 ...

  8. kaggle 实战 (1): PCA + KNN 手写数字识别

    文章目录 加载package read data PCA 降维探索 选择50维度, 拆分数据为训练集,测试机 KNN PCA降维和K值筛选 分析k & 维度 vs 精度 预测 生成提交文件 本 ...

  9. 用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识

    用tensorflow搭建RNN(LSTM)进行MNIST 手写数字辨识 循环神经网络RNN相比传统的神经网络在处理序列化数据时更有优势,因为RNN能够将加入上(下)文信息进行考虑.一个简单的RNN如 ...

随机推荐

  1. 普通Linux用户1分钟上手vi编辑器

    *导读:普通用户只要花1分钟看第二部分即可.高级用户请忽略本文* 目录 1. 编辑器之战 2. vi的使用 2.1 vi的3个模式 2.2 vi的3个模式切换 2.3 vi最基本的命令 2.4 vi的 ...

  2. 二.Flask 学习模板

    Flask 为你配置 Jinja2 模板引擎.使用 render_template() 方法可以渲染模板,只需提供模板名称和需要作为参数传递给模板的变量就可简单执行. 至于模板渲染? 简单来说,就是将 ...

  3. Linux安装MySQL_5.6

    E&T: CentOS_7.4 64位; mysql-5.6.42-linux-glibc2.12-x86_64.tar; Xftp5; Xshell5; P1.下载Linux环境下的MySQ ...

  4. 10. Firewalls (防火墙 2个)

    Netfilter是在标准Linux内核中实现的强大的包过滤器. 用户空间iptables工具用于配置. 它现在支持数据包过滤(无状态或有状态),各种网络地址和端口转换(NAT / NAPT),以及用 ...

  5. Comedi的学习过程

    1.介绍Comedi 1.1Comedi是一个设备驱动开发的软件工具,它采用了一种3层组织模型:上层是用户层,Comedi提供了在用户控件编写程序的接口Comedilib,通过系统调用来控制硬件设备: ...

  6. Python安装及IDE激活

    简介: Windows10下安装激活Pycharm,并同时安装Python 3.x.2.x,便于在Pycharm开发环境中使用不同版本的解释器进行对比学习. 目录: 一.Python 3.x安装 二. ...

  7. python,pil库的小应用

    <pre>#euraxluo 5.15 #obj_1#跳一跳的外挂 from PIL import Image import subprocess import time import r ...

  8. PythonStudy——函数对象的案例

    # part1 # 加法运算 def add(n1, n2): return n1 + n2 def low(n1, n2): return n1 - n2 # 四则运算 def computed(n ...

  9. VirtualBox虚拟机禁止时间同步

    某机器为客户提供,宿主机时间快了20分钟,导致虚拟机时间也跟着快20分钟,每次更改完虚拟机时间,不到1分钟时间又变回去了 在一些情况下必须让VirtualBox虚拟客户机的时间和主机不同步,百度了一番 ...

  10. Linux updatedb命令详解

    Linux updatedb命令 updatedb 命令用来创建或更新 locate 命令所必需的数据库文件. updatedb 命令的执行过程较长,因为在执行时它会遍历整个系统的目录树,并将所有的文 ...