TensorFlow 实现线性回归
1、生成高斯分布的随机数
导入numpy模块,通过numpy模块内的方法生成一组在方程
y = 2 * x + 3
周围小幅波动的随机坐标。代码如下:
import numpy as np
import matplotlib.pyplot as plot def getRandomPoints(count):
xList = []
yList = []
for i in range(count):
x = np.random.normal(0, 0.5)
y = 2 * x + 3 + np.random.normal(0, 0.3)
xList.append(x)
yList.append(y)
return xList, yList if __name__ == '__main__':
X, Y = getRandomPoints(1000)
plot.scatter(X, Y)
plot.show()
运行上述代码,输出图形如下:

2、采用TensorFlow来获取上述方程的系数
首先搭建基本的预估模型y = w * x + b,然后再采用梯度下降法进行训练,通过最小化损失函数的方法进行优化,最终训练得出方程的系数。
在下面的例子中,梯度下降法的学习率为0.2,训练迭代次数为100次。
def train(x, y):
# 生成随机系数
w = tf.Variable(tf.random_uniform([1], -1, 1))
# 生成随机截距
b = tf.Variable(tf.random_uniform([1], -1, 1))
# 预估值
preY = w * x + b # 损失值:预估值与实际值之间的均方差
loss = tf.reduce_mean(tf.square(preY - y))
# 优化器:梯度下降法,学习率为0.2
optimizer = tf.train.GradientDescentOptimizer(0.2)
# 训练:最小化损失函数
trainer = optimizer.minimize(loss) with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 打印初始随机系数
print('init w:', sess.run(w), 'b:', sess.run(b))
# 先训练个100次:
for i in range(100):
sess.run(trainer)
# 每10次打印下系数
if i % 10 == 9:
print('w:', sess.run(w), 'b:', sess.run(b)) if __name__ == '__main__':
X, Y = getRandomPoints(1000)
train(X, Y)
运行上面的代码,某次的最终结果为:
w = 1.9738449
b = 3.0027733
仅100次的训练迭代,得出的结果已十分接近方程的实际系数。
某次模拟训练中的输出结果如下:
init w: [-0.6468966] b: [0.52244043]
w: [1.0336646] b: [2.9878206]
w: [1.636582] b: [3.0026987]
w: [1.8528996] b: [3.0027785]
w: [1.930511] b: [3.0027752]
w: [1.9583567] b: [3.0027738]
w: [1.9683474] b: [3.0027735]
w: [1.9719319] b: [3.0027733]
w: [1.9732181] b: [3.0027733]
w: [1.9736794] b: [3.0027733]
w: [1.9738449] b: [3.0027733]
3、完整代码和结果
完整测试代码:
import numpy as np
import matplotlib.pyplot as plot
import tensorflow as tf def getRandomPoints(count, xscale=0.5, yscale=0.3):
xList = []
yList = []
for i in range(count):
x = np.random.normal(0, xscale)
y = 2 * x + 3 + np.random.normal(0, yscale)
xList.append(x)
yList.append(y)
return xList, yList def train(x, y, learnrate=0.2, cycle=100):
# 生成随机系数
w = tf.Variable(tf.random_uniform([1], -1, 1))
# 生成随机截距
b = tf.Variable(tf.random_uniform([1], -1, 1))
# 预估值
preY = w * x + b # 损失值:预估值与实际值之间的均方差
loss = tf.reduce_mean(tf.square(preY - y))
# 优化器:梯度下降法
optimizer = tf.train.GradientDescentOptimizer(learnrate)
# 训练:最小化损失函数
trainer = optimizer.minimize(loss) with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 打印初始随机系数
print('init w:', sess.run(w), 'b:', sess.run(b))
for i in range(cycle):
sess.run(trainer)
# 每10次打印下系数
if i % 10 == 9:
print('w:', sess.run(w), 'b:', sess.run(b))
return sess.run(w), sess.run(b) if __name__ == '__main__':
X, Y = getRandomPoints(1000)
w, b = train(X, Y)
plot.scatter(X, Y)
plot.plot(X, w * X + b, c='r')
plot.show()
最终效果图如下,蓝色为高斯随机分布数据,红色为最终得出的直线:

