————————————————————————————————————

写在开头:此文参照莫烦python教程(墙裂推荐!!!)

————————————————————————————————————

分类实验之识别手写数字

  • 这个实验的内容是:基于TensorFlow,实现手写数字的识别。
  • 这里用到的数据集是大家熟知的mnist数据集。
  • mnist有五万多张手写数字的图片,每个图片用28x28的像素矩阵表示。所以我们的输入层每个案列的特征个数就有28x28=784个;因为数字有0,1,2…9共十个,所以我们的输出层是个1x10的向量。输出层是十个小于1的非负数,表示该预测是0,1,2…9的概率,我们选取最大概率所对应的数字作为我们的最终预测。
  • 真实的数字表示为该数字所对应的位置为1,其余位置为0的1x10的向量。

下面就开始实验啦!

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #导入数据
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)#如果还没下载mnist就下载 #定义添加层
def add_layer(inputs,in_size,out_size,activation_function=None):
#定义添加层内容,返回这层的outputs
Weights = tf.Variable(tf.random_normal([in_size,out_size]))#Weigehts是一个in_size行、out_size列的矩阵,开始时用随机数填满
biases = tf.Variable(tf.zeros([1,out_size])+0.1) #biases是一个1行out_size列的矩阵,用0.1填满
Wx_plus_b = tf.matmul(inputs,Weights)+biases #预测
if activation_function is None: #如果没有激励函数,那么outputs就是预测值
outputs = Wx_plus_b
else: #如果有激励函数,那么outputs就是激励函数作用于预测值之后的值
outputs = activation_function(Wx_plus_b)
return outputs #定义计算正确率的函数
def t_accuracy(t_xs,t_ys):
global prediction
y_pre = sess.run(prediction,feed_dict={xs:t_xs})
correct_pre = tf.equal(tf.argmax(y_pre,1),tf.argmax(t_ys,1))
accuracy = tf.reduce_mean(tf.cast(correct_pre,tf.float32))
result = sess.run(accuracy,feed_dict={xs:t_xs,ys:t_ys})
return result #定义神经网络的输入值和输出值
xs = tf.placeholder(tf.float32,[None,784]) #None是不规定大小,这里指的是案例个数,而输入特征个数为28x28 = 784
ys = tf.placeholder(tf.float32,[None,10]) #Nnoe也是案例个数,不做规定;10是因为有10个数字,所以输出是10 #增加输出层
prediction = add_layer(xs,784,10,activation_function=tf.nn.softmax)#这里的激励函数是softmax,此函数多用于多类分类 #计算误差
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys*tf.log(prediction),reduction_indices=[1])) #此误差计算方式和softmax配套用,效果好 #训练
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)#学习因子为0.5 #开始训练
sess = tf.Session()
sess.run(tf.initialize_all_variables()) for i in range(1000):
batch_xs,batch_ys = mnist.train.next_batch(100) #提取数据集的100个数据,因为原来数据太大了
sess.run(train_step,feed_dict={xs:batch_xs,ys:batch_ys})
if i%50 == 0:
print (t_accuracy(mnist.test.images,mnist.test.labels)) #每隔50个,打印一下正确率。注意:这里是要用test的数据来测试
Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
0.1849
0.6537
0.7393
0.7836
0.8053
0.8203
0.8275
0.837
0.8465
0.8504
0.8567
0.8571
0.8643
0.8637
0.8664
0.8687
0.8719
0.8742
0.8763
0.8773

上面4行就是下载的mnist数据集的四个文件。然后看打印出来的正确率可知,这个网络的预测能力是越来越好的。

下面试一下啊,抽取500个数据来训练,看看效果如何:

for i in range(1000):
batch_xs,batch_ys = mnist.train.next_batch(500) #提取数据集的500个数据,因为原来数据太大了
sess.run(train_step,feed_dict={xs:batch_xs,ys:batch_ys})
if i%50 == 0:
print (t_accuracy(mnist.test.images,mnist.test.labels)) #每隔50个,打印一下正确率。注意:这里是要用test的数据来测试
0.9001
0.9022
0.9023
0.9026
0.903
0.903
0.9037
0.9036
0.9034
0.9027
0.9041
0.903
0.9039
0.9034
0.9037
0.9046
0.9055
0.9045
0.9053
0.905

由上面打印出来的正确率可知,抽取500个数据来训练的话,正确率会达到90%


