《手写数字识别——手动搭建全连接层》一文中,我们通过机器学习的基本公式构建出了一个网络模型,其实现过程毫无疑问是过于复杂了——不得不考虑诸如数据类型匹配、梯度计算、准确度的统计等问题,但是这样的实践对机器学习的理解是大有裨益的。在大多数情况下,我们还是希望能多简单就多简单地去搭建网络模型,这同时也算对得起TensorFlow这个强大的工具了。本节,还是以手写数据集MNIST为例,利用TensorFlow2.0的keras高层API重现之前的网络。

一、数据的导入与预处理


关于这个过程,与上节讲过的类似,就不再赘述了。需要提一点的就是,为了程序的整洁,将数据类型的转换过程单独写成一个预处理函数preprocess,通过Dataset对象的map方法应用该预处理函数。整个数据导入与预处理代码如下:

import tensorflow as tf
from tensorflow.keras import datasets,optimizers,Sequential,metrics,layers # 改变数据类型
def preprocess(x,y):
x = tf.cast(x,dtype=tf.float32)/255-0.5
x = tf.reshape(x,[-1,28*28])
y = tf.one_hot(y, depth=10)
y = tf.cast(y, dtype=tf.int32)
return x,y #60k 28*28
(train_x,train_y),(val_x,val_y) = datasets.mnist.load_data() #生成Dataset对象
train_db = tf.data.Dataset.from_tensor_slices((train_x,train_y)).shuffle(10000).batch(256)
val_db = tf.data.Dataset.from_tensor_slices((val_x,val_y)).shuffle(10000).batch(256) #预处理,对每个批次数据应用preprocess
train_db = train_db.map(preprocess)
val_db = val_db.map(preprocess)

二、模型构建


对于全连接层,keras提供了layers.Dense(units,activation)接口,利用它可以建立一层layer,多层堆叠放入keras提供的Sequential容器中,就形成了一个网络模型。在Dense的参数中,units决定了这层layer含有的神经元数量,activation是激活函数的选择。同之前的网络一样,我们的网络传播可以看做是:input(784 units)->layer1(256 units)->ReLu->layer2(128 units)->ReLu->output(10 units)。因此,在Sequential容器中定义后三层,activation指定为ReLu,而输入层需要通过build时候指定input_shape来告诉网络输入层的神经元数量。构建的代码如下,通过summary方法可以打印网络信息。

#网络模型
model = Sequential([
layers.Dense(256,activation=tf.nn.relu),
layers.Dense(128,activation=tf.nn.relu),
layers.Dense(10),
])
#input_shape=(batch_size,input_dims)
model.build(input_shape=(None,28*28))
model.summary()

三、模型的训练


模型的训练最重要的就是权重更新和准确度统计。keras提供了多种优化器(optimizer)用于更新权重。优化器实际就是不同的梯度下降算法,缓解了传统梯度下降可能无法收敛到全局最小值的问题。在上一节中就稍加讨论了三种。这里就简单对比一下一些优化器,至于详细的区别今后有时间再写篇随笔专门讨论:

  1. SGD:TensorFlow2.0 SGD实际是随机梯度下降+动量的综合优化器。随机梯度下降是每次更新随机选取一个样本计算梯度,这样计算梯度快很多,但怕大噪声;动量是在梯度下降的基础上,累计历史梯度信息加速梯度下降,这是因为一方面它想水稀释牛奶一样,能减小随机梯度下降对噪声的敏感度,另一方面动量赋予下降以惯性,可以预见梯度变化。这优化器实话说让我联想到了PID控制。SGD需要指定学习率和动量大小。一般地,动量大小设置为0.9。
  2. Adagrad:采用自适应梯度的优化器。所谓自适应梯度,就是根据参数的频率,对每个参数应用不同的学习速率。但是该算法在迭代次数变得很大时,学习速率会变得很小,导致不能继续更新。Adagrad要求指定初始化的学习速率、累加器初始值和防止分母为0的偏置值。
  3. Adadelta:采用自适应增量的优化器。解决了adagrad算法学习速率消失的问题。Adagrad要求指定初始化的学习速率、衰减率和防止分母为0的偏置值。这个衰减率跟动量差不多,一般也指定为0.9。
  4. RMSprop:类似于Adadelta。
  5. Adam:采用梯度的一阶和二阶矩来估计更新参数。它结合了Adadelta和RMSprop的优点。可以说,深度学习通常都会选择Adam优化器。TensorFlow中,Adam优化器需要指定4个参数,但经验证明,它的默认参数能表现出很好的效果。

