————————————————————————————————————

写在开头:此文参照莫烦python教程(墙裂推荐!!!)

————————————————————————————————————

分类实验之识别手写数字

  • 这个实验的内容是:基于TensorFlow,实现手写数字的识别。
  • 这里用到的数据集是大家熟知的mnist数据集。
  • mnist有五万多张手写数字的图片,每个图片用28x28的像素矩阵表示。所以我们的输入层每个案列的特征个数就有28x28=784个;因为数字有0,1,2…9共十个,所以我们的输出层是个1x10的向量。输出层是十个小于1的非负数,表示该预测是0,1,2…9的概率,我们选取最大概率所对应的数字作为我们的最终预测。
  • 真实的数字表示为该数字所对应的位置为1,其余位置为0的1x10的向量。

下面就开始实验啦!

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data #导入数据
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)#如果还没下载mnist就下载 #定义添加层
def add_layer(inputs,in_size,out_size,activation_function=None):
#定义添加层内容,返回这层的outputs
Weights = tf.Variable(tf.random_normal([in_size,out_size]))#Weigehts是一个in_size行、out_size列的矩阵,开始时用随机数填满
biases = tf.Variable(tf.zeros([1,out_size])+0.1) #biases是一个1行out_size列的矩阵,用0.1填满
Wx_plus_b = tf.matmul(inputs,Weights)+biases #预测
if activation_function is None: #如果没有激励函数,那么outputs就是预测值
outputs = Wx_plus_b
else: #如果有激励函数,那么outputs就是激励函数作用于预测值之后的值
outputs = activation_function(Wx_plus_b)
return outputs #定义计算正确率的函数
def t_accuracy(t_xs,t_ys):
global prediction
y_pre = sess.run(prediction,feed_dict={xs:t_xs})
correct_pre = tf.equal(tf.argmax(y_pre,1),tf.argmax(t_ys,1))
accuracy = tf.reduce_mean(tf.cast(correct_pre,tf.float32))
result = sess.run(accuracy,feed_dict={xs:t_xs,ys:t_ys})
return result #定义神经网络的输入值和输出值
xs = tf.placeholder(tf.float32,[None,784]) #None是不规定大小,这里指的是案例个数,而输入特征个数为28x28 = 784
ys = tf.placeholder(tf.float32,[None,10]) #Nnoe也是案例个数,不做规定;10是因为有10个数字,所以输出是10 #增加输出层
prediction = add_layer(xs,784,10,activation_function=tf.nn.softmax)#这里的激励函数是softmax,此函数多用于多类分类 #计算误差
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys*tf.log(prediction),reduction_indices=[1])) #此误差计算方式和softmax配套用,效果好 #训练
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)#学习因子为0.5 #开始训练
sess = tf.Session()
sess.run(tf.initialize_all_variables()) for i in range(1000):
batch_xs,batch_ys = mnist.train.next_batch(100) #提取数据集的100个数据,因为原来数据太大了
sess.run(train_step,feed_dict={xs:batch_xs,ys:batch_ys})
if i%50 == 0:
print (t_accuracy(mnist.test.images,mnist.test.labels)) #每隔50个,打印一下正确率。注意:这里是要用test的数据来测试
Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
0.1849
0.6537
0.7393
0.7836
0.8053
0.8203
0.8275
0.837
0.8465
0.8504
0.8567
0.8571
0.8643
0.8637
0.8664
0.8687
0.8719
0.8742
0.8763
0.8773

上面4行就是下载的mnist数据集的四个文件。然后看打印出来的正确率可知,这个网络的预测能力是越来越好的。

下面试一下啊,抽取500个数据来训练,看看效果如何:

