目录

设置

基于checkpoints的模型保存

通过ModelCheckpoint模块来自动保存数据

手动保存权重

整个模型保存

总体代码


模型可以在训练中或者训练完成后保存。具体文档参考:https://tensorflow.google.cn/tutorials/keras/save_and_restore_models

设置

依赖项设置:

!pip install -q h5py pyyaml 

模型建立:

from __future__ import absolute_import, division, print_function

import os

import tensorflow as tf
from tensorflow import keras tf.__version__ (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data() train_labels = train_labels[:1000]
test_labels = test_labels[:1000] train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0 # 模型创建模型
def create_model():
model = tf.keras.models.Sequential([
keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation=tf.nn.softmax)
]) model.compile(optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy']) return model #创建模型
model = create_model()
model.summary()

基于checkpoints的模型保存

通过ModelCheckpoint模块来自动保存数据

#创建回调函数
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
save_weights_only=True, #只保存权重
verbose=1) model = create_model() model.fit(train_images, train_labels, epochs = 10,
validation_data = (test_images,test_labels),
callbacks = [cp_callback]) #保存模型

通过load_weight读取权重

#对全新没有训练的模型进行预测
model = create_model()
loss, acc = model.evaluate(test_images, test_labels)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc)) #11.4% #载入权重参数后的模型
model.load_weights(checkpoint_path)
loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #86.2

手动保存权重

# 保存权重
model.save_weights('./checkpoints/my_checkpoint') #恢复模型
model = create_model()
model.load_weights('./checkpoints/my_checkpoint') loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #87.00%

整个模型保存

基于keras的HD5文件保存整个模型所有参数,优化器参数等。

#将整个模型保存为HDF5文件
model = create_model()
model.fit(train_images, train_labels, epochs=5)
model.save('my_model.h5')
#载入一个相同的模型
new_model = keras.models.load_model('my_model.h5')
new_model.summary()
loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #86.30%

总体代码

from __future__ import absolute_import, division, print_function

import os

