环境Anaconda

废话不多说,关键看代码

import tensorflow as tf
import os
os.environ['TF_CPP_MIN_LOG_LEVEL']='2' tf.app.flags.DEFINE_integer("max_step", 300, "训练模型的步数")
FLAGS = tf.app.flags.FLAGS def linear_regression():
'''
自实现线性回归
:return:
'''
#1.准备100个样本 特征值X,目标值y_true with tf.variable_scope("original_data"):
#mean是平均值
#stddev代表方差
X = tf.random_normal(shape=(100,1),mean=0,stddev=1) y_true = tf.matmul(X,[[0.8]])+0.7 #2.建立线性模型:
with tf.variable_scope("linear_model"):
weigh = tf.Variable(initial_value=tf.random_normal(shape=(1,1)))
bias = tf.Variable(initial_value=tf.random_normal(shape=(1,1))) y_predict = tf.matmul(X,weigh)+bias # 3 确定损失函数
#均方误差((y-y_repdict)^2)/m = 平均每一个样本的误差
with tf.variable_scope("loss"):
error = tf.reduce_mean(tf.square(y_predict-y_true)) #4梯度下降优化损失:需要指定学习率
with tf.variable_scope("gd_optimizer"):
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.01).minimize(error) #收集变量
tf.summary.scalar("error",error)
tf.summary.histogram("weights",weigh)
tf.summary.histogram("bias",bias) #合并变量
merge = tf.summary.merge_all() #初始化变量
init = tf.global_variables_initializer() #创建一个saver
saver = tf.train.Saver()
#开启会话进行训练
with tf.Session() as sess:
#初始化变量op
sess.run(init)
print("随机初始化的权重为{},偏执为{}".format(weigh.eval(),bias.eval())) # print(weigh.eval(), bias.eval())
# saver.restore(sess,"./checkpoint/linearregression")
# print(weigh.eval(),bias.eval())
#创建文件事件
file_writer = tf.summary.FileWriter(logdir="./",graph=sess.graph)
#训练模型 for i in range(FLAGS.max_step):
sess.run(optimizer)
summary = sess.run(merge)
file_writer.add_summary(summary,i)
print("第{}步的误差为{},权重为{},偏执为{}".format(i,error.eval(),weigh.eval(),bias.eval()))
#checkpoint:检查点文件
#tf.keras:h5
# saver.save(sess,"./checkpoint/linearregression") if __name__ == '__main__':
linear_regression()

  部分结果输出:

第294步的误差为7.031372661003843e-06,权重为[[0.7978232]],偏执为[[0.69850117]]
第295步的误差为5.66376502320054e-06,权重为[[0.7978593]],偏执为[[0.6985256]]
第296步的误差为5.646746103593614e-06,权重为[[0.7978932]],偏执为[[0.698556]]
第297步的误差为5.33674938196782e-06,权重为[[0.7979515]],偏执为[[0.69858944]]
第298步的误差为5.233380761637818e-06,权重为[[0.79799336]],偏执为[[0.6986183]]
第299步的误差为5.024347956350539e-06,权重为[[0.7980382]],偏执为[[0.6986382]]

  

