[深度学习] tf.keras入门5-模型保存和载入
目录
模型可以在训练中或者训练完成后保存。具体文档参考:https://tensorflow.google.cn/tutorials/keras/save_and_restore_models
设置
依赖项设置:
!pip install -q h5py pyyaml
模型建立:
from __future__ import absolute_import, division, print_function
import os
import tensorflow as tf
from tensorflow import keras
tf.__version__
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
# 模型创建模型
def create_model():
model = tf.keras.models.Sequential([
keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy'])
return model
#创建模型
model = create_model()
model.summary()
基于checkpoints的模型保存
通过ModelCheckpoint模块来自动保存数据
#创建回调函数
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
save_weights_only=True, #只保存权重
verbose=1)
model = create_model()
model.fit(train_images, train_labels, epochs = 10,
validation_data = (test_images,test_labels),
callbacks = [cp_callback]) #保存模型
通过load_weight读取权重
#对全新没有训练的模型进行预测
model = create_model()
loss, acc = model.evaluate(test_images, test_labels)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc)) #11.4%
#载入权重参数后的模型
model.load_weights(checkpoint_path)
loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #86.2
手动保存权重
# 保存权重
model.save_weights('./checkpoints/my_checkpoint')
#恢复模型
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')
loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #87.00%
整个模型保存
基于keras的HD5文件保存整个模型所有参数,优化器参数等。
#将整个模型保存为HDF5文件
model = create_model()
model.fit(train_images, train_labels, epochs=5)
model.save('my_model.h5')
#载入一个相同的模型
new_model = keras.models.load_model('my_model.h5')
new_model.summary()
loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #86.30%
总体代码
from __future__ import absolute_import, division, print_function
import os
import tensorflow as tf
from tensorflow import keras
tf.__version__
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
# 模型创建模型
def create_model():
model = tf.keras.models.Sequential([
keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy'])
return model
#创建模型
model = create_model()
model.summary()
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
'''
#创建回调函数
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
save_weights_only=True, #只保存权重
verbose=1)
model = create_model()
model.fit(train_images, train_labels, epochs = 10,
validation_data = (test_images,test_labels),
callbacks = [cp_callback]) #保存模型
#对全新没有训练的模型进行预测
model = create_model()
loss, acc = model.evaluate(test_images, test_labels)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc)) #11.4%
#载入权重参数后的模型
model.load_weights(checkpoint_path)
loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #86.2
# 保存权重
model.save_weights('./checkpoints/my_checkpoint')
#恢复模型
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')
loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #87.00%
'''
#将整个模型保存为HDF5文件
model = create_model()
model.fit(train_images, train_labels, epochs=5)
model.save('my_model.h5')
#载入一个相同的模型
new_model = keras.models.load_model('my_model.h5')
new_model.summary()
loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #86.30%
[深度学习] tf.keras入门5-模型保存和载入的更多相关文章
- [深度学习] tf.keras入门1-基本函数介绍
目录 构建一个简单的模型 序贯(Sequential)模型 网络层的构造 模型训练和参数评价 模型训练 模型的训练 tf.data的数据集 模型评估和预测 基本模型的建立 网络层模型 模型子类函数构建 ...
- [深度学习] tf.keras入门4-过拟合和欠拟合
过拟合和欠拟合 简单来说过拟合就是模型训练集精度高,测试集训练精度低:欠拟合则是模型训练集和测试集训练精度都低. 官方文档地址为 https://tensorflow.google.cn/tutori ...
- [深度学习] tf.keras入门3-回归
目录 波士顿房价数据集 数据集 数据归一化 模型训练和预测 模型建立和训练 模型预测 总结 回归主要基于波士顿房价数据库进行建模,官方文档地址为:https://tensorflow.google.c ...
- [深度学习] tf.keras入门2-分类
目录 Fashion MNIST数据库 分类模型的建立 模型预测 总体代码 主要介绍基于tf.keras的Fashion MNIST数据库分类, 官方文档地址为:https://tensorflow. ...
- 深度学习:Keras入门(一)之基础篇
1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架. Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结 ...
- 深度学习:Keras入门(一)之基础篇【转】
本文转载自:http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorfl ...
- 深度学习:Keras入门(一)之基础篇(转)
转自http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深 ...
- 深度学习:Keras入门(二)之卷积神经网络(CNN)
说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么是卷积? 简单来说,卷积(或内积)就是一种先把对应位置相乘然后再把结果相加的运算.(具体含义或者数学公式 ...
- 深度学习:Keras入门(二)之卷积神经网络(CNN)【转】
本文转载自:https://www.cnblogs.com/lc1217/p/7324935.html 说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么 ...
随机推荐
- python实现给定K个字符数组,从这k个字符数组中任意取一个字符串,按顺序拼接,列出所有可能的字符串组合结果!
题目描述:给定K个字符数组,从这k个字符数组中任意取一个字符串,按顺序拼接,列出所有可能的字符串组合结果! 样例: input:[["a","b"," ...
- 二手商城集成jwt认证授权
------------恢复内容开始------------ 使用jwt进行认证授权的主要流程 参考博客(https://www.cnblogs.com/RayWang/p/9536524.html) ...
- 2022-08-21-Freewind主题_cdn替换版
layout: post cid: 16 title: Freewind主题 cdn替换版 slug: 16 date: 2022/08/21 14:06:00 updated: 2022/08/21 ...
- 一天五道Java面试题----第十天(简述Redis事务实现--------->负载均衡算法、类型)
这里是参考B站上的大佬做的面试题笔记.大家也可以去看视频讲解!!! 文章目录 1.简述Redis事务实现 2.redis集群方案 3.redis主从复制的核心原理 4.CAP理论,BASE理论 5.负 ...
- 解决ffmpeg的播放摄像头的延时优化问题(项目案例使用有效)
在目前的项目中使用了flv的播放摄像头的方案,但是延时达到了7-8秒,所以客户颇有微词,没有办法,只能开始优化播放延时的问题,至于对接摄像头的方案有好几种,这种咱们以后在聊,今天只要聊聊聊优化参数的问 ...
- 如何检查“lateinit”变量是否已初始化?
kotlin中经常会使用延迟初始化,如果要校验lateinit var 变量是否初始化.可以使用属性引用上的.isInitialized. 原文中是这样描述的:To check whether a l ...
- 快读《ASP.NET Core技术内幕与项目实战》WebApi3.1:WebApi最佳实践
本节内容,涉及到6.1-6.6(P155-182),以WebApi说明为主.主要NuGet包:无 一.创建WebApi的最佳实践,综合了RPC和Restful两种风格的特点 1 //定义Person类 ...
- OpenFOAM 编程 | 求解捕食者与被捕食者模型(predator-prey model)问题(ODEs)
0. 写在前面 本文问题参考自文献 \(^{[1]}\) 第一章例 6,并假设了一些条件,基于 OpenFOAM-v2206 编写程序数值上求解该问题.笔者之前也写过基于 OpenFOAM 求解偏分方 ...
- Django系列---理论一
教程:http://c.biancheng.net/django/ 特点 集成 ORM 组件:Django 的 Model 层自带数据库 ORM 组件,为操作不同类型的数据库提供了统一的方式. URL ...
- 2流高手速成记(之八):基于Sentinel实现微服务体系下的限流与熔断
我们接上回 上一篇中,我们进行了简要的微服务实现,也体会到了SpringCloudAlibaba的强大和神奇之处 我们仅改动了两个注释,其他全篇代码不变,原来的独立服务就被我们分为了provider和 ...