本文同步自:https://zhuanlan.zhihu.com/p/30738405

本文旨在通过介绍线性回归来引出一些基本概念:h(x),J(θ),梯度下降法

有一组数据:

x=[1,2,3,4,5,6,7,8,9,10]

y=[1,2,3,4,5,6,7,8,9,10]

要求画一条过原点的直线,穿过上述所有点

这组数据在二维平面表现如下

引入概念,假设函数:h(x)。h代表hypothesis

由于是过原点的直线,所以可以列出方程h(x):

先随意假设一个 ,在这先假设 =0.5 ,函数图如下

显而易见这条直线并不是我们想要的。那么具体的,怎么判断一条直线的好坏呢

引入概念代价函数 cost function

在本例中,可以由拟合数据和原始数据对应点的误差的平方的均值来判断直线的好坏;列出J(θ)如下:

其中m表示数据的总量,在本题中为10; 并不是代表x的i次方,而是代表第i个x的数值,例如在本例中, 为2

将h(x)带入,得

函数图是这样一个形状,数值对不上,凑合着看吧;有一点值得注意,在J(θ)中, 都应该作为常量来处理

显然,J(θ)值越小,点到直线的距离总和越少,画出来的直线效果也就越好。放到题目中就是当J(θ)=0的时候,画出的直线穿过了所有的点

那么问题就变成了如何最小化J(θ)

在这个例子中可以手动计算,也就是正规方程法,但是随着问题复杂度的增加,正规方程法的实用性会越来越低

引入梯度下降法

其中α为步幅

梯度下降法可以解释为:对J(θ)求关于 (本例中只有 )的偏导数并乘以步幅,再用 减去该值,得到的结果赋值给 。此过程需要重复多次

步幅的选择会直接关系到梯度下降法的效果,如下图

当选取了一个较小步幅的时候,将正确收敛

当选取了一个较大步幅的时候,将震荡收敛

当选取了一个过大步幅的时候,将无法收敛

调整一下例子

x=[5,6,7,8,9,10,11,12,13,14]

y=[1,2,3,4,5,6,7,8,9,10]

要求画一条直线,穿过上述所有点

很明显,对于这组数据,仅仅是过原点的直线无法满足要求,所以列出新的h(x):

而判断一条直线的好坏还可以沿用之前的J(θ):

函数图是这样一个形状,数值对不上,凑合着看吧

之后就是如何最小化J(θ)的问题了。下面给出tensorflow的代码实现

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Parameters
learning_rate = 0.05
training_epochs = 2000
display_step = 50
# Training Data
train_X = np.asarray([5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0])
train_Y = train_X - 4
paint_X = np.asarray([-100.0, 100.0])

n_samples = train_X.shape[0]
# tf Graph Input
X = tf.placeholder("float")
Y = tf.placeholder("float")

# Set model weights
W = tf.Variable(-10., name="weight")
b = tf.Variable(10., name="bias")
# Construct a linear model
pred = tf.add(tf.multiply(X, W), b)
# Mean squared error
cost = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples)
# Gradient descent
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
# Initializing the variables
init = tf.global_variables_initializer()
# Launch the graph

plt.figure()
plt.ion()

with tf.Session() as sess:
    sess.run(init)

    # Fit all training data
    for epoch in range(training_epochs):
        for (x, y) in zip(train_X, train_Y):
            sess.run(optimizer, feed_dict={X: x, Y: y})

        # Display logs per epoch step
        if (epoch + 1) % display_step == 0:
            c = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(c),
                  "W=", sess.run(W), "b=", sess.run(b))

            plt.axis([0.0, np.max(train_X) + 1, 0.0, np.max(train_Y) + 1])
            plt.plot(train_X, train_Y, 'ro', label='Original data')
            plt.plot(paint_X, sess.run(W) * paint_X + sess.run(b), label='Fitted line')
            plt.pause(0.001)
            plt.clf()

    print("Optimization Finished!")
    training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
    print("Training cost=", training_cost, "W=", sess.run(W), "b=", sess.run(b), '\n')

    plt.axis([0.0, np.max(train_X) + 1, 0.0, np.max(train_Y) + 1])
    plt.plot(train_X, train_Y, 'ro', label='Original data')
    plt.plot(paint_X, sess.run(W) * paint_X + sess.run(b), label='Fitted line')
    plt.pause(10)