for i in range(1000):
batch_xs,batch_ys = mnist.train.next_batch(500) #提取数据集的500个数据,因为原来数据太大了
sess.run(train_step,feed_dict={xs:batch_xs,ys:batch_ys})
if i%50 == 0:
print (t_accuracy(mnist.test.images,mnist.test.labels)) #每隔50个,打印一下正确率。注意:这里是要用test的数据来测试
0.9001
0.9022
0.9023
0.9026
0.903
0.903
0.9037
0.9036
0.9034
0.9027
0.9041
0.903
0.9039
0.9034
0.9037
0.9046
0.9055
0.9045
0.9053
0.905

由上面打印出来的正确率可知,抽取500个数据来训练的话,正确率会达到90%


*点击[这儿:TensorFlow]发现更多关于TensorFlow的文章*


3 TensorFlow入门之识别手写数字的更多相关文章

  1. 6 TensorFlow实现cnn识别手写数字

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  2. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

  3. TensorFlow实战之Softmax Regression识别手写数字

         关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...

  4. 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

    笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...

  5. 使用神经网络来识别手写数字【译】(三)- 用Python代码实现

    实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...

  6. python手写神经网络实现识别手写数字

    写在开头:这个实验和matlab手写神经网络实现识别手写数字一样. 实验说明 一直想自己写一个神经网络来实现手写数字的识别,而不是套用别人的框架.恰巧前几天,有幸从同学那拿到5000张已经贴好标签的手 ...

  7. 用BP人工神经网络识别手写数字

    http://wenku.baidu.com/link?url=HQ-5tZCXBQ3uwPZQECHkMCtursKIpglboBHq416N-q2WZupkNNH3Gv4vtEHyPULezDb5 ...

  8. python机器学习使用PCA降维识别手写数字

    PCA降维识别手写数字 关注公众号"轻松学编程"了解更多. PCA 用于数据降维,减少运算时间,避免过拟合. PCA(n_components=150,whiten=True) n ...

  9. KNN 算法-实战篇-如何识别手写数字

    公号:码农充电站pro 主页:https://codeshellme.github.io 上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字? 1,手写数字数据集 手写数字数 ...

随机推荐

  1. OpenERP report doesn't work

    1. When you have used OpenOffice edited  one of reports,it has stored the report's banary data is da ...

  2. ecmall的物流配送体系改造

    接触多了ecshop.ecmall原始逻辑的,一般都习惯以整单的方式统一计算运费,这是一种很简单的思路. 但淘宝多了,就发现,物流运费没有那么简单. 首先,每种商品单独设置运费的体系,或者叫运费模板: ...

  3. [转]OpenMP中的private/firstprivate/lastprivate/threadprivate之间的比较

    转自:http://blog.csdn.net/gengshenghong/article/details/6985431 private/firstprivate/lastprivate/threa ...

  4. 木马suppoie 处理的几个思路 木马文件的权限所有者 属主数组 定时任务 目录权限

    木马suppoie 处理的几个思路  木马文件的权限所有者  属主数组  定时任务   目录权限

  5. CentOS安装python setuptools and pip

    安装setup-tools wget https://pypi.python.org/packages/2.7/s/setuptools/setuptools-0.6c11-py2.7.egg --n ...

  6. Flea Circus(Project Euler 213)

    original version hackerrank programming version 题目大意是N*N的格子,每个格子一开始有1个跳蚤,每过单位时间跳蚤会等概率向四周跳,问M秒后空格子的期望 ...

  7. 【 D3.js 入门系列 --- 9.4 】 集群图的制作

    本人的个人博客为: www.ourd3js.com csdn博客为: blog.csdn.net/lzhlzz 转载请注明出处,谢谢. 集群图( Cluster )通经常使用于表示包括与被包括关系. ...

  8. tcpdf

    将文档整为pdf格式文档 网址:http://www.tcpdf.org/examples.php

  9. 手动编译svn

    #!/bin/bash yum -y remove subversionmkdir -p /dist/{dist,src}cd /dist/dist/bin/rm -f openssl* subver ...

  10. js判断css动画效果是否结束

    <!-- css样式 --> <style> .test{ width: 100px; height: 100px; transition: all 5s; backgroun ...