[深度学习] tf.keras入门4-过拟合和欠拟合
过拟合和欠拟合
简单来说过拟合就是模型训练集精度高,测试集训练精度低;欠拟合则是模型训练集和测试集训练精度都低。
官方文档地址为 https://tensorflow.google.cn/tutorials/keras/overfit_and_underfit
过拟合和欠拟合
以IMDB dataset为例,对于过拟合和欠拟合,不同模型的测试集和验证集损失函数图如下:
baseline模型结构为:10000-16-16-1
smaller_model模型结构为:10000-4-4-1
bigger_model模型结构为:10000-512-512-1
造成过拟合的原因通常是参数过多或者数据较少,欠拟合往往是训练次数不够。
解决方法
正则化
正则化简单来说就是稀疏化参数,使得模型参数较少。类似于降维。
正则化参考: https://blog.csdn.net/jinping_shi/article/details/52433975
tf.keras通常在损失函数后添加正则项,l1正则化和l2正则化。
l2_model = keras.models.Sequential([
keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),#权重l2正则化
activation=tf.nn.relu, input_shape=(10000,)),
keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),#权重l2正则化
activation=tf.nn.relu),
keras.layers.Dense(1, activation=tf.nn.sigmoid)
])
l2_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy'])
l2_model_history = l2_model.fit(train_data, train_labels,
epochs=20,
batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
dropout
Dropout将在训练过程中每次更新参数时按一定概率(rate)随机断开输入神经元,使得比例为rate的神经元不被训练。
具体见: https://yq.aliyun.com/articles/68901
dpt_model = keras.models.Sequential([
keras.layers.Dense(16, activation=tf.nn.relu, input_shape=(10000,)),
keras.layers.Dropout(0.3), #百分之30的神经元失效
keras.layers.Dense(16, activation=tf.nn.relu),
keras.layers.Dropout(0.7), #百分之70的神经元失效
keras.layers.Dense(1, activation=tf.nn.sigmoid)
])
dpt_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy','binary_crossentropy'])
dpt_model_history = dpt_model.fit(train_data, train_labels,
epochs=20,
batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
总结
常用防止过拟合的方法有:
- 增加数据量
- 减少网络结构参数
- 正则化
- dropout
- 数据扩增data-augmentation
- 批标准化
[深度学习] tf.keras入门4-过拟合和欠拟合的更多相关文章
- [深度学习] tf.keras入门3-回归
目录 波士顿房价数据集 数据集 数据归一化 模型训练和预测 模型建立和训练 模型预测 总结 回归主要基于波士顿房价数据库进行建模,官方文档地址为:https://tensorflow.google.c ...
- [深度学习] tf.keras入门5-模型保存和载入
目录 设置 基于checkpoints的模型保存 通过ModelCheckpoint模块来自动保存数据 手动保存权重 整个模型保存 总体代码 模型可以在训练中或者训练完成后保存.具体文档参考:http ...
- [深度学习] tf.keras入门2-分类
目录 Fashion MNIST数据库 分类模型的建立 模型预测 总体代码 主要介绍基于tf.keras的Fashion MNIST数据库分类, 官方文档地址为:https://tensorflow. ...
- [深度学习] tf.keras入门1-基本函数介绍
目录 构建一个简单的模型 序贯(Sequential)模型 网络层的构造 模型训练和参数评价 模型训练 模型的训练 tf.data的数据集 模型评估和预测 基本模型的建立 网络层模型 模型子类函数构建 ...
- 深度学习:Keras入门(一)之基础篇
1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架. Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结 ...
- 深度学习:Keras入门(一)之基础篇【转】
本文转载自:http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorfl ...
- 深度学习:Keras入门(一)之基础篇(转)
转自http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深 ...
- 深度学习:Keras入门(二)之卷积神经网络(CNN)
说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么是卷积? 简单来说,卷积(或内积)就是一种先把对应位置相乘然后再把结果相加的运算.(具体含义或者数学公式 ...
- 深度学习:Keras入门(二)之卷积神经网络(CNN)【转】
本文转载自:https://www.cnblogs.com/lc1217/p/7324935.html 说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么 ...
随机推荐
- POJ2533 Longest Ordered Subsequence (线性DP)
设dp[i]表示以i结尾的最长上升子序列的长度. dp[i]=max(dp[i],dp[j]+1). 1 #include <map> 2 #include <set> 3 # ...
- C++编程范式(函数)
1 // 2 // main.cpp 3 // test 4 // 5 // Created by Shaojun on 30/5/2020. 6 // Copyright 2020 Shaojun. ...
- React魔法堂:echarts-for-react源码略读
前言 在当前工业4.0和智能制造的产业升级浪潮当中,智慧大屏无疑是展示企业IT成果的最有效方式之一.然而其背后怎么能缺少ECharts的身影呢?对于React应用而言,直接使用ECharts并不是最高 ...
- JS---HelloWorld
1.功能效果图 2.代码实现 <!DOCTYPE html> <html> <head> <meta charset="utf-8"> ...
- Audacity开源音频处理软件使用入门
操作系统 :Windows10_x64 Audacity版本:3.2.1 Audacity是一款开源.免费.跨平台的音频处理及录音软件,支持Windows.macOS及Linux操作系统. 这里记录下 ...
- 【第1篇】人工智能(AI)语音测试原理和实践---宣传
前言 本文主要介绍作者关于人工智能(AI)语音测试的各方面知识点和实战技术. 本书共分为9章,第1.2章详细介绍人工智能(AI)语音测试各种知识点和人工智能(AI)语音交互原理:第3.4章介绍人工智 ...
- .net core 读取appsettings.json 文件中文乱码的问题
解决办法:设置高级保存选项 第一步:在工具栏找到自定义选项 第二步:添加高级保存选项Advanced save options 第三步:在Appsettings.json页面操作
- Netty学习记录-入门篇
你如果,缓缓把手举起来,举到顶,再突然张开五指,那恭喜你,你刚刚给自己放了个烟花. 模块介绍 netty-bio: 阻塞型网络通信demo. netty-nio: 引入channel(通道).buff ...
- 【FAQ】关于华为地图服务定位存在偏差的原因及解决办法
一. 问题描述: 华为地图服务"我的位置"能力,在中国大陆地区,向用户展示他们在地图上的当前位置与用户的实际位置存在较大的偏差. 具体差别可以查看下方的图片: 二. 偏差较大的原因 ...
- RDD(弹性分布式数据集)及常用算子
RDD(弹性分布式数据集)及常用算子 RDD(Resilient Distributed Dataset)叫做弹性分布式数据集,是 Spark 中最基本的数据 处理模型.代码中是一个抽象类,它代表一个 ...