环境Anaconda

废话不多说,关键看代码

import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' tf.app.flags.DEFINE_integer("max_step", 300, "训练模型的步数")
FLAGS = tf.app.flags.FLAGS def linear_regression():
'''
自实现线性回归
:return:
'''
#1.准备100个样本 特征值X,目标值y_true with tf.variable_scope("original_data"):
#mean是平均值
#stddev代表方差
X = tf.random_normal(shape=(100,1),mean=0,stddev=1) y_true = tf.matmul(X,[[0.8]])+0.7 #2.建立线性模型:
with tf.variable_scope("linear_model"):
weigh = tf.Variable(initial_value=tf.random_normal(shape=(1,1)))
bias = tf.Variable(initial_value=tf.random_normal(shape=(1,1))) y_predict = tf.matmul(X,weigh)+bias # 3 确定损失函数
#均方误差((y-y_repdict)^2)/m = 平均每一个样本的误差
with tf.variable_scope("loss"):
error = tf.reduce_mean(tf.square(y_predict-y_true)) #4梯度下降优化损失:需要指定学习率
with tf.variable_scope("gd_optimizer"):
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(error) #收集变量
tf.summary.scalar("error",error)
tf.summary.histogram("weights",weigh)
tf.summary.histogram("bias",bias) #合并变量
merge = tf.summary.merge_all() #初始化变量
init = tf.global_variables_initializer() #创建一个saver
saver = tf.train.Saver()
#开启会话进行训练
with tf.Session() as sess:
#初始化变量op
sess.run(init)
print("随机初始化的权重为{},偏执为{}".format(weigh.eval(),bias.eval())) # print(weigh.eval(), bias.eval())
# saver.restore(sess,"./checkpoint/linearregression")
# print(weigh.eval(),bias.eval())
#创建文件事件
file_writer = tf.summary.FileWriter(logdir="./",graph=sess.graph)
#训练模型 for i in range(FLAGS.max_step):
sess.run(optimizer)
summary = sess.run(merge)
file_writer.add_summary(summary,i)
print("第{}步的误差为{},权重为{},偏执为{}".format(i,error.eval(),weigh.eval(),bias.eval()))
#checkpoint:检查点文件
#tf.keras:h5
# saver.save(sess,"./checkpoint/linearregression") if __name__ == '__main__':
linear_regression()

  部分结果输出:

第294步的误差为7.031372661003843e-06,权重为[[0.7978232]],偏执为[[0.69850117]]
第295步的误差为5.66376502320054e-06,权重为[[0.7978593]],偏执为[[0.6985256]]
第296步的误差为5.646746103593614e-06,权重为[[0.7978932]],偏执为[[0.698556]]
第297步的误差为5.33674938196782e-06,权重为[[0.7979515]],偏执为[[0.69858944]]
第298步的误差为5.233380761637818e-06,权重为[[0.79799336]],偏执为[[0.6986183]]
第299步的误差为5.024347956350539e-06,权重为[[0.7980382]],偏执为[[0.6986382]]

  