线性回归,附tensorflow实现的更多相关文章

  1. 逻辑回归,附tensorflow实现

    本文旨在通过二元分类问题.多元分类问题介绍逻辑回归算法,并实现一个简单的数字分类程序 在生活中,我们经常会碰到这样的问题: 根据苹果表皮颜色判断是青苹果还是红苹果 根据体温判断是否发烧 这种答案只有两 ...

  2. 简单的线性回归问题-TensorFlow+MATLAB·

    首先我们要试验的是 人体脂肪fat和年龄age以及体重weight之间的关系,我们的目标就是得到一个最优化的平面来表示三者之间的关系: TensorFlow的程序如下: import tensorfl ...

  3. 利用VGG19实现火灾分类(附tensorflow代码及训练集)

    源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...

  4. 利用卷积神经网络(VGG19)实现火灾分类(附tensorflow代码及训练集)

    源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...

  5. 经典损失函数:交叉熵(附tensorflow)

    每次都是看了就忘,看了就忘,从今天开始,细节开始,推一遍交叉熵. 我的第一篇CSDN,献给你们(有错欢迎指出啊). 一.什么是交叉熵 交叉熵是一个信息论中的概念,它原来是用来估算平均编码长度的.给定两 ...

  6. Tensorflow之单变量线性回归问题的解决方法

    跟着网易云课堂上面的免费公开课深度学习应用开发Tensorflow实践学习,学到线性回归这里感觉有很多需要总结,梳理记录下阶段性学习内容. 题目:通过生成人工数据集合,基于TensorFlow实现y= ...

  7. Tensorflow学习笔记01

    Tensorflow官方网站:http://tensorflow.org/ 极客学院Tensorflow中文版:http://wiki.jikexueyuan.com/project/tensorfl ...

  8. TensorFlow 从零到helloWorld

    目录 1.git安装与使用 1.1 git安装 1.2 修改git bash默认路径 1.3 git常用操作 2.环境搭建   2.1 tensorflow安装   2.2 CUDA安装   2.3 ...

  9. TF linear regression

    本文的作者 Nishant Shukla 为加州大学洛杉矶分校的机器视觉研究者,从事研究机器人机器学习技术.Nishant Shukla 一直以来兼任 Microsoft.Facebook 和 Fou ...

随机推荐

  1. adb获取Android性能数据

    环境:Android测试环境 搭建Android测试环境: 1.下载AndroidSDK: 2.配置环境变量: (1).ANDROID_HOME (2).ANDROID_HOME-TOOLS (3). ...

  2. springboot自定义配置源

    概述 我们知道,在Spring boot中可以通过xml或者@ImportResource 来引入自己的配置文件,但是这里有个限制,必须是本地,而且格式只能是 properties(或者 yaml). ...

  3. 彻底理解Java的Future模式

    先上一个场景:假如你突然想做饭,但是没有厨具,也没有食材.网上购买厨具比较方便,食材去超市买更放心. 实现分析:在快递员送厨具的期间,我们肯定不会闲着,可以去超市买食材.所以,在主线程里面另起一个子线 ...

  4. ASP.NET MVC 学习笔记 1

    1. 什么是ASP.Net MVC ASP.Net MVC是一种开发Web应用程序的工具(is a web application development framework),采用Model-Vie ...

  5. Ubuntu 中登录相关的日志

    登录相关的日志涉及到系统的安全,所以是系统管理中非常重要的一部分内容.本文试图对登录相关的日志做一个整理. /var/log/auth.log 这是一个文本文件,记录了所有和用户认证相关的日志.无论是 ...

  6. Coursera_程序设计与算法_计算导论与C语言基础_数组应用练习

    您也可以在我的个人博客中阅读此文章:跳转 编程题#1:求字母的个数 描述 在一个字符串中找出元音字母a,e,i,o,u出现的次数. 输入 输入一行字符串(字符串中可能有空格,请用gets(s)方法把一 ...

  7. UVW源码漫谈(四)

    十一假期后就有点懒散,好长时间都没想起来写东西了.另外最近在打LOL的S赛.接触LOL时间不长,虽然平时玩的比较少,水平也相当菜,但是像这种大型的赛事有时间还是不会错过的.主要能够感受到选手们对竞技的 ...

  8. DNS over TLS到底有多牛?你想知道的都在这儿

    DNS over TLS,让电信.移动等各种ISP无法监视你的浏览轨迹...... SSL证书有助于客户端浏览器和网站服务器之间的加密连接. 这意味着在连接期间,所有的通信和活动都被遮蔽. 但通常意义 ...

  9. Android Studio 3.0 使用问题解决方案总结

    问题:创建新项目非常慢 问题描述: 更新到3.0版本后,出现创建新项目一直停留在如下图的界面: 选择等待?不知道要等到什么时候,这时候怎么办呢?显然,不能一直等待下去呀,需要想办法让他能尽快的加载好才 ...

  10. 用BroadcastReceiver监听手机网络状态变化

    android--解决方案--用BroadcastReceiver监听手机网络状态变化 标签: android网络状态监听方案 2015-01-20 15:23 1294人阅读 评论(3) 收藏 举报 ...