————————————————————————————————————

写在开头:此文参照莫烦python教程(墙裂推荐!!!)

————————————————————————————————————

分类实验之识别手写数字

  • 这个实验的内容是:基于TensorFlow,实现手写数字的识别。
  • 这里用到的数据集是大家熟知的mnist数据集。
  • mnist有五万多张手写数字的图片,每个图片用28x28的像素矩阵表示。所以我们的输入层每个案列的特征个数就有28x28=784个;因为数字有0,1,2…9共十个,所以我们的输出层是个1x10的向量。输出层是十个小于1的非负数,表示该预测是0,1,2…9的概率,我们选取最大概率所对应的数字作为我们的最终预测。
  • 真实的数字表示为该数字所对应的位置为1,其余位置为0的1x10的向量。

下面就开始实验啦!

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #导入数据
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)#如果还没下载mnist就下载 #定义添加层
def add_layer(inputs,in_size,out_size,activation_function=None):
#定义添加层内容,返回这层的outputs
Weights = tf.Variable(tf.random_normal([in_size,out_size]))#Weigehts是一个in_size行、out_size列的矩阵,开始时用随机数填满
biases = tf.Variable(tf.zeros([1,out_size])+0.1) #biases是一个1行out_size列的矩阵,用0.1填满
Wx_plus_b = tf.matmul(inputs,Weights)+biases #预测
if activation_function is None: #如果没有激励函数,那么outputs就是预测值
outputs = Wx_plus_b
else: #如果有激励函数,那么outputs就是激励函数作用于预测值之后的值
outputs = activation_function(Wx_plus_b)
return outputs #定义计算正确率的函数
def t_accuracy(t_xs,t_ys):
global prediction
y_pre = sess.run(prediction,feed_dict={xs:t_xs})
correct_pre = tf.equal(tf.argmax(y_pre,1),tf.argmax(t_ys,1))
accuracy = tf.reduce_mean(tf.cast(correct_pre,tf.float32))
result = sess.run(accuracy,feed_dict={xs:t_xs,ys:t_ys})
return result #定义神经网络的输入值和输出值
xs = tf.placeholder(tf.float32,[None,784]) #None是不规定大小,这里指的是案例个数,而输入特征个数为28x28 = 784
ys = tf.placeholder(tf.float32,[None,10]) #Nnoe也是案例个数,不做规定;10是因为有10个数字,所以输出是10 #增加输出层
prediction = add_layer(xs,784,10,activation_function=tf.nn.softmax)#这里的激励函数是softmax,此函数多用于多类分类 #计算误差
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys*tf.log(prediction),reduction_indices=[1])) #此误差计算方式和softmax配套用,效果好 #训练
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)#学习因子为0.5 #开始训练
sess = tf.Session()
sess.run(tf.initialize_all_variables()) for i in range(1000):
batch_xs,batch_ys = mnist.train.next_batch(100) #提取数据集的100个数据,因为原来数据太大了
sess.run(train_step,feed_dict={xs:batch_xs,ys:batch_ys})
if i%50 == 0:
print (t_accuracy(mnist.test.images,mnist.test.labels)) #每隔50个,打印一下正确率。注意:这里是要用test的数据来测试
Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
0.1849
0.6537
0.7393
0.7836
0.8053
0.8203
0.8275
0.837
0.8465
0.8504
0.8567
0.8571
0.8643
0.8637
0.8664
0.8687
0.8719
0.8742
0.8763
0.8773

上面4行就是下载的mnist数据集的四个文件。然后看打印出来的正确率可知,这个网络的预测能力是越来越好的。

下面试一下啊,抽取500个数据来训练,看看效果如何:

for i in range(1000):
batch_xs,batch_ys = mnist.train.next_batch(500) #提取数据集的500个数据,因为原来数据太大了
sess.run(train_step,feed_dict={xs:batch_xs,ys:batch_ys})
if i%50 == 0:
print (t_accuracy(mnist.test.images,mnist.test.labels)) #每隔50个,打印一下正确率。注意:这里是要用test的数据来测试
0.9001
0.9022
0.9023
0.9026
0.903
0.903
0.9037
0.9036
0.9034
0.9027
0.9041
0.903
0.9039
0.9034
0.9037
0.9046
0.9055
0.9045
0.9053
0.905

