Python之TensorFlow的模型训练保存与加载-3
一、TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式。我们把训练好的模型通过二次加载训练,或者独立加载模型训练。这基本上都是比较常用的方式。
二、模型的保存与加载类型有2种
1)需要重新建立图谱,来实现模型的加载
2)独家加载模型
模型的保存与训练加载:
tf.train.Saver(<var_list>,<max_to_keep>)
var_list: 指定要保存和还原的变量,作为一个dict或者list传递
max_to_keep: 指示要保留的最大检查点文件个数。
保存模型的文件:checkpoint文件/检查点文件
method:
save(<session>, <path>)
restore(<session>, <path>)
模型的独立加载:
1、tf.train.import_meta_graph(<meta_graph_or_file>) 读取训练时的数据流图
meta_graph_or_file: *.meta的文件
2、saver.restore(<session>, tf.train.latest_checkpoint(<path>)) 加载最后一次检测点
path: 含有checkpoint的上一级目录
3、graph = tf.get_default_graph() 默认图谱
graph.get_tensor_by_name(<name>) 获取对应数据传入占位符
name: tensor的那么名称,如果没有生命name,则为(placeholder:0), 数字0依次往后推
graph.get_collection(<name>) 获取收集集合
return tensor列表
补充:
如果不知道怎么去获取tensor的相关图谱,可以通过
graph.get_operations() 查看所有的操作符,最好断点查看
三、模型的保存与训练加载
import os
import tensorflow as tf def model_save():
# 1、准备特征值和目标值
with tf.variable_scope("data"):
# 占位符,用于数据传入
x = tf.placeholder(dtype=tf.float32, shape=[None, 1], name="x")
# 矩阵相乘必须是二维(为了模拟效果而设定固定值来训练)
y_true = tf.matmul(x, [[0.7]]) + 0.8 # 2、建立回归模型,随机给权重值和偏置的值,让他去计算损失,然后在当前状态下优化
with tf.variable_scope("model"):
# 模型 y = wx + b, w的个数根据特征数据而定,b随机
# 其中Variable的参数trainable可以指定变量是否跟着梯度下降一起优化(默认True)
w = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0), name="w", trainable=True)
b = tf.Variable(0.0, name="b")
# 预测值
y_predict = tf.matmul(x, w) + b # 3、建立损失函数,均方误差
with tf.variable_scope("loss"):
loss = tf.reduce_mean(tf.square(y_true - y_predict)) # 4、梯度下降优化损失
with tf.variable_scope("optimizer"):
# 学习率的控制非常重要,如果过大会出现梯度消失/梯度爆炸导致NaN
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.1).minimize(loss) # 收集需要用于预测的模型
tf.add_to_collection("y_predict", y_predict) # 定义保存模型
saver = tf.train.Saver() # 通过绘画运行程序
with tf.Session() as sess:
# 存在变量时需要初始化
sess.run(tf.global_variables_initializer()) # 加载上次训练的模型结果
if os.path.exists("model/model/checkpoint"):
saver.restore(sess, "model/model/model") # 循环训练
for i in range(100):
# 读取数据(这里自己生成数据)
x_train = sess.run(tf.random_normal([100, 1], mean=1.75, stddev=0.5, name="x")) sess.run(train_op, feed_dict={x: x_train}) # 保存模型
if (i + 1) % 10 == 0:
print("第%d次训练保存,权重:%f, 偏值:%f" % (((i + 1) / 10), w.eval(), b.eval()))
saver.save(sess, "model/model/model")
四、模型的独立加载
def model_load():
with tf.Session() as sess:
# 1、加载模型
saver = tf.train.import_meta_graph("model/model/model.meta")
saver.restore(sess, tf.train.latest_checkpoint("model/model"))
graph = tf.get_default_graph() # 2、获取占位符
x = graph.get_tensor_by_name("data/x:0") # 3、获取权重和偏置
y_predict = graph.get_collection("y_predict")[0] # 4、读取测试数据
x_test = sess.run(tf.random_normal([10, 1], mean=1.75, stddev=0.5, name="x"))
# 5、预测
for i in range(len(x_test)):
predict = sess.run(y_predict, feed_dict={x: [x_test[i]]})
print("第%d个数据,原值:%f, 预测值:%f" % ((i + 1), x_test[i], predict))
Python之TensorFlow的模型训练保存与加载-3的更多相关文章
- tensorflow 之模型的保存与加载(二)
上一遍博文提到 有些场景下,可能只需要保存或加载部分变量,并不是所有隐藏层的参数都需要重新训练. 在实例化tf.train.Saver对象时,可以提供一个列表或字典来指定需要保存或加载的变量. #!/ ...
- tensorflow 之模型的保存与加载(三)
前面的两篇博文 第一篇:简单的模型保存和加载,会包含所有的信息:神经网络的op,node,args等; 第二篇:选择性的进行模型参数的保存与加载. 本篇介绍,只保存和加载神经网络的计算图,即前向传播的 ...
- tensorflow 之模型的保存与加载(一)
怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ### ...
- tensorflow模型的保存与加载
模型的保存与加载一般有三种模式:save/load weights(最干净.最轻量级的方式,只保存网络参数,不保存网络状态),save/load entire model(最简单粗暴的方式,把网络所有 ...
- [深度学习] Pytorch(三)—— 多/单GPU、CPU,训练保存、加载模型参数问题
[深度学习] Pytorch(三)-- 多/单GPU.CPU,训练保存.加载预测模型问题 上一篇实践学习中,遇到了在多/单个GPU.GPU与CPU的不同环境下训练保存.加载使用使用模型的问题,如果保存 ...
- pytorch 中模型的保存与加载,增量训练
让模型接着上次保存好的模型训练,模型加载 #实例化模型.优化器.损失函数 model = MnistModel().to(config.device) optimizer = optim.Adam( ...
- (sklearn)机器学习模型的保存与加载
需求: 一直写的代码都是从加载数据,模型训练,模型预测,模型评估走出来的,但是实际业务线上咱们肯定不能每次都来训练模型,而是应该将训练好的模型保存下来 ,如果有新数据直接套用模型就行了吧?现在问题就是 ...
- pytorch_模型参数-保存,加载,打印
1.保存模型参数(gen-我自己的模型名字) torch.save(self.gen.state_dict(), os.path.join(self.gen_save_path, 'gen_%d.pt ...
- fashion_mnist多分类训练,两种模型的保存与加载
from tensorflow.python.keras.preprocessing.image import load_img,img_to_array from tensorflow.python ...
随机推荐
- nRF51822 配对之device_manager_init 调用,以及保证 用户数据存储 的Flash 操作不与device manager 模块冲突
昨天 遇到了一个烦心的问题,被老外客户怼了两句,恼火,很想发火,发现英文不够用,算了,就不跟直肠的鬼佬一般见识.说正事. 最近的一个nRF51822+MT2503 钱包防丢项目,准备接近量产了.昨天做 ...
- Jmeter(四十二)_控制器下遍历一组参数
概述 在接口自动化的过程中,经常遇到需要遍历的参数组.jmeter在中,foreach控制器可以实现遍历参数,但是只能有一个入参.一旦遇到数组,foreach控制器表示我也无能为力... 为了解决这个 ...
- python 运行当前目录下的所有文件
查看当前目录下所有py文件(本身除外run) import os file_list = os.listdir(os.getcwd()) # 获取当前目录下所有的文件名print(file_list ...
- 用户账户——《Python编程从入门到实践》
Web应用程序的核心是让任何用户都能够注册账户并能够使用它,不管用户身处何方 1.让用户能够输入数据 建立用于创建用户的身份验证系统之前,我们先来添加几个页面,让用户能够输入数据.当前,只有超级用户能 ...
- Jmeter5.11安装
jmeter5.11要对应jdk1.8以上版本 1.选择zip后缀进行下载 2.配置环境变量 (1)电脑桌面---->"计算机"图标---->鼠标右键选择"属 ...
- 【转】Android系统中Fastboot和Recovery所扮演的角色。
Android 刷机过程中 Fastboot 和 Recovery 的作用是什么? 自己在知乎的一篇回答,,现在翻出来放到博客,希望可以解答更多人的疑惑,抑或有什么理解上的错误,也望网友指出~ 今天恰 ...
- 利用detours写了一个工具用于instrument任意指定dll的任意指定函数入口
目录 wiki Disas Dtest Simple withdll load一个dll到指定进程 tracebld显示相关进程涉及的文件读写操作 My Instrumentation tool: w ...
- WebGL学习笔记(一):理解基本概念和渲染管线
WebGL 是以 OpenGL ES 2.0 为基础的 3D 编程应用接口. 渲染管线(图形流水线) 渲染管线是指将数据从3D场景转换成2D图像,最终在屏幕上显示出来的总过程.它分为几个阶段:应用阶段 ...
- apicloud打包成apk
前言:本文是打包vue项目,其他项目也是这样打包 页面的开发过程跟我们平时开发一样,利用vue把页面全部完成,最后进行npm run build将项目打包. 接下来就是apicloud打包的过程,首先 ...
- Swift编码总结3
1.编码转换: dataString.addingPercentEncoding(withAllowedCharacters: .urlQueryAllowed) ?? "" re ...