鉴于以上对比,此处选用Adam作为优化器,并采用其默认参数。

除了梯度下降,还需要考虑的是Loss的计算方法。之前,我们采用的是预测概率与实际值的差平方的均值,专业名称应该是欧几里得损失函数。其实,这是个错误,欧几里得损失函数适用于二元分类,多元分类应该采用交叉熵损失函数。有时候针对多元函数,我们会很不自觉地想把输出层归一化,于是会在输出层之后,交叉熵计算前先softmax一下。但是由于softmax是采用指数形式进行计算的,如果输出各类概率相差较大,则大概率在归一化后几乎为1,小概率归一化之后几乎为0。为了避免这一问题,通常是去掉softmax,在交叉熵函数tf.losses.CategoricalCrossentropy的参数中指from_logits=True。

Loss函数和优化器配置都可以通过compile方法指定,同时,还可以指定metrics列表来决定需要自动计算的信息,如准确度。

通过fit方法可以传入训练数据和测试数据。代码如下:

#配合Adam优化器、交叉熵Loss函数、metrics列表
model.compile(optimizer=optimizers.Adam(),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
#数据传入,迭代10次train_db,每迭代1次,计算一次测试数据集准确度
model.fit(train_db,epochs=10,validation_data=val_db,validation_freq=1)

以上建立的网络模型在第一次train_db迭代完后就可以达到0.8以上的准确度,而且这个迭代每次仅花费3秒左右。经过大约50次迭代,准确度就可以高达0.98!而通过上一节的方式,要达到这样的准确度,起码得训练半个小时。这其中最主要的差别就在于梯度下降算法的优化。

四、完整代码


 import tensorflow as tf
from tensorflow.keras import datasets,optimizers,Sequential,metrics,layers # 改变数据类型
def preprocess(x,y):
x = tf.cast(x,dtype=tf.float32)/255-0.5
x = tf.reshape(x,[-1,28*28])
y = tf.one_hot(y, depth=10)
y = tf.cast(y, dtype=tf.int32)
return x,y #60k 28*28
(train_x,train_y),(val_x,val_y) = datasets.mnist.load_data() #生成Dataset对象
train_db = tf.data.Dataset.from_tensor_slices((train_x,train_y)).shuffle(10000).batch(256)
val_db = tf.data.Dataset.from_tensor_slices((val_x,val_y)).shuffle(10000).batch(256) #预处理,对每个数据应用preprocess
train_db = train_db.map(preprocess)
val_db = val_db.map(preprocess) #网络模型
model = Sequential([
layers.Dense(256,activation=tf.nn.relu),
layers.Dense(128,activation=tf.nn.relu),
layers.Dense(10),
])
#input_shape=(batch_size,input_dims)
model.build(input_shape=(None,28*28))
model.summary() #配合Adam优化器、交叉熵Loss函数、metrics列表
model.compile(optimizer=optimizers.Adam(),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
#数据传入,迭代10次train_db,每迭代1次,计算一次测试数据集准确度
model.fit(train_db,epochs=100,validation_data=val_db,validation_freq=5)

手写数字识别——利用keras高层API快速搭建并优化网络模型的更多相关文章

  1. 手写数字识别——基于LeNet-5卷积网络模型

    在<手写数字识别——利用Keras高层API快速搭建并优化网络模型>一文中,我们搭建了全连接层网络,准确率达到0.98,但是这种网络的参数量达到了近24万个.本文将搭建LeNet-5网络, ...

  2. 【百度飞桨】手写数字识别模型部署Paddle Inference

    从完成一个简单的『手写数字识别任务』开始,快速了解飞桨框架 API 的使用方法. 模型开发 『手写数字识别』是深度学习里的 Hello World 任务,用于对 0 ~ 9 的十类数字进行分类,即输入 ...

  3. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  4. 卷积神经网络CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  5. 利用神经网络算法的C#手写数字识别

    欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 下载Demo - 2.77 MB (原始地址):handwritten_character_recognition.zip 下载源码 - 70. ...

  6. keras框架的MLP手写数字识别MNIST,梳理?

    keras框架的MLP手写数字识别MNIST 代码: # coding: utf-8 # In[1]: import numpy as np import pandas as pd from kera ...

  7. 【问题解决方案】Keras手写数字识别-ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接

    参考:台大李宏毅老师视频课程-Keras-Demo 在载入数据阶段报错: ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接 Google之 ...

  8. keras—多层感知器MLP—MNIST手写数字识别

    一.手写数字识别 现在就来说说如何使用神经网络实现手写数字识别. 在这里我使用mind manager工具绘制了要实现手写数字识别需要的模块以及模块的功能:  其中隐含层节点数量(即神经细胞数量)计算 ...

  9. keras和tensorflow搭建DNN、CNN、RNN手写数字识别

    MNIST手写数字集 MNIST是一个由美国由美国邮政系统开发的手写数字识别数据集.手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载,总共4个.gz后缀的压缩文件,该文件 ...

随机推荐

  1. 使用requests、re、BeautifulSoup、线程池爬取携程酒店信息并保存到Excel中

    import requests import json import re import csv import threadpool import time, random from bs4 impo ...

  2. getElementsByName和getElementById获取控件

    js对控件的操作通常使用getElementsByName或getElementById来获取不同的控件进行操作 getElementsByName() 得到的是一个array, 不能直接设value ...

  3. 3,HDFS原理

    1,HDFS体系结构 ··· HDFS是采用master/slaves即主从结构模型来管理数据的.这种模型主要由四部分组成,分别是Client.NameNode.DataNode.SecondaryN ...

  4. A——奇怪的玩意(POJ1862)

      题目: 我们的化学生物学家发明了一种新的叫stripies非常神奇的生命.该stripies是透明的无定形变形虫似的生物,生活在果冻状的营养培养基平板菌落.大部分的时间stripies在移动.当他 ...

  5. 0226 rest接口设计

                背景 为了更方便的书写和阐述问题,文章中按照第一人称的角度书写.作为一个以java为主要开发语言的工程师,我所描述的都是java相关的编码和设计. 工程师的静态输出就是代码和文 ...

  6. linux中文件处理命令

    目录 touch cat more less head tail touch 解释 命令名称:touch 命令所在路径:/bin/touch 执行权限:所有用户 功能描述:创建空文件 语法 touch ...

  7. C#设计模式学习笔记:(22)备忘录模式

    本笔记摘抄自:https://www.cnblogs.com/PatrickLiu/p/8176974.html,记录一下学习过程以备后续查用. 一.引言 今天我们要讲行为型设计模式的第十个模式--备 ...

  8. wow.js wow.min.js animate.css animate.min.css

    奉献给下载不到源码的小伙伴,下载到的请忽视 wow.js (function() { var MutationObserver, Util, WeakMap, getComputedStyle, ge ...

  9. SQLServer之查询当前服务器下所有目录视图表

    SQL脚本 /*************1:删除临时表*************/ if exists(select * from tempdb..sysobjects where id=object ...

  10. 【转】JS 的 new 到底是干什么的?

    原文:https://zhuanlan.zhihu.com/p/23987456?refer=study-fe 大部分讲 new 的文章会从面向对象的思路讲起,但是我始终认为,在解释一个事物的时候,不 ...