线性回归,附tensorflow实现
本文同步自:https://zhuanlan.zhihu.com/p/30738405
本文旨在通过介绍线性回归来引出一些基本概念:h(x),J(θ),梯度下降法
有一组数据:
x=[1,2,3,4,5,6,7,8,9,10]
y=[1,2,3,4,5,6,7,8,9,10]
要求画一条过原点的直线,穿过上述所有点
这组数据在二维平面表现如下

引入概念,假设函数:h(x)。h代表hypothesis
由于是过原点的直线,所以可以列出方程h(x):

先随意假设一个 ,在这先假设
=0.5 ,函数图如下

显而易见这条直线并不是我们想要的。那么具体的,怎么判断一条直线的好坏呢
引入概念代价函数 cost function
在本例中,可以由拟合数据和原始数据对应点的误差的平方的均值来判断直线的好坏;列出J(θ)如下:

其中m表示数据的总量,在本题中为10; 并不是代表x的i次方,而是代表第i个x的数值,例如在本例中,
为2
将h(x)带入,得

函数图是这样一个形状,数值对不上,凑合着看吧;有一点值得注意,在J(θ)中, 与
都应该作为常量来处理

显然,J(θ)值越小,点到直线的距离总和越少,画出来的直线效果也就越好。放到题目中就是当J(θ)=0的时候,画出的直线穿过了所有的点
那么问题就变成了如何最小化J(θ)
在这个例子中可以手动计算,也就是正规方程法,但是随着问题复杂度的增加,正规方程法的实用性会越来越低
引入梯度下降法

其中α为步幅
梯度下降法可以解释为:对J(θ)求关于 (本例中只有
)的偏导数并乘以步幅,再用
减去该值,得到的结果赋值给
。此过程需要重复多次
步幅的选择会直接关系到梯度下降法的效果,如下图

当选取了一个较小步幅的时候,将正确收敛
当选取了一个较大步幅的时候,将震荡收敛
当选取了一个过大步幅的时候,将无法收敛
调整一下例子
x=[5,6,7,8,9,10,11,12,13,14]
y=[1,2,3,4,5,6,7,8,9,10]
要求画一条直线,穿过上述所有点
很明显,对于这组数据,仅仅是过原点的直线无法满足要求,所以列出新的h(x):

而判断一条直线的好坏还可以沿用之前的J(θ):

函数图是这样一个形状,数值对不上,凑合着看吧