import tensorflow as tf
from tensorflow import keras tf.__version__ (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data() train_labels = train_labels[:1000]
test_labels = test_labels[:1000] train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0 # 模型创建模型
def create_model():
model = tf.keras.models.Sequential([
keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation=tf.nn.softmax)
]) model.compile(optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy']) return model #创建模型
model = create_model()
model.summary() checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
'''
#创建回调函数
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
save_weights_only=True, #只保存权重
verbose=1) model = create_model() model.fit(train_images, train_labels, epochs = 10,
validation_data = (test_images,test_labels),
callbacks = [cp_callback]) #保存模型 #对全新没有训练的模型进行预测
model = create_model()
loss, acc = model.evaluate(test_images, test_labels)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc)) #11.4% #载入权重参数后的模型
model.load_weights(checkpoint_path)
loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #86.2 # 保存权重
model.save_weights('./checkpoints/my_checkpoint') #恢复模型
model = create_model()
model.load_weights('./checkpoints/my_checkpoint') loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #87.00%
'''
#将整个模型保存为HDF5文件
model = create_model()
model.fit(train_images, train_labels, epochs=5)
model.save('my_model.h5')
#载入一个相同的模型
new_model = keras.models.load_model('my_model.h5')
new_model.summary()
loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #86.30%

[深度学习] tf.keras入门5-模型保存和载入的更多相关文章

  1. [深度学习] tf.keras入门1-基本函数介绍

    目录 构建一个简单的模型 序贯(Sequential)模型 网络层的构造 模型训练和参数评价 模型训练 模型的训练 tf.data的数据集 模型评估和预测 基本模型的建立 网络层模型 模型子类函数构建 ...

  2. [深度学习] tf.keras入门4-过拟合和欠拟合

    过拟合和欠拟合 简单来说过拟合就是模型训练集精度高,测试集训练精度低:欠拟合则是模型训练集和测试集训练精度都低. 官方文档地址为 https://tensorflow.google.cn/tutori ...

  3. [深度学习] tf.keras入门3-回归

    目录 波士顿房价数据集 数据集 数据归一化 模型训练和预测 模型建立和训练 模型预测 总结 回归主要基于波士顿房价数据库进行建模,官方文档地址为:https://tensorflow.google.c ...

  4. [深度学习] tf.keras入门2-分类

    目录 Fashion MNIST数据库 分类模型的建立 模型预测 总体代码 主要介绍基于tf.keras的Fashion MNIST数据库分类, 官方文档地址为:https://tensorflow. ...

  5. 深度学习:Keras入门(一)之基础篇

    1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架. Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结 ...

  6. 深度学习:Keras入门(一)之基础篇【转】

    本文转载自:http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorfl ...

  7. 深度学习:Keras入门(一)之基础篇(转)

    转自http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深 ...

  8. 深度学习:Keras入门(二)之卷积神经网络(CNN)

    说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么是卷积? 简单来说,卷积(或内积)就是一种先把对应位置相乘然后再把结果相加的运算.(具体含义或者数学公式 ...

  9. 深度学习:Keras入门(二)之卷积神经网络(CNN)【转】

    本文转载自:https://www.cnblogs.com/lc1217/p/7324935.html 说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么 ...

随机推荐

  1. Mybatis PageHelper 使用的注意事项

    什么时候会导致不安全的分页? PageHelper 方法使用了静态的 ThreadLocal 参数,分页参数和线程是绑定的. 只要你可以保证在 PageHelper 方法调用后紧跟 MyBatis 查 ...

  2. if、where、trim、choose、when、otherwise、foreach

    1.if if标签可通过test属性的表达式进行判断,若表达式的结果为true,则标签中的内容会执行:反之标签中 的内容不会执行 <!--List<Emp> getEmpListBy ...

  3. __g is not defined

    新手小白学习小程序开发遇到的问题以及解决方法 文章目录 1.出现的问题 2.解决的方法 1.出现的问题 2.解决的方法 删除app.json中的 "lazyCodeLoading" ...

  4. 制作一个windows垃圾清理小程序

    制作一个windows垃圾清理小程序: 把下列代码保存为.bat文件(如垃圾清理.bat) 双击它就能很快地清理垃圾文件,大约一分钟不到. 就是下面的文字(这行不用复制)=============== ...

  5. Spring Retry 重试

    重试的使用场景比较多,比如调用远程服务时,由于网络或者服务端响应慢导致调用超时,此时可以多重试几次.用定时任务也可以实现重试的效果,但比较麻烦,用Spring Retry的话一个注解搞定所有.话不多说 ...

  6. MongoDB数据库新手入门

    windows安装mongodb 5.0.2 官网下载msi文件 自定义安装到 d:/apptoools/mongodb/ 不要勾选mongodb compass 报错:verify that you ...

  7. 安装nvm 和 yarn

    安装nvm curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.39.1/install.sh | bash 执行上面的命令 如果出现问题 ...

  8. MFC 学习笔记

    MFC 学习笔记 一.MFC编程基础: 概述: 常用头文件: MFC控制台程序: MFC库程序: 规则库可以被各种程序所调用,扩展库只能被MFC程序调用. MFC窗口程序: 示例: MFC库中类的简介 ...

  9. Pycharm系列---QT配置

    PYSIDE2 添加外部工具 file---settings External Tools,点击左上角的 加号+ designer 位置: envs\QT6\Lib\site-packages\PyS ...

  10. Linux网络通信(TCP套接字编写,多进程多线程版本)

    预备知识 源IP地址和目的IP地址 IP地址在上一篇博客中也介绍过,它是用来标识网络中不同主机的地址.两台主机进行通信时,发送方需要知道自己往哪一台主机发送,这就需要知道接受方主机的的IP地址,也就是 ...