《手写数字识别——手动搭建全连接层》一文中,我们通过机器学习的基本公式构建出了一个网络模型,其实现过程毫无疑问是过于复杂了——不得不考虑诸如数据类型匹配、梯度计算、准确度的统计等问题,但是这样的实践对机器学习的理解是大有裨益的。在大多数情况下,我们还是希望能多简单就多简单地去搭建网络模型,这同时也算对得起TensorFlow这个强大的工具了。本节,还是以手写数据集MNIST为例,利用TensorFlow2.0的keras高层API重现之前的网络。

一、数据的导入与预处理


关于这个过程,与上节讲过的类似,就不再赘述了。需要提一点的就是,为了程序的整洁,将数据类型的转换过程单独写成一个预处理函数preprocess,通过Dataset对象的map方法应用该预处理函数。整个数据导入与预处理代码如下:

import tensorflow as tf
from tensorflow.keras import datasets,optimizers,Sequential,metrics,layers # 改变数据类型
def preprocess(x,y):
x = tf.cast(x,dtype=tf.float32)/255-0.5
x = tf.reshape(x,[-1,28*28])
y = tf.one_hot(y, depth=10)
y = tf.cast(y, dtype=tf.int32)
return x,y #60k 28*28
(train_x,train_y),(val_x,val_y) = datasets.mnist.load_data() #生成Dataset对象
train_db = tf.data.Dataset.from_tensor_slices((train_x,train_y)).shuffle(10000).batch(256)
val_db = tf.data.Dataset.from_tensor_slices((val_x,val_y)).shuffle(10000).batch(256) #预处理,对每个批次数据应用preprocess
train_db = train_db.map(preprocess)
val_db = val_db.map(preprocess)

二、模型构建


对于全连接层,keras提供了layers.Dense(units,activation)接口,利用它可以建立一层layer,多层堆叠放入keras提供的Sequential容器中,就形成了一个网络模型。在Dense的参数中,units决定了这层layer含有的神经元数量,activation是激活函数的选择。同之前的网络一样,我们的网络传播可以看做是:input(784 units)->layer1(256 units)->ReLu->layer2(128 units)->ReLu->output(10 units)。因此,在Sequential容器中定义后三层,activation指定为ReLu,而输入层需要通过build时候指定input_shape来告诉网络输入层的神经元数量。构建的代码如下,通过summary方法可以打印网络信息。

#网络模型
model = Sequential([
layers.Dense(256,activation=tf.nn.relu),
layers.Dense(128,activation=tf.nn.relu),
layers.Dense(10),
])
#input_shape=(batch_size,input_dims)
model.build(input_shape=(None,28*28))
model.summary()

三、模型的训练


模型的训练最重要的就是权重更新和准确度统计。keras提供了多种优化器(optimizer)用于更新权重。优化器实际就是不同的梯度下降算法,缓解了传统梯度下降可能无法收敛到全局最小值的问题。在上一节中就稍加讨论了三种。这里就简单对比一下一些优化器,至于详细的区别今后有时间再写篇随笔专门讨论:

  1. SGD:TensorFlow2.0 SGD实际是随机梯度下降+动量的综合优化器。随机梯度下降是每次更新随机选取一个样本计算梯度,这样计算梯度快很多,但怕大噪声;动量是在梯度下降的基础上,累计历史梯度信息加速梯度下降,这是因为一方面它想水稀释牛奶一样,能减小随机梯度下降对噪声的敏感度,另一方面动量赋予下降以惯性,可以预见梯度变化。这优化器实话说让我联想到了PID控制。SGD需要指定学习率和动量大小。一般地,动量大小设置为0.9。
  2. Adagrad:采用自适应梯度的优化器。所谓自适应梯度,就是根据参数的频率,对每个参数应用不同的学习速率。但是该算法在迭代次数变得很大时,学习速率会变得很小,导致不能继续更新。Adagrad要求指定初始化的学习速率、累加器初始值和防止分母为0的偏置值。
  3. Adadelta:采用自适应增量的优化器。解决了adagrad算法学习速率消失的问题。Adagrad要求指定初始化的学习速率、衰减率和防止分母为0的偏置值。这个衰减率跟动量差不多,一般也指定为0.9。
  4. RMSprop:类似于Adadelta。
  5. Adam:采用梯度的一阶和二阶矩来估计更新参数。它结合了Adadelta和RMSprop的优点。可以说,深度学习通常都会选择Adam优化器。TensorFlow中,Adam优化器需要指定4个参数,但经验证明,它的默认参数能表现出很好的效果。

