目录

设置

基于checkpoints的模型保存

通过ModelCheckpoint模块来自动保存数据

手动保存权重

整个模型保存

总体代码


模型可以在训练中或者训练完成后保存。具体文档参考:https://tensorflow.google.cn/tutorials/keras/save_and_restore_models

设置

依赖项设置:

!pip install -q h5py pyyaml 

模型建立:

from __future__ import absolute_import, division, print_function

import os

import tensorflow as tf
from tensorflow import keras tf.__version__ (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data() train_labels = train_labels[:1000]
test_labels = test_labels[:1000] train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0 # 模型创建模型
def create_model():
model = tf.keras.models.Sequential([
keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation=tf.nn.softmax)
]) model.compile(optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy']) return model #创建模型
model = create_model()
model.summary()

基于checkpoints的模型保存

通过ModelCheckpoint模块来自动保存数据

#创建回调函数
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
save_weights_only=True, #只保存权重
verbose=1) model = create_model() model.fit(train_images, train_labels, epochs = 10,
validation_data = (test_images,test_labels),
callbacks = [cp_callback]) #保存模型

通过load_weight读取权重

#对全新没有训练的模型进行预测
model = create_model()
loss, acc = model.evaluate(test_images, test_labels)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc)) #11.4% #载入权重参数后的模型
model.load_weights(checkpoint_path)
loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #86.2

手动保存权重

# 保存权重
model.save_weights('./checkpoints/my_checkpoint') #恢复模型
model = create_model()
model.load_weights('./checkpoints/my_checkpoint') loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #87.00%

整个模型保存

基于keras的HD5文件保存整个模型所有参数,优化器参数等。

#将整个模型保存为HDF5文件
model = create_model()
model.fit(train_images, train_labels, epochs=5)
model.save('my_model.h5')
#载入一个相同的模型
new_model = keras.models.load_model('my_model.h5')
new_model.summary()
loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #86.30%

总体代码

from __future__ import absolute_import, division, print_function

import os

