一:MNIST数据集   

下载地址

MNIST是一个包含很多手写数字图片的数据集,一共4个二进制压缩文件

分别是test set images,test set labels,training set images,training set labels

training set包括60000个样本,test set包括10000个样本。

test set中前5000个样本来自原始的NISTtraining set,后5000个样本来自原始的NIST test set,因此,前5000个样本比后5000个样本更简单和干净。

每个样本是28*28像素的图片

二:tensorflow构建模型识别MNIST

导入数据:

from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
import tensorflow as tf
x = tf.placeholder(tf.float32, shape=[None, 784])
y_ = tf.placeholder(tf.float32, shape=[None, 10]) #真实值
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, w) + b) #预测值

softmax的目的:将输出转化为是每个数字的概率

#计算交叉熵
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_label *tf.log(y), reduction_indices=[1]))
train = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

交叉熵:衡量预测值与真实值之间的差别,当然是越小越好

公式为:

其中y'是真实值,y为预测值

最后用梯度下降法优化参数即可

在Session中运行graph:

total_steps = 5000
batch_size = 100
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for step in range(total_steps+1):
batch_x, batch_y = mnist.train.next_batch(batch_size)
sess.run(train,feed_dict={x: batch_x, y_label: batch_y})

 预测正确率:

correct_prediction = tf.equal(tf.argmax(y, axis=1), tf.argmax(y_label, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

tf.argmax()函数返回axis轴上最大值的index

tf.equal()函数返回的是布尔值,需要用tf.cast()方法转为tf.float32类型

最后在test set上进行预测:

step_per_test = 100
if step % step_per_test == 0:
print(step, sess.run(accuracy, feed_dict={x: mnist.test.images, y_label: mnist.test.labels}))

完整代码如下:

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)
x = tf.placeholder(tf.float32, [None, 784])
y_label = tf.placeholder(tf.float32, [None, 10])
w = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.nn.softmax(tf.matmul(x, w) + b) #计算交叉熵
cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_label *tf.log(y), reduction_indices=[1]))
train = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
#eval
correct_prediction = tf.equal(tf.argmax(y, axis=1), tf.argmax(y_label, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) total_steps = 5000
batch_size = 100
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
for step in range(total_steps+1):
batch_x, batch_y = mnist.train.next_batch(batch_size)
sess.run(train,feed_dict={x: batch_x, y_label: batch_y}) step_per_test = 100
if step % step_per_test == 0:
print(step, sess.run(accuracy, feed_dict={x: mnist.test.images, y_label: mnist.test.labels}))

运行结果:

准确率为0.92左右

后面我们会构建更好的模型达到更高的正确率。

相关链接:

详解 MNIST 数据集

基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型

基于tensorflow的MNIST手写数字识别(二)--入门篇

基于tensorflow的MNIST手写数字识别(三)--神经网络篇

基于TensorFlow的MNIST手写数字识别-初级的更多相关文章

  1. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  2. 基于TensorFlow的MNIST手写数字识别-深入

    构建多层卷积神经网络时需要多组W和偏移项b,我们封装2个方法来产生W和b 初级MNIST中用0初始化W和b,这里用噪声初始化进行对称打破,防止产生梯度0,同时用一个小的正值来初始化b避免dead ne ...

  3. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  4. Tensorflow之MNIST手写数字识别:分类问题(1)

    一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点:   1.将离散特征的取值扩展 ...

  5. Tensorflow实现MNIST手写数字识别

    之前我们讲了神经网络的起源.单层神经网络.多层神经网络的搭建过程.搭建时要注意到的具体问题.以及解决这些问题的具体方法.本文将通过一个经典的案例:MNIST手写数字识别,以代码的形式来为大家梳理一遍神 ...

  6. [Python]基于CNN的MNIST手写数字识别

    目录 一.背景介绍 1.1 卷积神经网络 1.2 深度学习框架 1.3 MNIST 数据集 二.方法和原理 2.1 部署网络模型 (1)权重初始化 (2)卷积和池化 (3)搭建卷积层1 (4)搭建卷积 ...

  7. Tensorflow之MNIST手写数字识别:分类问题(2)

    整体代码: #数据读取 import tensorflow as tf import matplotlib.pyplot as plt import numpy as np from tensorfl ...

  8. TensorFlow——MNIST手写数字识别

    MNIST手写数字识别 MNIST数据集介绍和下载:http://yann.lecun.com/exdb/mnist/   一.数据集介绍: MNIST是一个入门级的计算机视觉数据集 下载下来的数据集 ...

  9. 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型

    持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...

随机推荐

  1. Linux-Cacti监控{Verson:1.2.8}

    首先需要一个LAMP平台 或LNMP平台 yum -y install httpd mariadb php mariadb-server mariadb-devel zlib freetype lib ...

  2. Angular.的简单运用

    从script引用angular文件.开始编写angular事件: 在angular文件中添加属性: ag-xxxx;初始化使用: ng-app="name"; 没有这个属性就不会 ...

  3. Celery 收下这捆芹菜!

    目录 Celery简介 Celery构成 Task Broker Worker Backend Celery使用 安装 基本使用 异步任务: delay 延迟任务: apply_async 周期任务: ...

  4. Antd组件库使用方法

    零.介绍: Ant design,是阿里巴巴的蚂蚁金服公司设计的一套适应用于web端和移动端网页的Ui组件库,组件好看,非常适合React框架使用. 官网:https://ant.design/ind ...

  5. python切片(获取一个子列表(数组))

    切片: 切片指从现有列表中,获取一个子列表 返回一个新列表,不影响原列表. 下标以 0 开始: list = ['红','绿','蓝','白','黑','黄','青']# 下标 0 1 2 3 4 5 ...

  6. 常见基本数据结构——树,二叉树,二叉查找树,AVL树

    常见数据结构——树 处理大量的数据时,链表的线性时间太慢了,不宜使用.在树的数据结构中,其大部分的运行时间平均为O(logN).并且通过对树结构的修改,我们能够保证它的最坏情形下上述的时间界. 树的定 ...

  7. qiniuLive 连麦流程介绍

    本文出自APICloud官方论坛 qiniuLive 封装了七牛直播云服务平台的移动端开放 SDK.该模块包括视频流采集和视频流播放两部分 iOS连麦流程图: Android连麦流程图: 以下部分代码 ...

  8. 洛谷 UVA11021 Tribles

    UVA11021 Tribles 题意翻译 题目大意 一开始有kk种生物,这种生物只能活1天,死的时候有p_ipi​的概率产生ii只这种生物(也只能活一天),询问m天内所有生物都死的概率(包括m天前死 ...

  9. nor flash之擦除和写入

    最近研究了下nor flash的掉电问题,对nor的掉电有了更多的认识.总结分享如下 擦除从0变1,写入从1变0 nor flash的物理特性是,写入之前需要先进行擦除.擦除后数据为全0xFF,此时写 ...

  10. lvs+keepalived部署k8s v1.16.4高可用集群

    一.部署环境 1.1 主机列表 主机名 Centos版本 ip docker version flannel version Keepalived version 主机配置 备注 lvs-keepal ...