https://cloud.tencent.com/developer/article/1010815

8.更科学地模型训练与模型保存

filepath = 'model-ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5'
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min')
# fit model
model.fit(x, y, epochs=20, verbose=2, callbacks=[checkpoint], validation_data=(x, y))

save_best_only打开之后,会如下:

 ETA: 3s - loss: 0.5820Epoch 00017: val_loss did not improve

如果val_loss 提高了就会保存,没有提高就不会保存。

ModelCheckpoint

keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)

该回调函数将在每个epoch后保存模型到filepath

filepath可以是格式化的字符串,里面的占位符将会被epoch值和传入on_epoch_endlogs关键字所填入

例如,filepath若为weights.{epoch:02d-{val_loss:.2f}}.hdf5,则会生成对应epoch和验证集loss的多个文件。

参数

  • filename:字符串,保存模型的路径

  • monitor:需要监视的值

  • verbose:信息展示模式,0或1

  • save_best_only:当设置为True时,将只保存在验证集上性能最好的模型

  • mode:‘auto’,‘min’,‘max’之一,在save_best_only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。

  • save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)

  • period:CheckPoint之间的间隔的epoch数。

9.如何在keras中使用tensorboard

 RUN = RUN + 1 if 'RUN' in locals() else 1   # locals() 函数会以字典类型返回当前位置的全部局部变量。

     LOG_DIR = model_save_path + '/training_logs/run{}'.format(RUN)
LOG_FILE_PATH = LOG_DIR + '/checkpoint-{epoch:02d}-{val_loss:.4f}.hdf5' # 模型Log文件以及.h5模型文件存放地址 tensorboard = TensorBoard(log_dir=LOG_DIR, write_images=True)
checkpoint = ModelCheckpoint(filepath=LOG_FILE_PATH, monitor='val_loss', verbose=1, save_best_only=True)
early_stopping = EarlyStopping(monitor='val_loss', patience=5, verbose=1) history = model.fit_generator(generator=gen.generate(True), steps_per_epoch=int(gen.train_batches / 4),
validation_data=gen.generate(False), validation_steps=int(gen.val_batches / 4),
epochs=EPOCHS, verbose=1, callbacks=[tensorboard, checkpoint, early_stopping])

都是在回调函数中起作用:

  • EarlyStopping patience:当early 
    (1)stop被激活(如发现loss相比上一个epoch训练没有下降),则经过patience个epoch后停止训练。 
    (2)mode:‘auto’,‘min’,‘max’之一,在min模式下,如果检测值停止下降则中止训练。在max模式下,当检测值不再上升则停止训练。

  • 模型检查点ModelCheckpoint 
    (1)save_best_only:当设置为True时,将只保存在验证集上性能最好的模型 
    (2) mode:‘auto’,‘min’,‘max’之一,在save_best_only=True时决定性能最佳模型的评判准则,例如,当监测值为val_acc时,模式应为max,当检测值为val_loss时,模式应为min。在auto模式下,评价准则由被监测值的名字自动推断。 
    (3)save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等) 
    (4)period:CheckPoint之间的间隔的epoch数

  • 可视化tensorboard write_images: 是否将模型权重以图片的形式可视化

其他内容可参考keras中文文档

