本文同步自:https://zhuanlan.zhihu.com/p/30738405

本文旨在通过介绍线性回归来引出一些基本概念:h(x),J(θ),梯度下降法

有一组数据:

x=[1,2,3,4,5,6,7,8,9,10]

y=[1,2,3,4,5,6,7,8,9,10]

要求画一条过原点的直线,穿过上述所有点

这组数据在二维平面表现如下

引入概念,假设函数:h(x)。h代表hypothesis

由于是过原点的直线,所以可以列出方程h(x):

先随意假设一个 ,在这先假设 =0.5 ,函数图如下

显而易见这条直线并不是我们想要的。那么具体的,怎么判断一条直线的好坏呢

引入概念代价函数 cost function

在本例中,可以由拟合数据和原始数据对应点的误差的平方的均值来判断直线的好坏;列出J(θ)如下:

其中m表示数据的总量,在本题中为10; 并不是代表x的i次方,而是代表第i个x的数值,例如在本例中, 为2

将h(x)带入,得

函数图是这样一个形状,数值对不上,凑合着看吧;有一点值得注意,在J(θ)中, 都应该作为常量来处理

显然,J(θ)值越小,点到直线的距离总和越少,画出来的直线效果也就越好。放到题目中就是当J(θ)=0的时候,画出的直线穿过了所有的点

那么问题就变成了如何最小化J(θ)

在这个例子中可以手动计算,也就是正规方程法,但是随着问题复杂度的增加,正规方程法的实用性会越来越低

引入梯度下降法

其中α为步幅

梯度下降法可以解释为:对J(θ)求关于 (本例中只有 )的偏导数并乘以步幅,再用 减去该值,得到的结果赋值给 。此过程需要重复多次

步幅的选择会直接关系到梯度下降法的效果,如下图

当选取了一个较小步幅的时候,将正确收敛

当选取了一个较大步幅的时候,将震荡收敛

当选取了一个过大步幅的时候,将无法收敛

调整一下例子

x=[5,6,7,8,9,10,11,12,13,14]

y=[1,2,3,4,5,6,7,8,9,10]

要求画一条直线,穿过上述所有点

很明显,对于这组数据,仅仅是过原点的直线无法满足要求,所以列出新的h(x):

而判断一条直线的好坏还可以沿用之前的J(θ):

函数图是这样一个形状,数值对不上,凑合着看吧

之后就是如何最小化J(θ)的问题了。下面给出tensorflow的代码实现

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Parameters
learning_rate = 0.05
training_epochs = 2000
display_step = 50
# Training Data
train_X = np.asarray([5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0])
train_Y = train_X - 4
paint_X = np.asarray([-100.0, 100.0])

n_samples = train_X.shape[0]
# tf Graph Input
X = tf.placeholder("float")
Y = tf.placeholder("float")

# Set model weights
W = tf.Variable(-10., name="weight")
b = tf.Variable(10., name="bias")
# Construct a linear model
pred = tf.add(tf.multiply(X, W), b)
# Mean squared error
cost = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples)
# Gradient descent
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
# Initializing the variables
init = tf.global_variables_initializer()
# Launch the graph

plt.figure()
plt.ion()

with tf.Session() as sess:
    sess.run(init)

    # Fit all training data
    for epoch in range(training_epochs):
        for (x, y) in zip(train_X, train_Y):
            sess.run(optimizer, feed_dict={X: x, Y: y})

        # Display logs per epoch step
        if (epoch + 1) % display_step == 0:
            c = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(c),
                  "W=", sess.run(W), "b=", sess.run(b))

            plt.axis([0.0, np.max(train_X) + 1, 0.0, np.max(train_Y) + 1])
            plt.plot(train_X, train_Y, 'ro', label='Original data')
            plt.plot(paint_X, sess.run(W) * paint_X + sess.run(b), label='Fitted line')
            plt.pause(0.001)
            plt.clf()

    print("Optimization Finished!")
    training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
    print("Training cost=", training_cost, "W=", sess.run(W), "b=", sess.run(b), '\n')

    plt.axis([0.0, np.max(train_X) + 1, 0.0, np.max(train_Y) + 1])
    plt.plot(train_X, train_Y, 'ro', label='Original data')
    plt.plot(paint_X, sess.run(W) * paint_X + sess.run(b), label='Fitted line')
    plt.pause(10)

