Save and restore models

Official tutorial: https://www.tensorflow.org/tutorials/keras/save_and_restore_models

Save checkpoints during training

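Checkpoints can be saved automatically during training or when training ends.
The weights are stored in a collection of checkpoint-format files that contain only the trained weights, in binary form.
With these files you can use a trained model without retraining it, or resume training where it left off if the training process was interrupted.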
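Resuming where training left off does not depend on Keras itself. The framework-free sketch below (plain Python, not the Keras API) assumes a toy one-parameter model trained by gradient descent on (w - 3)**2 and a hypothetical JSON state file named toy_ckpt.json. It writes the weight and the epoch number after every epoch, so a restarted run continues from the last saved epoch and finishes with the same weight as an uninterrupted run.

```python
import json
import os
import tempfile


def train(ckpt_path, total_epochs, lr=0.1, stop_after=None):
    """Toy gradient descent on f(w) = (w - 3) ** 2 with a checkpoint after every epoch.

    stop_after simulates an interruption: the run halts once that epoch is saved.
    Returns (final weight, number of epochs run in this call).
    """
    # Resume from the checkpoint file if it exists, otherwise start from scratch.
    if os.path.exists(ckpt_path):
        with open(ckpt_path) as f:
            state = json.load(f)
    else:
        state = {"epoch": 0, "w": 0.0}

    epochs_run = 0
    while state["epoch"] < total_epochs:
        grad = 2 * (state["w"] - 3)  # derivative of (w - 3) ** 2
        state["w"] -= lr * grad
        state["epoch"] += 1
        epochs_run += 1
        with open(ckpt_path, "w") as f:  # save weight + progress after each epoch
            json.dump(state, f)
        if stop_after is not None and state["epoch"] == stop_after:
            break  # simulated crash right after a checkpoint was written
    return state["w"], epochs_run


with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "toy_ckpt.json")
    train(path, total_epochs=10, stop_after=4)                # interrupted after epoch 4
    w_resumed, resumed_epochs = train(path, total_epochs=10)  # continues at epoch 5

with tempfile.TemporaryDirectory() as d:
    w_full, _ = train(os.path.join(d, "toy_ckpt.json"), total_epochs=10)  # uninterrupted run

print(resumed_epochs, w_resumed == w_full)  # 6 True
```

In Keras the saving side corresponds to the ModelCheckpoint callback and the loading side to load_weights; the sketch only mirrors that split of responsibilities.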

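  • Checkpoint callback usage: create a ModelCheckpoint callback and pass it to the model when training; this produces a collection of checkpoint files that can be used to share the weights.
  • Checkpoint callback options: the callback offers several options for giving each checkpoint a unique name and for adjusting how often checkpoints are created.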
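The unique-name option relies on an ordinary str.format template. The small sketch below assumes the behaviour visible in this post's run output: the epoch number is 1-based and a file is written every period epochs. The helper checkpoint_names is hypothetical; it simply lists the names that the template training_2/cp-{epoch:04d}.ckpt produces over 50 epochs with period=5.

```python
def checkpoint_names(template, epochs, period):
    """List the file names a per-epoch checkpoint callback would write.

    Assumption: after every `period`-th epoch the template is filled with
    str.format using the 1-based epoch number.
    """
    return [template.format(epoch=epoch)
            for epoch in range(1, epochs + 1)
            if epoch % period == 0]


names = checkpoint_names("training_2/cp-{epoch:04d}.ckpt", epochs=50, period=5)
print(names[0], names[-1], len(names))  # training_2/cp-0005.ckpt training_2/cp-0050.ckpt 10
```

The :04d format spec zero-pads the epoch to four digits, which keeps the files in epoch order when sorted alphabetically.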

Save weights manually

Weights can be saved manually with the Model.save_weights method.
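A weights-only file stores numbers but not the architecture, so the model that loads it must be built by the same code (the script comments note that the new model needs the same architecture as the original in order to share weights). The sketch below illustrates that requirement with plain Python; create_model, save_weights and load_weights here are hypothetical toy stand-ins that use JSON and dicts of lists, not the Keras methods or file format.

```python
import json
import os
import tempfile


def create_model(inputs=3, hidden=4, outputs=2):
    """Hypothetical stand-in for create_model(): the architecture lives in code.

    Weights start at zero here; a real model would initialise them randomly.
    """
    return {
        "dense": [[0.0] * hidden for _ in range(inputs)],
        "dense_1": [[0.0] * outputs for _ in range(hidden)],
    }


def shape(matrix):
    return (len(matrix), len(matrix[0]))


def save_weights(model, path):
    # Only the numbers are written; nothing describes how to rebuild the layers.
    with open(path, "w") as f:
        json.dump(model, f)


def load_weights(model, path):
    with open(path) as f:
        saved = json.load(f)
    # The receiving model must already have the same layers with the same shapes.
    if set(saved) != set(model) or any(shape(model[k]) != shape(v) for k, v in saved.items()):
        raise ValueError("saved weights do not match this model's architecture")
    model.update(saved)


with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "my_checkpoint.json")
    trained = create_model()
    trained["dense"][0][0] = 0.5  # pretend training changed a weight
    save_weights(trained, path)

    restored = create_model()  # same architecture: loads fine
    load_weights(restored, path)

    try:
        load_weights(create_model(hidden=8), path)  # different architecture
        mismatch_detected = False
    except ValueError:
        mismatch_detected = True

print(restored["dense"][0][0], mismatch_detected)  # 0.5 True
```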

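Save the entire model

The entire model can be saved to a single file that contains the weight values, the model configuration (architecture), and the optimizer configuration.
This lets you checkpoint a model and later resume training from exactly the same state, without access to the original code.
Keras saves a model by inspecting its architecture, and uses the HDF5 standard as the basic save format.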
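The difference from a weights-only file is that the architecture travels with the numbers. The sketch below makes that concrete under loudly stated assumptions: JSON stands in for HDF5, and the field names (config, weights, optimizer, layers, input_dim, units) and the save_model/load_model helpers are hypothetical, not the real .h5 layout or the Keras functions. The loader rebuilds the layers from the stored configuration alone, without calling any create_model() code.

```python
import json
import os
import tempfile


def save_model(path, config, weights, optimizer):
    """Write architecture, weights and optimizer configuration to ONE file."""
    with open(path, "w") as f:
        json.dump({"config": config, "weights": weights, "optimizer": optimizer}, f)


def load_model(path):
    """Rebuild the model purely from the file; no model-building code is needed."""
    with open(path) as f:
        blob = json.load(f)
    # Recreate each layer from the stored configuration ...
    model = {layer["name"]: [[0.0] * layer["units"] for _ in range(layer["input_dim"])]
             for layer in blob["config"]["layers"]}
    # ... then fill in the stored weight values.
    model.update(blob["weights"])
    return model, blob["optimizer"]


config = {"layers": [{"name": "dense", "input_dim": 3, "units": 4},
                     {"name": "dense_1", "input_dim": 4, "units": 2}]}
weights = {"dense": [[0.1] * 4 for _ in range(3)],
           "dense_1": [[0.2] * 2 for _ in range(4)]}
optimizer = {"name": "Adam", "learning_rate": 0.001}

with tempfile.TemporaryDirectory() as d:
    path = os.path.join(d, "my_model.json")
    save_model(path, config, weights, optimizer)
    model, opt = load_model(path)

print(opt["name"], model == weights)  # Adam True
```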
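Notes:

  • TensorFlow optimizers (from tf.train) cannot currently be saved.
  • When using such an optimizer, you must re-compile the model after loading it, and the optimizer's state is lost.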
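Losing the optimizer state matters because optimizers such as Adam keep internal running values between steps. The sketch below uses plain SGD with momentum as a simpler stand-in (a swapped-in optimizer, not Adam, and not TensorFlow code) on the toy loss (w - 3)**2: with identical weights, the next update differs depending on whether the velocity state was kept or reset to zero, as happens after a re-compile.

```python
def momentum_step(w, velocity, grad, lr=0.1, beta=0.9):
    """One SGD-with-momentum update; `velocity` is the optimizer's internal state."""
    velocity = beta * velocity + grad
    return w - lr * velocity, velocity


def grad_fn(w):
    return 2 * (w - 3)  # gradient of the toy loss (w - 3) ** 2


# Train for three steps; w is what a weights file keeps, v is optimizer state.
w, v = 0.0, 0.0
for _ in range(3):
    w, v = momentum_step(w, v, grad_fn(w))

# Next step with the optimizer state preserved vs. reset to 0 after re-compiling.
w_kept, _ = momentum_step(w, v, grad_fn(w))
w_reset, _ = momentum_step(w, 0.0, grad_fn(w))
print(round(w_kept, 4), round(w_reset, 4))  # 3.9258 2.8512
```

Neither result is "better" in general; the point is only that training does not continue from exactly the same state once the optimizer state is discarded.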

MNIST dataset

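MNIST (Modified National Institute of Standards and Technology database) is a computer vision dataset of 70,000 grayscale 28×28 images of handwritten digits (60,000 for training and 10,000 for testing).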
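The script preprocesses each image with reshape(-1, 28 * 28) / 255.0, turning a 28×28 grid of 0-255 pixel values into a 784-element vector in [0, 1]. A plain-Python sketch of the same transformation for a single image (flatten_and_scale is a hypothetical helper; NumPy's default reshape is row-major, which the row-by-row flattening below matches):

```python
def flatten_and_scale(image):
    """Flatten a 28x28 grid of 0-255 pixel values row by row into 784 floats in [0, 1]."""
    if len(image) != 28 or any(len(row) != 28 for row in image):
        raise ValueError("expected a 28x28 image")
    return [pixel / 255.0 for row in image for pixel in row]


image = [[0] * 28 for _ in range(28)]
image[0][1] = 51   # second pixel of the first row
image[1][0] = 255  # first pixel of the second row
features = flatten_and_scale(image)
print(len(features), features[1], features[28])  # 784 0.2 1.0
```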

Example

Script

GitHub: https://github.com/anliven/Hello-AI/blob/master/Google-Learn-and-use-ML/5_save_and_restore_models.py

 # coding=utf-8
import os
import pathlib

# TensorFlow C++ log level; it must be set before TensorFlow is imported ('2' would hide INFO and WARNING messages)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''

import numpy as np
import tensorflow as tf
from tensorflow import keras

print("# TensorFlow version: {} - tf.keras version: {}".format(tf.__version__, tf.keras.__version__))  # check versions

# ### Get the example dataset
ds_path = str(pathlib.Path.cwd() / "datasets" / "mnist" / "mnist.npz")  # path of the local dataset file
np_data = np.load(ds_path)  # load the NumPy-format data
print("# np_data keys: ", list(np_data.keys()))  # list all keys

# Load the MNIST dataset; `path` points at the local mnist.npz
(train_images, train_labels), (test_images, test_labels) = keras.datasets.mnist.load_data(path=ds_path)
train_labels = train_labels[:1000]  # use only the first 1000 samples to keep the example fast
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0  # flatten to 784 features, scale to [0, 1]
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0


# ### Define the model
def create_model():
    model = tf.keras.models.Sequential([
        keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation=tf.nn.softmax)
    ])  # build a simple model
    model.compile(optimizer=tf.keras.optimizers.Adam(),
                  loss=tf.keras.losses.sparse_categorical_crossentropy,
                  metrics=['accuracy'])
    return model


mod = create_model()
mod.summary()

# ### Save checkpoints during training
# Checkpoint callback usage
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)  # directory that holds the checkpoints
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)  # create the checkpoint callback
model1 = create_model()
model1.fit(train_images, train_labels,
           epochs=10,
           validation_data=(test_images, test_labels),
           verbose=0,
           callbacks=[cp_callback])  # pass the ModelCheckpoint callback to the model
# Training creates a collection of TensorFlow checkpoint files that is updated at the end of every epoch

model2 = create_model()  # a fresh, untrained model; it needs the same architecture as the original to share weights
loss, acc = model2.evaluate(test_images, test_labels)  # evaluate on the test set
print("# Untrained model2, accuracy: {:5.2f}%".format(100 * acc))  # untrained model: roughly 10% (chance level)
model2.load_weights(checkpoint_path)  # load the weights from the checkpoint
loss, acc = model2.evaluate(test_images, test_labels)  # evaluate on the test set again
print("# Restored model2, accuracy: {:5.2f}%".format(100 * acc))  # accuracy improves substantially

# Checkpoint callback options
checkpoint_path2 = "training_2/cp-{epoch:04d}.ckpt"  # str.format-style template gives each checkpoint a unique name
checkpoint_dir2 = os.path.dirname(checkpoint_path2)
cp_callback2 = tf.keras.callbacks.ModelCheckpoint(checkpoint_path2,
                                                  verbose=1,
                                                  save_weights_only=True,
                                                  period=5)  # save every 5 epochs (`period` is deprecated in TF 2.x)
model3 = create_model()
model3.fit(train_images, train_labels,
           epochs=50,
           callbacks=[cp_callback2],  # pass the ModelCheckpoint callback to the model
           validation_data=(test_images, test_labels),
           verbose=0)  # train a new model, saving a uniquely named checkpoint every 5 epochs
latest = tf.train.latest_checkpoint(checkpoint_dir2)
print("# latest checkpoint: {}".format(latest))  # show the most recent checkpoint

model4 = create_model()  # create another fresh model
loss, acc = model4.evaluate(test_images, test_labels)  # evaluate on the test set
print("# Untrained model4, accuracy: {:5.2f}%".format(100 * acc))  # untrained model: roughly 10%
model4.load_weights(latest)  # load the latest checkpoint
loss, acc = model4.evaluate(test_images, test_labels)  # evaluate again
print("# Restored model4, accuracy: {:5.2f}%".format(100 * acc))  # accuracy improves substantially

# ### Save weights manually
model5 = create_model()
model5.fit(train_images, train_labels,
           epochs=10,
           validation_data=(test_images, test_labels),
           verbose=0)  # train the model
model5.save_weights('./training_3/my_checkpoint')  # save the weights manually

model6 = create_model()
loss, acc = model6.evaluate(test_images, test_labels)
print("# Untrained model6, accuracy: {:5.2f}%".format(100 * acc))
model6.load_weights('./training_3/my_checkpoint')
loss, acc = model6.evaluate(test_images, test_labels)
print("# Restored model6, accuracy: {:5.2f}%".format(100 * acc))

# ### Save the entire model
model7 = create_model()
model7.fit(train_images, train_labels, epochs=5)
model7.save('my_model.h5')  # save the entire model to an HDF5 file

model8 = keras.models.load_model('my_model.h5')  # recreate exactly the same model, including weights and optimizer
model8.summary()
loss, acc = model8.evaluate(test_images, test_labels)
print("Restored model8, accuracy: {:5.2f}%".format(100 * acc))

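Run output

Note: three lines in this log come from slips in the script version that produced it. "latest checkpoint" shows training_1\cp.ckpt because checkpoint_dir2 was computed from checkpoint_path instead of checkpoint_path2; "Untrained model4" shows 86.40% because the already restored model2 was evaluated instead of model4; and the first "Restored model6" line (18.20%) was printed before any weights were loaded, so it is actually the untrained model6.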

C:\Users\anliven\AppData\Local\conda\conda\envs\mlcc\python.exe D:/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML/5_save_and_restore_models.py
# TensorFlow version: 1.12.0 - tf.keras version: 2.1.6-tf
# np_data keys: ['x_test', 'x_train', 'y_train', 'y_test']
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense (Dense)                (None, 512)               401920
_________________________________________________________________
dropout (Dropout)            (None, 512)               0
_________________________________________________________________
dense_1 (Dense)              (None, 10)                5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
Epoch 00001: saving model to training_1/cp.ckpt
Epoch 00002: saving model to training_1/cp.ckpt
Epoch 00003: saving model to training_1/cp.ckpt
Epoch 00004: saving model to training_1/cp.ckpt
Epoch 00005: saving model to training_1/cp.ckpt
Epoch 00006: saving model to training_1/cp.ckpt
Epoch 00007: saving model to training_1/cp.ckpt
Epoch 00008: saving model to training_1/cp.ckpt
Epoch 00009: saving model to training_1/cp.ckpt
Epoch 00010: saving model to training_1/cp.ckpt
32/1000 [..............................] - ETA: 3s
1000/1000 [==============================] - 0s 140us/step
# Untrained model2, accuracy: 8.20%
32/1000 [..............................] - ETA: 0s
1000/1000 [==============================] - 0s 40us/step
# Restored model2, accuracy: 86.40%
Epoch 00005: saving model to training_2/cp-0005.ckpt
Epoch 00010: saving model to training_2/cp-0010.ckpt
Epoch 00015: saving model to training_2/cp-0015.ckpt
Epoch 00020: saving model to training_2/cp-0020.ckpt
Epoch 00025: saving model to training_2/cp-0025.ckpt
Epoch 00030: saving model to training_2/cp-0030.ckpt
Epoch 00035: saving model to training_2/cp-0035.ckpt
Epoch 00040: saving model to training_2/cp-0040.ckpt
Epoch 00045: saving model to training_2/cp-0045.ckpt
Epoch 00050: saving model to training_2/cp-0050.ckpt
# latest checkpoint: training_1\cp.ckpt
32/1000 [..............................] - ETA: 3s
1000/1000 [==============================] - 0s 140us/step
# Untrained model4, accuracy: 86.40%
32/1000 [..............................] - ETA: 2s
1000/1000 [==============================] - 0s 110us/step
# Restored model4, accuracy: 86.40%
32/1000 [..............................] - ETA: 5s
1000/1000 [==============================] - 0s 220us/step
# Restored model6, accuracy: 18.20%
32/1000 [..............................] - ETA: 0s
1000/1000 [==============================] - 0s 40us/step
# Restored model6, accuracy: 87.40%
Epoch 1/5
32/1000 [..............................] - ETA: 9s - loss: 2.4141 - acc: 0.0625
320/1000 [========>.....................] - ETA: 0s - loss: 1.8229 - acc: 0.4469
576/1000 [================>.............] - ETA: 0s - loss: 1.4932 - acc: 0.5694
864/1000 [========================>.....] - ETA: 0s - loss: 1.2624 - acc: 0.6481
1000/1000 [==============================] - 1s 530us/step - loss: 1.1978 - acc: 0.6620
Epoch 2/5
32/1000 [..............................] - ETA: 0s - loss: 0.5490 - acc: 0.8750
320/1000 [========>.....................] - ETA: 0s - loss: 0.4832 - acc: 0.8594
576/1000 [================>.............] - ETA: 0s - loss: 0.4630 - acc: 0.8715
864/1000 [========================>.....] - ETA: 0s - loss: 0.4356 - acc: 0.8808
1000/1000 [==============================] - 0s 200us/step - loss: 0.4298 - acc: 0.8790
Epoch 3/5
32/1000 [..............................] - ETA: 0s - loss: 0.1681 - acc: 0.9688
320/1000 [========>.....................] - ETA: 0s - loss: 0.2826 - acc: 0.9437
576/1000 [================>.............] - ETA: 0s - loss: 0.2774 - acc: 0.9340
832/1000 [=======================>......] - ETA: 0s - loss: 0.2740 - acc: 0.9327
1000/1000 [==============================] - 0s 200us/step - loss: 0.2781 - acc: 0.9280
Epoch 4/5
32/1000 [..............................] - ETA: 0s - loss: 0.1589 - acc: 0.9688
288/1000 [=======>......................] - ETA: 0s - loss: 0.2169 - acc: 0.9410
608/1000 [=================>............] - ETA: 0s - loss: 0.2186 - acc: 0.9457
864/1000 [========================>.....] - ETA: 0s - loss: 0.2231 - acc: 0.9479
1000/1000 [==============================] - 0s 200us/step - loss: 0.2164 - acc: 0.9480
Epoch 5/5
32/1000 [..............................] - ETA: 0s - loss: 0.1095 - acc: 1.0000
352/1000 [=========>....................] - ETA: 0s - loss: 0.1631 - acc: 0.9744
608/1000 [=================>............] - ETA: 0s - loss: 0.1671 - acc: 0.9638
864/1000 [========================>.....] - ETA: 0s - loss: 0.1545 - acc: 0.9688
1000/1000 [==============================] - 0s 210us/step - loss: 0.1538 - acc: 0.9670
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
dense_14 (Dense)             (None, 512)               401920
_________________________________________________________________
dropout_7 (Dropout)          (None, 512)               0
_________________________________________________________________
dense_15 (Dense)             (None, 10)                5130
=================================================================
Total params: 407,050
Trainable params: 407,050
Non-trainable params: 0
_________________________________________________________________
32/1000 [..............................] - ETA: 3s
1000/1000 [==============================] - 0s 150us/step
Restored model8, accuracy: 86.10%

Process finished with exit code 0

Generated files

anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$ ll training_1
total 1601
-rw-r--r-- 1 anliven 197121 71 5月 5 23:36 checkpoint
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:36 cp.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:36 cp.ckpt.index

anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$ ls -l training_2
total 16001
-rw-r--r-- 1 anliven 197121 81 5月 5 23:37 checkpoint
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:36 cp-0005.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:36 cp-0005.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:36 cp-0010.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:36 cp-0010.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:36 cp-0015.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:36 cp-0015.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:36 cp-0020.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:36 cp-0020.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:36 cp-0025.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:36 cp-0025.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:37 cp-0030.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:37 cp-0030.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:37 cp-0035.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:37 cp-0035.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:37 cp-0040.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:37 cp-0040.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:37 cp-0045.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:37 cp-0045.ckpt.index
-rw-r--r-- 1 anliven 197121 1631508 5月 5 23:37 cp-0050.ckpt.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:37 cp-0050.ckpt.index

anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$ ls -l training_3
total 1601
-rw-r--r-- 1 anliven 197121 83 5月 5 23:37 checkpoint
-rw-r--r-- 1 anliven 197121 1631517 5月 5 23:37 my_checkpoint.data-00000-of-00001
-rw-r--r-- 1 anliven 197121 647 5月 5 23:37 my_checkpoint.index

anliven@ANLIVEN MINGW64 /d/Anliven/Anliven-Code/PycharmProjects/Google-Learn-and-use-ML
$ ls -l my_model.h5
-rw-r--r-- 1 anliven 197121 4909112 5月 5 23:37 my_model.h5
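Next to each .index and .data-00000-of-00001 pair, every checkpoint directory above holds a small text file named checkpoint. It records the most recent checkpoint prefix (model_checkpoint_path) and the retained ones (all_model_checkpoint_paths), and tf.train.latest_checkpoint reads it to find the newest checkpoint. The sketch below is a simplified, hypothetical stand-in written in plain Python (latest_checkpoint_sketch is not TensorFlow code; the real function parses the file as a protocol-buffer text message and handles more cases). It returns the recorded prefix only if the matching .index file exists.

```python
import os
import re
import tempfile


def latest_checkpoint_sketch(checkpoint_dir):
    """Simplified stand-in for tf.train.latest_checkpoint (not TensorFlow's code).

    Reads the text-format `checkpoint` file in checkpoint_dir, takes its
    model_checkpoint_path entry, and returns that prefix only if the matching
    .index file exists; otherwise returns None.
    """
    state_file = os.path.join(checkpoint_dir, "checkpoint")
    if not os.path.exists(state_file):
        return None
    with open(state_file) as f:
        for line in f:
            match = re.match(r'\s*model_checkpoint_path:\s*"(.*)"', line)
            if match:
                prefix = match.group(1)
                if not os.path.isabs(prefix):  # relative entries are relative to the directory
                    prefix = os.path.join(checkpoint_dir, prefix)
                return prefix if os.path.exists(prefix + ".index") else None
    return None


with tempfile.TemporaryDirectory() as d:
    with open(os.path.join(d, "checkpoint"), "w") as f:
        f.write('model_checkpoint_path: "cp-0050.ckpt"\n'
                'all_model_checkpoint_paths: "cp-0050.ckpt"\n')
    open(os.path.join(d, "cp-0050.ckpt.index"), "w").close()  # empty placeholder index file
    latest = latest_checkpoint_sketch(d)
    expected = os.path.join(d, "cp-0050.ckpt")

print(latest == expected)  # True
```

Because the lookup is driven by whichever directory is passed in, pointing it at training_1 instead of training_2 returns training_1's checkpoint.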

Troubleshooting

Problem: the following warning is printed.

WARNING:tensorflow:This model was compiled with a Keras optimizer (<tensorflow.python.keras.optimizers.Adam object at 0x00000280FD318780>) but is being saved in TensorFlow format with `save_weights`. The model's weights will be saved, but unlike with TensorFlow optimizers in the TensorFlow format the optimizer's state will not be saved.

Consider using a TensorFlow optimizer from `tf.train`.

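Resolution:

This is an expected warning; it does not affect how the script runs or its results, so it can be ignored here. It only means that the state of the Keras Adam optimizer is not written to the TensorFlow-format checkpoint created by save_weights; the weights themselves are saved and restored normally.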
