TensorFlow笔记-模型的保存,恢复,实现线性回归
模型的保存
tf.train.Saver(var_list=None,max_to_keep=5)
•var_list:指定将要保存和还原的变量。它可以作为一个
dict或一个列表传递.
•max_to_keep:指示要保留的最近检查点文件的最大数量。
创建新文件时,会删除较旧的文件。如果无或0,则保留所有
检查点文件。默认为5(即保留最新的5个检查点文件。)
saver = tf.train.Saver()
saver.save(sess, "")
模型的恢复
恢复模型的方法是restore(sess, save_path),save_path是以前保存参数的路径,我们可以使用tf.train.latest_checkpoint来获取最近的检查点文件(也恶意直接写文件目录)
if os.path.exists("tmp/ckpt/checkpoint"):
saver.restore(sess,"")
print("恢复模型")
自定义命令行参数
import tensorflow as tf FLAGS = tf.app.flags.FLAGS tf.app.flags.DEFINE_string('data_dir', '/tmp/tensorflow/mnist/input_data',
"""数据集目录""")
tf.app.flags.DEFINE_integer('max_steps', 2000,
"""训练次数""")
tf.app.flags.DEFINE_string('summary_dir', '/tmp/summary/mnist/convtrain',
"""事件文件目录""") def main(argv):
print(FLAGS.data_dir)
print(FLAGS.max_steps)
print(FLAGS.summary_dir)
print(argv) if __name__=="__main__":
tf.app.run()
线性回归
准备数据
with tf.variable_scope("data"):
# 1、准备数据,x 特征值 [100, 1] y 目标值[100]
x = tf.random_normal([100, 1], mean=1.75, stddev=0.5, name="x_data")
# 矩阵相乘必须是二维的
y_true = tf.matmul(x, [[0.7]]) + 0.8
构建模型
with tf.variable_scope("model"):
# 2、建立线性回归模型 1个特征,1个权重, 一个偏置 y = x w + b
# 随机给一个权重和偏置的值,让他去计算损失,然后再当前状态下优化
# 用变量定义才能优化
weight = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0), name="w")
bias = tf.Variable(0.0, name="b")
y_predict = tf.matmul(x, weight) + bias
构造损失函数
with tf.variable_scope("loss"):
# 3、建立损失函数,均方误差
loss = tf.reduce_mean(tf.square(y_true - y_predict))
利用梯度下降
with tf.variable_scope("optimizer"):
# 4、梯度下降优化损失 leaning_rate: 0 ~ 1, 2, 3,5, 7, 10
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
源码
import tensorflow as tf
import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = ''
# 在这里立flag
tf.app.flags.DEFINE_integer("max_step",100,"模型训练的步数")
tf.app.flags.DEFINE_string("model_dir","tmp/summary/test","模型文件的加载路径") FLAGS = tf.app.flags.FLAGS
def myregression():
with tf.variable_scope("data"):
x = tf.random_normal([100, 1], mean=1.75, stddev=0.5)
y_true = tf.matmul(x, [[0.7]]) + 0.8
with tf.variable_scope("model"):
# 权重 trainable 指定权重是否随着session改变
weight = tf.Variable(tf.random_normal([int(x.shape[1]), 1], mean=0, stddev=1), name="w")
# 偏置项
bias = tf.Variable(0.0, name='b')
# 构造y函数
y_predict = tf.matmul(x, weight) + bias
with tf.variable_scope("loss"):
# 定义损失函数
loss = tf.reduce_mean(tf.square(y_true - y_predict))
with tf.variable_scope("optimizer"):
# 使用梯度下降进行求解
train_op = tf.train.GradientDescentOptimizer(0.1).minimize((loss))
# 1.收集tensor
tf.summary.scalar("losses", loss)
tf.summary.histogram("weights", weight)
# 2.定义合并tensor的op
merged = tf.summary.merge_all()
# 定义一个保存模型的op
saver = tf.train.Saver()
with tf.Session() as sess:
tf.global_variables_initializer().run()
# import matplotlib.pyplot as plt
# plt.scatter(x.eval(), y_true.eval())
# plt.show()
print("初始化的权重:%f,偏置项:%f" % (weight.eval(), bias.eval()))
# 建立事件文件
filewriter = tf.summary.FileWriter('./tmp/summary/test/', graph=sess.graph)
# 加载模型
if os.path.exists("tmp/ckpt/checkpoint"):
saver.restore(sess,FLAGS.model_dir)
print("加载")
n = 0
while loss.eval() > 1e-6:
n += 1
if(n==FLAGS.max_step):
break
sess.run(train_op)
summary = sess.run(merged)
filewriter.add_summary(summary, n)
print("第%d次权重:%f,偏置项:%f" % (n, weight.eval(), bias.eval()))
saver.save(sess, FLAGS.model_dir)
return weight, bias myregression()
# x_min,x_max = np.min(x.eval()),np.max(x.eval())
# tx = np.arange(x_min,x_max,100)
TensorFlow笔记-模型的保存,恢复,实现线性回归的更多相关文章
- Tensorflow Learning1 模型的保存和恢复
CKPT->pb Demo 解析 tensor name 和 node name 的区别 Pb 的恢复 CKPT->pb tensorflow的模型保存有两种形式: 1. ckpt:可以恢 ...
- Tensorflow学习笔记----模型的保存和读取(4)
一.模型的保存:tf.train.Saver类中的save TensorFlow提供了一个一个API来保存和还原一个模型,即tf.train.Saver类.以下代码为保存TensorFlow计算图的方 ...
- Python之TensorFlow的模型训练保存与加载-3
一.TensorFlow的模型保存和加载,使我们在训练和使用时的一种常用方式.我们把训练好的模型通过二次加载训练,或者独立加载模型训练.这基本上都是比较常用的方式. 二.模型的保存与加载类型有2种 1 ...
- tensorflow 之模型的保存与加载(三)
前面的两篇博文 第一篇:简单的模型保存和加载,会包含所有的信息:神经网络的op,node,args等; 第二篇:选择性的进行模型参数的保存与加载. 本篇介绍,只保存和加载神经网络的计算图,即前向传播的 ...
- tensorflow 之模型的保存与加载(二)
上一遍博文提到 有些场景下,可能只需要保存或加载部分变量,并不是所有隐藏层的参数都需要重新训练. 在实例化tf.train.Saver对象时,可以提供一个列表或字典来指定需要保存或加载的变量. #!/ ...
- tensorflow 之模型的保存与加载(一)
怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ### ...
- tensorflow:模型的保存和训练过程可视化
在使用tf来训练模型的时候,难免会出现中断的情况.这时候自然就希望能够将辛辛苦苦得到的中间参数保留下来,不然下次又要重新开始. 保存模型的方法: #之前是各种构建模型graph的操作(矩阵相乘,sig ...
- 【TensorFlow】TensorFlow基础 —— 模型的保存读取与可视化方法总结
TensorFlow提供了一个用于保存模型的工具以及一个可视化方案 这里使用的TensorFlow为1.3.0版本 一.保存模型数据 模型数据以文件的形式保存到本地: 使用神经网络模型进行大数据量和复 ...
- tensorflow模型的保存与恢复
1.tensorflow中模型的保存 创建tf.train.saver,使用saver进行保存: saver = tf.train.Saver() saver.save(sess, './traine ...
随机推荐
- QString之simplified()用于读取数据、规范数据,非常方便
在工程项目开发中,遇到这么个问题:手工计入文件中的数据,每行有三个,前两个是数字,最后一个是标识,现在把这3个数据提取出来. 一提取就出现问题了:由于手工导入,数据间使用空白间隔,有可能是一个空格,有 ...
- [收录] Highcharts-ng —— AngularJS 的图表扩展
原文:http://www.tuicool.com/articles/u6VZJjQ Highcharts-ng 是一个 AngularJS 的指令扩展,实现了在AngularJS 应用中集成High ...
- visual studio添加docker支持简记
很久以前学过一段时间的docker,那时装了电脑卡得受不了,学了一会就卸载了,最近电脑又装上了docker,感觉好像没有以前这么卡了,还是同一台电脑surface pro4, 试了一下visual s ...
- QT_NO_CAST_FROM_ASCII这个宏的,禁用一切来自双引号字符串字面量传入QString(有2种解决方法)
这两天制作了两个Qt Creator增强套装的两个插件,其实也是非常简单的,但是其实花了我超过四天的时间,为什么呢?因为我之前很长一段时间都是在Linux下开发的,一切安好,没有任何问题,但是到了Wi ...
- Realm_King 之 XPDL(XML Process Definition Language)
XPDL(XML Process Definition Language)是由Workflow Management Coalition(简写为:WfMC)所提出的一个标准化规格,使用XML文件让不同 ...
- 宜信开源|数据库审核软件Themis的规则解析与部署攻略
一.介绍 Themis是宜信公司DBA团队开发的一款数据库审核产品,可帮助DBA.开发人员快速发现数据库质量问题,提升工作效率.其名称源自希腊神话中的正义与法律女神.项目取此名称,寓意此平台对数据库质 ...
- python bmp image injection
1. 将原BMP文件的第三,第四字节替换为\x2F\x2A, 对应js中的注释符号/*BMP文件的第三.四.五.六字节表示BMP文件的大小2. 在BMP文件末尾添加(1)\xFF(2)\x2A\x2F ...
- 【python3两小时快速入门】入门笔记03:简单爬虫+多线程爬虫
作用,之间将目标网页保存金本地 1.爬虫代码修改自网络,目前运行平稳,博主需要的是精准爬取,数据量并不大,暂未加多线程. 2.分割策略是通过查询条件进行分类,循环启动多条线程. 1.单线程简单爬虫(第 ...
- 【hibernate-validator+SpringMVC】后台参数校验框架
hibernate-validator+SpringMVC 简介:简单说,就是对Entity进行校验. 1.导包,没有很严谨的对应关系,所以我用了比较新的版本,支持更多的注解. <depende ...
- python的数据类型之字符串(一)
字符串(str) 双引号或者单引号中的数据,就是字符串. 注意事项 1.反斜杠可以用来转义,使用r可以让反斜杠不发生转义. 2.字符串可以用+运算符连接在一起,用*运算符重复. 3.Python中的字 ...