*点击[这儿:TensorFlow]发现更多关于TensorFlow的文章*


3 TensorFlow入门之识别手写数字的更多相关文章

  1. 6 TensorFlow实现cnn识别手写数字

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  2. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

  3. TensorFlow实战之Softmax Regression识别手写数字

         关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...

  4. 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

    笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...

  5. 使用神经网络来识别手写数字【译】(三)- 用Python代码实现

    实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...

  6. python手写神经网络实现识别手写数字

    写在开头:这个实验和matlab手写神经网络实现识别手写数字一样. 实验说明 一直想自己写一个神经网络来实现手写数字的识别,而不是套用别人的框架.恰巧前几天,有幸从同学那拿到5000张已经贴好标签的手 ...

  7. 用BP人工神经网络识别手写数字

    http://wenku.baidu.com/link?url=HQ-5tZCXBQ3uwPZQECHkMCtursKIpglboBHq416N-q2WZupkNNH3Gv4vtEHyPULezDb5 ...

  8. python机器学习使用PCA降维识别手写数字

    PCA降维识别手写数字 关注公众号"轻松学编程"了解更多. PCA 用于数据降维,减少运算时间,避免过拟合. PCA(n_components=150,whiten=True) n ...

  9. KNN 算法-实战篇-如何识别手写数字

    公号:码农充电站pro 主页:https://codeshellme.github.io 上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字? 1,手写数字数据集 手写数字数 ...

随机推荐

  1. spring读取配置文件PropertyPlaceholderConfigurer类的使用

    这里主要介绍PropertyPlaceholderConfigurer这个类的使用,spring中的该类主要用来读取配置文件并将配置文件中的变量设置到上下文环境中,并进行赋值. 一.此处使用list标 ...

  2. CentOS/Linux 网卡设置 IP地址配置永久生效

    CentOS/Linux下设置IP地址 1.临时生效设置 1.1修改IP地址 #ifconfig eth0 192.168.100.100 1.2修改网关地址 #route add default g ...

  3. awk数组处理字符串合并

    需求: 有一文本文件 lessons.txt 内容如下,请使用 awk 处理该文本,并输出内容如 result.txt lessons.txt: 634751 预排 568688 预排 386760 ...

  4. libxl库的介绍,对Excel操作封装得很好的一个库,兼容2007版和多字节字符(最后有破解版下载)

    前段时间忙着毕业论文,终于有时间写博客了. 早些时候老大给我的一个任务需要对excel进行读表操作,研究了一下c++对excel的操作. 对Excel的操作基本有com,ODBC,AD等,其中ODBC ...

  5. CTO俱乐部

    主持人:目前互联网金融行业存在哪些行业痛点?云信CreditCloud 是如何解决这些痛点的,过程中有哪些思考? 朱家波:目前这个行业有点乱象丛生, 投资人对行业的不信任是一个很大的痛点.解决不信任的 ...

  6. 【Google Earth】pro之视频录制

    一.谷歌地球文件简介 谷歌地球能识别的文件分为:gpx.kml.kmz文件.谷歌地球的官方文件为kml和kmz,其中kmz是kml和图片.模型等数据的压缩文件,kml为数据信息文件,也可以分为航迹和字 ...

  7. 57、Design Support Library 介绍及环境搭建

    一.Material Design几个要素 扁平化.简洁: 水波反馈: 良好体验的过渡动画: 材料空间位置的直观变化: 二.Android Studio配置 在 build.gradle 文件中加入, ...

  8. python 之re模块(正则表达式) 分组、断言详解

    正则表达式分组.断言详解   提示:阅读本文需要有一定的正则表达式基础. 正则表达式中的断言,作为高级应用出现,倒不是因为它有多难,而是概念比较抽象,不容易理解而已,今天就让小菜通俗的讲解一下. 如果 ...

  9. android应用安全——代码安全(android代码混淆)

    android2.3的SDK开始在eclipse中支持代码混淆功能(理论上java都支持混淆,但关键在于如何编写proguard的混淆脚本,2.3的SDK使用简单的配置就可以实现混淆).使用SDK2. ...

  10. cascade(级联)和inverse关系详解

    序言 写这篇文章之前,自己也查了很多的资料来搞清楚这两者的关系和各自所做的事情,但是百度一搜,大多数博文感觉说的云里雾里,可能博主自己清楚是怎么一回事,但是给一个不懂的人或者一知半解的人看的话,别人也 ...