tensorflow 已经发布了 2.0 alpha 版本,所以是时候学一波 tf 了。官方教程有个平面拟合的类似Hello World的例子,但没什么解释,新手理解起来比较困难。

所以本文对这个案例进行详细解释,对关键的numpy, tf, matplotlib 函数加了注释,并且对原始数据和训练效果进行了可视化展示,希望对你理解这个案例有所帮助。

因为 2.0 成熟还需要一段时间,所以本文使用的是 tf 1.13.1 版本,Python 代码也从 Python 2 迁移到了 Python 3。

原始代码见如下链接:

http://www.tensorfly.cn/tfdoc/get_started/introduction.html

原始代码如下:

import tensorflow as tf
import numpy as np # 使用 NumPy 生成假数据(phony data), 总共 100 个点.
x_data = np.float32(np.random.rand(2, 100)) # 随机输入
y_data = np.dot([0.100, 0.200], x_data) + 0.300 # 构造一个线性模型
#
b = tf.Variable(tf.zeros([1]))
W = tf.Variable(tf.random_uniform([1, 2], -1.0, 1.0))
y = tf.matmul(W, x_data) + b # 最小化方差
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
train = optimizer.minimize(loss) # 初始化变量
init = tf.initialize_all_variables() # 启动图 (graph)
sess = tf.Session()
sess.run(init) # 拟合平面
for step in xrange(0, 201):
sess.run(train)
if step % 20 == 0:
print step, sess.run(W), sess.run(b) # 得到最佳拟合结果 W: [[0.100 0.200]], b: [0.300]

使用 NumPy 生成假数据(phony data), 总共 100 个点.

x_data 是二维数组,每个维度各 100 个点,定义了一个平面

import tensorflow as tf
import numpy as np x_data = np.float32(np.random.rand(2, 100)) # 随机输入
x_data[0][:10]
array([0.35073978, 0.16348423, 0.7059651 , 0.7696817 , 0.4036316 ,
0.52306384, 0.8748454 , 0.52280265, 0.9512267 , 0.10213694],
dtype=float32)
x_data[1][:10]
array([0.33513898, 0.07861521, 0.58426493, 0.87010854, 0.24188931,
0.64622885, 0.39593607, 0.4805421 , 0.6906034 , 0.41190282],
dtype=float32)

y_datax_data 经过变换得到,np.dot 实现矩阵乘法,要求第一个矩阵的列数和第二个矩阵的行数相同,最后加一个偏移量

比如 y_data[0] 就等于 x_data[0][0]*0.1 + x_data[1][0]*0.2 +0.3

这里整体的效果,相当于对原始的平面在三维空间进行了一个倾斜旋转,倾斜的参数由一个权重 W=[0.1, 0.2] 和偏移量 b=0.3 来确定

y_data = np.dot([0.100, 0.200], x_data) + 0.300
y_data[:10]
array([0.40210177, 0.33207147, 0.4874495 , 0.55098988, 0.38874102,
0.48155215, 0.46667175, 0.44838868, 0.53324335, 0.39259426])

原始数据可视化

使用 matplotlib 的 scatter 功能实现 3D 散点图,x 轴是 x_data[0], y 轴是 x_data[1],z 轴是 y_data

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D x, y, z = x_data[0], x_data[1], y_data
fig = plt.figure(figsize=(20, 14))
ax = fig.add_subplot(111, projection='3d')
ax.scatter(x, y, z, c='y')
plt.show()

构造一个线性模型

线性模型一般由权重 W 和偏移量 b 来描述,平面上直线拟合 W 是一个标量数字,而本例在三维空间进行平面拟合,所以 W 是一个有两个分量的向量。

b = tf.Variable(tf.zeros([1]))
b
<tf.Variable 'Variable:0' shape=(1,) dtype=float32_ref>
W = tf.Variable(tf.random_uniform([1, 2], -1.0, 1.0))
W
<tf.Variable 'Variable_1:0' shape=(1, 2) dtype=float32_ref>

y 是模拟的结果,tf.matmul 将矩阵 A 乘以矩阵 B,生成 A * B,最后加上偏移量 b

y = tf.matmul(W, x_data) + b
y
<tf.Tensor 'add:0' shape=(1, 100) dtype=float32>

最小化方差

定义损失函数,线性回归里常用的是均方误差,就是真实值和预测值的差的平方和

loss = tf.reduce_mean(tf.square(y - y_data))

定义优化器,这里使用梯度下降算法

optimizer = tf.train.GradientDescentOptimizer(0.5)