import tensorflow as tf
from tensorflow import keras tf.__version__ (train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data() train_labels = train_labels[:1000]
test_labels = test_labels[:1000] train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0 # 模型创建模型
def create_model():
model = tf.keras.models.Sequential([
keras.layers.Dense(512, activation=tf.nn.relu, input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10, activation=tf.nn.softmax)
]) model.compile(optimizer=tf.keras.optimizers.Adam(),
loss=tf.keras.losses.sparse_categorical_crossentropy,
metrics=['accuracy']) return model #创建模型
model = create_model()
model.summary() checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
'''
#创建回调函数
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path,
save_weights_only=True, #只保存权重
verbose=1) model = create_model() model.fit(train_images, train_labels, epochs = 10,
validation_data = (test_images,test_labels),
callbacks = [cp_callback]) #保存模型 #对全新没有训练的模型进行预测
model = create_model()
loss, acc = model.evaluate(test_images, test_labels)
print("Untrained model, accuracy: {:5.2f}%".format(100*acc)) #11.4% #载入权重参数后的模型
model.load_weights(checkpoint_path)
loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #86.2 # 保存权重
model.save_weights('./checkpoints/my_checkpoint') #恢复模型
model = create_model()
model.load_weights('./checkpoints/my_checkpoint') loss,acc = model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #87.00%
'''
#将整个模型保存为HDF5文件
model = create_model()
model.fit(train_images, train_labels, epochs=5)
model.save('my_model.h5')
#载入一个相同的模型
new_model = keras.models.load_model('my_model.h5')
new_model.summary()
loss, acc = new_model.evaluate(test_images, test_labels)
print("Restored model, accuracy: {:5.2f}%".format(100*acc)) #86.30%

[深度学习] tf.keras入门5-模型保存和载入的更多相关文章

  1. [深度学习] tf.keras入门1-基本函数介绍

    目录 构建一个简单的模型 序贯(Sequential)模型 网络层的构造 模型训练和参数评价 模型训练 模型的训练 tf.data的数据集 模型评估和预测 基本模型的建立 网络层模型 模型子类函数构建 ...

  2. [深度学习] tf.keras入门4-过拟合和欠拟合

    过拟合和欠拟合 简单来说过拟合就是模型训练集精度高,测试集训练精度低:欠拟合则是模型训练集和测试集训练精度都低. 官方文档地址为 https://tensorflow.google.cn/tutori ...

  3. [深度学习] tf.keras入门3-回归

    目录 波士顿房价数据集 数据集 数据归一化 模型训练和预测 模型建立和训练 模型预测 总结 回归主要基于波士顿房价数据库进行建模,官方文档地址为:https://tensorflow.google.c ...

  4. [深度学习] tf.keras入门2-分类

    目录 Fashion MNIST数据库 分类模型的建立 模型预测 总体代码 主要介绍基于tf.keras的Fashion MNIST数据库分类, 官方文档地址为:https://tensorflow. ...

  5. 深度学习:Keras入门(一)之基础篇

    1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深度学习框架. Keras是一个高层神经网络API,支持快速实验,能够把你的idea迅速转换为结 ...

  6. 深度学习:Keras入门(一)之基础篇【转】

    本文转载自:http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorfl ...

  7. 深度学习:Keras入门(一)之基础篇(转)

    转自http://www.cnblogs.com/lc1217/p/7132364.html 1.关于Keras 1)简介 Keras是由纯python编写的基于theano/tensorflow的深 ...

  8. 深度学习:Keras入门(二)之卷积神经网络(CNN)

    说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么是卷积? 简单来说,卷积(或内积)就是一种先把对应位置相乘然后再把结果相加的运算.(具体含义或者数学公式 ...

  9. 深度学习:Keras入门(二)之卷积神经网络(CNN)【转】

    本文转载自:https://www.cnblogs.com/lc1217/p/7324935.html 说明:这篇文章需要有一些相关的基础知识,否则看起来可能比较吃力. 1.卷积与神经元 1.1 什么 ...

随机推荐

  1. 插入排序算法(Java代码实现)

    其它经典排序算法:https://blog.csdn.net/weixin_43304253/article/details/121209905 插入排序算法: 思路:将数据分为已经排序好的数据和未排 ...

  2. 累加和为 K 的子数组问题

    累加和为 K 的子数组问题 作者:Grey 原文地址: 博客园:累加和为 K 的子数组问题 CSDN:累加和为 K 的子数组问题 题目说明 数组全为正数,且每个数各不相同,求累加和为K的子数组组合有哪 ...

  3. 18.-cookies和session

    一.会话定义 从打开浏览器访问一个网站,到关闭浏览器结束此次访问,称之为一次绘画 HTTP协议是无状态的,导致绘画状态难以保持 Cookies和session就是为了保持会话状态而诞生的两个存储技术 ...

  4. maven的下载、安装、配置,idea中配置Maven

    下载 下载链接: 点击下载地址 : 找到:对应版本的包下载 安装 下载后的压缩包解压出来,然后将解压后的包放到日常安装软件的位置即安装成功,当然取决于个人意愿,也可以不移动. 打开安装包后的目录结构简 ...

  5. 从ObjectPool到CAS指令

    相信最近看过我的文章的朋友对于Microsoft.Extensions.ObjectPool不陌生:复用.池化是在很多高性能场景的优化技巧,它能减少内存占用率.降低GC频率.提升系统TPS和降低请求时 ...

  6. 死磕面试系列,Java到底是值传递还是引用传递?

    Java到底是值传递还是引用传递? 这虽然是一个老生常谈的问题,但是对于没有深入研究过这块,或者Java基础不牢的同学,还是很难回答得让人满意. 可能很多同学能够很轻松的背出JVM.分布式事务.高并发 ...

  7. Windows骚操作

    电脑常用的快捷键 键盘功能健:Tab.Shift.Ctrl.Alt.Windows.Enter.空格.上下左右健.CapsLock(大小写转换).NumLock(对小键盘控制开/关) 键盘快捷键:全选 ...

  8. [C++] - GCC和LLVM对方法 warning: non-void function does not return a value [-Wreturn-type] 的处理差异

    最近做一个C++开源项目发现一个奇怪问题,通过clang编译链接执行程序每到有一个就崩溃了,gcc下则没有此问题. 后来通过调试,发现原因是bool返回的方法是没有return语句!问题是为啥还能通过 ...

  9. 【题解】CF45I TCMCF+++

    题面传送门 题目描述 有 \(n\) 个数 \(a_i\) 请你从中至少选出一个数,使它们的乘积最大 解决思路 对于正数,对答案一定有贡献(正数越乘越大),所以输入正数时直接输出即可. 对于负数,如果 ...

  10. Day10:for循环结构的使用详解

    for循环 将0~100内的奇.偶数分别求和 思路 第一步先将0~100以内的奇.偶数分成两队,第二步使奇数累加.ou'shu public class ForCirculate{ public st ...