keras训练和保存的更多相关文章

  1. Keras模型的保存方式

    Keras模型的保存方式 在运行并且训练出一个模型后获得了模型的结构与许多参数,为了防止再次训练以及需要更好地去使用,我们需要保存当前状态 基本保存方式 h5 # 此处假设model为一个已经训练好的 ...

  2. keras训练cnn模型时loss为nan

    keras训练cnn模型时loss为nan 1.首先记下来如何解决这个问题的:由于我代码中 model.compile(loss='categorical_crossentropy', optimiz ...

  3. 使用Keras训练神经网络备忘录

    小书匠深度学习 文章太长,放个目录: 1.优化函数的选择 2.损失函数的选择 2.2常用的损失函数 2.2自定义函数 2.1实践 2.2将损失函数自定义为网络层 3.模型的保存 3.1同时保持结构和权 ...

  4. Keras 训练 inceptionV3 并移植到OpenCV4.0 in C++

    1. 训练 # --coding:utf--- import os import sys import glob import argparse import matplotlib.pyplot as ...

  5. 使用Keras训练大规模数据集

    官方提供的.flow_from_directory(directory)函数可以读取并训练大规模训练数据,基本可以满足大部分需求.但是在有些场合下,需要自己读取大规模数据以及对应标签,下面提供一种方法 ...

  6. Keras 训练一个单层全连接网络的线性回归模型

    1.准备环境,探索数据 import numpy as np from keras.models import Sequential from keras.layers import Dense im ...

  7. keras训练大量数据的办法

    最近在做一个鉴黄的项目,数据量比较大,有几百个G,一次性加入内存再去训练模青型是不现实的. 查阅资料发现keras中可以用两种方法解决,一是将数据转为tfrecord,但转换后数据大小会方法不好:另外 ...

  8. keras训练实例-python实现

    用keras训练模型并实时显示loss/acc曲线,(重要的事情说三遍:实时!实时!实时!)实时导出loss/acc数值(导出的方法就是实时把loss/acc等写到一个文本文件中,其他模块如前端调用时 ...

  9. Keras处理已保存模型中的自定义层(或其他自定义对象)

    如果要加载的模型包含自定义层或其他自定义类或函数,则可以通过 custom_objects 参数将它们传递给加载机制: from keras.models import load_model # 假设 ...

随机推荐

  1. 普林斯顿数学指南(第一卷) (Timothy Gowers 著)

    第I部分 引论 I.1 数学是做什么的 I.2 数学的语言和语法 I.3 一些基本的数学定义 I.4 数学研究的一般目的 第II部分 现代数学的起源 II.1 从数到数系 II.2 几何学 II.3 ...

  2. windows下能搭建php-fpm吗 phpstudy

    这个Windows和Linux系统是不一样的,因为一般nginx搭配php需要php-fpm中间件,但是Windows下需要第三方编译. 下载的包里有php-cgi.exe 但不是php-fpm如果想 ...

  3. 随机重拍与抽样(random_shuffle,random_sample,random_sample_n)

    //版本一:使用内部的随机数生成器 template<class RandomAccessIterator> void random_shuffle( RandomAccessIterat ...

  4. 本地开发不用改hosts 也可以绑定域名开发

    以往我们在开发 web 应用时,为了模拟生产环境都会修改系统中的hosts 文件,加入一个域名指向 127.0.0.1,绑定到开发目录,如下: 但是在 Chrome 中有一个域名是可以不用修改 hos ...

  5. kafka 的经典教程

    一.基本概念 介绍 Kafka是一个分布式的.可分区的.可复制的消息系统.它提供了普通消息系统的功能,但具有自己独特的设计. 这个独特的设计是什么样的呢? 首先让我们看几个基本的消息系统术语:Kafk ...

  6. webpack 中,loader、plugin 的区别

    loader 和 plugin 的主要区别: loader 用于加载某些资源文件. 因为 webpack 只能理解 JavaScript 和 JSON 文件,对于其他资源例如 css,图片,或者其他的 ...

  7. Python 基础语法+简单地爬取百度贴吧内容

    Python笔记 1.Python3和Pycharm2018的安装 2.Python3基础语法 2.1.1.数据类型 2.1.1.1.数据类型:数字(整数和浮点数) 整数:int类型 浮点数:floa ...

  8. windows系统如何设置域名解析

      C:\Windows\System32\drivers\etc        

  9. centos添加额外测源,解决:No package openvpn available.

    centos添加额外测源,解决:No package openvpn available. ##添加额外的repositories,安装openvpn yum install epel-release ...

  10. Elasticsearch的数据导出和导入操作(elasticdump工具),以及删除指定type的数据(delete-by-query插件)

    Elasticseach目前作为查询搜索平台,的确非常实用方便.我们今天在这里要讨论的是如何做数据备份和type删除.我的ES的版本是2.4.1. ES的备份,可不像MySQL的mysqldump这么 ...