使用指定的优化器和损失函数定义一个训练

train = optimizer.minimize(loss)

初始化变量

init = tf.global_variables_initializer()

启动图 (graph)

sess = tf.Session()
sess.run(init)

拟合平面

我们知道真实的 W[0.1, 0.2]b0.3,看下迭代训练 200 次的拟合效果怎么样

for step in range(0, 201):
sess.run(train)
if step % 20 == 0:
print(step, sess.run(W), sess.run(b))
0 [[ 0.8425213  -0.12354811]] [0.13099673]
20 [[0.289453 0.12614608]] [0.2357107]
40 [[0.15044135 0.18556874]] [0.28013656]
60 [[0.11361164 0.19769716]] [0.29380444]
80 [[0.10372839 0.1998468 ]] [0.29805225]
100 [[0.10103785 0.20009856]] [0.2993837]
120 [[0.1002938 0.20006898]] [0.29980397]
140 [[0.1000846 0.20003161]] [0.2999374]
160 [[0.10002476 0.20001256]] [0.29997995]
180 [[0.10000735 0.20000464]] [0.29999357]
200 [[0.10000221 0.20000164]] [0.29999793]

这里迭代 200 次的结果 W[0.10000221 0.20000164], b0.29999793,可以看出跟真实值差别非常小了

拟合效果可视化

把原始的分布在三维空间的点,组成一个个的三元组,分别表示 x, y, z 的坐标值

import numpy as np
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt points = list(zip(x_data[0],x_data[1],y_data))
points[:10]
[(0.35073978, 0.33513898, 0.40210177302360534),
(0.16348423, 0.07861521, 0.33207146525382997),
(0.7059651, 0.58426493, 0.4874494969844818),
(0.7696817, 0.87010854, 0.5509898781776428),
(0.4036316, 0.24188931, 0.3887410223484039),
(0.52306384, 0.64622885, 0.4815521538257599),
(0.8748454, 0.39593607, 0.4666717529296875),
(0.52280265, 0.4805421, 0.44838868379592894),
(0.9512267, 0.6906034, 0.5332433462142945),
(0.10213694, 0.41190282, 0.3925942569971085)]
w_val = sess.run(W)
b_val = sess.run(b)
def cross(a, b):
return [a[1]*b[2] - a[2]*b[1],
a[2]*b[0] - a[0]*b[2],
a[0]*b[1] - a[1]*b[0]] # https://stackoverflow.com/questions/20699821/find-and-draw-regression-plane-to-a-set-of-points
def show(points, a, b, c):
# 定义画布
fig = plt.figure(figsize=(20, 14))
ax = fig.add_subplot(111, projection='3d') # 绘制原始的散点
xs, ys, zs = zip(*points)
ax.scatter(xs, ys, zs) # 绘制拟合平面
point = np.array([0.0, 0.0, c])
normal = np.array(cross([1,0,a], [0,1,b]))
d = -point.dot(normal)
xx, yy = np.meshgrid([0,1], [0,1])
z = (-normal[0] * xx - normal[1] * yy - d) * 1. / normal[2]
ax.plot_surface(xx, yy, z, alpha=0.2, color=[0,1,0]) ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z') plt.show()
show(points, w_val[0][0],w_val[0][1],b_val[0])

