转载:tensorflow保存训练后的模型
1、使用tf.train.Saver.save()方法保存模型
tf.train.Saver.save(sess, save_path, global_step=None, latest_filename=None, meta_graph_suffix='meta', write_meta_graph=True, write_state=True)
- sess: 用于保存变量操作的会话。
- save_path: String类型,用于指定训练结果的保存路径。
- global_step: 如果提供的话,这个数字会添加到save_path后面,用于构建checkpoint文件。这个参数有助于我们区分不同训练阶段的结果。
2、使用tf.train.Saver.restore方法价值模型
tf.train.Saver.restore(sess, save_path)
- sess: 用于加载变量操作的会话。
- save_path: 同保存模型是用到的的save_path参数。
下面通过一个代码演示这两个函数的使用方法
import tensorflow as tf
import numpy as np
x = tf.placeholder(tf.float32, shape=[None, 1])
y = 4 * x + 4
w = tf.Variable(tf.random_normal([1], -1, 1))
b = tf.Variable(tf.zeros([1]))
y_predict = w * x + b
loss = tf.reduce_mean(tf.square(y - y_predict))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss)
isTrain = False
train_steps = 100
checkpoint_steps = 50
checkpoint_dir = ''
saver = tf.train.Saver() # defaults to saving all variables - in this case w and b
x_data = np.reshape(np.random.rand(10).astype(np.float32), (10, 1))
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
if isTrain:
for i in xrange(train_steps):
sess.run(train, feed_dict={x: x_data})
if (i + 1) % checkpoint_steps == 0:
saver.save(sess, checkpoint_dir + 'model.ckpt', global_step=i+1)
else:
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
else:
pass
print(sess.run(w))
print(sess.run(b))
转载:tensorflow保存训练后的模型的更多相关文章
- 在 C/C++ 中使用 TensorFlow 预训练好的模型—— 直接调用 C++ 接口实现
现在的深度学习框架一般都是基于 Python 来实现,构建.训练.保存和调用模型都可以很容易地在 Python 下完成.但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过直 ...
- 在 C/C++ 中使用 TensorFlow 预训练好的模型—— 间接调用 Python 实现
现在的深度学习框架一般都是基于 Python 来实现,构建.训练.保存和调用模型都可以很容易地在 Python 下完成.但有时候,我们在实际应用这些模型的时候可能需要在其他编程语言下进行,本文将通过 ...
- TensorFlow保存、加载模型参数 | 原理描述及踩坑经验总结
写在前面 我之前使用的LSTM计算单元是根据其前向传播的计算公式手动实现的,这两天想要和TensorFlow自带的tf.nn.rnn_cell.BasicLSTMCell()比较一下,看看哪个训练速度 ...
- Tensorflow使用训练好的模型进行测试,发现计算速度越来越慢
实验时要对多个NN模型进行对比,依次加载直到第8个模型时,发现运行速度明显变慢而且电脑开始卡顿,查看内存占用90+%. 原因:使用过的NN模型还会保存在内存,继续加载一方面使新模型加载特别特别慢,另一 ...
- libsvm 训练后的模型参数讲解(转)
主要就是讲解利用libsvm-mat工具箱建立分类(回归模型)后,得到的模型model里面参数的意义都是神马?以及如果通过model得到相应模型的表达式,这里主要以分类问题为例子.测试数据使用的是li ...
- [转]libsvm 训练后的模型参数讲解
http://blog.sina.com.cn/s/blog_6646924501018fqc.html 主要就是讲解利用libsvm-mat工具箱建立分类(回归模型)后,得到的模型model里面参数 ...
- keras 保存训练的最佳模型
转自:https://anifacc.github.io/deeplearning/machinelearning/python/2017/08/30/dlwp-ch14-keep-best-mode ...
- Tensorflow 用训练好的模型预测
本节涉及点: 从命令行参数读取需要预测的数据 从文件中读取数据进行预测 从任意字符串中读取数据进行预测 一.从命令行参数读取需要预测的数据 训练神经网络是让神经网络具备可用性,真正使用神经网络时,需要 ...
- TensorFlow 模型优化工具包 — 训练后整型量化
模型优化工具包是一套先进的技术工具包,可协助新手和高级开发者优化待部署和执行的机器学习模型.自推出该工具包以来, 我们一直努力降低机器学习模型量化的复杂性 (https://www.tensorfl ...
随机推荐
- [IOT] - 在树莓派的 Raspbian 系统中安装 .Net Core 3.0 运行环境
之前在 Docker 中配置过 .Net Core 运行环境,地址:[IOT] - Raspberry Pi 4 Model B 系统初始化,Docker CE + .Net Core 开发环境配置 ...
- 测试wss是否连接企业微信成功
企业微信考勤机有时候无法连接,可以使用下面代码来测试下网络情况 <html> <head> <title>测试wss</title> </hea ...
- 谈谈 Callable 任务是怎么运行的?它的执行结果又是怎么获取的?
谈谈 Callable 任务是怎么运行的?它的执行结果又是怎么获取的? 向线程池提交Callable任务,会创建一个新线程(执行任务的线程)去执行这个Callable任务,但是通过Future#get ...
- SpringBoot+Mybatis+Druid批量更新 multi-statement not allow异常
本文链接:https://blog.csdn.net/weixin_43947588/article/details/90109325 注:该文是本博主记录学习之用,没有太多详细的讲解,敬请谅解! ...
- TCP 协议 精解
http://www.cnblogs.com/sunev/archive/2012/06/23/2559389.html
- moodle3.7上传中文文件,无法引用,图片不显示
初始安装moodle3.7 上传图片,名称为中文时,无法引用图片,图片不显示.这里采用修改moodle根目录下的config.php文件, 添加了变量$CFG->slasharguments = ...
- Java之路---Day17(数据结构)
2019-11-04-23:03:13 目录: 1.常用的数据结构 2.栈 3.队列 4.数组 5.链表 6.红黑树 常用的数据结构: 包含:栈.队列.数组.链表和红黑树 栈: 栈:stack,又称堆 ...
- 【转载】C#中Datatable修改列名
在C#的数据表格DataTable操作过程中,有时候会遇到修改DataTable数据表格的列名的需求,其实C#中的DataTable的列名支持手动修改调整,可以通过DataTable类的Columns ...
- TinyMCE基础配置
选择器配置 插件配置 工具栏配置 菜单配置 皮肤配置 编辑区宽高配置 编辑区样式配置 隐藏状态栏 选择器配置 选择器就是CSS选择器,它告诉TinyMCE哪个元素是可编辑的. 示例: tinymce. ...
- js 数组 添加或删除 元素 splice 创建一个新的数组,新数组中的元素是通过检查指定数组中符合条件的所有元素 filter
里面可以用 箭头函数 splice 删除 增加 数组 中元素 操作数组 filter 创建新数组 检查指定数组中符合条件的所有元素