[深度学习] tf.keras入门4-过拟合和欠拟合
过拟合和欠拟合
简单来说过拟合就是模型训练集精度高,测试集训练精度低;欠拟合则是模型训练集和测试集训练精度都低。
官方文档地址为 https://tensorflow.google.cn/tutorials/keras/overfit_and_underfit
过拟合和欠拟合
以IMDB dataset为例,对于过拟合和欠拟合,不同模型的测试集和验证集损失函数图如下:
baseline模型结构为:10000-16-16-1
smaller_model模型结构为:10000-4-4-1
bigger_model模型结构为:10000-512-512-1
造成过拟合的原因通常是参数过多或者数据较少,欠拟合往往是训练次数不够。
解决方法
正则化
正则化简单来说就是稀疏化参数,使得模型参数较少。类似于降维。
正则化参考: https://blog.csdn.net/jinping_shi/article/details/52433975
tf.keras通常在损失函数后添加正则项,l1正则化和l2正则化。
l2_model = keras.models.Sequential([
keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),#权重l2正则化
activation=tf.nn.relu, input_shape=(10000,)),
keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),#权重l2正则化
activation=tf.nn.relu),
keras.layers.Dense(1, activation=tf.nn.sigmoid)
])
l2_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy'])
l2_model_history = l2_model.fit(train_data, train_labels,
epochs=20,
batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
dropout
Dropout将在训练过程中每次更新参数时按一定概率(rate)随机断开输入神经元,使得比例为rate的神经元不被训练。
具体见: https://yq.aliyun.com/articles/68901
dpt_model = keras.models.Sequential([
keras.layers.Dense(16, activation=tf.nn.relu, input_shape=(10000,)),
keras.layers.Dropout(0.3), #百分之30的神经元失效
keras.layers.Dense(16, activation=tf.nn.relu),
keras.layers.Dropout(0.7), #百分之70的神经元失效
keras.layers.Dense(1, activation=tf.nn.sigmoid)
])
dpt_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy','binary_crossentropy'])
dpt_model_history = dpt_model.fit(train_data, train_labels,
epochs=20,
batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
总结
常用防止过拟合的方法有:
- 增加数据量
- 减少网络结构参数
- 正则化
- dropout
- 数据扩增data-augmentation
- 批标准化
[深度学习] tf.keras入门4-过拟合和欠拟合的更多相关文章
- [深度学习] tf.keras入门3-回归
目录 波士顿房价数据集 数据集 数据归一化 模型训练和预测 模型建立和训练 模型预测 总结 回归主要基于波士顿房价数据库进行建模,官方文档地址为:https://tensorflow.google.c ...
- [深度学习] tf.keras入门5-模型保存和载入
目录 设置 基于checkpoints的模型保存 通过ModelCheckpoint模块来自动保存数据 手动保存权重 整个模型保存 总体代码 模型可以在训练中或者训练完成后保存.具体文档参考:http ...
- [深度学习] tf.keras入门2-分类
目录 Fashion MNIST数据库 分类模型的建立 模型预测 总体代码 主要介绍基于tf.keras的Fashion MNIST数据库分类, 官方文档地址为:https://tensorflow. ...
- [深度学习] tf.keras入门1-基本函数介绍
目录 构建一个简单的模型 序贯(Sequential)模型 网络层的构造 模型训练和参数评价 模型训练 模型的训练 tf.data的数据集 模型评估和预测 基本模型的建立 网络层模型 模型子类函数构建 ...
- 深度学习:Keras入门(一)之基础篇
1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架. Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结 ...
- 深度学习:Keras入门(一)之基础篇【转】
本文转载自:http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorfl ...
- 深度学习:Keras入门(一)之基础篇(转)
转自http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深 ...
- 深度学习:Keras入门(二)之卷积神经网络(CNN)
说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么是卷积? 简单来说,卷积(或内积)就是一种先把对应位置相乘然后再把结果相加的运算.(具体含义或者数学公式 ...
- 深度学习:Keras入门(二)之卷积神经网络(CNN)【转】
本文转载自:https://www.cnblogs.com/lc1217/p/7324935.html 说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么 ...
随机推荐
- 手把手教你使用LabVIEW OpenCV dnn实现图像分类(含源码)
@ 目录 前言 一.什么是图像分类? 1.图像分类的概念 2.MobileNet简介 二.使用python实现图像分类(py_to_py_ssd_mobilenet.py) 1.获取预训练模型 2.使 ...
- Linux 下指定端口开放访问权限
Linux 下指定端口开放访问权限 作者:Grey 原文地址: 博客园:Linux 下指定端口开放访问权限 CSDN:Linux 下指定端口开放访问权限 环境 CentOS 系和 Debian 系的防 ...
- Linux中CentOS 7的安装及Linux常用命令
1. 前言 什么是Linux Linux是一套免费使用和自由传播的操作系统.说到操作系统,大家比较熟知的应该就是Windows和MacOS操作系统,我们今天所学习的Linux也是一款操作系统. 为什么 ...
- Audacity开源音频处理软件使用入门
操作系统 :Windows10_x64 Audacity版本:3.2.1 Audacity是一款开源.免费.跨平台的音频处理及录音软件,支持Windows.macOS及Linux操作系统. 这里记录下 ...
- HTML元素大全(1)
01.基础元素 <h1/2/3/4/5/6>标题 从大h1到小h6,块元素,有6级标题.是一种标题类语义标签,内置了字体.边距样式. 合理使用h标签,主要用于标题,不要为了加粗效果而随意使 ...
- JUC学习笔记——共享模型之管程
JUC学习笔记--共享模型之管程 在本系列内容中我们会对JUC做一个系统的学习,本片将会介绍JUC的管程部分 我们会分为以下几部分进行介绍: 共享问题 共享问题解决方案 线程安全分析 Monitor ...
- JAVA-注解之 TODO、FIXME、XXX
TODO.FIXME.XXX //TODO : 表示待实现的功能 //FIXME: 代码存在Bug,不能Run或运行结果不正确,需要修复 //XXX : 勉强可以工作,但是实现的方 ...
- C#使用正则表达式来验证是否是16进制字符串
/// <summary> /// 判断是否为16进制字符串 /// </summary> /// <param name="hexString"&g ...
- Linux *.service文件详解
什么是systemd service? systemd service是一种以.service 结尾的配置文件,是一个专用于Linux操作系统的系统与服务管理器.简单来说,用于后台以守护精灵(daem ...
- extern "C"的使用
在使用C++开发程序时,有时使用到别人开发的第三方库,而这第三库是使用C开发的.直接使用会报错如下: cpp error LNK2019: 无法解析的外部符号 "int __cdecl su ...