由上面打印出来的正确率可知,抽取500个数据来训练的话,正确率会达到90%


*点击[这儿:TensorFlow]发现更多关于TensorFlow的文章*


3 TensorFlow入门之识别手写数字的更多相关文章

  1. 6 TensorFlow实现cnn识别手写数字

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  2. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

  3. TensorFlow实战之Softmax Regression识别手写数字

         关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...

  4. 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

    笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...

  5. 使用神经网络来识别手写数字【译】(三)- 用Python代码实现

    实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...

  6. python手写神经网络实现识别手写数字

    写在开头:这个实验和matlab手写神经网络实现识别手写数字一样. 实验说明 一直想自己写一个神经网络来实现手写数字的识别,而不是套用别人的框架.恰巧前几天,有幸从同学那拿到5000张已经贴好标签的手 ...

  7. 用BP人工神经网络识别手写数字

    http://wenku.baidu.com/link?url=HQ-5tZCXBQ3uwPZQECHkMCtursKIpglboBHq416N-q2WZupkNNH3Gv4vtEHyPULezDb5 ...

  8. python机器学习使用PCA降维识别手写数字

    PCA降维识别手写数字 关注公众号"轻松学编程"了解更多. PCA 用于数据降维,减少运算时间,避免过拟合. PCA(n_components=150,whiten=True) n ...

  9. KNN 算法-实战篇-如何识别手写数字

    公号:码农充电站pro 主页:https://codeshellme.github.io 上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字? 1,手写数字数据集 手写数字数 ...

随机推荐

  1. OpenERP report doesn't work

    1. When you have used OpenOffice edited  one of reports,it has stored the report's banary data is da ...

  2. centos7 mysql 5.7 官网下载tar安装

    https://dev.mysql.com/downloads/mysql/5.7.html#downloads 下载好上传到服务器,解压后以此安装 libs,client,server三个rpm r ...

  3. ThinkPHP 模板 Volist 标签嵌套循环输出多维数组

    ThinkPHP 中对 volist 标签嵌套使用可实现多维数组的输出. volist 嵌套使用 一般的二维数组,可以用 volist 标签直接循环输出.对于多维数组,则需要对其中的数组成员再次使用 ...

  4. 【Debian】install

    n年前的报废台式机实在不能忍受xp的速度,果断装Linux近期家里的小本装了Ubuntu14.04 ,实在不习惯最新的图形界面.装个debian试试吧. 1.专门弄一个空白分区2.官网下载debian ...

  5. 理解java的 多态

    http://www.cnblogs.com/chenssy/p/3372798.html

  6. 在window把项目上传到github

    作为一个开发者,写博客,上传项目到github好像是不可不会的技能,很多有经验的老司机都会这么建议你.本宝宝第一次要把项目传到github的时候,确实有点蒙蔽,什么鬼,传个东西有必要这么难吗? git ...

  7. Python全栈day24-25(面向对象编程)

    参考文档: http://www.cnblogs.com/linhaifeng/articles/6182264.html# 类:把一类事物的相同的特征和动作整合到一起就是类,类是抽象的概练 对象:就 ...

  8. PAT 1013 Battle Over Cities(并查集)

    1013. Battle Over Cities (25) 时间限制 400 ms 内存限制 65536 kB 代码长度限制 16000 B 判题程序 Standard 作者 CHEN, Yue It ...

  9. HDU 1695 GCD (欧拉函数,容斥原理)

    GCD Time Limit: 6000/3000 MS (Java/Others)    Memory Limit: 32768/32768 K (Java/Others) Total Submis ...

  10. 记录--关于Jquery uploadify 不能动态传值的问题(java)

    动态传值纠结多时后无效, 后得下面一番代码,依旧无效~~ 纳了几个闷,心灰意冷下   清理了 tomcat 一次 再出运行   可以了 真心纠结很久很久   无奈之下还是得  清理清理tomcat: ...