3层-CNN卷积神经网络预测MNIST数字

本文创建一个简单的三层卷积网络来预测 MNIST 数字。这个深层网络由两个带有 ReLU 和 maxpool 的卷积层以及两个全连接层组成。



MNIST 由 60000 个手写体数字的图片组成。本文的目标是高精度地识别这些数字。

具体实现过程

  1. 导入 tensorflow、matplotlib、random 和 numpy。然后,导入
    mnist 数据集并进行独热编码。请注意,TensorFlow 有一些内置的库来处理 MNIST,也会用到它们:

  1. 仔细观察一些数据有助于理解 MNIST 数据集。了解训练数据集中有多少张图片,测试数据集中有多少张图片。可视化一些数字,以便了解它们是如何表示的。这种输出可以对于识别手写体数字的难度有一种视觉感知,即使是对于人类来说也是如此。

上述代码的输出:

图 1 MNIST手写数字的一个例子

  1. 设置学习参数 batch_size和display_step。另外,MNIST 图片都是 28×28 像素,因此设置 n_input=784,n_classes=10 代表输出数字 [0-9],并且 dropout 概率是 0.85,则:

  1. 设置 TensorFlow 计算图的输入。定义两个占位符来存储预测值和真实标签:

  1. 定义一个输入为 x,权值为 W,偏置为 b,给定步幅的卷积层。激活函数是 ReLU,padding 设定为
    SAME 模式:

  1. 定义一个输入是 x 的 maxpool 层,卷积核为
    ksize 并且 padding 为 SAME:

  1. 定义 convnet,构成是两个卷积层,然后是全连接层,一个 dropout 层,最后是输出层:

  1. 定义网络层的权重和偏置。第一个 conv 层有一个 5×5 的卷积核,1 个输入和 32 个输出。第二个
    conv 层有一个 5×5 的卷积核,32 个输入和 64 个输出。全连接层有
    7×7×64 个输入和 1024 个输出,而第二层有 1024 个输入和 10 个输出对应于最后的数字数目。所有的权重和偏置用 randon_normal 分布完成初始化:

  1. 建立一个给定权重和偏置的 convnet。定义基于 cross_entropy_with_logits 的损失函数,并使用 Adam 优化器进行损失最小化。优化后,计算精度:

  1. 启动计算图并迭代 training_iterats次,其中每次输入 batch_size 个数据进行优化。用从 mnist 数据集分离出的 mnist.train 数据进行训练。每进行 display_step 次迭代,会计算当前的精度。最后,在 2048 个测试图片上计算精度,此时无 dropout。

  1. 画出每次迭代的 Softmax 损失以及训练和测试的精度:

以下是上述代码的输出。首先看一下每次迭代的 Softmax 损失:

图 2 减少损失的一个例子

再来看一下训练和测试的精度:

图 3 训练和测试精度上升的一个例子

解读分析

使用 ConvNet,在 MNIST 数据集上的表现提高到了近 95% 的精度。ConvNet 的前两层网络由卷积、ReLU 激活函数和最大池化部分组成,然后是两层全连接层(含dropout)。训练的 batch 大小为 128,使用 Adam 优化器,学习率为 0.001,最大迭代次数为 500 次。

3层-CNN卷积神经网络预测MNIST数字的更多相关文章

  1. TensorFlow——CNN卷积神经网络处理Mnist数据集

    CNN卷积神经网络处理Mnist数据集 CNN模型结构: 输入层:Mnist数据集(28*28) 第一层卷积:感受视野5*5,步长为1,卷积核:32个 第一层池化:池化视野2*2,步长为2 第二层卷积 ...

  2. tensorflow CNN 卷积神经网络中的卷积层和池化层的代码和效果图

    tensorflow CNN 卷积神经网络中的卷积层和池化层的代码和效果图 因为很多 demo 都比较复杂,专门抽出这两个函数,写的 demo. 更多教程:http://www.tensorflown ...

  3. Deep Learning模型之:CNN卷积神经网络(一)深度解析CNN

    http://m.blog.csdn.net/blog/wu010555688/24487301 本文整理了网上几位大牛的博客,详细地讲解了CNN的基础结构与核心思想,欢迎交流. [1]Deep le ...

  4. cnn(卷积神经网络)比较系统的讲解

    本文整理了网上几位大牛的博客,详细地讲解了CNN的基础结构与核心思想,欢迎交流. [1]Deep learning简介 [2]Deep Learning训练过程 [3]Deep Learning模型之 ...

  5. Keras(四)CNN 卷积神经网络 RNN 循环神经网络 原理及实例

    CNN 卷积神经网络 卷积 池化 https://www.cnblogs.com/peng8098/p/nlp_16.html 中有介绍 以数据集MNIST构建一个卷积神经网路 from keras. ...

  6. Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现(转)

    Deep Learning论文笔记之(四)CNN卷积神经网络推导和实现 zouxy09@qq.com http://blog.csdn.net/zouxy09          自己平时看了一些论文, ...

  7. day-16 CNN卷积神经网络算法之Max pooling池化操作学习

    利用CNN卷积神经网络进行训练时,进行完卷积运算,还需要接着进行Max pooling池化操作,目的是在尽量不丢失图像特征前期下,对图像进行downsampling. 首先看下max pooling的 ...

  8. 人工智能——CNN卷积神经网络项目之猫狗分类

    首先先导入所需要的库 import sys from matplotlib import pyplot from tensorflow.keras.utils import to_categorica ...

  9. [转]Theano下用CNN(卷积神经网络)做车牌中文字符OCR

    Theano下用CNN(卷积神经网络)做车牌中文字符OCR 原文地址:http://m.blog.csdn.net/article/details?id=50989742 之前时间一直在看 Micha ...

