过拟合和欠拟合

简单来说过拟合就是模型训练集精度高,测试集训练精度低;欠拟合则是模型训练集和测试集训练精度都低。

官方文档地址为 https://tensorflow.google.cn/tutorials/keras/overfit_and_underfit

过拟合和欠拟合

以IMDB dataset为例,对于过拟合和欠拟合,不同模型的测试集和验证集损失函数图如下:

baseline模型结构为:10000-16-16-1

smaller_model模型结构为:10000-4-4-1

bigger_model模型结构为:10000-512-512-1

造成过拟合的原因通常是参数过多或者数据较少,欠拟合往往是训练次数不够。

解决方法

正则化

正则化简单来说就是稀疏化参数,使得模型参数较少。类似于降维。

正则化参考: https://blog.csdn.net/jinping_shi/article/details/52433975

tf.keras通常在损失函数后添加正则项,l1正则化和l2正则化。

l2_model = keras.models.Sequential([
keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),#权重l2正则化
activation=tf.nn.relu, input_shape=(10000,)),
keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),#权重l2正则化
activation=tf.nn.relu),
keras.layers.Dense(1, activation=tf.nn.sigmoid)
]) l2_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy']) l2_model_history = l2_model.fit(train_data, train_labels,
epochs=20,
batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)

dropout

Dropout将在训练过程中每次更新参数时按一定概率(rate)随机断开输入神经元,使得比例为rate的神经元不被训练。

具体见: https://yq.aliyun.com/articles/68901

dpt_model = keras.models.Sequential([
keras.layers.Dense(16, activation=tf.nn.relu, input_shape=(10000,)),
keras.layers.Dropout(0.3), #百分之30的神经元失效
keras.layers.Dense(16, activation=tf.nn.relu),
keras.layers.Dropout(0.7), #百分之70的神经元失效
keras.layers.Dense(1, activation=tf.nn.sigmoid)
]) dpt_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy','binary_crossentropy']) dpt_model_history = dpt_model.fit(train_data, train_labels,
epochs=20,
batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)

总结

常用防止过拟合的方法有:

  1. 增加数据量
  2. 减少网络结构参数
  3. 正则化
  4. dropout
  5. 数据扩增data-augmentation
  6. 批标准化

[深度学习] tf.keras入门4-过拟合和欠拟合的更多相关文章

  1. [深度学习] tf.keras入门3-回归

    目录 波士顿房价数据集 数据集 数据归一化 模型训练和预测 模型建立和训练 模型预测 总结 回归主要基于波士顿房价数据库进行建模,官方文档地址为:https://tensorflow.google.c ...

  2. [深度学习] tf.keras入门5-模型保存和载入

    目录 设置 基于checkpoints的模型保存 通过ModelCheckpoint模块来自动保存数据 手动保存权重 整个模型保存 总体代码 模型可以在训练中或者训练完成后保存.具体文档参考:http ...

  3. [深度学习] tf.keras入门2-分类

    目录 Fashion MNIST数据库 分类模型的建立 模型预测 总体代码 主要介绍基于tf.keras的Fashion MNIST数据库分类, 官方文档地址为:https://tensorflow. ...

  4. [深度学习] tf.keras入门1-基本函数介绍

    目录 构建一个简单的模型 序贯(Sequential)模型 网络层的构造 模型训练和参数评价 模型训练 模型的训练 tf.data的数据集 模型评估和预测 基本模型的建立 网络层模型 模型子类函数构建 ...

  5. 深度学习:Keras入门(一)之基础篇

    1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架. Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结 ...

  6. 深度学习:Keras入门(一)之基础篇【转】

    本文转载自:http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorfl ...

  7. 深度学习:Keras入门(一)之基础篇(转)

    转自http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深 ...

  8. 深度学习:Keras入门(二)之卷积神经网络(CNN)

    说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么是卷积? 简单来说,卷积(或内积)就是一种先把对应位置相乘然后再把结果相加的运算.(具体含义或者数学公式 ...

  9. 深度学习:Keras入门(二)之卷积神经网络(CNN)【转】

    本文转载自:https://www.cnblogs.com/lc1217/p/7324935.html 说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么 ...

随机推荐

  1. 手把手教你使用LabVIEW OpenCV dnn实现图像分类(含源码)

    @ 目录 前言 一.什么是图像分类? 1.图像分类的概念 2.MobileNet简介 二.使用python实现图像分类(py_to_py_ssd_mobilenet.py) 1.获取预训练模型 2.使 ...

  2. Linux 下指定端口开放访问权限

    Linux 下指定端口开放访问权限 作者:Grey 原文地址: 博客园:Linux 下指定端口开放访问权限 CSDN:Linux 下指定端口开放访问权限 环境 CentOS 系和 Debian 系的防 ...

  3. Linux中CentOS 7的安装及Linux常用命令

    1. 前言 什么是Linux Linux是一套免费使用和自由传播的操作系统.说到操作系统,大家比较熟知的应该就是Windows和MacOS操作系统,我们今天所学习的Linux也是一款操作系统. 为什么 ...

  4. Audacity开源音频处理软件使用入门

    操作系统 :Windows10_x64 Audacity版本:3.2.1 Audacity是一款开源.免费.跨平台的音频处理及录音软件,支持Windows.macOS及Linux操作系统. 这里记录下 ...

  5. HTML元素大全(1)

    01.基础元素 <h1/2/3/4/5/6>标题 从大h1到小h6,块元素,有6级标题.是一种标题类语义标签,内置了字体.边距样式. 合理使用h标签,主要用于标题,不要为了加粗效果而随意使 ...

  6. JUC学习笔记——共享模型之管程

    JUC学习笔记--共享模型之管程 在本系列内容中我们会对JUC做一个系统的学习,本片将会介绍JUC的管程部分 我们会分为以下几部分进行介绍: 共享问题 共享问题解决方案 线程安全分析 Monitor ...

  7. JAVA-注解之 TODO、FIXME、XXX

    TODO.FIXME.XXX    //TODO : 表示待实现的功能    //FIXME: 代码存在Bug,不能Run或运行结果不正确,需要修复    //XXX  : 勉强可以工作,但是实现的方 ...

  8. C#使用正则表达式来验证是否是16进制字符串

    /// <summary> /// 判断是否为16进制字符串 /// </summary> /// <param name="hexString"&g ...

  9. Linux *.service文件详解

    什么是systemd service? systemd service是一种以.service 结尾的配置文件,是一个专用于Linux操作系统的系统与服务管理器.简单来说,用于后台以守护精灵(daem ...

  10. extern "C"的使用

    在使用C++开发程序时,有时使用到别人开发的第三方库,而这第三库是使用C开发的.直接使用会报错如下: cpp error LNK2019: 无法解析的外部符号 "int __cdecl su ...