手写数字识别——利用keras高层API快速搭建并优化网络模型
在《手写数字识别——手动搭建全连接层》一文中,我们通过机器学习的基本公式构建出了一个网络模型,其实现过程毫无疑问是过于复杂了——不得不考虑诸如数据类型匹配、梯度计算、准确度的统计等问题,但是这样的实践对机器学习的理解是大有裨益的。在大多数情况下,我们还是希望能多简单就多简单地去搭建网络模型,这同时也算对得起TensorFlow这个强大的工具了。本节,还是以手写数据集MNIST为例,利用TensorFlow2.0的keras高层API重现之前的网络。
一、数据的导入与预处理
关于这个过程,与上节讲过的类似,就不再赘述了。需要提一点的就是,为了程序的整洁,将数据类型的转换过程单独写成一个预处理函数preprocess,通过Dataset对象的map方法应用该预处理函数。整个数据导入与预处理代码如下:
import tensorflow as tf
from tensorflow.keras import datasets,optimizers,Sequential,metrics,layers # 改变数据类型
def preprocess(x,y):
x = tf.cast(x,dtype=tf.float32)/255-0.5
x = tf.reshape(x,[-1,28*28])
y = tf.one_hot(y, depth=10)
y = tf.cast(y, dtype=tf.int32)
return x,y #60k 28*28
(train_x,train_y),(val_x,val_y) = datasets.mnist.load_data() #生成Dataset对象
train_db = tf.data.Dataset.from_tensor_slices((train_x,train_y)).shuffle(10000).batch(256)
val_db = tf.data.Dataset.from_tensor_slices((val_x,val_y)).shuffle(10000).batch(256) #预处理,对每个批次数据应用preprocess
train_db = train_db.map(preprocess)
val_db = val_db.map(preprocess)
二、模型构建
对于全连接层,keras提供了layers.Dense(units,activation)接口,利用它可以建立一层layer,多层堆叠放入keras提供的Sequential容器中,就形成了一个网络模型。在Dense的参数中,units决定了这层layer含有的神经元数量,activation是激活函数的选择。同之前的网络一样,我们的网络传播可以看做是:input(784 units)->layer1(256 units)->ReLu->layer2(128 units)->ReLu->output(10 units)。因此,在Sequential容器中定义后三层,activation指定为ReLu,而输入层需要通过build时候指定input_shape来告诉网络输入层的神经元数量。构建的代码如下,通过summary方法可以打印网络信息。
#网络模型
model = Sequential([
layers.Dense(256,activation=tf.nn.relu),
layers.Dense(128,activation=tf.nn.relu),
layers.Dense(10),
])
#input_shape=(batch_size,input_dims)
model.build(input_shape=(None,28*28))
model.summary()
三、模型的训练
模型的训练最重要的就是权重更新和准确度统计。keras提供了多种优化器(optimizer)用于更新权重。优化器实际就是不同的梯度下降算法,缓解了传统梯度下降可能无法收敛到全局最小值的问题。在上一节中就稍加讨论了三种。这里就简单对比一下一些优化器,至于详细的区别今后有时间再写篇随笔专门讨论:
- SGD:TensorFlow2.0 SGD实际是随机梯度下降+动量的综合优化器。随机梯度下降是每次更新随机选取一个样本计算梯度,这样计算梯度快很多,但怕大噪声;动量是在梯度下降的基础上,累计历史梯度信息加速梯度下降,这是因为一方面它想水稀释牛奶一样,能减小随机梯度下降对噪声的敏感度,另一方面动量赋予下降以惯性,可以预见梯度变化。这优化器实话说让我联想到了PID控制。SGD需要指定学习率和动量大小。一般地,动量大小设置为0.9。
- Adagrad:采用自适应梯度的优化器。所谓自适应梯度,就是根据参数的频率,对每个参数应用不同的学习速率。但是该算法在迭代次数变得很大时,学习速率会变得很小,导致不能继续更新。Adagrad要求指定初始化的学习速率、累加器初始值和防止分母为0的偏置值。
- Adadelta:采用自适应增量的优化器。解决了adagrad算法学习速率消失的问题。Adagrad要求指定初始化的学习速率、衰减率和防止分母为0的偏置值。这个衰减率跟动量差不多,一般也指定为0.9。
- RMSprop:类似于Adadelta。
- Adam:采用梯度的一阶和二阶矩来估计更新参数。它结合了Adadelta和RMSprop的优点。可以说,深度学习通常都会选择Adam优化器。TensorFlow中,Adam优化器需要指定4个参数,但经验证明,它的默认参数能表现出很好的效果。
鉴于以上对比,此处选用Adam作为优化器,并采用其默认参数。
除了梯度下降,还需要考虑的是Loss的计算方法。之前,我们采用的是预测概率与实际值的差平方的均值,专业名称应该是欧几里得损失函数。其实,这是个错误,欧几里得损失函数适用于二元分类,多元分类应该采用交叉熵损失函数。有时候针对多元函数,我们会很不自觉地想把输出层归一化,于是会在输出层之后,交叉熵计算前先softmax一下。但是由于softmax是采用指数形式进行计算的,如果输出各类概率相差较大,则大概率在归一化后几乎为1,小概率归一化之后几乎为0。为了避免这一问题,通常是去掉softmax,在交叉熵函数tf.losses.CategoricalCrossentropy的参数中指from_logits=True。
Loss函数和优化器配置都可以通过compile方法指定,同时,还可以指定metrics列表来决定需要自动计算的信息,如准确度。
通过fit方法可以传入训练数据和测试数据。代码如下:
#配合Adam优化器、交叉熵Loss函数、metrics列表
model.compile(optimizer=optimizers.Adam(),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
#数据传入,迭代10次train_db,每迭代1次,计算一次测试数据集准确度
model.fit(train_db,epochs=10,validation_data=val_db,validation_freq=1)
以上建立的网络模型在第一次train_db迭代完后就可以达到0.8以上的准确度,而且这个迭代每次仅花费3秒左右。经过大约50次迭代,准确度就可以高达0.98!而通过上一节的方式,要达到这样的准确度,起码得训练半个小时。这其中最主要的差别就在于梯度下降算法的优化。
四、完整代码
import tensorflow as tf
from tensorflow.keras import datasets,optimizers,Sequential,metrics,layers # 改变数据类型
def preprocess(x,y):
x = tf.cast(x,dtype=tf.float32)/255-0.5
x = tf.reshape(x,[-1,28*28])
y = tf.one_hot(y, depth=10)
y = tf.cast(y, dtype=tf.int32)
return x,y #60k 28*28
(train_x,train_y),(val_x,val_y) = datasets.mnist.load_data() #生成Dataset对象
train_db = tf.data.Dataset.from_tensor_slices((train_x,train_y)).shuffle(10000).batch(256)
val_db = tf.data.Dataset.from_tensor_slices((val_x,val_y)).shuffle(10000).batch(256) #预处理,对每个数据应用preprocess
train_db = train_db.map(preprocess)
val_db = val_db.map(preprocess) #网络模型
model = Sequential([
layers.Dense(256,activation=tf.nn.relu),
layers.Dense(128,activation=tf.nn.relu),
layers.Dense(10),
])
#input_shape=(batch_size,input_dims)
model.build(input_shape=(None,28*28))
model.summary() #配合Adam优化器、交叉熵Loss函数、metrics列表
model.compile(optimizer=optimizers.Adam(),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy'])
#数据传入,迭代10次train_db,每迭代1次,计算一次测试数据集准确度
model.fit(train_db,epochs=100,validation_data=val_db,validation_freq=5)
手写数字识别——利用keras高层API快速搭建并优化网络模型的更多相关文章
- 手写数字识别——基于LeNet-5卷积网络模型
在<手写数字识别——利用Keras高层API快速搭建并优化网络模型>一文中,我们搭建了全连接层网络,准确率达到0.98,但是这种网络的参数量达到了近24万个.本文将搭建LeNet-5网络, ...
- 【百度飞桨】手写数字识别模型部署Paddle Inference
从完成一个简单的『手写数字识别任务』开始,快速了解飞桨框架 API 的使用方法. 模型开发 『手写数字识别』是深度学习里的 Hello World 任务,用于对 0 ~ 9 的十类数字进行分类,即输入 ...
- CNN 手写数字识别
1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...
- 卷积神经网络CNN 手写数字识别
1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...
- 利用神经网络算法的C#手写数字识别
欢迎大家前往云+社区,获取更多腾讯海量技术实践干货哦~ 下载Demo - 2.77 MB (原始地址):handwritten_character_recognition.zip 下载源码 - 70. ...
- keras框架的MLP手写数字识别MNIST,梳理?
keras框架的MLP手写数字识别MNIST 代码: # coding: utf-8 # In[1]: import numpy as np import pandas as pd from kera ...
- 【问题解决方案】Keras手写数字识别-ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接
参考:台大李宏毅老师视频课程-Keras-Demo 在载入数据阶段报错: ConnectionResetError: [WinError 10054] 远程主机强迫关闭了一个现有的连接 Google之 ...
- keras—多层感知器MLP—MNIST手写数字识别
一.手写数字识别 现在就来说说如何使用神经网络实现手写数字识别. 在这里我使用mind manager工具绘制了要实现手写数字识别需要的模块以及模块的功能: 其中隐含层节点数量(即神经细胞数量)计算 ...
- keras和tensorflow搭建DNN、CNN、RNN手写数字识别
MNIST手写数字集 MNIST是一个由美国由美国邮政系统开发的手写数字识别数据集.手写内容是0~9,一共有60000个图片样本,我们可以到MNIST官网免费下载,总共4个.gz后缀的压缩文件,该文件 ...
随机推荐
- [jQuery]顶级对象$(二)
$ 是 jQuery 的缩写 <script> # 方法1. $ 是jQuery的别称 弹出提示 $(function () { alert(11) ); # 方法2 jQuery(fun ...
- React之拆分组件与组件之间的传值
父子组件传值: 父组件向子组件传值通过向子组件TodoItem进行属性绑定(content={item}.index={index}),代码如下 getTodoItem () { return thi ...
- Webpack之(progressive web application) - PWA中的 Service Workers 是什么
学习文档:https://webpack.docschina.org/guides/progressive-web-application/ 参考文档:https://developers.googl ...
- 惊讶!缓存刚Put再Get居然获取不到?
最近一直在老家远程办公,微信突然响了下,有同事说遇到了一个奇怪的问题,让我帮忙看下. 现象就是标题所说的缓存获取不到的问题,我一听感觉这个问题挺有意思的,决定一探究竟. 下面给出部分代码还原下案发现场 ...
- C语言实现双人控制的战斗小游戏
实现功能 1.双人分别控制小人移动 2.子弹碰撞 3.可改变出弹方向 4.血条实体化 前言 这个游戏是看了知乎一位非常好的老师的专栏后练手写的,(至于是哪位,知乎搜C语言小游戏最牛逼的那位) 有老师系 ...
- Git简易教程(常用命令)
本文章参考了Pro Git 1 Git简介 Linux内核开源项目有着众多参与者,为了提高开发效率,项目组于2002年开始启用分布式版本控制系统BitKeeper来管理和维护代码.在BitKeeper ...
- CDC+ETL实现数据集成方案
欢迎咨询,合作! weix:wonter 名词解释: CDC又称变更数据捕获(Change Data Capture),开启cdc的源表在插入INSERT.更新UPDATE和删除DELETE活动时会插 ...
- MySql学习-1.MySql的安装:
1.安装包的下载(mysql-v5.7.25 )(NavicatforMySQL_11.2.15): 链接:https://pan.baidu.com/s/166hyyYd3DMjYhMwdW805F ...
- VUE中使用XLSX实现导出excel表格
简介 项目中经常会用导出数据的场景,这里介绍 VUE 中如何使用插件 xlsx 导出数据 安装 ## 1.使用 npm 或 yarn 安装依赖(三个依赖) npm install -S file-sa ...
- sqlserver 批量修改数据库表主键名称为PK_表名
1.我们在创建sqlserver得数据表的主键的时候,有时会出现,后面加一串随机字符串的情况,如图所示: 2.如果你有强迫症的话,可以使用以下sql脚本进行修改,将主键的名称修改为PK_表名. --将 ...