线性回归,附tensorflow实现的更多相关文章

  1. 逻辑回归,附tensorflow实现

    本文旨在通过二元分类问题.多元分类问题介绍逻辑回归算法,并实现一个简单的数字分类程序 在生活中,我们经常会碰到这样的问题: 根据苹果表皮颜色判断是青苹果还是红苹果 根据体温判断是否发烧 这种答案只有两 ...

  2. 简单的线性回归问题-TensorFlow+MATLAB·

    首先我们要试验的是 人体脂肪fat和年龄age以及体重weight之间的关系,我们的目标就是得到一个最优化的平面来表示三者之间的关系: TensorFlow的程序如下: import tensorfl ...

  3. 利用VGG19实现火灾分类(附tensorflow代码及训练集)

    源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...

  4. 利用卷积神经网络(VGG19)实现火灾分类(附tensorflow代码及训练集)

    源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...

  5. 经典损失函数:交叉熵(附tensorflow)

    每次都是看了就忘,看了就忘,从今天开始,细节开始,推一遍交叉熵. 我的第一篇CSDN,献给你们(有错欢迎指出啊). 一.什么是交叉熵 交叉熵是一个信息论中的概念,它原来是用来估算平均编码长度的.给定两 ...

  6. Tensorflow之单变量线性回归问题的解决方法

    跟着网易云课堂上面的免费公开课深度学习应用开发Tensorflow实践学习,学到线性回归这里感觉有很多需要总结,梳理记录下阶段性学习内容. 题目:通过生成人工数据集合,基于TensorFlow实现y= ...

  7. Tensorflow学习笔记01

    Tensorflow官方网站:http://tensorflow.org/ 极客学院Tensorflow中文版:http://wiki.jikexueyuan.com/project/tensorfl ...

  8. TensorFlow 从零到helloWorld

    目录 1.git安装与使用 1.1 git安装 1.2 修改git bash默认路径 1.3 git常用操作 2.环境搭建   2.1 tensorflow安装   2.2 CUDA安装   2.3 ...

  9. TF linear regression

    本文的作者 Nishant Shukla 为加州大学洛杉矶分校的机器视觉研究者,从事研究机器人机器学习技术.Nishant Shukla 一直以来兼任 Microsoft.Facebook 和 Fou ...

随机推荐

  1. app启动页问题

    今天自己做的小作品准备提交,就差一个启动页,各种百度,各种搜,结果还好最后终于出来了,和大家分享一下,这个过程中遇到的各种小问题.(注XCode版本为7.2) 1.启动页一般都是图片,因为苹果有4,4 ...

  2. Django REST FrameWork中文教程2:请求和响应

    从这一点开始,我们将真正开始覆盖REST框架的核心.我们来介绍几个基本的构建块. 请求对象REST框架引入了Request扩展常规的对象HttpRequest,并提供更灵活的请求解析.Request对 ...

  3. Logcat monkey命令

    1. monkey命令 adb shell monkey -p com.autonavi.gxdtaojin --bugreport --ignore-crashes --ignore-timeout ...

  4. Spring装配bean

    Spring配置的可选方案 Spring提供了如下三种装配机制: (1)在XML中显式配置 (2)在Java中显式配置 (3)隐式的bean发现机制和自动装配 Spring有多种方式可以装配bean, ...

  5. 《Unity3D/2D游戏开发从0到1(第二版本)》 书稿完结总结

    前几天,个人著作<Unity3D/2D游戏开发从0到1(第二版)>经过七八个月的技术准备以及近3个月的日夜编写,在十一长假后终于完稿.今天抽出一点时间来,给广大热心小伙伴们汇报一下书籍概况 ...

  6. wordpress 显示数学公式 (MathJax-LaTeX)

    blog 不放一堆数学公式怎么能显得高大上,所以 MathJax-LaTeX 也是必装的插件之一了. 一.安装 MathJax-LaTex 插件 直接在 wordpress 插件中,搜索并安装 Mat ...

  7. [BC]Four Inages Strategy(三维空间判断正方形)

    题目连接 :http://bestcoder.hdu.edu.cn/contests/contest_showproblem.php?cid=577&pid=1001 题目大意:在三维空间中, ...

  8. 从零自学Hadoop(25):Impala相关操作下

    阅读目录 序 导入数据 查询 系列索引 本文版权归mephisto和博客园共有,欢迎转载,但须保留此段声明,并给出原文链接,谢谢合作. 文章是哥(mephisto)写的,SourceLink 序 上一 ...

  9. Android 圆角的效果实现

    Android 自定义ImageView实现圆角图片昨天给学生布置作业,写微信首页,也就是聊天的界面,listView里的item中联系人的头像是圆角的,图形界面如下: 那么我就仔细研究了圆角的具体实 ...

  10. java实现在线文档浏览

    目前发现两种方法: 1.http://dxx23.iteye.com/blog/1947083 FlexPaper+SWFTools ,java实现在线文档浏览 2.webOffice