TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model

Checkmate is designed to be a simple drop-in solution for a very common Tensorflow use-case: keeping track of the best model checkpoints during training.

The BestCheckpointSaver is a wrapper around a tf.train.Saver.

The BestCheckpointSaver provides the ability to save the best n checkpoints, whereas the tf.train.Saver can only save the last n checkpoints.

Features

  • Save only best n checkpoints
  • Compares checkpoints based on a user-provided value
  • Can rank checkpoints by highest or lowest values
  • Automatically delete outdated checkpoints
  • Provide at a glance record of each checkpoint's associated value (the user-provided value obtained from that checkpoint)

Using the BestCheckpointSaver

from checkmate import BestCheckpointSaver

# ...build model...

best_ckpt_saver = BestCheckpointSaver(
save_dir=best_checkpoint_dir,
num_to_keep=3,
maximize=True
) # train and evaluate
for train_step in range(max_steps):
sess.run(train_op)
if train_step % evaluation_interval == 0:
accuracy = sess.run(eval_op, feed_dict=validation_data)
best_ckpt_saver.handle(accuracy, sess, global_step_tensor)

Loading the best checkpoint

import checkmate

# ...build model...

saver = tf.train.Saver()
saver.restore(sess, checkmate.get_best_checkpoint(best_checkpoint_dir, select_maximum_value=True))

At this stage, the module is no-frills with limited documentation. It is not intended to work in distributed settings or with complex Session/Graph management (i.e. the tf.Estimator framework). Contributions are welcome.

TensorFlow Saver 保存最佳模型 tf.train.Saver Save Best Model的更多相关文章

  1. 跟我学算法- tensorflow模型的保存与读取 tf.train.Saver()

    save =  tf.train.Saver() 通过save. save() 实现数据的加载 通过save.restore() 实现数据的导出 第一步: 数据的载入 import tensorflo ...

  2. TensorFlow:tf.train.Saver()模型保存与恢复

    1.保存 将训练好的模型参数保存起来,以便以后进行验证或测试.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.S ...

  3. tensorflow的tf.train.Saver()模型保存与恢复

    将训练好的模型参数保存起来,以便以后进行验证或测试.tf里面提供模型保存的是tf.train.Saver()模块. 模型保存,先要创建一个Saver对象:如 saver=tf.train.Saver( ...

  4. 图融合之加载子图:Tensorflow.contrib.slim与tf.train.Saver之坑

    import tensorflow as tf import tensorflow.contrib.slim as slim import rawpy import numpy as np impor ...

  5. 机器学习与Tensorflow(7)——tf.train.Saver()、inception-v3的应用

    1. tf.train.Saver() tf.train.Saver()是一个类,提供了变量.模型(也称图Graph)的保存和恢复模型方法. TensorFlow是通过构造Graph的方式进行深度学习 ...

  6. tf.train.Saver()

    1. 实例化对象 saver = tf.train.Saver(max_to_keep=1) max_to_keep: 表明保存的最大checkpoint文件数.当一个新文件创建的时候,旧文件就会被删 ...

  7. Tensorflow滑动平均模型tf.train.ExponentialMovingAverage解析

    觉得有用的话,欢迎一起讨论相互学习~Follow Me 移动平均法相关知识 移动平均法又称滑动平均法.滑动平均模型法(Moving average,MA) 什么是移动平均法 移动平均法是用一组最近的实 ...

  8. tensorflow 下的滑动平均模型 —— tf.train.ExponentialMovingAverage

    在采用随机梯度下降算法训练神经网络时,使用 tf.train.ExponentialMovingAverage 滑动平均操作的意义在于提高模型在测试数据上的健壮性(robustness). tenso ...

  9. TensorFlow 实战(二)—— tf.train(优化算法)

    Training | TensorFlow tf 下以大写字母开头的含义为名词的一般表示一个类(class) 1. 优化器(optimizer) 优化器的基类(Optimizer base class ...

随机推荐

  1. POJ 2643

    #include<iostream> #include<stdio.h> #include<string> #include<algorithm> #d ...

  2. 再谈C#委托与事件

    之前写过一篇关于C#委托与事件的文章(见<C#委托和事件例析>),不过还是收到一些网友的提问.所以,今天再换另一个角度来详解一下这个问题. 一.在控制台下使用委托和事件 我们都知道,C#中 ...

  3. 七、Framework类库

    1.Framework类库简介 .Net Framework类库包含Framework类库(Framework Class Library,FCL).FCL是一组DLL程序集的统称,其中含有数千个类型 ...

  4. JavaScript -- Window-框架

    -----025-Window-框架.html----- <!DOCTYPE html> <html> <head> <meta http-equiv=&qu ...

  5. scala-02-数组的操作

    scala中的数组和 java中的数组一样, 定义了长度后不可改变 1, 产生一个数组: 有3种创建数组的方式, 分别直接new, 直接赋值, 或者使用 Array中的rang来产生 /** * 获取 ...

  6. Eclipse juno 中安装 JBoss Tools,集成Hibernate

    在Eclipse中集成Hibernate工具可以帮助开发者根据数据库生成映射文件.注释代码以及反向工程. Hibernate Tools作为JBoss Tools的核心组件,已经被捆绑在JBoss T ...

  7. Spring技术内幕_IOC容器载入Bean定义资源文件

    转自:http://blog.csdn.net/chjttony/article/details/6259723 1.当spring的IoC容器将Bean定义的资源文件封装为Spring的Resour ...

  8. SPA页面初试

    之前一直很好奇,SPA应用到底是怎么实现的,昨天无意间看到了有一篇介绍的文章,就想着来试一下水(以下根据我的理解所写,可能会让你看的云里雾里,如果想加深了解,最好先了解下window.location ...

  9. Deep learning with Python 学习笔记(4)

    本节讲卷积神经网络的可视化 三种方法 可视化卷积神经网络的中间输出(中间激活) 有助于理解卷积神经网络连续的层如何对输入进行变换,也有助于初步了解卷积神经网络每个过滤器的含义 可视化卷积神经网络的过滤 ...

  10. 使用EntityManager批量保存数据

    @PersistenceContext EntityManager em; 从别的系统中定期同步某张表的数据,由于数据量较大,采用批量保存 JPA EntityManager的四个主要方法 ① pub ...