过拟合和欠拟合

简单来说过拟合就是模型训练集精度高,测试集训练精度低;欠拟合则是模型训练集和测试集训练精度都低。

官方文档地址为 https://tensorflow.google.cn/tutorials/keras/overfit_and_underfit

过拟合和欠拟合

以IMDB dataset为例,对于过拟合和欠拟合,不同模型的测试集和验证集损失函数图如下:

baseline模型结构为:10000-16-16-1

smaller_model模型结构为:10000-4-4-1

bigger_model模型结构为:10000-512-512-1

造成过拟合的原因通常是参数过多或者数据较少,欠拟合往往是训练次数不够。

解决方法

正则化

正则化简单来说就是稀疏化参数,使得模型参数较少。类似于降维。

正则化参考: https://blog.csdn.net/jinping_shi/article/details/52433975

tf.keras通常在损失函数后添加正则项,l1正则化和l2正则化。

l2_model = keras.models.Sequential([
keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),#权重l2正则化
activation=tf.nn.relu, input_shape=(10000,)),
keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),#权重l2正则化
activation=tf.nn.relu),
keras.layers.Dense(1, activation=tf.nn.sigmoid)
]) l2_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy']) l2_model_history = l2_model.fit(train_data, train_labels,
epochs=20,
batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)

dropout

Dropout将在训练过程中每次更新参数时按一定概率(rate)随机断开输入神经元,使得比例为rate的神经元不被训练。

具体见: https://yq.aliyun.com/articles/68901

dpt_model = keras.models.Sequential([
keras.layers.Dense(16, activation=tf.nn.relu, input_shape=(10000,)),
keras.layers.Dropout(0.3), #百分之30的神经元失效
keras.layers.Dense(16, activation=tf.nn.relu),
keras.layers.Dropout(0.7), #百分之70的神经元失效
keras.layers.Dense(1, activation=tf.nn.sigmoid)
]) dpt_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy','binary_crossentropy']) dpt_model_history = dpt_model.fit(train_data, train_labels,
epochs=20,
batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)

总结

常用防止过拟合的方法有:

  1. 增加数据量
  2. 减少网络结构参数
  3. 正则化
  4. dropout
  5. 数据扩增data-augmentation
  6. 批标准化

[深度学习] tf.keras入门4-过拟合和欠拟合的更多相关文章

  1. [深度学习] tf.keras入门3-回归

    目录 波士顿房价数据集 数据集 数据归一化 模型训练和预测 模型建立和训练 模型预测 总结 回归主要基于波士顿房价数据库进行建模,官方文档地址为:https://tensorflow.google.c ...

  2. [深度学习] tf.keras入门5-模型保存和载入

    目录 设置 基于checkpoints的模型保存 通过ModelCheckpoint模块来自动保存数据 手动保存权重 整个模型保存 总体代码 模型可以在训练中或者训练完成后保存.具体文档参考:http ...

  3. [深度学习] tf.keras入门2-分类

    目录 Fashion MNIST数据库 分类模型的建立 模型预测 总体代码 主要介绍基于tf.keras的Fashion MNIST数据库分类, 官方文档地址为:https://tensorflow. ...

  4. [深度学习] tf.keras入门1-基本函数介绍

    目录 构建一个简单的模型 序贯(Sequential)模型 网络层的构造 模型训练和参数评价 模型训练 模型的训练 tf.data的数据集 模型评估和预测 基本模型的建立 网络层模型 模型子类函数构建 ...

  5. 深度学习:Keras入门(一)之基础篇

    1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架. Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结 ...

  6. 深度学习:Keras入门(一)之基础篇【转】

    本文转载自:http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorfl ...

  7. 深度学习:Keras入门(一)之基础篇(转)

    转自http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深 ...

  8. 深度学习:Keras入门(二)之卷积神经网络(CNN)

    说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么是卷积? 简单来说,卷积(或内积)就是一种先把对应位置相乘然后再把结果相加的运算.(具体含义或者数学公式 ...

  9. 深度学习:Keras入门(二)之卷积神经网络(CNN)【转】

    本文转载自:https://www.cnblogs.com/lc1217/p/7324935.html 说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么 ...

随机推荐

  1. 「JOISC 2022 Day1」京都观光 题解

    Solution 考虑从\((x_1,y_1)\)走到\((x_2,y_2)\)满足只改变一次方向,则容易求出先向南走当且仅当 \[\frac{a_{x_1} - a_{x_2}}{x_1 - x_2 ...

  2. 设计一个网上书店,该系统中所有的计算机类图书(ComputerBook)每本都有10%的折扣,所有的语言类图书(LanguageBook)每本都有2元的折扣,小说类图书(NovelBook)每100元

    现使用策略模式来设计该系统,绘制类图并编程实现 UML类图 书籍 package com.zheng; public class Book { private double price;// 价格 p ...

  3. 齐博x1文本代码标签的使用

    文本标签虽然简单,但是使用的地方确实非常多的. {qb:tag name="XXXX" type="text"}推荐新闻{/qb:tag} 类似这种使用的频率是 ...

  4. MinGW配置C语言编译器gcc和g++

    首先,在 https://sourceforge.net/projects/mingw/files/latest/download 下载安装MinGW,如下图所示: 点Installation-> ...

  5. python中的if条件语句

    # 如果...就... # 1. print('1.') if 1+1 == 2: print('1+1是等于2的') print('1+1还是等于2的') print('1+1就等于2的') # 2 ...

  6. Python基础部分:1、typora软件和对计算机的认识

    目录 一.typora软件 1.安装 2.markdown语法 二.计算机的本质 1.进制数 三.计算机五大组成部分概要 1.控制器 2.运算器 3.存储器 4.输入设备 5.输出设备 一.typor ...

  7. MFC 学习笔记

    MFC 学习笔记 一.MFC编程基础: 概述: 常用头文件: MFC控制台程序: MFC库程序: 规则库可以被各种程序所调用,扩展库只能被MFC程序调用. MFC窗口程序: 示例: MFC库中类的简介 ...

  8. Codeforces Round #805 (Div. 3)G2. Passable Paths

    题目大意: 给出一个无向无环连通图(树),n个点n-1条边,m次查询,每次询问给出一个集合,问集合里的树是否都在同一条链上(即能否不重复的走一条边而遍历整个点集) 思路:通过求lca,若有三个点x,y ...

  9. 真正“搞”懂HTTP协议03之时间穿梭

    上一篇我们简单的介绍了一下DoD模型和OSI模型,还着重的讲解了TCP的三次握手和四次挥手,让我们在空间层面,稍稍宏观的了解了HTTP所依赖的底层模型,那么这一篇,我们来追溯一下HTTP的历史,看一看 ...

  10. spring源码解析(一) 环境搭建(各种坑的解决办法)

    上次搭建spring源码的环境还是两年前,依稀记得那时候也是一顿折腾,奈何当时没有记录,导致两年后的今天把坑重踩了一遍,还遇到了新的坑,真是欲哭无泪;为了以后类似的事情不再发生,这次写下这篇博文来必坑 ...