今天让我们一起来学习如何用TF实现线性回归模型。所谓线性回归模型就是y = W * x + b的形式的表达式拟合的模型。

我们先假设一条直线为 y = 0.1x + 0.3,即W = 0.1,b = 0.3,然后利用随机数在这条直线附近产生1000个随机点,然后利用tensorflow构造的线性模型去学习,最后对比模型所得的W和b与真实值的差距即可。

(某天在浏览Github的时候,发现了一个好东西,Github上有一个比较好的有关tensorflow的Demo合集,有注释有源代码非常适合新手入门。)

import numpy as np     #numpy库可用来存储和处理大型矩阵
import tensorflow as tf
import matplotlib.pyplot as plt    #主要用于画图

#产生1000个随机点
num_points = 1000

vectors_set = []
for i in range(num_points):
#利用random的内置函数产生1000个符合 均值为0,标准差为0.55的正态分布
  x1 = np.random.normal(0.0, 0.55)
  y1 = x1 * 0.1 + 0.3 + np.random.normal(0.0, 0.03)
  vectors_set.append([x1,y1])

x_data = [v[0] for v in vectors_set]
y_data = [v[1] for v in vectors_set]

plt.scatter(x_data, y_data, c = 'r')
plt.show()

#生成1维的W矩阵,取值为【-1,1】之间的随机数
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name = 'W')
#生成1维的b矩阵,初始值为0
b = tf.Variable(tf.zeros([1]), name = 'b')
#经过计算得出预估值Y
y = W * x_data + b

#以预估值Y和实际值Y_data之间的均方误差作为损失
loss = tf.reduce_mean(tf.square(y - y_data), name = 'loss')

#采用梯度下降法进行优化参数(梯度下降原理详情见另一篇博客)
#optimizer = tf.train.GradientDescentOptimizera(0.5).minimize(loss)
optimizer = tf.train.GradientDescentOptimizer(0.5)

#训练的过程就是最小化这个误差值
train = optimizer.minimize(loss, name = 'train')

sess = tf.Session()

init = tf.global_variables_initializer()
sess.run(init)

#打印初始化的W和b的值
print('W = ', sess.run(W), 'b = ', sess.run(b), "loss = ", sess.run(loss))
#因为数据规模不大且符合正态分布,所以执行20次训练就能达到一定效果
for step in range(20):
  sess.run(train)
#输出训练后的W和B
  print('W = ', sess.run(W), 'b = ', sess.run(b), "loss = ", sess.run(loss))

实验结果如下:

1.1000个散点图

2.预测出W、b以及loss的值

W = [0.40727448] b = [0.] loss = 0.12212546
W = [0.30741683] b = [0.30278787] loss = 0.014318982
W = [0.24240384] b = [0.3016729] loss = 0.0071945195
W = [0.19786316] b = [0.30094698] loss = 0.0038506198
W = [0.16734858] b = [0.30044967] loss = 0.0022811447
W = [0.1464432] b = [0.30010894] loss = 0.001544504
W = [0.13212104] b = [0.29987553] loss = 0.0011987583
W = [0.122309] b = [0.2997156] loss = 0.0010364805
W = [0.11558682] b = [0.29960606] loss = 0.00096031476
W = [0.11098149] b = [0.29953098] loss = 0.0009245659
W = [0.1078264] b = [0.29947957] loss = 0.00090778706
W = [0.10566486] b = [0.29944435] loss = 0.00089991186
W = [0.10418401] b = [0.2994202] loss = 0.0008962157
W = [0.10316949] b = [0.29940367] loss = 0.0008944806
W = [0.10247444] b = [0.29939234] loss = 0.00089366647
W = [0.10199826] b = [0.2993846] loss = 0.00089328433
W = [0.10167204] b = [0.29937926] loss = 0.0008931049
W = [0.10144854] b = [0.29937562] loss = 0.00089302065
W = [0.10129543] b = [0.29937312] loss = 0.00089298113
W = [0.10119054] b = [0.29937142] loss = 0.0008929627
W = [0.10111867] b = [0.29937026] loss = 0.000892954

根据实验结果可以看出第20次预测出的W和b值基本符合我们之前假设直线的值