如何用TensorFlow实现线性回归的更多相关文章

  1. 深度学习入门实战(二)-用TensorFlow训练线性回归

    欢迎大家关注腾讯云技术社区-博客园官方主页,我们将持续在博客园为大家推荐技术精品文章哦~ 作者 :董超 上一篇文章我们介绍了 MxNet 的安装,但 MxNet 有个缺点,那就是文档不太全,用起来可能 ...

  2. 一文详解如何用 TensorFlow 实现基于 LSTM 的文本分类(附源码)

    雷锋网按:本文作者陆池,原文载于作者个人博客,雷锋网已获授权. 引言 学习一段时间的tensor flow之后,想找个项目试试手,然后想起了之前在看Theano教程中的一个文本分类的实例,这个星期就用 ...

  3. 如何用Tensorflow训练模型成pb文件和和如何加载已经训练好的模型文件

    这篇薄荷主要是讲了如何用tensorflow去训练好一个模型,然后生成相应的pb文件.最后会将如何重新加载这个pb文件. 首先先放出PO主的github: https://github.com/ppp ...

  4. 从原理到代码:大牛教你如何用 TensorFlow 亲手搭建一套图像识别模块 | AI 研习社

    从原理到代码:大牛教你如何用 TensorFlow 亲手搭建一套图像识别模块 | AI 研习社 PPT链接: https://pan.baidu.com/s/1i5Jrr1N 视频链接: https: ...

  5. tensorflow实现线性回归、以及模型保存与加载

    内容:包含tensorflow变量作用域.tensorboard收集.模型保存与加载.自定义命令行参数 1.知识点 """ 1.训练过程: 1.准备好特征和目标值 2.建 ...

  6. TensorFlow简单线性回归

    TensorFlow简单线性回归 将针对波士顿房价数据集的房间数量(RM)采用简单线性回归,目标是预测在最后一列(MEDV)给出的房价. 波士顿房价数据集可从http://lib.stat.cmu.e ...

  7. 利用TensorFlow实现线性回归模型

    准备数据: import numpy as np import tensorflow as tf import matplotlib.pylot as plt # 随机生成1000个点,围绕在y=0. ...

  8. tensorflow实现线性回归总结

    1.知识点 """ 模拟一个y = 0.7x+0.8的案例 报警: 1.initialize_all_variables (from tensorflow.python. ...

  9. TensorFlow多元线性回归实现

    多元线性回归的具体实现 导入需要的所有软件包:   因为各特征的数据范围不同,需要归一化特征数据.为此定义一个归一化函数.另外,这里添加一个额外的固定输入值将权重和偏置结合起来.为此定义函数 appe ...

随机推荐

  1. 解决GPU显存未释放问题

    前言 今早我想用多块GPU测试模型,于是就用了PyTorch里的torch.nn.parallel.DistributedDataParallel来支持用多块GPU的同时使用(下面简称其为Dist). ...

  2. rabbitmq++:rabbitmq 三种常用的交换机

    更多 rabbitmq 介绍 首先先介绍一个简单的一个消息推送到接收的流程,提供一个简单的图: 黄色的圈圈就是我们的消息推送服务,将消息推送到 中间方框里面也就是 rabbitMq的服务器: 然后经过 ...

  3. Downgrade ASM DATABASE_COMPATIBILITY (from 11.2.0.4.0 to 11.2.0.0.0) on 12C CRS stack.

    使用Onecommand完成快速Oracle 12c RAC部署后 发现ASM database compatibilty无法设置,默认为11.2.0.4.0. 由于我们还有些数据库低于这个版本,所以 ...

  4. Polya 定理相关题目

    参考知识链接   关于枚举旋转置换:   前两题都是枚举了 n 种旋转, 但这个可以优化到\(O(\sqrt{n})\) (这个其实是基本操作). 考虑到每个循环节的长度都是 n 的因数, 所以可以枚 ...

  5. [noip模拟]小猫爬山<迭代深搜>

    [题目描述]: Freda和rainbow饲养了N只小猫,这天,小猫们要去爬山.经历了千辛万苦,小猫们终于爬上了山顶,但是疲倦的它们再也不想徒步走下山了(呜咕>_<). Freda和rai ...

  6. Gin框架系列02:路由与参数

    回顾 上一节我们用Gin框架快速搭建了一个GET请求的接口,今天来学习路由和参数的获取. 请求动词 熟悉RESTful的同学应该知道,RESTful是网络应用程序的一种设计风格和开发方式,每一个URI ...

  7. python爬虫之requests的高级使用

    1.requests能上传文件 # 导入requests模块 import requests # 定义一个dict files = {'file': open('D:/360Downloads/1.t ...

  8. 分享淘宝时间服务器同步时间接口api和苏宁时间服务器接口api

    最近要开发一款抢购秒杀的小工具,需要同步系统时间,这里分享两个时间服务器接口api给大家: 1.淘宝时间服务器时间接口 http://api.m.taobao.com/rest/api3.do?api ...

  9. PTA数据结构与算法题目集(中文) 7-40奥运排行榜 (25 分)

    PTA数据结构与算法题目集(中文)  7-40奥运排行榜 (25 分) 7-40 奥运排行榜 (25 分)   每年奥运会各大媒体都会公布一个排行榜,但是细心的读者发现,不同国家的排行榜略有不同.比如 ...

  10. Node.js快速创建一个Express应用的几个步骤

    Node.js 的 Express 框架学习 转载和参考地址: https://developer.mozilla.org/zh-CN/docs/Learn/Server-side/Express_N ...