模型的保存与加载一般有三种模式:save/load weights(最干净、最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有的状态都保存起来),saved_model(更通用的方式,以固定模型格式保存,该格式是各种语言通用的)

具体使用方法如下:

        # 保存模型
model.save_weights('./checkpoints/my_checkpoint')
# 加载模型
model = keras.create_model()
model.load_weights('./checkpoints/my_checkpoint')

示例:

import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics def preprocess(x, y):
x = tf.cast(x, dtype=tf.float32) / 255.
x = tf.reshape(x, [28 * 28])
y = tf.cast(y, dtype=tf.int32)
y = tf.one_hot(y, depth=10)
return x, y batchsz = 128
(x, y), (x_val, y_val) = datasets.mnist.load_data()
print('datasets:', x.shape, y.shape, x.min(), x.max()) db = tf.data.Dataset.from_tensor_slices((x, y))
db = db.map(preprocess).shuffle(60000).batch(batchsz)
ds_val = tf.data.Dataset.from_tensor_slices((x_val, y_val))
ds_val = ds_val.map(preprocess).batch(batchsz) sample = next(iter(db))
print(sample[0].shape, sample[1].shape) network = Sequential([layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10)])
network.build(input_shape=(None, 28 * 28))
network.summary() network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
) network.fit(db, epochs=3, validation_data=ds_val, validation_freq=2) network.evaluate(ds_val) network.save_weights('weights.ckpt')
print('saved weights.')
del network network = Sequential([layers.Dense(256, activation='relu'),
layers.Dense(128, activation='relu'),
layers.Dense(64, activation='relu'),
layers.Dense(32, activation='relu'),
layers.Dense(10)])
network.compile(optimizer=optimizers.Adam(lr=0.01),
loss=tf.losses.CategoricalCrossentropy(from_logits=True),
metrics=['accuracy']
)
network.load_weights('weights.ckpt')
print('loaded weights!')
network.evaluate(ds_val)

运行效果如下:

可以看到保存前后的精度和损失差距不大,这是由于神经网络的运算过程中会有很多不确定因子,这些不确定因子不会通过save_weights方法保存,要想保存前后运行结果一致,就需要完整的保存网络模型。即model.save方法

使用方法如下:

# 模型保存
network.save('model.h5')
print('saved total model.')
# 模型加载
print('load model from file')
network = tf.keras.models.load_model('model.h5')
# 评估
network.evaluate(x_val,y_val)

除了这种方法之外,tensorflow还支持保存为标准的可以给其他语言使用的模型,使用saved_model即可

使用方法如下:

tf.saved_model.save(m,'/tmp/saved_model/')
imported = tf.saved_model.load(path)
f = imported.signatures["serving_default"]
print(f(x=tf.ones([1,28,28,3])))

