欢迎大家关注我们的网站和系列教程:http://www.tensorflownews.com/,学习更多的机器学习、深度学习的知识!

手写数字识别

接下来将会以 MNIST 数据集为例,使用卷积层和池化层,实现一个卷积神经网络来进行手写数字识别,并输出卷积和池化效果。

数据准备

  • MNIST 数据集下载

MNIST 数据集可以从 THE MNIST DATABASE of handwritten digits 的网站直接下载。

网址:http://yann.lecun.com/exdb/mnist/

train-images-idx3-ubyte.gz: 训练集图片

train-labels-idx1-ubyte.gz: 训练集列标

t10k-images-idx3-ubyte.gz: 测试集图片

t10k-labels-idx1-ubyte.gz: 测试集列标

TensorFlow 有加载 MNIST 数据库相关的模块,可以在程序运行时直接加载。

代码如下:

from tensorflow.examples.tutorials.mnist import input_data
import matplotlib.pyplot as pyplot #引入 MNIST 数据集
mnist = input_data.read_data_sets("/tmp/data/", one_hot=False) #选取训练集中的第 1 个图像的矩阵
mnist_one=mnist.train.images[0] #输出图片的维度,结果是:(784,)
print(mnist_one.shape) #因为原始的数据是长度是 784 向量,需要转换成 28*28 的矩阵。
mnist_one_image=mnist_one.reshape((28,28)) #输出矩阵的维度
print(mnist_one_image.shape) #使用 matplotlib 输出为图片
pyplot.imshow(mnist_one_image) pyplot.show()

代码的输出依次是:

1.单个手写数字图片的维度:

(784,)

2.转化为二维矩阵之后的打印结果:

(28, 28)

3.使用 matplotlib 输出为图片

模型实现

TensorFlow conv2d 函数介绍:

tf.nn.conv2d(x, W, strides, padding=’SAME’)

针对输入的 4 维数据 x 计算 2 维卷积。

参数 x:

4 维张量,每一个维度分别是 batch,in_height,in_height,in_channels。

[batch, in_height, in_width, in_channels]

灰度图像只有 2 维来表示每一个像素的值,彩色图像每一个像素点有 3 通道的 RGB 值,所以一个彩色图片转化成张量后是 3 维的,分别是长度,宽度,颜色通道数。又因为每一次训练都是训练都是输入很多张图片,所以,多个 3 维张量组合在一起变成了 4 维张量。

参数 w:

过滤器,因为是二维卷积,所以它的维度是:

[filter_height, filter_width, in_channels, out_channels]

与参数 x 对应,前 3 个参数分别是对应 x 的 filter_height, filter_width, in_channels,最后一个参数是过滤器的输出通道数量。

参数 strides:

1 维长度为 4 的张量,对应参数 x 的 4 个维度上的步长。

参数 padding:

边缘填充方式,主要是 “SAME”, “VALID”,一般使用 “SAME”。

卷积层简单封装
# 池化操作
def conv2d(x, W, b, strides=1):
# Conv2D wrapper, with bias and relu activation
x = tf.nn.conv2d(x, W, strides=[1, strides, strides, 1], padding='SAME')
x = tf.nn.bias_add(x, b)
return tf.nn.relu(x)
TensorFlow max_pool 函数介绍:

tf.nn.max_pool(x, ksize, strides ,padding)

参数 x:

和 conv2d 的参数 x 相同,是一个 4 维张量,每一个维度分别代表 batch,in_height,in_height,in_channels。

参数 ksize:

池化核的大小,是一个 1 维长度为 4 的张量,对应参数 x 的 4 个维度上的池化大小。

参数 strides:

1 维长度为 4 的张量,对应参数 x 的 4 个维度上的步长。

参数 padding:

边缘填充方式,主要是 “SAME”, “VALID”,一般使用 “SAME”。

接下来将会使用 TensorFlow 实现以下结构的卷积神经网络:

下一篇文章,将会用 TensorFlow 实现这个卷积神经网络。

本篇文章出自http://www.tensorflownews.com,对深度学习感兴趣,热爱Tensorflow的小伙伴,欢迎关注我们的网站!

