过拟合和欠拟合

简单来说过拟合就是模型训练集精度高,测试集训练精度低;欠拟合则是模型训练集和测试集训练精度都低。

官方文档地址为 https://tensorflow.google.cn/tutorials/keras/overfit_and_underfit

过拟合和欠拟合

以IMDB dataset为例,对于过拟合和欠拟合,不同模型的测试集和验证集损失函数图如下:

baseline模型结构为:10000-16-16-1

smaller_model模型结构为:10000-4-4-1

bigger_model模型结构为:10000-512-512-1

造成过拟合的原因通常是参数过多或者数据较少,欠拟合往往是训练次数不够。

解决方法

正则化

正则化简单来说就是稀疏化参数,使得模型参数较少。类似于降维。

正则化参考: https://blog.csdn.net/jinping_shi/article/details/52433975

tf.keras通常在损失函数后添加正则项,l1正则化和l2正则化。

l2_model = keras.models.Sequential([
keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),#权重l2正则化
activation=tf.nn.relu, input_shape=(10000,)),
keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),#权重l2正则化
activation=tf.nn.relu),
keras.layers.Dense(1, activation=tf.nn.sigmoid)
]) l2_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy']) l2_model_history = l2_model.fit(train_data, train_labels,
epochs=20,
batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)

dropout

Dropout将在训练过程中每次更新参数时按一定概率(rate)随机断开输入神经元,使得比例为rate的神经元不被训练。

具体见: https://yq.aliyun.com/articles/68901

dpt_model = keras.models.Sequential([
keras.layers.Dense(16, activation=tf.nn.relu, input_shape=(10000,)),
keras.layers.Dropout(0.3), #百分之30的神经元失效
keras.layers.Dense(16, activation=tf.nn.relu),
keras.layers.Dropout(0.7), #百分之70的神经元失效
keras.layers.Dense(1, activation=tf.nn.sigmoid)
]) dpt_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy','binary_crossentropy']) dpt_model_history = dpt_model.fit(train_data, train_labels,
epochs=20,
batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)

总结

常用防止过拟合的方法有:

  1. 增加数据量
  2. 减少网络结构参数
  3. 正则化
  4. dropout
  5. 数据扩增data-augmentation
  6. 批标准化

[深度学习] tf.keras入门4-过拟合和欠拟合的更多相关文章

  1. [深度学习] tf.keras入门3-回归

    目录 波士顿房价数据集 数据集 数据归一化 模型训练和预测 模型建立和训练 模型预测 总结 回归主要基于波士顿房价数据库进行建模,官方文档地址为:https://tensorflow.google.c ...

  2. [深度学习] tf.keras入门5-模型保存和载入

    目录 设置 基于checkpoints的模型保存 通过ModelCheckpoint模块来自动保存数据 手动保存权重 整个模型保存 总体代码 模型可以在训练中或者训练完成后保存.具体文档参考:http ...

  3. [深度学习] tf.keras入门2-分类

    目录 Fashion MNIST数据库 分类模型的建立 模型预测 总体代码 主要介绍基于tf.keras的Fashion MNIST数据库分类, 官方文档地址为:https://tensorflow. ...

  4. [深度学习] tf.keras入门1-基本函数介绍

    目录 构建一个简单的模型 序贯(Sequential)模型 网络层的构造 模型训练和参数评价 模型训练 模型的训练 tf.data的数据集 模型评估和预测 基本模型的建立 网络层模型 模型子类函数构建 ...

  5. 深度学习:Keras入门(一)之基础篇

    1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架. Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结 ...

  6. 深度学习:Keras入门(一)之基础篇【转】

    本文转载自:http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorfl ...

  7. 深度学习:Keras入门(一)之基础篇(转)

    转自http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深 ...

  8. 深度学习:Keras入门(二)之卷积神经网络(CNN)

    说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么是卷积? 简单来说,卷积(或内积)就是一种先把对应位置相乘然后再把结果相加的运算.(具体含义或者数学公式 ...

  9. 深度学习:Keras入门(二)之卷积神经网络(CNN)【转】

    本文转载自:https://www.cnblogs.com/lc1217/p/7324935.html 说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么 ...

随机推荐

  1. 洛谷P3381 (最小费用最大流模板)

    记得把数组开大一点,不然就RE了... 1 #include<bits/stdc++.h> 2 using namespace std; 3 #define int long long 4 ...

  2. 洛谷P3810 陌上花开 (cdq)

    最近才学了cdq,所以用cdq写的代码(这道题也是cdq的模板题) 这道题是个三维偏序问题,先对第一维排序,然后去掉重复的,然后cdq分治即可. 为什么要去掉重复的呢?因为相同的元素互相之间都能贡献, ...

  3. pthread_mutex_t & pthread_cond_t 总结

    pthread_mutex_t & pthread_cond_t 总结 一.多线程并发 1.1 多线程并发引起的问题 我们先来看如下代码: #include <stdio.h> # ...

  4. JVM运行模式和逃逸分析

    JVM三种运行模式: 解释模式(Interpreted Mode):只使用解释器(-Xint强制JVM使用解释模式),执行一行JVM字节码就编译一行为机器码.(可以马上看到效果,但是运行过程比较慢) ...

  5. springboot+thymeleaf中前台页面展示中、将不同的数字替换成不同的字符串。使用条件运算符

    主要用到的知识就是thyme leaf中的条件运算符 表达式:(condition)?:then:else 当条件condition成立时返回then.否则返回else 具体代码:<td th: ...

  6. token字段,请务加在请求地址的头部header

    如下图所示,你必须在请求的头部加上 token参数,主要原因有两个.第一点,这个是登录标志,因为接口访问用不了cookie,所以只能通过这个header请求标志判断用户是否已经登录.第二点,系统有时候 ...

  7. Android 跨进程渲染

    本项目用于验证 Android 是否能够跨进程渲染 View,最终实现了在子进程创建WebView,主进程显示的功能. 一.跨进程渲染的意义 有一些组件比如 WebView 如果在主进程初始化,会大大 ...

  8. Java 多线程写zip文件遇到的错误 write beyond end of stream!

    最近在写一个大量小文件直接压缩到一个zip的需求,由于zip中的entry每一个都是独立的,不需要追加写入,也就是一个entry文件,写一个内容, 因此直接使用了多线程来处理,结果就翻车了,代码给出了 ...

  9. .NET 7.0 重磅发布及资源汇总

    2022-11-8 .NET 7.0 作为微软的开源跨平台开发平台正式发布.微软在公告中表示.NET 7为您的应用程序带来了C# 11 / F# 7,.NET MAUI,ASP.NET Core/Bl ...

  10. Java代码审计sql注入

    java_sec_code 该项目也可以叫做Java Vulnerability Code(Java漏洞代码). 每个漏洞类型代码默认存在安全漏洞(除非本身不存在漏洞),相关修复代码在注释里.具体可查 ...