《手写数字识别——手动搭建全连接层》一文中,我们通过机器学习的基本公式构建出了一个网络模型,其实现过程毫无疑问是过于复杂了——不得不考虑诸如数据类型匹配、梯度计算、准确度的统计等问题,但是这样的实践对机器学习的理解是大有裨益的。在大多数情况下,我们还是希望能多简单就多简单地去搭建网络模型,这同时也算对得起TensorFlow这个强大的工具了。本节,还是以手写数据集MNIST为例,利用TensorFlow2.0的keras高层API重现之前的网络。

一、数据的导入与预处理


关于这个过程,与上节讲过的类似,就不再赘述了。需要提一点的就是,为了程序的整洁,将数据类型的转换过程单独写成一个预处理函数preprocess,通过Dataset对象的map方法应用该预处理函数。整个数据导入与预处理代码如下:

import tensorflow as tf
from tensorflow.keras import datasets,optimizers,Sequential,metrics,layers # 改变数据类型
def preprocess(x,y):
x = tf.cast(x,dtype=tf.float32)/255-0.5
x = tf.reshape(x,[-1,28*28])
y = tf.one_hot(y, depth=10)
y = tf.cast(y, dtype=tf.int32)
return x,y #60k 28*28
(train_x,train_y),(val_x,val_y) = datasets.mnist.load_data() #生成Dataset对象
train_db = tf.data.Dataset.from_tensor_slices((train_x,train_y)).shuffle(10000).batch(256)
val_db = tf.data.Dataset.from_tensor_slices((val_x,val_y)).shuffle(10000).batch(256) #预处理,对每个批次数据应用preprocess
train_db = train_db.map(preprocess)
val_db = val_db.map(preprocess)

二、模型构建


对于全连接层,keras提供了layers.Dense(units,activation)接口,利用它可以建立一层layer,多层堆叠放入keras提供的Sequential容器中,就形成了一个网络模型。在Dense的参数中,units决定了这层layer含有的神经元数量,activation是激活函数的选择。同之前的网络一样,我们的网络传播可以看做是:input(784 units)->layer1(256 units)->ReLu->layer2(128 units)->ReLu->output(10 units)。因此,在Sequential容器中定义后三层,activation指定为ReLu,而输入层需要通过build时候指定input_shape来告诉网络输入层的神经元数量。构建的代码如下,通过summary方法可以打印网络信息。

#网络模型
model = Sequential([
layers.Dense(256,activation=tf.nn.relu),
layers.Dense(128,activation=tf.nn.relu),
layers.Dense(10),
])
#input_shape=(batch_size,input_dims)
model.build(input_shape=(None,28*28))
model.summary()

三、模型的训练


模型的训练最重要的就是权重更新和准确度统计。keras提供了多种优化器(optimizer)用于更新权重。优化器实际就是不同的梯度下降算法,缓解了传统梯度下降可能无法收敛到全局最小值的问题。在上一节中就稍加讨论了三种。这里就简单对比一下一些优化器,至于详细的区别今后有时间再写篇随笔专门讨论:

  1. SGD:TensorFlow2.0 SGD实际是随机梯度下降+动量的综合优化器。随机梯度下降是每次更新随机选取一个样本计算梯度,这样计算梯度快很多,但怕大噪声;动量是在梯度下降的基础上,累计历史梯度信息加速梯度下降,这是因为一方面它想水稀释牛奶一样,能减小随机梯度下降对噪声的敏感度,另一方面动量赋予下降以惯性,可以预见梯度变化。这优化器实话说让我联想到了PID控制。SGD需要指定学习率和动量大小。一般地,动量大小设置为0.9。
  2. Adagrad:采用自适应梯度的优化器。所谓自适应梯度,就是根据参数的频率,对每个参数应用不同的学习速率。但是该算法在迭代次数变得很大时,学习速率会变得很小,导致不能继续更新。Adagrad要求指定初始化的学习速率、累加器初始值和防止分母为0的偏置值。
  3. Adadelta:采用自适应增量的优化器。解决了adagrad算法学习速率消失的问题。Adagrad要求指定初始化的学习速率、衰减率和防止分母为0的偏置值。这个衰减率跟动量差不多,一般也指定为0.9。
  4. RMSprop:类似于Adadelta。
  5. Adam:采用梯度的一阶和二阶矩来估计更新参数。它结合了Adadelta和RMSprop的优点。可以说,深度学习通常都会选择Adam优化器。TensorFlow中,Adam优化器需要指定4个参数,但经验证明,它的默认参数能表现出很好的效果。