tensorflow模型的保存与加载的更多相关文章

  1. tensorflow 之模型的保存与加载(二)

    上一遍博文提到 有些场景下,可能只需要保存或加载部分变量,并不是所有隐藏层的参数都需要重新训练. 在实例化tf.train.Saver对象时,可以提供一个列表或字典来指定需要保存或加载的变量. #!/ ...

  2. tensorflow 之模型的保存与加载(三)

    前面的两篇博文 第一篇:简单的模型保存和加载,会包含所有的信息:神经网络的op,node,args等; 第二篇:选择性的进行模型参数的保存与加载. 本篇介绍,只保存和加载神经网络的计算图,即前向传播的 ...

  3. tensorflow 之模型的保存与加载(一)

    怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ### ...

  4. Python之TensorFlow的模型训练保存与加载-3

    一.TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式.我们把训练好的模型通过二次加载训练,或者独立加载模型训练.这基本上都是比较常用的方式. 二.模型的保存与加载类型有2种 1 ...

  5. Tensorflow 模型持久化saver及加载图结构

    主要内容: 1. 直接保存,加载模型; (可以指定加载,保存的var_list) 2. 加载,保存指定变量的模型 3. slim加载模型使用 4. 加载模型图结构和参数等 tensorflow 恢复部 ...

  6. (sklearn)机器学习模型的保存与加载

    需求: 一直写的代码都是从加载数据,模型训练,模型预测,模型评估走出来的,但是实际业务线上咱们肯定不能每次都来训练模型,而是应该将训练好的模型保存下来 ,如果有新数据直接套用模型就行了吧?现在问题就是 ...

  7. pytorch_模型参数-保存,加载,打印

    1.保存模型参数(gen-我自己的模型名字) torch.save(self.gen.state_dict(), os.path.join(self.gen_save_path, 'gen_%d.pt ...

  8. pytorch 中模型的保存与加载,增量训练

     让模型接着上次保存好的模型训练,模型加载 #实例化模型.优化器.损失函数 model = MnistModel().to(config.device) optimizer = optim.Adam( ...

  9. fashion_mnist多分类训练,两种模型的保存与加载

    from tensorflow.python.keras.preprocessing.image import load_img,img_to_array from tensorflow.python ...

随机推荐

  1. Codeforces_828

    A.模拟,注意单人的时候判断顺序. #include<bits/stdc++.h> using namespace std; int n,a,b; int main() { ios::sy ...

  2. Java集合中removeIf的使用

    在JDK1.8中,Collection以及其子类新加入了removeIf方法,作用是按照一定规则过滤集合中的元素.这里给读者展示removeIf的用法.首先设想一个场景,你是公司某个岗位的HR,收到了 ...

  3. Algorithms - Insertion Sort - 插入排序

    Insertion Sort - 插入排序 插入排序算法的 '时间复杂度' 是输入规模的二次函数, 深度抽象后表示为, n 的二次方. import time, random F = 0 alist ...

  4. MySQL8.0 InnoDB并行执行

    概述 MySQL经过多年的发展已然成为最流行的数据库,广泛用于互联网行业,并逐步向各个传统行业渗透.之所以流行,一方面是其优秀的高并发事务处理的能力,另一方面也得益于MySQL丰富的生态.MySQL在 ...

  5. php面试笔记(4)-php基础知识-流程控制

    本文是根据慕课网Jason老师的课程进行的PHP面试知识点总结和升华,如有侵权请联系我进行删除,email:guoyugygy@163.com 在面试中,考官往往喜欢基础扎实的面试者,而流程控制相关的 ...

  6. Java开发最佳实践(二) ——《Java开发手册》之"异常处理、MySQL 数据库"

    二.异常日志 (一) 异常处理 (二) 日志规约 三.单元测试 四.安全规约 五.MySQL数据库 (一) 建表规约 (二) 索引规约 (三) SQL语句 (四) ORM映射 六.工程结构 七.设计规 ...

  7. Linux文本三剑客

    grep 文本过滤工具. 作用: 文本搜索工具,根据用户指定的行进行匹配检查,打印匹配到的行. 模式: 由正则表达式字符及文本字符所编写的过滤条件. grep的使用 语法:  grep [OPTION ...

  8. centos7安装bind(DNS服务)

    环境介绍 公网IP:149.129.92.239 内网IP:172.17.56.249 系统:CentOS 7.4 一.安装 yum install bind bind-utils -y 二.修改bi ...

  9. 杭电--1009 C语言实现

    思路:要用有限的猫粮得到最多的javabean,则在房间中得到的javabean比例应尽可能的大. 用一个结构体,保存每个房间中的javabean和猫粮比例和房间号,然后将结构体按比例排序,则从比例最 ...

  10. 前端vue开发中的跨域问题解决,以及nginx上线部署。(vue devServer与nginx)

    前言 最近做的一个项目中使用了vue+springboot的前后端分离模式 在前端开发的的时候,使用vue cli3的devServer来解决跨域问题 上线部署则是用的nginx反向代理至后台服务所开 ...