TensorFlow 卷积神经网络手写数字识别数据集介绍的更多相关文章

  1. 深度学习-使用cuda加速卷积神经网络-手写数字识别准确率99.7%

    源码和运行结果 cuda:https://github.com/zhxfl/CUDA-CNN C语言版本参考自:http://eric-yuan.me/ 针对著名手写数字识别的库mnist,准确率是9 ...

  2. 吴裕雄 python 神经网络——TensorFlow 卷积神经网络手写数字图片识别

    import os import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data INPUT_N ...

  3. 实现手写数字识别(数据集50000张图片)比较3种算法神经网络、灰度平均值、SVM各自的准确率—Jason niu

    对手写数据集50000张图片实现阿拉伯数字0~9识别,并且对结果进行分析准确率, 手写数字数据集下载:http://yann.lecun.com/exdb/mnist/ 首先,利用图片本身的属性,图片 ...

  4. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  5. 基于TensorFlow的MNIST手写数字识别-初级

    一:MNIST数据集    下载地址 MNIST是一个包含很多手写数字图片的数据集,一共4个二进制压缩文件 分别是test set images,test set labels,training se ...

  6. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  7. 基于Numpy的神经网络+手写数字识别

    基于Numpy的神经网络+手写数字识别 本文代码来自Tariq Rashid所著<Python神经网络编程> 代码分为三个部分,框架如下所示: # neural network class ...

  8. Tensorflow之MNIST手写数字识别:分类问题(1)

    一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点:   1.将离散特征的取值扩展 ...

  9. Tensorflow实现MNIST手写数字识别

    之前我们讲了神经网络的起源.单层神经网络.多层神经网络的搭建过程.搭建时要注意到的具体问题.以及解决这些问题的具体方法.本文将通过一个经典的案例:MNIST手写数字识别,以代码的形式来为大家梳理一遍神 ...

随机推荐

  1. Linux系统发行版本及其区别

    1 Linux系统组成 Linux操作系统=Linux内核+GNU软件及系统软件+必要的应用程序.下表为Linux系统各组成部分的贡献人员: Linux内核 GNU组件(gcc.bash) 其他必要应 ...

  2. JS逆向某网站登录密码分析

    声明: 本文仅供研究学习使用,请勿用于非法用途! 目标网站 aHR0cHM6Ly9hdXRoLmFsaXBheS5jb20vbG9naW4vaW5kZXguaHRt 今日目标网站是某知名支付网站,感觉 ...

  3. 【WPF学习】第五十三章 动画类型回顾

    创建动画面临的第一个挑战是为动画选择正确的属性.期望的结果(例如,在窗口中移动元素)与需要使用的属性(在这种情况下是Canvas.Left和Canvas.Top属性)之间的关系并不总是很直观.下面是一 ...

  4. Java面试必问之Hashmap底层实现原理(JDK1.8)

    1. 前言 上一篇从源码方面了解了JDK1.7中Hashmap的实现原理,可以看到其源码相对还是比较简单的.本篇笔者和大家一起学习下JDK1.8下Hashmap的实现.JDK1.8中对Hashmap做 ...

  5. 「ReStory」在 Markdown 中自由书写 React 组件 (Beta)

    介绍 先睹为快 我们在开发一个小小的 React 组件库,但是我们遇到了一个大难题,那就是为我们的组件库书写一个合理的文档. 作为组件文档,我们非常希望我们的组件用例代码能够展现出来,是的我们在书写文 ...

  6. 微服务架构-Gradle下载安装配置教程

    一.开发条件 JDK8下载地址:https://www.oracle.com/java/technologies/javase-jdk8-downloads.html Eclipse下载地址:http ...

  7. JS模块规范:AMD,CMD,CommonJS

    浅析JS模块规范 随着JS模块化编程的发展,处理模块之间的依赖关系成为了维护的关键. AMD,CMD,CommonJS是目前最常用的三种模块化书写规范. CommonJS CommonJS规范是诞生比 ...

  8. 在Deepin系统上装Python 3.8遇到的那些坑

    - 作为一天时间在Deepin上都没装好Python的代表,我感觉有必要记录一下我自己的解决方法 坑1-- SSL/TLS 字样错误 "pip is configured wih locat ...

  9. 【BIM】BIMFACE中创建矢量文本

    背景 在三维模型产品的设计中,针对空间的管理存在这样一个普遍的需求,那就是在三维模型中,将模型所代表的空间通过附加文本的方式来展示其所代表的实际位置或功能,之前尝试过若干方式,比如直接在建模的时候,将 ...

  10. burpsuit之Spider、Scanner、Intruder模块

    1.spider模块 1.spider模块介绍 被动爬网:(被动爬网获得的链接是手动爬网的时候返回页面的信息中分析发现超链接) 对于爬网的时候遇到HTML表单如何操作: 需要表单身份认证时如何操作(默 ...