本文地址:https://www.cnblogs.com/laishenghao/p/9571343.html
TensorFlow 实现线性回归的更多相关文章
- tensorflow实现线性回归、以及模型保存与加载
内容:包含tensorflow变量作用域.tensorboard收集.模型保存与加载.自定义命令行参数 1.知识点 """ 1.训练过程: 1.准备好特征和目标值 2.建 ...
- TensorFlow简单线性回归
TensorFlow简单线性回归 将针对波士顿房价数据集的房间数量(RM)采用简单线性回归,目标是预测在最后一列(MEDV)给出的房价. 波士顿房价数据集可从http://lib.stat.cmu.e ...
- 深度学习入门实战(二)-用TensorFlow训练线性回归
欢迎大家关注腾讯云技术社区-博客园官方主页,我们将持续在博客园为大家推荐技术精品文章哦~ 作者 :董超 上一篇文章我们介绍了 MxNet 的安装,但 MxNet 有个缺点,那就是文档不太全,用起来可能 ...
- 利用TensorFlow实现线性回归模型
准备数据: import numpy as np import tensorflow as tf import matplotlib.pylot as plt # 随机生成1000个点,围绕在y=0. ...
- tensorflow实现线性回归总结
1.知识点 """ 模拟一个y = 0.7x+0.8的案例 报警: 1.initialize_all_variables (from tensorflow.python. ...
- 如何用TensorFlow实现线性回归
环境Anaconda 废话不多说,关键看代码 import tensorflow as tf import os os.environ['TF_CPP_MIN_LOG_LEVEL']='2' tf.a ...
- TensorFlow多元线性回归实现
多元线性回归的具体实现 导入需要的所有软件包: 因为各特征的数据范围不同,需要归一化特征数据.为此定义一个归一化函数.另外,这里添加一个额外的固定输入值将权重和偏置结合起来.为此定义函数 appe ...
- TensorFlow实现线性回归模型代码
模型构建 1.示例代码linear_regression_model.py #!/usr/bin/python # -*- coding: utf-8 -* import tensorflow as ...
- 学习TensorFlow,线性回归模型
学习TensorFlow,在MNIST数据集上建立softmax回归模型并测试 一.代码 <span style="font-size:18px;">from tens ...
- tensorflow 学习1——tensorflow 做线性回归
. 首先 Numpy: Numpy是Python的科学计算库,提供矩阵运算. 想想list已经提供了矩阵的形式,为啥要用Numpy,因为numpy提供了更多的函数. 使用numpy,首先要导入nump ...
随机推荐
- 获取本机正在使用的ipv4地址(访问互联网的IP)
[转]原文地址:http://www.cnblogs.com/lijianda/p/6604651.html 1.一个电脑有多个网卡,有线的.无线的.还有vmare虚拟的两个网卡.2.就算只有一个网卡 ...
- 搞定pycharm专业版的安装
学习python也有一段时间了,装了python2,也装了python3.对于IDE当然首选了人人拍掌叫好的pycharm.其实作为小白,一开始的时候并不知道什么是IDE,什么是pychram,以为装 ...
- 天河2号-保持使用yhrun/srun时连接不中断 (screen 命令教程 )
问题重述: 当我们使用天河机进行并行程序实验的时候,都会使用到yhrun/srun命令.在超算环境下,yhrun 命令用来进行提交交互式作业,有屏幕输出.但是容易受到网络波动影响导致断网或者关闭窗口最 ...
- scp机器间远程拷贝
scp是 secure copy的缩写, scp是linux系统下基于ssh登陆进行安全的远程文件拷贝命令.linux的scp命令可以在linux服务器之间复制文件和目录. 使用语法:scp [参数 ...
- [Python_6] Python 配置 MySQL 访问
0. 说明 Python 访问 MySQL 数据库,需要安装 MySQL 的 Python 插件. 1. 安装 MySQL 插件 pip install PyMySQL 2. 编写代码 # -*-co ...
- Linux 基本概念 & 命令
0. Linux 理解 Linux 是一种操作系统,主要应用于服务器. Linux 性能稳定,其中的许多版本不收费(如CentOS),占用资源较少. 1. 命令行的状态 在 Linux 命令行下以上分 ...
- Win7下的C盘重新划分为两个盘
Win 7分盘 注意事项:操作之前,先备份好重要数据,以免误操作导致数据丢失 . 方法步骤如下: 1.在桌面右键点击"计算机"-"管理": 2.鼠标左键单击选& ...
- eclipse 汉化详细方法
1.首先确认自己的 eclipse 是哪个版本,这个很关键,涉及到后面要用到的语言包需要与版本匹配,启动 eclipse,观察对应的版本号,比如我用的是 Photon 版本 2.参照官方给的方法进行下 ...
- Java中Map根据键值(key)或者值(value)进行排序实现
我们都知道,java中的Map结构是key->value键值对存储的,而且根据Map的特性,同一个Map中 不存在两个Key相同的元素,而value不存在这个限制.换句话说,在同一个Map中Ke ...
- SASS对css的管理
一.SASS简介 SASS是一种CSS的开发工具,提供了许多便利的写法,大大节省了设计者的时间,使得CSS的开发,变得简单和可维护. 本文总结了SASS的主要用法.我的目标是,有了这篇文章,日常的一般 ...