tensorflow入门(1):构造线性回归模型的更多相关文章

  1. 用Tensorflow完成简单的线性回归模型

    思路:在数据上选择一条直线y=Wx+b,在这条直线上附件随机生成一些数据点如下图,让TensorFlow建立回归模型,去学习什么样的W和b能更好去拟合这些数据点. 1)随机生成1000个数据点,围绕在 ...

  2. tensorflow学习笔记四----------构造线性回归模型

    首先通过构造随机数,模拟数据. import numpy as np import tensorflow as tf import matplotlib.pyplot as plt # 随机生成100 ...

  3. 【TensorFlow入门完全指南】模型篇·线性回归模型

    首先呢,进行import,对于日常写代码来说,第二行经常写成:import numpy as np,这样会更加简洁.第三行import用于绘图. 定义了学习率.迭代数epoch,以及展示的学习步骤,三 ...

  4. 【TensorFlow入门完全指南】模型篇·逻辑斯蒂回归模型

    import库,加载mnist数据集. 设置学习率,迭代次数,batch并行计算数量,以及log显示. 这里设置了占位符,输入是batch * 784的矩阵,由于是并行计算,所以None实际上代表并行 ...

  5. 【TensorFlow入门完全指南】模型篇·最近邻模型

    最近邻模型,更为常见的是k-最近邻模型,是一种常见的机器学习模型,原理如下: KNN算法的前提是存在一个样本的数据集,每一个样本都有自己的标签,表明自己的类型.现在有一个新的未知的数据,需要判断它的类 ...

  6. [tensorflow] 线性回归模型实现

    在这一篇博客中大概讲一下用tensorflow如何实现一个简单的线性回归模型,其中就可能涉及到一些tensorflow的基本概念和操作,然后因为我只是入门了点tensorflow,所以我只能对部分代码 ...

  7. TensorFlow入门教程集合

    TensorFlow入门教程之0: BigPicture&极速入门 TensorFlow入门教程之1: 基本概念以及理解 TensorFlow入门教程之2: 安装和使用 TensorFlow入 ...

  8. TensorFlow从1到2(七)线性回归模型预测汽车油耗以及训练过程优化

    线性回归模型 "回归"这个词,既是Regression算法的名称,也代表了不同的计算结果.当然结果也是由算法决定的. 不同于前面讲过的多个分类算法或者逻辑回归,线性回归模型的结果是 ...

  9. 线性回归模型的 MXNet 与 TensorFlow 实现

    本文主要探索如何使用深度学习框架 MXNet 或 TensorFlow 实现线性回归模型?并且以 Kaggle 上数据集 USA_Housing 做线性回归任务来预测房价. 回归任务,scikit-l ...

随机推荐

  1. HTML5的学习(二)HTML5标签

    3.按功能排列标签 (注:红色为HTML5不支持的,蓝色为HTML5新增的标签元素.)   3.1基本 标签 描述 HTML4 HTML5 <!--...--> 定义注释. √ √ < ...

  2. php rsa理解

    参考链接:http://www.cnblogs.com/firstForEver/p/5803940.html 自己封装的一个类: <?php class CRsaAuthorization { ...

  3. 2018-2019-2 网络对抗技术 20165230 Exp5 MSF基础应用

    目录 1.实验内容 2.基础问题回答 3.实验内容 任务一:一个主动攻击实践 漏洞MS08_067(成功) 任务二:一个针对浏览器的攻击 ms11_050(成功) ms14_064(成功) 任务三:一 ...

  4. Java获取资源路径——(八)

     获取文件资源有两种方式: 第一种是: 获取Java项目根目录开始制定文件夹下指定文件,不用类加载器(目录开始要加/) // 获取工程路径 System.out.println(System.getP ...

  5. Linux性能分析的前60000毫秒【转】

    Linux性能分析的前60000毫秒 为了解决性能问题,你登入了一台Linux服务器,在最开始的一分钟内需要查看什么? 在Netflix我们有一个庞大的EC2 Linux集群,还有非常多的性能分析工具 ...

  6. WPF 未能加载文件或程序集“CefSharp.Core.dll”或它的某一个依赖项

    1.检查代码不存在问题,最后找到问题,Nut管理包没有安装CefSharp.wpf. 2.安装对应的版本即可.

  7. 带你玩转Visual Studio——带你理解微软的预编译头技术

    原文地址:http://blog.csdn.net/luoweifu/article/details/49010627 不陌生的stdafx.h 还记得带你玩转Visual Studio——带你新建一 ...

  8. 量化投资与Python之NumPy

      数组计算 NumPy是高性能科学计算和数据分析的基础包.它是pandas等其他各种工具的基础.NumPy的主要功能:ndarray,一个多维数组结构,高效且节省空间无需循环对整组数据进行快速运算的 ...

  9. Project Euler Problem6

    Sum square difference Problem 6 The sum of the squares of the first ten natural numbers is, 12 + 22  ...

  10. 020_秘钥管理服务器vault

    一. https://github.com/hashicorp/vault     #待研究