3 TensorFlow入门之识别手写数字
————————————————————————————————————
写在开头:此文参照莫烦python教程(墙裂推荐!!!)
————————————————————————————————————
分类实验之识别手写数字
- 这个实验的内容是:基于TensorFlow,实现手写数字的识别。
- 这里用到的数据集是大家熟知的mnist数据集。
- mnist有五万多张手写数字的图片,每个图片用28x28的像素矩阵表示。所以我们的输入层每个案列的特征个数就有28x28=784个;因为数字有0,1,2…9共十个,所以我们的输出层是个1x10的向量。输出层是十个小于1的非负数,表示该预测是0,1,2…9的概率,我们选取最大概率所对应的数字作为我们的最终预测。
- 真实的数字表示为该数字所对应的位置为1,其余位置为0的1x10的向量。
下面就开始实验啦!
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
#导入数据
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)#如果还没下载mnist就下载
#定义添加层
def add_layer(inputs,in_size,out_size,activation_function=None):
#定义添加层内容,返回这层的outputs
Weights = tf.Variable(tf.random_normal([in_size,out_size]))#Weigehts是一个in_size行、out_size列的矩阵,开始时用随机数填满
biases = tf.Variable(tf.zeros([1,out_size])+0.1) #biases是一个1行out_size列的矩阵,用0.1填满
Wx_plus_b = tf.matmul(inputs,Weights)+biases #预测
if activation_function is None: #如果没有激励函数,那么outputs就是预测值
outputs = Wx_plus_b
else: #如果有激励函数,那么outputs就是激励函数作用于预测值之后的值
outputs = activation_function(Wx_plus_b)
return outputs
#定义计算正确率的函数
def t_accuracy(t_xs,t_ys):
global prediction
y_pre = sess.run(prediction,feed_dict={xs:t_xs})
correct_pre = tf.equal(tf.argmax(y_pre,1),tf.argmax(t_ys,1))
accuracy = tf.reduce_mean(tf.cast(correct_pre,tf.float32))
result = sess.run(accuracy,feed_dict={xs:t_xs,ys:t_ys})
return result
#定义神经网络的输入值和输出值
xs = tf.placeholder(tf.float32,[None,784]) #None是不规定大小,这里指的是案例个数,而输入特征个数为28x28 = 784
ys = tf.placeholder(tf.float32,[None,10]) #Nnoe也是案例个数,不做规定;10是因为有10个数字,所以输出是10
#增加输出层
prediction = add_layer(xs,784,10,activation_function=tf.nn.softmax)#这里的激励函数是softmax,此函数多用于多类分类
#计算误差
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys*tf.log(prediction),reduction_indices=[1])) #此误差计算方式和softmax配套用,效果好
#训练
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)#学习因子为0.5
#开始训练
sess = tf.Session()
sess.run(tf.initialize_all_variables())
for i in range(1000):
batch_xs,batch_ys = mnist.train.next_batch(100) #提取数据集的100个数据,因为原来数据太大了
sess.run(train_step,feed_dict={xs:batch_xs,ys:batch_ys})
if i%50 == 0:
print (t_accuracy(mnist.test.images,mnist.test.labels)) #每隔50个,打印一下正确率。注意:这里是要用test的数据来测试
Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
0.1849
0.6537
0.7393
0.7836
0.8053
0.8203
0.8275
0.837
0.8465
0.8504
0.8567
0.8571
0.8643
0.8637
0.8664
0.8687
0.8719
0.8742
0.8763
0.8773
上面4行就是下载的mnist数据集的四个文件。然后看打印出来的正确率可知,这个网络的预测能力是越来越好的。
下面试一下啊,抽取500个数据来训练,看看效果如何:
for i in range(1000):
batch_xs,batch_ys = mnist.train.next_batch(500) #提取数据集的500个数据,因为原来数据太大了
sess.run(train_step,feed_dict={xs:batch_xs,ys:batch_ys})
if i%50 == 0:
print (t_accuracy(mnist.test.images,mnist.test.labels)) #每隔50个,打印一下正确率。注意:这里是要用test的数据来测试
0.9001
0.9022
0.9023
0.9026
0.903
0.903
0.9037
0.9036
0.9034
0.9027
0.9041
0.903
0.9039
0.9034
0.9037
0.9046
0.9055
0.9045
0.9053
0.905
由上面打印出来的正确率可知,抽取500个数据来训练的话,正确率会达到90%
*点击[这儿:TensorFlow]发现更多关于TensorFlow的文章*
3 TensorFlow入门之识别手写数字的更多相关文章
- 6 TensorFlow实现cnn识别手写数字
------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...
- 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字
TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...
- TensorFlow实战之Softmax Regression识别手写数字
关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...
- 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)
笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...
- 使用神经网络来识别手写数字【译】(三)- 用Python代码实现
实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...
- python手写神经网络实现识别手写数字
写在开头:这个实验和matlab手写神经网络实现识别手写数字一样. 实验说明 一直想自己写一个神经网络来实现手写数字的识别,而不是套用别人的框架.恰巧前几天,有幸从同学那拿到5000张已经贴好标签的手 ...
- 用BP人工神经网络识别手写数字
http://wenku.baidu.com/link?url=HQ-5tZCXBQ3uwPZQECHkMCtursKIpglboBHq416N-q2WZupkNNH3Gv4vtEHyPULezDb5 ...
- python机器学习使用PCA降维识别手写数字
PCA降维识别手写数字 关注公众号"轻松学编程"了解更多. PCA 用于数据降维,减少运算时间,避免过拟合. PCA(n_components=150,whiten=True) n ...
- KNN 算法-实战篇-如何识别手写数字
公号:码农充电站pro 主页:https://codeshellme.github.io 上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字? 1,手写数字数据集 手写数字数 ...
随机推荐
- 使用StringTokenizer分解字符串
Java切割字符串.一般使用substring.split.StringTokenizer来处理,前两种是String对象的方法,使用字符串能够直接处理,本文介绍下StringTokenizer的使用 ...
- RelativeSource.TemplatedParent 属性wpf
今天看到这一句代码时候,自己只是知道绑定了,可是不知道绑定了什么啊 就去查了一下,后来说的好像是绑定的TemplateParent返回的 一个值.可是这是为什么呢, 有的说是绑定的是一个资源. 下面有 ...
- mui 子页面切换父页面底部导航
在父页面中新增方法: function switchTab(tab){ plus.webview.hide(activeTab); activeTab= tab; plus.webview.show( ...
- UML学习目录
用例图:http://www.cnblogs.com/yjjm/archive/2012/01/28/2385861.html http://kb.cnblogs.com/page/129491/
- VC++显示文件或文件夹属性
When you select a file or folder in Explorer window, and choose 'Properties' from the menu, you get ...
- 动画间隔AnimationInterval 场景切换、图层叠加
从这一个月的学习进度上来看算比较慢的了,从开始学习C++到初试cocos,这也是我做过的比较大的决定,从工作中里挤出时间来玩玩自己喜欢的游戏开发也是一件非常幸福的事情,虽然现在对cocos的了解还只是 ...
- 模拟ORA-26040: Data block was loaded using the NOLOGGING option
我们知道通过设置nologging选项.能够加快oracle的某些操作的运行速度,这在运行某些维护任务时是非常实用的,可是该选项也非常危急,假设使用不当,就可能导致数据库发生ORA-26040错误. ...
- redis问题集
Redis有哪些数据结构? 字符串String.字典Hash.列表List.集合Set.有序集合SortedSet. 如果你是Redis中高级用户,还需要加上下面几种数据结构HyperLogLog.G ...
- hdu1066(经典题)
求N个数阶乘末尾除0后的数值. 主要的难点在于要把这个N个数所含的2和5的队数去掉. 网上方法很多很好. 不多说 Last non-zero Digit in N! Time Limit: 2000/ ...
- 《从零开始学Swift》学习笔记(Day 57)——Swift编码规范之注释规范:文件注释、文档注释、代码注释、使用地标注释
原创文章,欢迎转载.转载请注明:关东升的博客 前面说到Swift注释的语法有两种:单行注释(//)和多行注释(/*...*/).这里来介绍一下他们的使用规范. 1.文件注释 文件注释就在每一个文件开头 ...