之后就是如何最小化J(θ)的问题了。下面给出tensorflow的代码实现
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Parameters
learning_rate = 0.05
training_epochs = 2000
display_step = 50
# Training Data
train_X = np.asarray([5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0])
train_Y = train_X - 4
paint_X = np.asarray([-100.0, 100.0])
n_samples = train_X.shape[0]
# tf Graph Input
X = tf.placeholder("float")
Y = tf.placeholder("float")
# Set model weights
W = tf.Variable(-10., name="weight")
b = tf.Variable(10., name="bias")
# Construct a linear model
pred = tf.add(tf.multiply(X, W), b)
# Mean squared error
cost = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples)
# Gradient descent
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
# Initializing the variables
init = tf.global_variables_initializer()
# Launch the graph
plt.figure()
plt.ion()
with tf.Session() as sess:
sess.run(init)
# Fit all training data
for epoch in range(training_epochs):
for (x, y) in zip(train_X, train_Y):
sess.run(optimizer, feed_dict={X: x, Y: y})
# Display logs per epoch step
if (epoch + 1) % display_step == 0:
c = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(c),
"W=", sess.run(W), "b=", sess.run(b))
plt.axis([0.0, np.max(train_X) + 1, 0.0, np.max(train_Y) + 1])
plt.plot(train_X, train_Y, 'ro', label='Original data')
plt.plot(paint_X, sess.run(W) * paint_X + sess.run(b), label='Fitted line')
plt.pause(0.001)
plt.clf()
print("Optimization Finished!")
training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
print("Training cost=", training_cost, "W=", sess.run(W), "b=", sess.run(b), '\n')
plt.axis([0.0, np.max(train_X) + 1, 0.0, np.max(train_Y) + 1])
plt.plot(train_X, train_Y, 'ro', label='Original data')
plt.plot(paint_X, sess.run(W) * paint_X + sess.run(b), label='Fitted line')
plt.pause(10)
线性回归,附tensorflow实现的更多相关文章
- 逻辑回归,附tensorflow实现
本文旨在通过二元分类问题.多元分类问题介绍逻辑回归算法,并实现一个简单的数字分类程序 在生活中,我们经常会碰到这样的问题: 根据苹果表皮颜色判断是青苹果还是红苹果 根据体温判断是否发烧 这种答案只有两 ...
- 简单的线性回归问题-TensorFlow+MATLAB·
首先我们要试验的是 人体脂肪fat和年龄age以及体重weight之间的关系,我们的目标就是得到一个最优化的平面来表示三者之间的关系: TensorFlow的程序如下: import tensorfl ...
- 利用VGG19实现火灾分类(附tensorflow代码及训练集)
源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...
- 利用卷积神经网络(VGG19)实现火灾分类(附tensorflow代码及训练集)
源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...
- 经典损失函数:交叉熵(附tensorflow)
每次都是看了就忘,看了就忘,从今天开始,细节开始,推一遍交叉熵. 我的第一篇CSDN,献给你们(有错欢迎指出啊). 一.什么是交叉熵 交叉熵是一个信息论中的概念,它原来是用来估算平均编码长度的.给定两 ...
- Tensorflow之单变量线性回归问题的解决方法
跟着网易云课堂上面的免费公开课深度学习应用开发Tensorflow实践学习,学到线性回归这里感觉有很多需要总结,梳理记录下阶段性学习内容. 题目:通过生成人工数据集合,基于TensorFlow实现y= ...
- Tensorflow学习笔记01
Tensorflow官方网站:http://tensorflow.org/ 极客学院Tensorflow中文版:http://wiki.jikexueyuan.com/project/tensorfl ...
- TensorFlow 从零到helloWorld
目录 1.git安装与使用 1.1 git安装 1.2 修改git bash默认路径 1.3 git常用操作 2.环境搭建 2.1 tensorflow安装 2.2 CUDA安装 2.3 ...
- TF linear regression
本文的作者 Nishant Shukla 为加州大学洛杉矶分校的机器视觉研究者,从事研究机器人机器学习技术.Nishant Shukla 一直以来兼任 Microsoft.Facebook 和 Fou ...
随机推荐
- 云计算---openstack各服务端口使用情况
端口占用情况 端口情况可以使用ss -tanp命令进行查看 监听的所有端口ss -tanp | grep LISTEN 基础服务 22 --SSH 3306 --MariaDB(MySQL) 2701 ...
- MySql5.7安装及配置
MySQL是一个关系型数据库管理系统,由瑞典MySQL AB 公司开发,目前属于 Oracle 旗下公司.MySQL 最流行的关系型数据库管理系统,在 WEB 应用方面MySQL是最好的 RDBMS ...
- python 字典详解
1.字典的定义 字典类似于列表,但相对于列表来说字典更加通用,列表的下标必须必须为整数,而字典下标则可以为任意字符串/数字等,不可以是可变数据类型(列表,数组,元组) 字典包含下标(keys)集合和值 ...
- win10 uwp 切换主题
本文主要说如何在UWP切换主题,并且如何制作主题. 一般我们的应用都要有多种颜色,一种是正常的白天颜色,一种是晚上的黑夜颜色,还需要一种辅助的高对比颜色.这是微软建议的,一般应用都要包含的颜色. 我们 ...
- linux 内核 zImage 生成过程分析
1. 依据arch/arm/kernel/vmlinux.lds 生成linux内核源码根目录下的vmlinux,这个vmlinux属于未压缩,带调试信息.符号表的最初的内核,大小约23MB: arm ...
- ThreadPoolExecutor系列<一、ThreadPoolExecutor 机制>
本文系作者原创,转载请注明出处:http://www.cnblogs.com/further-further-further/p/7681529.html 解决问题: 1. 处理大量异步任务时能减少每 ...
- (转)java内存泄漏的定位与分析
转自:http://blog.csdn.net/x_i_y_u_e/article/details/51137492 1.为什么会发生内存泄漏 java 如何检测内在泄漏呢?我们需要一些工具进行检测, ...
- 【ASP.NET MVC 学习笔记】- 19 REST和RESTful Web API
本文参考:http://www.cnblogs.com/willick/p/3441432.html 1.目前使用Web服务的三种主流的方式是:远程过程调用(RPC),面向服务架构(SOA)以及表征性 ...
- 【20171027中】alert(1) to win 第13,14,15,16题
第13题 题目: function escape(s) { var tag = document.createElement('iframe'); // For this one, you get t ...
- IOS 中的JS
文章摘自: http://www.cocoachina.com/ios/20150127/11037.html JSContext/JSValue JSContext 即JavaScript代码的 ...