鉴于以上对比,此处选用Adam作为优化器,并采用其默认参数。

除了梯度下降,还需要考虑的是Loss的计算方法。之前,我们采用的是预测概率与实际值的差平方的均值,专业名称应该是欧几里得损失函数。其实,这是个错误,欧几里得损失函数适用于二元分类,多元分类应该采用交叉熵损失函数。有时候针对多元函数,我们会很不自觉地想把输出层归一化,于是会在输出层之后,交叉熵计算前先softmax一下。但是由于softmax是采用指数形式进行计算的,如果输出各类概率相差较大,则大概率在归一化后几乎为1,小概率归一化之后几乎为0。为了避免这一问题,通常是去掉softmax,在交叉熵函数tf.losses.CategoricalCrossentropy的参数中指from_logits=True。

Loss函数和优化器配置都可以通过compile方法指定,同时,还可以指定metrics列表来决定需要自动计算的信息,如准确度。

通过fit方法可以传入训练数据和测试数据。代码如下:

#配合Adam优化器、交叉熵Loss函数、metrics列表
model.compile(optimizer=optimizers.Adam(),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
#数据传入,迭代10次train_db,每迭代1次,计算一次测试数据集准确度
model.fit(train_db,epochs=10,validation_data=val_db,validation_freq=1)

以上建立的网络模型在第一次train_db迭代完后就可以达到0.8以上的准确度,而且这个迭代每次仅花费3秒左右。经过大约50次迭代,准确度就可以高达0.98!而通过上一节的方式,要达到这样的准确度,起码得训练半个小时。这其中最主要的差别就在于梯度下降算法的优化。

四、完整代码


 import tensorflow as tf
from tensorflow.keras import datasets,optimizers,Sequential,metrics,layers # 改变数据类型
def preprocess(x,y):
x = tf.cast(x,dtype=tf.float32)/255-0.5
x = tf.reshape(x,[-1,28*28])
y = tf.one_hot(y, depth=10)
y = tf.cast(y, dtype=tf.int32)
return x,y #60k 28*28
(train_x,train_y),(val_x,val_y) = datasets.mnist.load_data() #生成Dataset对象
train_db = tf.data.Dataset.from_tensor_slices((train_x,train_y)).shuffle(10000).batch(256)
val_db = tf.data.Dataset.from_tensor_slices((val_x,val_y)).shuffle(10000).batch(256) #预处理,对每个数据应用preprocess
train_db = train_db.map(preprocess)
val_db = val_db.map(preprocess) #网络模型
model = Sequential([
layers.Dense(256,activation=tf.nn.relu),
layers.Dense(128,activation=tf.nn.relu),
layers.Dense(10),
])
#input_shape=(batch_size,input_dims)
model.build(input_shape=(None,28*28))
model.summary() #配合Adam优化器、交叉熵Loss函数、metrics列表
model.compile(optimizer=optimizers.Adam(),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
#数据传入,迭代10次train_db,每迭代1次,计算一次测试数据集准确度
model.fit(train_db,epochs=100,validation_data=val_db,validation_freq=5)

手写数字识别——利用keras高层API快速搭建并优化网络模型的更多相关文章

  1. 手写数字识别——基于LeNet-5卷积网络模型

    在<手写数字识别——利用Keras高层API快速搭建并优化网络模型>一文中,我们搭建了全连接层网络,准确率达到0.98,但是这种网络的参数量达到了近24万个.本文将搭建LeNet-5网络, ...

  2. 【百度飞桨】手写数字识别模型部署Paddle Inference

    从完成一个简单的『手写数字识别任务』开始,快速了解飞桨框架 API 的使用方法. 模型开发 『手写数字识别』是深度学习里的 Hello World 任务,用于对 0 ~ 9 的十类数字进行分类,即输入 ...

  3. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  4. 卷积神经网络CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  5. 利用神经网络算法的C#手写数字识别

    欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 下载Demo - 2.77 MB (原始地址):handwritten_character_recognition.zip 下载源码 - 70. ...

  6. keras框架的MLP手写数字识别MNIST,梳理?

    keras框架的MLP手写数字识别MNIST 代码: # coding: utf-8 # In[1]: import numpy as np import pandas as pd from kera ...

  7. 【问题解决方案】Keras手写数字识别-ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接

    参考:台大李宏毅老师视频课程-Keras-Demo 在载入数据阶段报错: ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接 Google之 ...

  8. keras—多层感知器MLP—MNIST手写数字识别

    一.手写数字识别 现在就来说说如何使用神经网络实现手写数字识别. 在这里我使用mind manager工具绘制了要实现手写数字识别需要的模块以及模块的功能:  其中隐含层节点数量(即神经细胞数量)计算 ...

  9. keras和tensorflow搭建DNN、CNN、RNN手写数字识别

    MNIST手写数字集 MNIST是一个由美国由美国邮政系统开发的手写数字识别数据集.手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载,总共4个.gz后缀的压缩文件,该文件 ...

随机推荐

  1. 全文检索框架---Lucene

    一.什么是全文检索 1.数据分类 我们生活中的数据总体分为两种:结构化数据和非结构化数据.   结构化数据:指具有固定格式或有限长度的数据,如数据库,元数据等.   非结构化数据:指不定长或无固定格式 ...

  2. 珠峰-6-http和http-server原理

    ???? websock改天研究下然后用node去搞. websock的实现原理. ##### 第9天的笔记内容. ## Header 规范 ## Http 状态码 - 101 webscoket 双 ...

  3. 解释为什么wait()和notify(), notifyAll()要放在同步块中

    首先,wait()是释放锁的,因此wait()之前要先获得锁,而锁在同步块开始的时候获得,结束时释放,即同步块内为持有锁的阶段. 那为什么要设计同步块呢?或者说没有同步块会怎样呢?

  4. for遍历用例数据时,报错:TypeError: list indices must be integers, not dict,'int' object is not iterable解决方法

    一:报错:TypeError: list indices must be integers, not dict for i in range(0,len(test_data)): suite.addT ...

  5. C# bubble sort,selection sort,insertion sort

    static void Main(string[] args) { InsertionSortDemo(); Console.ReadLine(); } static void InsertionSo ...

  6. HTTPS原理及流程

    HTTPS为什么更安全:数据对称加密传出,对称密钥使用非对称加密协商. HTTPS就一定安全吗:不一定,如果用户在浏览器端执意访问证书可疑或过期的站点,就存在安全隐患. --- HTTPS实现原理:h ...

  7. 题解【CF1311F Moving Points】

    \[ \texttt{Preface} \] 赛时,把 " 任意时刻 " 理解成 " 整数时刻 " 了,看起来一脸不可做的亚子,还各种推式子. 话说我为什么觉得 ...

  8. c#日期时间段判断

    select * from 表名 where (case when ISDATE(字段名)=1 then CONVERT(varchar(100),cast(字段名 as datetime),23) ...

  9. vue项目下的导入和导出

    本篇博文主要记录我们在写项目的时候经常需要用到导入和导出. 导入 首先定义一个模态弹窗,一般情况下会使用一个input(设置opacity:0)覆盖在显示的按钮上面 <!-- 3.导入 --&g ...

  10. 三种比较好玩的黑客效果JS代码(摘取)

    <html> <head> <title>The Matrix</title> <script src="http://ajax.goo ...