过拟合和欠拟合

简单来说过拟合就是模型训练集精度高,测试集训练精度低;欠拟合则是模型训练集和测试集训练精度都低。

官方文档地址为 https://tensorflow.google.cn/tutorials/keras/overfit_and_underfit

过拟合和欠拟合

以IMDB dataset为例,对于过拟合和欠拟合,不同模型的测试集和验证集损失函数图如下:

baseline模型结构为:10000-16-16-1

smaller_model模型结构为:10000-4-4-1

bigger_model模型结构为:10000-512-512-1

造成过拟合的原因通常是参数过多或者数据较少,欠拟合往往是训练次数不够。

解决方法

正则化

正则化简单来说就是稀疏化参数,使得模型参数较少。类似于降维。

正则化参考: https://blog.csdn.net/jinping_shi/article/details/52433975

tf.keras通常在损失函数后添加正则项,l1正则化和l2正则化。

l2_model = keras.models.Sequential([
keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),#权重l2正则化
activation=tf.nn.relu, input_shape=(10000,)),
keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),#权重l2正则化
activation=tf.nn.relu),
keras.layers.Dense(1, activation=tf.nn.sigmoid)
]) l2_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy']) l2_model_history = l2_model.fit(train_data, train_labels,
epochs=20,
batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)

dropout

Dropout将在训练过程中每次更新参数时按一定概率(rate)随机断开输入神经元,使得比例为rate的神经元不被训练。

具体见: https://yq.aliyun.com/articles/68901

dpt_model = keras.models.Sequential([
keras.layers.Dense(16, activation=tf.nn.relu, input_shape=(10000,)),
keras.layers.Dropout(0.3), #百分之30的神经元失效
keras.layers.Dense(16, activation=tf.nn.relu),
keras.layers.Dropout(0.7), #百分之70的神经元失效
keras.layers.Dense(1, activation=tf.nn.sigmoid)
]) dpt_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy','binary_crossentropy']) dpt_model_history = dpt_model.fit(train_data, train_labels,
epochs=20,
batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)

总结

常用防止过拟合的方法有:

  1. 增加数据量
  2. 减少网络结构参数
  3. 正则化
  4. dropout
  5. 数据扩增data-augmentation
  6. 批标准化

[深度学习] tf.keras入门4-过拟合和欠拟合的更多相关文章

  1. [深度学习] tf.keras入门3-回归

    目录 波士顿房价数据集 数据集 数据归一化 模型训练和预测 模型建立和训练 模型预测 总结 回归主要基于波士顿房价数据库进行建模,官方文档地址为:https://tensorflow.google.c ...

  2. [深度学习] tf.keras入门5-模型保存和载入

    目录 设置 基于checkpoints的模型保存 通过ModelCheckpoint模块来自动保存数据 手动保存权重 整个模型保存 总体代码 模型可以在训练中或者训练完成后保存.具体文档参考:http ...

  3. [深度学习] tf.keras入门2-分类

    目录 Fashion MNIST数据库 分类模型的建立 模型预测 总体代码 主要介绍基于tf.keras的Fashion MNIST数据库分类, 官方文档地址为:https://tensorflow. ...

  4. [深度学习] tf.keras入门1-基本函数介绍

    目录 构建一个简单的模型 序贯(Sequential)模型 网络层的构造 模型训练和参数评价 模型训练 模型的训练 tf.data的数据集 模型评估和预测 基本模型的建立 网络层模型 模型子类函数构建 ...

  5. 深度学习:Keras入门(一)之基础篇

    1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架. Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结 ...

  6. 深度学习:Keras入门(一)之基础篇【转】

    本文转载自:http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorfl ...

  7. 深度学习:Keras入门(一)之基础篇(转)

    转自http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深 ...

  8. 深度学习:Keras入门(二)之卷积神经网络(CNN)

    说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么是卷积? 简单来说,卷积(或内积)就是一种先把对应位置相乘然后再把结果相加的运算.(具体含义或者数学公式 ...

  9. 深度学习:Keras入门(二)之卷积神经网络(CNN)【转】

    本文转载自:https://www.cnblogs.com/lc1217/p/7324935.html 说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么 ...

随机推荐

  1. MYSQL-->InnoDB引擎底层原理

    逻辑存储结构 逻辑存储结构图 表空间 表空间文件在Linux下存放在 /var/lib/mysql文件中的 xxx.ibd 文件就是表空间文件 表空间文件用来存储,记录,索引等数据. 段 段分为,数据 ...

  2. 九、docker swarm主机编排

    一. 什么是Docker Swarm Swarm 是 Docker 公司推出的用来管理 docker 集群的平台,几乎全部用GO语言来完成的开发的,代码开源在https://github.com/do ...

  3. 三十三、HPA实现自动扩缩容

    通过HPA实现业务应用的动态扩缩容 HPA控制器介绍 当系统资源过高的时候,我们可以使用如下命令来实现 Pod 的扩缩容功能 $ kubectl -n luffy scale deployment m ...

  4. VirtualBox 下 CentOS7 静态 IP 的配置 → 多次踩坑总结,蚌埠住了!

    开心一刻 一个消化不良的病人向医生抱怨:我近来很不正常,吃什么拉什么,吃黄瓜拉黄瓜,吃西瓜拉西瓜,怎样才能恢复正常呢? 医生沉默片刻:那你只能吃屎了 环境准备 VirtualBox 6.1 网络连接方 ...

  5. ARM TrustZone白皮书部分阅读

    嵌入式系统安全的一些解决方法及缺陷 外部硬件安全模块:在主SoC之外包含一个专用的硬件安全模块或可信元件,e.g. 手机的SIM卡.隔离仅限于可以从非易失性存储器运行的相对静态程序 内部硬件安全模块: ...

  6. 【笔记】入门DP

    复习一下近期练习的入门 \(DP\) .巨佬勿喷.\(qwq\) 重新写一遍练手,加深理解. 代码已经处理,虽然很明显,但请勿未理解就贺 \(qwq\) 0X00 P1057 [NOIP2008 普及 ...

  7. oracle常用查看命令

    select sum(bytes/1024/1024/1024) from dba_segments;   #注:查看表空间大小,除以3个1024后的大小为GB du instance_name(实例 ...

  8. 项目上的业务《接收一个xml信息包进行解析,xml中包含base64解析为电子文件》

    我就直接贴代码了,不太会说,附上注释. ps:需要根据系统字段和xml里面的标签字段进行建表,之后把xml标签的值进行添加.创建表的方法就是拼的sql. // 在线接收接口 @Transactiona ...

  9. 基于sklearn的集成学习实战

    集成学习投票法与bagging 投票法 sklearn提供了VotingRegressor和VotingClassifier两个投票方法.使用模型需要提供一个模型的列表,列表中每个模型采用tuple的 ...

  10. Selenium4+Python3系列(十) - Page Object设计模式

    前言 Page Object(PO)模式,是Selenium实战中最为流行,并且被自动化测试同学所熟悉和推崇的一种设计模式之一.在设计测试时,把页面元素定位和元素操作方法按照页面抽象出来,分离成一定的 ...