Tensorflow学习笔记No.9
模型的保存与恢复
介绍一些常见的模型保存与恢复方法,以及如何使用回调函数保存模型。
1.保存完整模型
model.save()方法可以保存完整的模型,包括模型的架构、模型的权重以及优化器。
model.save()的参数为保存路径以及文件名。
首先我们构建一个简单的Sequential模型,使用fishion_mnist数据集进行训练,得到一个训练后的模型。
1 import tensorflow as tf
2 import numpy as np
3
4 (train_image, train_label), (test_image, test_label) = tf.keras.datasets.fashion_mnist.load_data()
5
6 train_image = np.expand_dims(train_image, -1)
7 test_image = np.expand_dims(test_image, -1)
8
9 model = tf.keras.Sequential()
10 model.add(tf.keras.layers.Conv2D(32, [3, 3], input_shape = (28, 28, 1), activation = 'relu'))
11 model.add(tf.keras.layers.Conv2D(64, [3, 3], activation = 'relu'))
12 model.add(tf.keras.layers.GlobalAveragePooling2D())
13 model.add(tf.keras.layers.Dense(64, activation = 'relu'))
14 model.add(tf.keras.layers.Dense(10, activation = 'softmax'))
15
16 model.compile(optimizer = 'adam',
17 loss = 'sparse_categorical_crossentropy',
18 metrics = ['acc'])
19
20 history = model.fit(train_image, train_label,
21 epochs = 10,
22 validation_data = (test_image, test_label))
使用model.summary()查看当前模型结构:
1 model.summary()

训练完成后我们使用model.evaluate()方法对测试集进行评估。
1 model.evaluate(test_image, test_label)
正确率如下图所示:

然后我们使用model.save()方法保存完整模型。
1 model.save('model_1.h5')
保存后我们会得到一个名为model_1.h5的文件,这个文件就是保存好的模型。
保存好的模型会放到指定位置。

我们可以使用tf.keras.models.load_model()方法来导入我们保存好的模型,参数为已保存模型的储存位置以及文件名。
1 new_model = tf.keras.models.load_model('model_1.h5')
我们使用.summary()方法查看一下模型结构是否与之前相同。
1 new_model.summary()

可以发现模型结构是完全一致的。
然后我们使用.evaluate()方法对测试集进行评估,看一下模型权重是否被保存。
1 new_model.evaluate(test_image, test_label)

可以发现,loss值与之前完全相同。
注意:这里acc值发生了很严重的变化,目前不知道我还是什么原因导致的,但这不代表我们保存的模型或者是权重出现了问题,使用.predict()方法依然可以正常对数据进行分类预测,这是我使用.predict()方法预测后与原数据一一比对后确认过的,可能是出现了一个小bug,知道原因的小伙伴可以评论区回复我。
重申:模型是已经正常保存了的,是可以正常使用的,大家无需担心,loss值足够证明我们的模型是正常保存的。
2.仅保存模型结构
有时候我们可能不需要保存模型的权重,而只想保存模型的架构。
这时可以使用model.to_jison()来获取模型架构。
1 json = model.to_json()
json中保存了模型架构的完整信息。
我们可以使用python的文件操作方法将它写入到磁盘上,使用时再从磁盘上读入即可,这里不详细说明了,大家自行百度即可。
使用tf.keras.models.model_from_jsom()来恢复模型,参数为我们之前保存模型信息的变量json。
1 new_model = tf.keras.models.model_from_json(json)
同样,我们查看模型结构并对模型进行评估。
1 new_model.summary()

1 new_model.compile(optimizer = 'adam',
2 loss = 'sparse_categorical_crossentropy',
3 metrics = ['acc'])
4
5 new_model.evaluate(test_image, test_label)
注意,由于我们没有保存优化器,所以要先对模型添加一个优化器再进行评估。

