线性回归,附tensorflow实现
本文同步自:https://zhuanlan.zhihu.com/p/30738405
本文旨在通过介绍线性回归来引出一些基本概念:h(x),J(θ),梯度下降法
有一组数据:
x=[1,2,3,4,5,6,7,8,9,10]
y=[1,2,3,4,5,6,7,8,9,10]
要求画一条过原点的直线,穿过上述所有点
这组数据在二维平面表现如下

引入概念,假设函数:h(x)。h代表hypothesis
由于是过原点的直线,所以可以列出方程h(x):

先随意假设一个 ,在这先假设
=0.5 ,函数图如下

显而易见这条直线并不是我们想要的。那么具体的,怎么判断一条直线的好坏呢
引入概念代价函数 cost function
在本例中,可以由拟合数据和原始数据对应点的误差的平方的均值来判断直线的好坏;列出J(θ)如下:

其中m表示数据的总量,在本题中为10; 并不是代表x的i次方,而是代表第i个x的数值,例如在本例中,
为2
将h(x)带入,得

函数图是这样一个形状,数值对不上,凑合着看吧;有一点值得注意,在J(θ)中, 与
都应该作为常量来处理

显然,J(θ)值越小,点到直线的距离总和越少,画出来的直线效果也就越好。放到题目中就是当J(θ)=0的时候,画出的直线穿过了所有的点
那么问题就变成了如何最小化J(θ)
在这个例子中可以手动计算,也就是正规方程法,但是随着问题复杂度的增加,正规方程法的实用性会越来越低
引入梯度下降法

其中α为步幅
梯度下降法可以解释为:对J(θ)求关于 (本例中只有
)的偏导数并乘以步幅,再用
减去该值,得到的结果赋值给
。此过程需要重复多次
步幅的选择会直接关系到梯度下降法的效果,如下图

当选取了一个较小步幅的时候,将正确收敛
当选取了一个较大步幅的时候,将震荡收敛
当选取了一个过大步幅的时候,将无法收敛
调整一下例子
x=[5,6,7,8,9,10,11,12,13,14]
y=[1,2,3,4,5,6,7,8,9,10]
要求画一条直线,穿过上述所有点
很明显,对于这组数据,仅仅是过原点的直线无法满足要求,所以列出新的h(x):

而判断一条直线的好坏还可以沿用之前的J(θ):

函数图是这样一个形状,数值对不上,凑合着看吧

