tensorflow实现线性回归、以及模型保存与加载
内容:包含tensorflow变量作用域、tensorboard收集、模型保存与加载、自定义命令行参数
1、知识点
"""
1、训练过程:
1、准备好特征和目标值
2、建立模型,随机初始化权重和偏置; 模型的参数必须要使用变量
3、求损失函数,误差为均方误差
4、梯度下降去优化损失过程,指定学习率 2、Tensorflow运算API:
1、矩阵运算:tf.matmul(x,w)
2、平方:tf.square(error)
3、均值:tf.reduce_mean(error)
4、梯度下降API: tf.train.tf.train.GradientDescentOptimizer(learning_rate)
learning_rate:学习率
minimize(lose):优化最小损失
return:梯度下降op
3、注意项:
1、tf.Variable()中的trainable表示为变量在训练过程可变
2、学习率设置很大时,可能会出现权重和偏置为NAN,这种现象表现叫梯度爆炸
解决方法:1、重新设计网络 2、调整学习率 3、使用梯度截断 4、使用激活函数 4、变量作用域:主要用于tensorboard查看,同时使代码更加清晰 with tf.variable_scope("data"): 5、添加权重、参数、损失值等在tensoroard观察的情况:
1、收集tensor变量 tf.summary.scalar('losses', loss)、tf.summary.histogram('weight',weight)
2、合并变量并写入事件文件:merged = tf.summary.merge_all()
3、运行合并的tensor:summary = sess.run(merged)、fileWriter.add_summary(summary,i) 6、模型保存与加载: tf.train.Saver(var_list=None,max_to_keep=5)
var_list:指定将要保存和还原的变量。它可以作为一个dict或一个列表传递.
max_to_keep:指示要保留的最近检查点文件的最大数量。
创建新文件时,会删除较旧的文件。如果无或0,则保留所有
检查点文件。默认为5(即保留最新的5个检查点文件。
a)例如:saver.save(sess, '/tmp/ckpt/test/model')
saver.restore(sess, '/tmp/ckpt/test/model')
保存文件格式:checkpoint文件 b)模型加载:
if os.path.exists('./ckpt/checkpoint'):
saver.restore(sess,'./ckpt/model') 7、自定义命令行参数:
1、首先定义有哪些参数需要在运行时候指定
2、程序当中获取定义命令行参数
3、运行 python *.py --max_step=500 --model_dir='./ckpt/model'
本例执行命令:python tensorflow实现线性回归.py --max_step=50 --model_dir="./ckpt/model" """
2、代码
# coding = utf-8
import tensorflow as tf
import os #自定义命令行参数
tf.app.flags.DEFINE_integer("max_step",100,"模型训练的步数")
tf.app.flags.DEFINE_string("model_dir"," ","模型文件加载路径")
#定义获取命令行参数名字
FLAGS = tf.app.flags.FLAGS
def myLinear():
"""
自实现一个线性回归预测
:return:
"""
#定义作用域
with tf.variable_scope("data"):
#1、准备数据,特征值
x = tf.random_normal([100,1],mean=1.75,stddev=0.5)
#目标值。矩阵相乘,必须是二维的
y_true = tf.matmul(x,[[0.7]])+0.8 with tf.variable_scope("model"):
#2、建立线性模型 y = wx+b ,随机给定w和b的值,必须定义成变量
weight = tf.Variable(tf.random_normal([1,1],mean=0.0,stddev=1.0,name="w"))
bias = tf.Variable(0.0,name="b")
y_predict = tf.matmul(x,weight)+bias with tf.variable_scope("loss"):
#3、建立损失函数,均方误差
loss = tf.reduce_mean(tf.square(y_true-y_predict))
with tf.variable_scope("optimizer"):
#4、梯度下降优化损失
train_op = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(loss) ##############模型保存################
with tf.variable_scope("save_model"):
saver = tf.train.Saver(); # 初始化变量
init_op = tf.global_variables_initializer()
####################收集变量#########################
# 收集tensor变量
tf.summary.scalar('losses', loss)
tf.summary.histogram('weight',weight) #合并变量并写入事件文件
merged = tf.summary.merge_all()
#通过会话运行程序
with tf.Session() as sess:
#必须要运行初始化变量
sess.run(init_op) #打印随机最先初始化的权重和偏置
print("随机初始化的参数权重为:%f,偏置为:%f" % (weight.eval(),bias.eval()))
# 建立事件文件
fileWriter = tf.summary.FileWriter("./tmp", graph=sess.graph) ###########加载模型,覆盖之前的参数##############
if os.path.exists('./ckpt/checkpoint'):
#saver.restore(sess,'./ckpt/model')
saver.restore(sess, FLAGS.model_dir) #循环优化
for i in range(FLAGS.max_step):
#运行优化
sess.run(train_op)
#运行合并的tensor
summary = sess.run(merged)
fileWriter.add_summary(summary,i)
print("第%d次优化参数权重为:%f,偏置为:%f" % (i,weight.eval(), bias.eval()))
################模型保存##############
# if i%1000==0:
# #saver.save(sess,'./ckpt/model')
# saver.save(sess,FLAGS.model_dir)
saver.save(sess, FLAGS.model_dir)
return None if __name__ == '__main__':
myLinear()
tensorflow实现线性回归、以及模型保存与加载的更多相关文章
- tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署
TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...
- sklearn模型保存与加载
sklearn模型保存与加载 sklearn模型的保存和加载API 线性回归的模型保存加载案例 保存模型 sklearn模型的保存和加载API from sklearn.externals impor ...
- [PyTorch 学习笔记] 7.1 模型保存与加载
本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py https://githu ...
- 转 tensorflow模型保存 与 加载
使用tensorflow过程中,训练结束后我们需要用到模型文件.有时候,我们可能也需要用到别人训练好的模型,并在这个基础上再次训练.这时候我们需要掌握如何操作这些模型数据.看完本文,相信你一定会有收获 ...
- TensorFlow构建卷积神经网络/模型保存与加载/正则化
TensorFlow 官方文档:https://www.tensorflow.org/api_guides/python/math_ops # Arithmetic Operators import ...
- Tensorflow模型保存与加载
在使用Tensorflow时,我们经常要将以训练好的模型保存到本地或者使用别人已训练好的模型,因此,作此笔记记录下来. TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提 ...
- TensorFlow的模型保存与加载
import os os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' import tensorflow as tf #tensorboard --logdir=&qu ...
- tensorflow 之模型的保存与加载(一)
怎样让通过训练的神经网络模型得以复用? 本文先介绍简单的模型保存与加载的方法,后续文章再慢慢深入解读. #!/usr/bin/env python3 #-*- coding:utf-8 -*- ### ...
- TensorFlow保存、加载模型参数 | 原理描述及踩坑经验总结
写在前面 我之前使用的LSTM计算单元是根据其前向传播的计算公式手动实现的,这两天想要和TensorFlow自带的tf.nn.rnn_cell.BasicLSTMCell()比较一下,看看哪个训练速度 ...
随机推荐
- IoU-Net论文笔记
原论文标题:Acquisition of Localization Confidence for Accurate Object Detection 1. 前言 Megvii在ECCV 2018上的一 ...
- 运行TensorFlow代码时报错
运行TensorFlow代码时报错 错误信息ImportError: libcublas.so.10.0: cannot open shared object file 原因:TensorFlow版本 ...
- Zookeeper01——zk的基本信息和安装
一.Zookeeper的基本信息 1.1背景 无论在前面,我们学习hdfs,还是学习redis集群,我们都会使用到一个zookeeper进行选举.这导致了Redis的产生. 我们知道,在先前我们使用Z ...
- 性能篇——函数调用结果的 LRU 缓存
1. 应用场景: 多次调用同一函数 2. 普通写法: def say(name): print("hellow:%s"%name) now = datetime.datetime. ...
- LeetCode01 - 两数之和(Java 实现)
LeetCode01 - 两数之和(Java 实现) 来源:力扣(LeetCode) 链接:https://leetcode-cn.com/problems/two-sum 题目描述 给定一个整数数组 ...
- MySQL 关于视图的操作
-- 视图就是一条select 语句 执行后返回结果集,是一种虚拟表,是一个逻辑表 -- 方便操作,减少复杂的SQL语句,增加可读性,更加安全一些 create view demo_view as s ...
- Linux - TCP/IP网络协议基础
1.0 Tcp / IP 背景介绍 上世纪70年代,随着计算机的发展,人们意识到如果想要发挥计算机的更大作用,就要讲世界各地的计算机连接起来. 但是简单的连接时不够的,因为计算机之间无法沟通.因此设计 ...
- django环境配置(基于命令行安装)
一.django简介 Python服务端开发框架,Django是一个开放源代码的Web应用框架,由Python写成,Django采用了MVC的软件设计模式,即模型M,视图V和控制器C 二.安装配置dj ...
- java上传超大文件解决方案
用JAVA实现大文件上传及显示进度信息 ---解析HTTP MultiPart协议 (本文提供全部源码下载,请访问 https://github.com/1269085759/up6-jsp-mysq ...
- 2018 焦作网络赛 K Transport Ship ( 二进制优化 01 背包 )
题目链接 题意 : 给出若干个物品的数量和单个的重量.问你能不能刚好组成总重 S 分析 : 由于物品过多.想到二进制优化 其实这篇博客就是存个二进制优化的写法 关于二进制优化的详情.百度一下有更多资料 ...