TensorFlow实战第五课(MNIST手写数据集识别)
Tensorflow实现softmax regression识别手写数字
MNIST手写数字识别可以形象的描述为机器学习领域中的hello world。
MNIST是一个非常简单的机器视觉数据集。它由几万张28*28像素的手写数字组成,这些图片只包含灰度值信息。我们的任务就是对这些手写数字进行分类。转换为0-9共十个分类。
首先在命令行中运行如下代码加载MNIST手写数据集:
from tensorflow.examples.tutorials.mnist import input_data
#number 1 to 10 data
#创建文件夹存放数据
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)
数据集中包含55000个样本,测试集中有10000个样本,同时验证集有5000个样本。每一个样本都有他对应的标注信息,即label。
我们将在训练集上训练模型,在验证集上检验效果并决定何时完成训练,最后我们在测评及评测模型效果。
准备好数据后我们开始设计算法。我们采用的是softmax regression的算法训练手写数字识别的分类模型。数字分为0-9,所以一共有十个类别,当我们对一张图片进行预测时,softmax regression会对每一种类别估算一个概率,然后取估算概率最大的数字作为模型的输出结果。
注:当我们处理多分类模型时,通常需要使用softmax regression。例如卷积神经网络或者循环神经网络,如果是分类模型,那么最后一层同样是softmax regression。
loss函数选择的是交叉熵函数,交叉熵用来衡量预测值与真实值的相似程度,如果完全相同,他们的交叉熵等于零。
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),
reduction_indices=[1])) # loss
train方法(最优化方法)采用梯度下降法。
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
sess = tf.Session()
# tf.initialize_all_variables() 这种写法马上就要被废弃
# 替换成下面的写法:
sess.run(tf.global_variables_initializer())
完整代码:
#classification 分类学习 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data
#number 1 to 10 data
#创建文件夹存放数据
mnist = input_data.read_data_sets('MNIST_data',one_hot=True) def add_layer(inputs,in_size,out_size,activation_function=None):
#添加一个以上的层 并且返回这个层的输出 Weights = tf.Variable(tf.random_normal([in_size,out_size]))
biases = tf.Variable(tf.zeros([1,out_size])+0.1)
Wx_plus_b = tf.matmul(inputs,Weights)+biases if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b)
return outputs def compute_accuracy(v_xs,v_ys):
global prediction
y_pre = sess.run(prediction,feed_dict={xs:v_xs})
correct_prediction = tf.equal(tf.argmax(y_pre,1),tf.argmax(v_ys,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
result = sess.run(accuracy, feed_dict={xs: v_xs, ys: v_ys})
return result #define placeholder for inputs to network
xs = tf.placeholder(tf.float32,[None,784])#None就是不规定他有多少sample,但是规定大小为28*28
ys = tf.placeholder(tf.float32,[None,10]) #add output layer
#激励函数采用softmax函数
prediction = add_layer(xs,784,10,activation_function=tf.nn.softmax) # the error between prediction and real data
'''loss函数即最优化目标函数 选用交叉熵函数
交叉熵用来衡量预测值和真实值相似程度
如果完全相同 ,他们的交叉熵为零
'''
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction),
reduction_indices=[1])) # loss
#采用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy) sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init) for i in range(2000):
#每次只取100张图片
batch_xs,batch_ys = mnist.train.next_batch(100)
sess.run(train_step,feed_dict={xs:batch_xs,ys:batch_ys})
if i%50==0:
print(compute_accuracy(mnist.test.images,mnist.test.labels))
输出结果:

TensorFlow实战第五课(MNIST手写数据集识别)的更多相关文章
- TensorFlow系列专题(六):实战项目Mnist手写数据集识别
欢迎大家关注我们的网站和系列教程:http://panchuang.net/ ,学习更多的机器学习.深度学习的知识! 目录: 导读 MNIST数据集 数据处理 单层隐藏层神经网络的实现 多层隐藏层神经 ...
- Tensorflow项目实战一:MNIST手写数字识别
此模型中,输入是28*28*1的图片,经过两个卷积层(卷积+池化)层之后,尺寸变为7*7*64,将最后一个卷积层展成一个以为向量,然后接两个全连接层,第一个全连接层加一个dropout,最后一个全连接 ...
- 吴裕雄 python 神经网络——TensorFlow 实现LeNet-5模型处理MNIST手写数据集
import os import numpy as np import tensorflow as tf from tensorflow.examples.tutorials.mnist import ...
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- Tensorflow之MNIST手写数字识别:分类问题(1)
一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点: 1.将离散特征的取值扩展 ...
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
- 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识
深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...
- 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别
深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...
- 用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别
用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 http://phunter.farbox.com/post/mxnet-tutorial1 用MXnet实战深度学 ...
随机推荐
- Pycharm2019最新激活码
激活pycharm的方法有很多,一种是使用最新的激活码,另一种是使用破解补丁的方式(可以长期使用) pycharm2019最新激活码: 812LFWMRSH-eyJsaWNlbnNlSWQiOiI4M ...
- HDU 6039 - Gear Up | 2017 Multi-University Training Contest 1
建模简析: /* HDU 6039 - Gear Up [ 建模,线段树,图论 ] | 2017 Multi-University Training Contest 1 题意: 给你n个齿轮,有些齿轮 ...
- nginx负载均衡 页面缓存
nginx的upstream目前支持4种方式的分配 1.轮询(默认) 每个请求按时间顺序逐一分配到不同的后端服务器,如果后端服务器down掉,能自动剔除. 2.weight 指定轮询几率,weight ...
- Python 运算符优先级
这个表给出Python的运算符优先级(从低到高). 从最低的优先级(最松散地结合)到最高的优先级(最紧密地结合). 这意味着在一个表达式中,Python会首先计算表中较下面的运算符,然后在计算列在表上 ...
- flask框架(一):初入
1.装饰器回顾 # -*- coding: utf-8 -*- # @Author : Felix Wang # @time : 2018/7/3 17:10 import functools &qu ...
- 灰度图像--图像分割 Marr-Hildreth算子(LoG算子)
学习DIP第49天 转载请标明本文出处:*http://blog.csdn.net/tonyshengtan *,出于尊重文章作者的劳动,转载请标明出处!文章代码已托管,欢迎共同开发: https:/ ...
- sql 语句中 order by 的用法
order by 是用在where条件之后,用来对查询结果进行排序 order by 字段名 asc/desc asc 表示升序(默认为asc,可以省略) desc表示降序 order by 无法用于 ...
- Java线程之如何分析死锁及避免死锁
什么是死锁 java中的死锁是一种编程情况,其中两个或多个线程被永久阻塞,Java死锁情况出现至少两个线程和两个或更多资源. 在这里,我们将写了一个简单的程序,它将导致java死锁场景,然后我们将分析 ...
- __new()__与__init__()
1. __new__:创建对象时调用,会返回当前对象的一个实例.(默认情况下也就是你在类中没有没有重新这个方法,会默认返回当前类的示例,如果你重写了这个方法,但是在方法中没有返回当前类的示例,那么也就 ...
- Backen-Development record 1
单例模式 在应用这个模式时,单例对象的类必须保证只有一个实例存在. 服务进程中的其他对象再通过这个单例对象获取这些配置信息.这种方式简化了在复杂环境下的配置管理. __new__实现 用装饰器实现单例 ...