鉴于以上对比,此处选用Adam作为优化器,并采用其默认参数。

除了梯度下降,还需要考虑的是Loss的计算方法。之前,我们采用的是预测概率与实际值的差平方的均值,专业名称应该是欧几里得损失函数。其实,这是个错误,欧几里得损失函数适用于二元分类,多元分类应该采用交叉熵损失函数。有时候针对多元函数,我们会很不自觉地想把输出层归一化,于是会在输出层之后,交叉熵计算前先softmax一下。但是由于softmax是采用指数形式进行计算的,如果输出各类概率相差较大,则大概率在归一化后几乎为1,小概率归一化之后几乎为0。为了避免这一问题,通常是去掉softmax,在交叉熵函数tf.losses.CategoricalCrossentropy的参数中指from_logits=True。

Loss函数和优化器配置都可以通过compile方法指定,同时,还可以指定metrics列表来决定需要自动计算的信息,如准确度。

通过fit方法可以传入训练数据和测试数据。代码如下:

#配合Adam优化器、交叉熵Loss函数、metrics列表
model.compile(optimizer=optimizers.Adam(),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
#数据传入,迭代10次train_db,每迭代1次,计算一次测试数据集准确度
model.fit(train_db,epochs=10,validation_data=val_db,validation_freq=1)

以上建立的网络模型在第一次train_db迭代完后就可以达到0.8以上的准确度,而且这个迭代每次仅花费3秒左右。经过大约50次迭代,准确度就可以高达0.98!而通过上一节的方式,要达到这样的准确度,起码得训练半个小时。这其中最主要的差别就在于梯度下降算法的优化。

四、完整代码


 import tensorflow as tf
from tensorflow.keras import datasets,optimizers,Sequential,metrics,layers # 改变数据类型
def preprocess(x,y):
x = tf.cast(x,dtype=tf.float32)/255-0.5
x = tf.reshape(x,[-1,28*28])
y = tf.one_hot(y, depth=10)
y = tf.cast(y, dtype=tf.int32)
return x,y #60k 28*28
(train_x,train_y),(val_x,val_y) = datasets.mnist.load_data() #生成Dataset对象
train_db = tf.data.Dataset.from_tensor_slices((train_x,train_y)).shuffle(10000).batch(256)
val_db = tf.data.Dataset.from_tensor_slices((val_x,val_y)).shuffle(10000).batch(256) #预处理,对每个数据应用preprocess
train_db = train_db.map(preprocess)
val_db = val_db.map(preprocess) #网络模型
model = Sequential([
layers.Dense(256,activation=tf.nn.relu),
layers.Dense(128,activation=tf.nn.relu),
layers.Dense(10),
])
#input_shape=(batch_size,input_dims)
model.build(input_shape=(None,28*28))
model.summary() #配合Adam优化器、交叉熵Loss函数、metrics列表
model.compile(optimizer=optimizers.Adam(),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
#数据传入,迭代10次train_db,每迭代1次,计算一次测试数据集准确度
model.fit(train_db,epochs=100,validation_data=val_db,validation_freq=5)

手写数字识别——利用keras高层API快速搭建并优化网络模型的更多相关文章

  1. 手写数字识别——基于LeNet-5卷积网络模型

    在<手写数字识别——利用Keras高层API快速搭建并优化网络模型>一文中,我们搭建了全连接层网络,准确率达到0.98,但是这种网络的参数量达到了近24万个.本文将搭建LeNet-5网络, ...

  2. 【百度飞桨】手写数字识别模型部署Paddle Inference

    从完成一个简单的『手写数字识别任务』开始,快速了解飞桨框架 API 的使用方法. 模型开发 『手写数字识别』是深度学习里的 Hello World 任务,用于对 0 ~ 9 的十类数字进行分类,即输入 ...

  3. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  4. 卷积神经网络CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  5. 利用神经网络算法的C#手写数字识别

    欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 下载Demo - 2.77 MB (原始地址):handwritten_character_recognition.zip 下载源码 - 70. ...

  6. keras框架的MLP手写数字识别MNIST,梳理?

    keras框架的MLP手写数字识别MNIST 代码: # coding: utf-8 # In[1]: import numpy as np import pandas as pd from kera ...

  7. 【问题解决方案】Keras手写数字识别-ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接

    参考:台大李宏毅老师视频课程-Keras-Demo 在载入数据阶段报错: ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接 Google之 ...

  8. keras—多层感知器MLP—MNIST手写数字识别

    一.手写数字识别 现在就来说说如何使用神经网络实现手写数字识别. 在这里我使用mind manager工具绘制了要实现手写数字识别需要的模块以及模块的功能:  其中隐含层节点数量(即神经细胞数量)计算 ...

  9. keras和tensorflow搭建DNN、CNN、RNN手写数字识别

    MNIST手写数字集 MNIST是一个由美国由美国邮政系统开发的手写数字识别数据集.手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载,总共4个.gz后缀的压缩文件,该文件 ...

随机推荐

  1. VSTO开发指南(VB2013版) 第三章 Excel编程

    通过前两章的内容,有了一定的基础,但进入第三章,实例的步骤非常多,并且随着VS版本的升级,部分功能菜单界面发生了很大变化,所以,第三章的案例我将逐步编写! 实例3.1的目标就是给Excel写一个加载宏 ...

  2. 挂号平台首页开发(UI组件部分)

    JQ插件模式开发UI组件 JQ插件开发方法: 1.$.extend() 扩展JQ(比较简单,功能略显不足) $.extend({ sayHello:function(){ console.log(&q ...

  3. JAVA系统架构高并发解决方案 分布式缓存 分布式事务解决方案

    JAVA系统架构高并发解决方案 分布式缓存 分布式事务解决方案

  4. 面向对象+闭包+三种对象的声明方式(字面式、new Object、构造函数、工厂模式、原型模式、混合模式)

    面向对象: 对代码的一种抽象,对外统一提供调用接口的编程思想 对象的属性:事物自身拥有的东西 对象的方法:事物的功能 对象:事物的一个实例 对象的原型:.prototype -> 内存地址 -& ...

  5. 【DTOJ】2703:两个数的余数和商

    DTOJ 2703:两个数的余数和商  解题报告 2017.11.10 第一版 ——由翱翔的逗比w原创,引用<C++ Primer Plus(第6版)中文版> 题目信息: 题目描述 给你a ...

  6. webstorm 添加代码模板

    file>setting>Live Templates>选择文件类型

  7. python3练习100题——002

    因为特殊原因,昨天没有做题.今天继续- 原题链接:http://www.runoob.com/python/python-exercise-example2.html 题目: 企业发放的奖金根据利润提 ...

  8. memcached与redis比较

    1- memcached介绍 Memcached是一个自由开源的,高性能,分布式内存对象缓存系统. Memcached是以LiveJournal旗下Danga Interactive公司的Brad F ...

  9. 【python基础语法】第1天作业练习题

    # 1.下面那些不能作为标识符? """ 1.find 2. _num 3.7val 4.add. 5.def 6.pan 7.-print 8.open_file 9. ...

  10. 数据结构(集合)学习之Map(一)

    集合 框架关系图: 补充:HashTable父类是Dictionary,不是AbstractMap. Map: Map(接口)和Collection都属于集合,但是Map不是Collection的子类 ...