蛙蛙推荐: TensorFlow Hello World 之平面拟合的更多相关文章

  1. 【蛙蛙推荐】Lucene.net试用

    [蛙蛙推荐]Lucene.net试用   [简介] lucene.net好多人都知道的吧,反正我是最近才好好的看了一下,别笑我拿历史当新闻哦,不太了解Lucence的朋友先听我说两句哦.Lucene的 ...

  2. 蛙蛙推荐:快速自定义Boostrap样式

    现在越来越多的网站使用Bootstrap,相信大家也审美疲劳了,所以我们要用Bootstrap的第一步就是先把顶部的导航栏来自定义一下. 我现在使用的是bootstrap3.0,顶部导航定义如下 &l ...

  3. html打造动画【系列2】- 可爱的蛙蛙表情

    先感受一下全部表情包: 在开始之前先安利一个知识点:Flex弹性布局 我们一般做水平三列布局都是用的float方法,将每一块浮动显示在同一行.这种方法会导致元素没有原来的高度属性,要用清除浮动来解决空 ...

  4. 数据的平面拟合 Plane Fitting

    数据的平面拟合 Plane Fitting 看到了一些利用Matlab的平面拟合程序 http://www.ilovematlab.cn/thread-220252-1-1.html

  5. [CC]平面拟合

    常见的平面拟合方法一般是最小二乘法.当误差服从正态分布时,最小二乘方法的拟合效果还是很好的,可以转化成PCA问题. 当观测值的误差大于2倍中误差时,认为误差较大.采用最小二乘拟合时精度降低,不够稳健. ...

  6. 蛙蛙推荐:WEB安全入门

    信息安全基础 信息安全目标 真实性:对信息的来源进行判断,能对伪造来源的信息予以鉴别, 就是身份认证. 保密性:保证机密信息不被窃听,盗取,或窃听者不能了解信息的真实含义. 完整性:保证数据的一致性, ...

  7. 蛙蛙推荐:AngularJS学习笔记

    为了降低前端代码的数量,提高可维护性,可测试性,学习了下AngularJS,正在准备投入项目开发中. AngularJS的概念比较多,如果面向对象方面的书理解的不透的话学习起来有些费劲,它的官方有个快 ...

  8. 蛙蛙推荐:如何实时监控MySql状态

    大多网站的性能瓶颈都会出在数据库上,所以想把Mysql监控起来,就搜索了下相关资料. 后来和同事讨论了下cacti和nagios有些老套和过时,graphite比较时尚,然后就搜了下相关的资料,最后搞 ...

  9. [蛙蛙推荐]SICP第一章学习笔记-编程入门

    本书简介 <计算机程序的构造与解释>这本书是MIT计算机科学学科的入门课程, 大部分学生在学这门课程前都没有接触过程序设计,也就是说这本书是针对编程新手写的. 虽然是入门课程,但起点比较高 ...

随机推荐

  1. YourSQLDba遭遇.NET Framework Error 6522

    一工厂的SQL Server数据库服务器上的YourSQLDba_LogBackups作业做事务日志备份时,突然出现异常,异常的错误信息指向.NET Framework,出现这个问题时,一般我估计是该 ...

  2. mssql sqlserver 关键字 GROUPING用法简介及说明

    转自: http://www.maomao365.com/?p=6208  摘要: GROUPING 用于区分列是否由 ROLLUP.CUBE 或 GROUPING SETS 聚合而产生的行 如果是原 ...

  3. mssql sql server ceiling floor 函数用法简介

    摘自: http://www.maomao365.com/?p=5581摘要: 下文主要讲述ceiling.floor函数的功能及举例说明  一.ceiling floor函数功能简介 ceiling ...

  4. Oracle EBS 查看执行计划

    explain plan forSELECT MMT.TRANSACTION_ID,GIR.JE_HEADER_ID,GIR.JE_LINE_NUMFROM   GL_IMPORT_REFERENCE ...

  5. 洗礼灵魂,修炼python(85)-- 知识拾遗篇 —— 深度剖析让人幽怨的编码

    编码 这篇博文的主题是,编码问题,老生常谈的问题了对吧?从我这一套的文章来看,前面已经提到好多次编码问题了,的确这个确实很重要,这可是难道了很多能人异士的,当你以为你学懂了,在研究爬虫时你发现你错了, ...

  6. SQLServer2016 AlwaysOn AG基于工作组的搭建笔记

    最近搭建了一套SQLServer2016 AlwaysOn AG. (后记:经实际测试,使用SQLServer2012 也同样可以在Winserver2016上搭建基于工作组的AlwaysOn AG, ...

  7. Servlet(二):初识Servlet

    在手动写完一个Servlet小例子后,是不是有很多疑问,接下来会为大家详细介绍Servlet的知识. 1.什么是Servlet 是在服务器上运行的小程序.一个servlet就是一个Java类,并且可以 ...

  8. ESLint笔记

    ESLint是JavaScript的代码检查工具.因为JS是弱类型的语言,不需要编译,代码错误是在运行时调适的,所以需要个工具在编码的过程发现问题.ESLint的初衷是为了让程序员可以创建自己的检测规 ...

  9. 16.Python网络爬虫之Scrapy框架(CrawlSpider)

    引入 提问:如果想要通过爬虫程序去爬取”糗百“全站数据新闻数据的话,有几种实现方法? 方法一:基于Scrapy框架中的Spider的递归爬取进行实现(Request模块递归回调parse方法). 方法 ...

  10. python 类与类之间的关系

    一.依赖关系(紧密程度最低) (1)简单的定义:就是方法中传递一个对象.此时类与类之间存在依赖关系,此关系比较低. (2)实例植物大战僵尸简易版 题目要求:创建一个植物,创建一个僵尸 1.植物:名字. ...