可以发现loss值非常的大,也就说明的模型没有被训练过,模型中的参数都是随机产生的。
3.仅保存模型权重
同样的,我们也可以仅保存模型权重。
权重的保存有两种方法,可以像上面保存模型结构一样使用model.get_weights()把模型结构读入到变量中再进行保存,也可以使用keras提供的方法直接保存到磁盘上。
这里主要介绍第二种(主要是第一种用处不大)。
使用model.save_weights()方法进行保存,参数为保存路径以及文件名。
1 model.save_weights('weights_1.h5')
同样,我们会得到一个对应的文件。

然后使用.load_weigths()方法可以载入权重,参数为路径及文件名。
1 new_model.load_weights('weights_1.h5')
对测试集进行评估查看是否被正常载入。
1 new_model.evaluate(test_image, test_label)

loss值与之前相同,说明权重被正常载入了。
注意,保存权重也不会保存优化器,这里不用重定义优化器是因为上面已经给new_model这个对象定义过优化器了。
4.使用回调函数保存模型
我觉得这是最实用也是最好的模型保存方法。
首先定义一个回调函数监测训练过程并保存模型。
使用tf.keras.callbacks.ModelCheckpoint()来定义这样一个回调函数。
它的主要参数为:
filepath:储存位置。
moinitor = 'val_loss':监视的变量。
verboss = 0:是否显示详细信息。
save_best_only = False:为True则会保存loss最低的或者acc最高的。
save_weihts_only = False:是否只保存权重,为False会保存整个模型。
1 checkpoint = tf.keras.callbacks.ModelCheckpoint('modelcp',
2 save_weights_only = True,
3 save_best_only = True,
4 verbose = 1)
然后我们构建模型训练一下试试。
1 new_model = tf.keras.models.model_from_json(json)
2
3 new_model.compile(optimizer = 'adam',
4 loss = 'sparse_categorical_crossentropy',
5 metrics = ['acc'])
6
7 history = new_model.fit(train_image, train_label,
8 epochs = 5,
9 validation_data = (test_image, test_label),
10 callbacks = [checkpoint])
要在.fit()中加入callbacks参数调用回调函数。

可以发现我们的模型信息被保存了,同时多出来三个保存好的文件。

同样使用.load_weights()来载入权重,并进行评估。
1 new_model = tf.keras.models.model_from_json(json)
2
3 new_model.compile(optimizer = 'adam',
4 loss = 'sparse_categorical_crossentropy',
5 metrics = ['acc'])
6
7 new_model.load_weights('modelcp')
8
9 new_model.evaluate(test_image, test_label)
得到结果:

与训练时最后保存的结果相同。
关于模型的保存方法就介绍到这里了,后续会更新更多内容哦!o(* ̄▽ ̄*)o
Tensorflow学习笔记No.9的更多相关文章
- Tensorflow学习笔记2:About Session, Graph, Operation and Tensor
简介 上一篇笔记:Tensorflow学习笔记1:Get Started 我们谈到Tensorflow是基于图(Graph)的计算系统.而图的节点则是由操作(Operation)来构成的,而图的各个节 ...
- Tensorflow学习笔记2019.01.22
tensorflow学习笔记2 edit by Strangewx 2019.01.04 4.1 机器学习基础 4.1.1 一般结构: 初始化模型参数:通常随机赋值,简单模型赋值0 训练数据:一般打乱 ...
- Tensorflow学习笔记2019.01.03
tensorflow学习笔记: 3.2 Tensorflow中定义数据流图 张量知识矩阵的一个超集. 超集:如果一个集合S2中的每一个元素都在集合S1中,且集合S1中可能包含S2中没有的元素,则集合S ...
- TensorFlow学习笔记之--[compute_gradients和apply_gradients原理浅析]
I optimizer.minimize(loss, var_list) 我们都知道,TensorFlow为我们提供了丰富的优化函数,例如GradientDescentOptimizer.这个方法会自 ...
- 深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识
深度学习-tensorflow学习笔记(1)-MNIST手写字体识别预备知识 在tf第一个例子的时候需要很多预备知识. tf基本知识 香农熵 交叉熵代价函数cross-entropy 卷积神经网络 s ...
- 深度学习-tensorflow学习笔记(2)-MNIST手写字体识别
深度学习-tensorflow学习笔记(2)-MNIST手写字体识别超级详细版 这是tf入门的第一个例子.minst应该是内置的数据集. 前置知识在学习笔记(1)里面讲过了 这里直接上代码 # -*- ...
- tensorflow学习笔记(4)-学习率
tensorflow学习笔记(4)-学习率 首先学习率如下图 所以在实际运用中我们会使用指数衰减的学习率 在tf中有这样一个函数 tf.train.exponential_decay(learning ...
- tensorflow学习笔记(3)前置数学知识
tensorflow学习笔记(3)前置数学知识 首先是神经元的模型 接下来是激励函数 神经网络的复杂度计算 层数:隐藏层+输出层 总参数=总的w+b 下图为2层 如下图 w为3*4+4个 b为4* ...
- tensorflow学习笔记(2)-反向传播
tensorflow学习笔记(2)-反向传播 反向传播是为了训练模型参数,在所有参数上使用梯度下降,让NN模型在的损失函数最小 损失函数:学过机器学习logistic回归都知道损失函数-就是预测值和真 ...
- tensorflow学习笔记(1)-基本语法和前向传播
tensorflow学习笔记(1) (1)tf中的图 图中就是一个计算图,一个计算过程. 图中的constant是个常量 计 ...
随机推荐
- Linux基本目录机构
Linux基本目录机构 1. 基本介绍 Linux的文件系统采用级层式子的树状目录结构 最上层是根目录"/" Linux世界里,一切皆文件 2. 目录用途 /bin: 是Binar ...
- Shell学习(三)Shell参数传递
一.传参实例 ##脚本文件内容 #执行的文件名 echo $0; #第一个参数 echo $1; #第二个参数 echo $2; #第三个参数 echo $3; ##调用语句 ./testShell. ...
- zico2靶机渗透
zico2靶机渗透 开放了四个端口,分别是22,80,111以及57781端口. 扫到了目录http://192.168.114.152/dbadmin/ 进入看到php文件,访问,发现一个登录窗口. ...
- MyEclipse中的项目导入到Eclipse中运行的错误解决
之前用的myEclipse,后来把项目导入eclipse发现报错,将MyEclipse中的项目导入到Eclipse中运行,不注意一些细节,会造成无法运行的后果.下面就说说具体操作:导入后出现如下错误: ...
- 每日一个知识点系列:volatile的可见性原理
每日一个知识点系列的目的是针对某一个知识点进行概括性总结,可在一分钟内完成知识点的阅读理解,此处不涉及详细的原理性解读. img 看图说话 关键点1: 总线嗅探器(MESI 缓存一致性原理 ) 关键点 ...
- 主键生成器效率提升方案|基于雪花算法和Redis控制进程隔离
背景 主键生成效率用数据库自增效率也是比较高的,为什么要用主键生成器呢?是因为需要insert主表和明细表时,明细表有个字段是主表的主键作为关联.所以就需要先生成主键填好主表明细表的信息后再一次过在一 ...
- Spring AOP系列(二) — 动态代理引言
接上一篇Spring AOP系列(一)- 代理模式,本篇来聊聊动态代理. 动态代理与静态代理的区别 要想了解动态代理与静态代理的区别,需要有两个前置知识点:java程序是如何执行的以及类加载机制. j ...
- OpenMP变量作用域【private】【shared】
(1) privateprivate子句将一个或多个变量声明为线程的私有变量.每个线程都有它自己的变量私有副本,其他线程无法访问.即使在并行区域外有同名的共享变量,共享变量在并行区域内不起任何作用,并 ...
- 【网络协议】TCP/IP:数据链路层
物理层负责把计算机中的0.1数字信号转换为具体传输媒介的物理信号(电压的高低.电波的强弱.光的闪灭) 数据链路层协议定义了(通过通信介质互连的设备间的)数据传输规范 (常见的通信介质有同轴电缆.双绞线 ...
- GAN训练技巧汇总
GAN自推出以来就以训练困难著称,因为它的训练过程并不是寻找损失函数的最小值,而是寻找生成器和判别器之间的纳什均衡.前者可以直接通过梯度下降来完成,而后者除此之外,还需要其它的训练技巧. 下面对历年关 ...