随机推荐

  1. 【Java集合】JDK1.7和1.8 HashMap有什么区别

    JDK1.7和1.8 HashMap区别: 1.数组+链表改成了数组+链表或红黑树: 2.表的插入方式从头插法改成了尾插法,简单说就是插入时,如果数组位置上已经有元素,1.7将新元素放到数组中,原始节 ...

  2. 反病毒攻防研究第005篇:简单木马分析与防范part1

    一.前言 病毒与木马技术发展到今天,由于二者总是相辅相成,你中有我,我中有你,所以它们之间的界限往往已经不再那么明显,相互之间往往都会采用对方的一些技术以达到自己的目的,所以现在很多时候也就将二者直接 ...

  3. 在进程空间使用虚拟内存(Windows 核心编程)

    虚拟内存空间 如今的 Windows 操作系统不仅可以运行多个应用程序,还可以让每一个应用程序享受到约 4 GB 的虚拟内存空间(包括系统占用),假如内存为 4 GB 的话.那为什么 Window 可 ...

  4. <JVM上篇:内存与垃圾回收篇>02-类加载子系统

    笔记来源:尚硅谷JVM全套教程,百万播放,全网巅峰(宋红康详解java虚拟机) 同步更新:https://gitee.com/vectorx/NOTE_JVM https://codechina.cs ...

  5. 如何使用java搭建一款高性能的Mqtt集群broker!

    SMQTT是一款开源的MQTT消息代理Broker, SMQTT基于Netty开发,底层采用Reactor3反应堆模型,支持单机部署,支持容器化部署,具备低延迟,高吞吐量,支持百万TCP连接,同时支持 ...

  6. Pytest自动化测试-简易入门教程(03)

    今天分享内容的重点,和大家来讲一下我们的测试框架--Pytest 讲到这个框架的话呢,可能有伙伴就会问老师,我在学习自动化测试过程中,我们要去学一些什么东西? 第一个肯定要学会的是一门编程语言,比如说 ...

  7. Kafka源码分析(二) - 生产者

    系列文章目录 https://zhuanlan.zhihu.com/p/367683572 目录 系列文章目录 一. 使用方式 step 1: 设置必要参数 step 2: 创建KafkaProduc ...

  8. ThreadLocal内存溢出代码演示和原因分析!

    ThreadLocal 翻译成中文是线程本地变量的意思,也就是说它是线程中的私有变量,每个线程只能操作自己的私有变量,所以不会造成线程不安全的问题. ​ 线程不安全是指,多个线程在同一时刻对同一个全局 ...

  9. 最简单的方法是使用标准的 Linux GUI 程序之一: i-nex 收集硬件信息,并且类似于 Windows 下流行的 CPU-Z 的显示。 HardInfo 显示硬件具体信息,甚至包括一组八个的流行的性能基准程序,你可以用它们评估你的系统性能。 KInfoCenter 和 Lshw 也能够显示硬件的详细信息,并且可以从许多软件仓库中获取。

    最简单的方法是使用标准的 Linux GUI 程序之一: i-nex 收集硬件信息,并且类似于 Windows 下流行的 CPU-Z 的显示. HardInfo 显示硬件具体信息,甚至包括一组八个的流 ...

  10. 云计算OpenStack共享组件---信息队列rabbitmq(2)

    一.MQ 全称为 Message Queue, 消息队列( MQ ) 是一种应用程序对应用程序的通信方法.应用程序通过读写出入队列的消息(针对应用程序的数据)来通信,而无需专用连接来链接它们. 消息传 ...