之后就是如何最小化J(θ)的问题了。下面给出tensorflow的代码实现
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Parameters
learning_rate = 0.05
training_epochs = 2000
display_step = 50
# Training Data
train_X = np.asarray([5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0])
train_Y = train_X - 4
paint_X = np.asarray([-100.0, 100.0])
n_samples = train_X.shape[0]
# tf Graph Input
X = tf.placeholder("float")
Y = tf.placeholder("float")
# Set model weights
W = tf.Variable(-10., name="weight")
b = tf.Variable(10., name="bias")
# Construct a linear model
pred = tf.add(tf.multiply(X, W), b)
# Mean squared error
cost = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples)
# Gradient descent
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
# Initializing the variables
init = tf.global_variables_initializer()
# Launch the graph
plt.figure()
plt.ion()
with tf.Session() as sess:
sess.run(init)
# Fit all training data
for epoch in range(training_epochs):
for (x, y) in zip(train_X, train_Y):
sess.run(optimizer, feed_dict={X: x, Y: y})
# Display logs per epoch step
if (epoch + 1) % display_step == 0:
c = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(c),
"W=", sess.run(W), "b=", sess.run(b))
plt.axis([0.0, np.max(train_X) + 1, 0.0, np.max(train_Y) + 1])
plt.plot(train_X, train_Y, 'ro', label='Original data')
plt.plot(paint_X, sess.run(W) * paint_X + sess.run(b), label='Fitted line')
plt.pause(0.001)
plt.clf()
print("Optimization Finished!")
training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
print("Training cost=", training_cost, "W=", sess.run(W), "b=", sess.run(b), '\n')
plt.axis([0.0, np.max(train_X) + 1, 0.0, np.max(train_Y) + 1])
plt.plot(train_X, train_Y, 'ro', label='Original data')
plt.plot(paint_X, sess.run(W) * paint_X + sess.run(b), label='Fitted line')
plt.pause(10)
线性回归,附tensorflow实现的更多相关文章
- 逻辑回归,附tensorflow实现
本文旨在通过二元分类问题.多元分类问题介绍逻辑回归算法,并实现一个简单的数字分类程序 在生活中,我们经常会碰到这样的问题: 根据苹果表皮颜色判断是青苹果还是红苹果 根据体温判断是否发烧 这种答案只有两 ...
- 简单的线性回归问题-TensorFlow+MATLAB·
首先我们要试验的是 人体脂肪fat和年龄age以及体重weight之间的关系,我们的目标就是得到一个最优化的平面来表示三者之间的关系: TensorFlow的程序如下: import tensorfl ...
- 利用VGG19实现火灾分类(附tensorflow代码及训练集)
源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...
- 利用卷积神经网络(VGG19)实现火灾分类(附tensorflow代码及训练集)
源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...
- 经典损失函数:交叉熵(附tensorflow)
每次都是看了就忘,看了就忘,从今天开始,细节开始,推一遍交叉熵. 我的第一篇CSDN,献给你们(有错欢迎指出啊). 一.什么是交叉熵 交叉熵是一个信息论中的概念,它原来是用来估算平均编码长度的.给定两 ...
- Tensorflow之单变量线性回归问题的解决方法
跟着网易云课堂上面的免费公开课深度学习应用开发Tensorflow实践学习,学到线性回归这里感觉有很多需要总结,梳理记录下阶段性学习内容. 题目:通过生成人工数据集合,基于TensorFlow实现y= ...
- Tensorflow学习笔记01
Tensorflow官方网站:http://tensorflow.org/ 极客学院Tensorflow中文版:http://wiki.jikexueyuan.com/project/tensorfl ...
- TensorFlow 从零到helloWorld
目录 1.git安装与使用 1.1 git安装 1.2 修改git bash默认路径 1.3 git常用操作 2.环境搭建 2.1 tensorflow安装 2.2 CUDA安装 2.3 ...
- TF linear regression
本文的作者 Nishant Shukla 为加州大学洛杉矶分校的机器视觉研究者,从事研究机器人机器学习技术.Nishant Shukla 一直以来兼任 Microsoft.Facebook 和 Fou ...
随机推荐
- Oracle RAC + ASM + Grid安装
(一)环境准备 主机操作系统 windows10 虚拟机平台 vmware workstation 12 虚拟机操作系统 redhat 5.5 x86(32位) :Linux.5.5.for.x86. ...
- python中sys.exit()和os._exit(0)退出程序
python中退出程序的两种方法,0为默认状态,可以为空,两者均会退出当前运行的程序,os._exit(0)中的0不能省略 sys.exit(0):可以捕获SystemExit异常,然后做相应的清理工 ...
- C++基础知识2
2 变量和基本类型 2.1 基本内置类型 C++定义了一系列包括算术类型(arithmetic type)和空类型(void)在内的基本数据类型.其中算术类型包含字符,整型数,布尔值和浮点数.空类型不 ...
- abapGit简介与教程
你是ABAP开发者?你用abapGit吗? 看到这个问题,读者也许会想,什么是abapGit?就让我们从这个问题开始.简单地说,abapGit是一个以ABAP写成为ABAP服务的Git客户端. 有的读 ...
- Hibernate映射类型
- JPA的一对多映射(双向)关联
实体Customer:用户. 实体Order:订单. Customer和Order是一对多关系.那么在JPA中,如何表示一对多的双向关联呢? JPA使用@OneToMany和@ManyToOne来标识 ...
- 2017 多校训练 1006 Function
Function Time Limit: 4000/2000 MS (Java/Others) Memory Limit: 131072/131072 K (Java/Others)Total ...
- Best Coder #86 1001 Price List(大水题)
Price List Accepts: 880 Submissions: 2184 Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 26214 ...
- (纪念第一道完全自己想的树DP)CodeForces 219D Choosing Capital for Treeland
Choosing Capital for Treeland time limit per test 3 seconds memory limit per test 256 megabytes inpu ...
- 2015-2016 ACM-ICPC, NEERC, Southern Subregional Contest A Email Aliases(模拟STL vector+map)
Email AliasesCrawling in process... Crawling failed Time Limit:2000MS Memory Limit:524288KB ...