如何用TensorFlow实现线性回归的更多相关文章

  1. 深度学习入门实战(二)-用TensorFlow训练线性回归

    欢迎大家关注腾讯云技术社区-博客园官方主页,我们将持续在博客园为大家推荐技术精品文章哦~ 作者 :董超 上一篇文章我们介绍了 MxNet 的安装,但 MxNet 有个缺点,那就是文档不太全,用起来可能 ...

  2. 一文详解如何用 TensorFlow 实现基于 LSTM 的文本分类(附源码)

    雷锋网按:本文作者陆池,原文载于作者个人博客,雷锋网已获授权. 引言 学习一段时间的tensor flow之后,想找个项目试试手,然后想起了之前在看Theano教程中的一个文本分类的实例,这个星期就用 ...

  3. 如何用Tensorflow训练模型成pb文件和和如何加载已经训练好的模型文件

    这篇薄荷主要是讲了如何用tensorflow去训练好一个模型,然后生成相应的pb文件.最后会将如何重新加载这个pb文件. 首先先放出PO主的github: https://github.com/ppp ...

  4. 从原理到代码:大牛教你如何用 TensorFlow 亲手搭建一套图像识别模块 | AI 研习社

    从原理到代码:大牛教你如何用 TensorFlow 亲手搭建一套图像识别模块 | AI 研习社 PPT链接: https://pan.baidu.com/s/1i5Jrr1N 视频链接: https: ...

  5. tensorflow实现线性回归、以及模型保存与加载

    内容:包含tensorflow变量作用域.tensorboard收集.模型保存与加载.自定义命令行参数 1.知识点 """ 1.训练过程: 1.准备好特征和目标值 2.建 ...

  6. TensorFlow简单线性回归

    TensorFlow简单线性回归 将针对波士顿房价数据集的房间数量(RM)采用简单线性回归,目标是预测在最后一列(MEDV)给出的房价. 波士顿房价数据集可从http://lib.stat.cmu.e ...

  7. 利用TensorFlow实现线性回归模型

    准备数据: import numpy as np import tensorflow as tf import matplotlib.pylot as plt # 随机生成1000个点,围绕在y=0. ...

  8. tensorflow实现线性回归总结

    1.知识点 """ 模拟一个y = 0.7x+0.8的案例 报警: 1.initialize_all_variables (from tensorflow.python. ...

  9. TensorFlow多元线性回归实现

    多元线性回归的具体实现 导入需要的所有软件包:   因为各特征的数据范围不同,需要归一化特征数据.为此定义一个归一化函数.另外,这里添加一个额外的固定输入值将权重和偏置结合起来.为此定义函数 appe ...

随机推荐

  1. 记录一个不同的流媒体网站实现方法,和用Python爬虫爬它的坑

    今天找到一片电影,想把它下载下来. 先开Networks工具分析一下: 初步分析发现,视频加载时会拉取TS格式的文件,推测这是一个m3u8的索引,记录着几百段TS文件,这样方便快进时加载. 但是实际分 ...

  2. VBScript - 动态 Array 实现方法大全!

    记录一些方法,关于 VBScript 中,动态 Array 的实现 ,也适用于 VBA, 很久以前,写 VBA 的时候,就觉得使用 Array 很不方便,因为大小固定, 当时想的是,要是 Array ...

  3. VBScript 打开含有"空格"的路径 (Open Path with Space)

    记录,VBScript 如何打开,含有"空格"的路径.这个问题和常见,却总是忘! 直接上代码了,多说无益. Option Explicit Dim obj Dim path Set ...

  4. 为何给CheckBox设置了checked属性还是没有勾选,行内样式都显示了checked

    为何给CheckBox设置了checked属性还是没有勾选,行内样式都显示了checked 正常情况下我们设置给CheckBox一个checked属性后一般都会选中 然而我今天在做案例的时候却遇到了类 ...

  5. sentry使用

    开篇-Sentry是什么 Sentry是开源错误跟踪,帮助开发人员实时监控和修复崩溃.不断重复.提高效率.改善用户体验. 这篇文章的作用 记录这篇文章是想分享一下,因为本人在配置时因为邮件服务花费了很 ...

  6. Springmvc与Struts区别?

    在一个技术群里看到机器人解释的二者区别,在此Mark下. 一.框架机制 spring mvc 和 struts2的加载机制不同:spring mvc的入口是servlet,而struts2是filte ...

  7. iOS 构建静态库

    一..a 文件静态库打包 打开 Xcode 创建一个新的 Static Library 工程,取名 MyStaticLibrary. 创建工程完毕后,系统自动创建了一个同名类,添加一个方法用于测试. ...

  8. What is MongoDB and For What?

    1.MongoDB是什么? MongoDB是一款为web应用程序和互联网基础设施设计的数据库管理系统.没错MongoDB就是数据库,是NoSQL类型的数据库 2.为什么要用MongoDB? (1)Mo ...

  9. sql MYSQL主从配置

    MYSQL主从配置 1.1 部署环境 主(master_mysql): 192.168.1.200 OS:CentOS 6.5 从(slave_mysql): 192.168.1.201 OS:Cen ...

  10. M_map(五)

    一.圆形区域的画图 1. clear all LATLIMS=[14 22]; LONLIMS=[108 118];%南海边界范围 m_proj('miller','lon',LONLIMS,'lat ...