本文同步自:https://zhuanlan.zhihu.com/p/30738405

本文旨在通过介绍线性回归来引出一些基本概念:h(x),J(θ),梯度下降法

有一组数据:

x=[1,2,3,4,5,6,7,8,9,10]

y=[1,2,3,4,5,6,7,8,9,10]

要求画一条过原点的直线,穿过上述所有点

这组数据在二维平面表现如下

引入概念,假设函数:h(x)。h代表hypothesis

由于是过原点的直线,所以可以列出方程h(x):

先随意假设一个 ,在这先假设 =0.5 ,函数图如下

显而易见这条直线并不是我们想要的。那么具体的,怎么判断一条直线的好坏呢

引入概念代价函数 cost function

在本例中,可以由拟合数据和原始数据对应点的误差的平方的均值来判断直线的好坏;列出J(θ)如下:

其中m表示数据的总量,在本题中为10; 并不是代表x的i次方,而是代表第i个x的数值,例如在本例中, 为2

将h(x)带入,得

函数图是这样一个形状,数值对不上,凑合着看吧;有一点值得注意,在J(θ)中, 都应该作为常量来处理

显然,J(θ)值越小,点到直线的距离总和越少,画出来的直线效果也就越好。放到题目中就是当J(θ)=0的时候,画出的直线穿过了所有的点

那么问题就变成了如何最小化J(θ)

在这个例子中可以手动计算,也就是正规方程法,但是随着问题复杂度的增加,正规方程法的实用性会越来越低

引入梯度下降法

其中α为步幅

梯度下降法可以解释为:对J(θ)求关于 (本例中只有 )的偏导数并乘以步幅,再用 减去该值,得到的结果赋值给 。此过程需要重复多次

步幅的选择会直接关系到梯度下降法的效果,如下图

当选取了一个较小步幅的时候,将正确收敛

当选取了一个较大步幅的时候,将震荡收敛

当选取了一个过大步幅的时候,将无法收敛

调整一下例子

x=[5,6,7,8,9,10,11,12,13,14]

y=[1,2,3,4,5,6,7,8,9,10]

要求画一条直线,穿过上述所有点

很明显,对于这组数据,仅仅是过原点的直线无法满足要求,所以列出新的h(x):

而判断一条直线的好坏还可以沿用之前的J(θ):

函数图是这样一个形状,数值对不上,凑合着看吧

之后就是如何最小化J(θ)的问题了。下面给出tensorflow的代码实现

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Parameters
learning_rate = 0.05
training_epochs = 2000
display_step = 50
# Training Data
train_X = np.asarray([5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0])
train_Y = train_X - 4
paint_X = np.asarray([-100.0, 100.0])

n_samples = train_X.shape[0]
# tf Graph Input
X = tf.placeholder("float")
Y = tf.placeholder("float")

# Set model weights
W = tf.Variable(-10., name="weight")
b = tf.Variable(10., name="bias")
# Construct a linear model
pred = tf.add(tf.multiply(X, W), b)
# Mean squared error
cost = tf.reduce_sum(tf.pow(pred - Y, 2)) / (2 * n_samples)
# Gradient descent
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
# Initializing the variables
init = tf.global_variables_initializer()
# Launch the graph

plt.figure()
plt.ion()

with tf.Session() as sess:
    sess.run(init)

    # Fit all training data
    for epoch in range(training_epochs):
        for (x, y) in zip(train_X, train_Y):
            sess.run(optimizer, feed_dict={X: x, Y: y})

        # Display logs per epoch step
        if (epoch + 1) % display_step == 0:
            c = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(c),
                  "W=", sess.run(W), "b=", sess.run(b))

            plt.axis([0.0, np.max(train_X) + 1, 0.0, np.max(train_Y) + 1])
            plt.plot(train_X, train_Y, 'ro', label='Original data')
            plt.plot(paint_X, sess.run(W) * paint_X + sess.run(b), label='Fitted line')
            plt.pause(0.001)
            plt.clf()

    print("Optimization Finished!")
    training_cost = sess.run(cost, feed_dict={X: train_X, Y: train_Y})
    print("Training cost=", training_cost, "W=", sess.run(W), "b=", sess.run(b), '\n')

    plt.axis([0.0, np.max(train_X) + 1, 0.0, np.max(train_Y) + 1])
    plt.plot(train_X, train_Y, 'ro', label='Original data')
    plt.plot(paint_X, sess.run(W) * paint_X + sess.run(b), label='Fitted line')
    plt.pause(10)

线性回归,附tensorflow实现的更多相关文章

  1. 逻辑回归,附tensorflow实现

    本文旨在通过二元分类问题.多元分类问题介绍逻辑回归算法,并实现一个简单的数字分类程序 在生活中,我们经常会碰到这样的问题: 根据苹果表皮颜色判断是青苹果还是红苹果 根据体温判断是否发烧 这种答案只有两 ...

  2. 简单的线性回归问题-TensorFlow+MATLAB·

    首先我们要试验的是 人体脂肪fat和年龄age以及体重weight之间的关系,我们的目标就是得到一个最优化的平面来表示三者之间的关系: TensorFlow的程序如下: import tensorfl ...

  3. 利用VGG19实现火灾分类(附tensorflow代码及训练集)

    源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...

  4. 利用卷积神经网络(VGG19)实现火灾分类(附tensorflow代码及训练集)

    源码地址 https://github.com/stephen-v/tensorflow_vgg_classify 1. VGG介绍 1.1. VGG模型结构 1.2. VGG19架构 2. 用Ten ...

  5. 经典损失函数:交叉熵(附tensorflow)

    每次都是看了就忘,看了就忘,从今天开始,细节开始,推一遍交叉熵. 我的第一篇CSDN,献给你们(有错欢迎指出啊). 一.什么是交叉熵 交叉熵是一个信息论中的概念,它原来是用来估算平均编码长度的.给定两 ...

  6. Tensorflow之单变量线性回归问题的解决方法

    跟着网易云课堂上面的免费公开课深度学习应用开发Tensorflow实践学习,学到线性回归这里感觉有很多需要总结,梳理记录下阶段性学习内容. 题目:通过生成人工数据集合,基于TensorFlow实现y= ...

  7. Tensorflow学习笔记01

    Tensorflow官方网站:http://tensorflow.org/ 极客学院Tensorflow中文版:http://wiki.jikexueyuan.com/project/tensorfl ...

  8. TensorFlow 从零到helloWorld

    目录 1.git安装与使用 1.1 git安装 1.2 修改git bash默认路径 1.3 git常用操作 2.环境搭建   2.1 tensorflow安装   2.2 CUDA安装   2.3 ...

  9. TF linear regression

    本文的作者 Nishant Shukla 为加州大学洛杉矶分校的机器视觉研究者,从事研究机器人机器学习技术.Nishant Shukla 一直以来兼任 Microsoft.Facebook 和 Fou ...

随机推荐

  1. 当谈到 GitLab CI 的时候,我们该聊些什么(上篇)

    "微服务"这个概念近两年非常热,正在慢慢改变 DevOps 的思路.微服务架构把一个庞大的业务系统拆解开来,每一个组件变得更加独立自治.松耦合.但是,同时也伴随着部署单元粒度越来越 ...

  2. 初识SQL Server2017 图数据库(一)

    背景: 图数据库对于表现和遍历复杂的实体之间关系是很有效果的.而这些在传统的关系型数据库中尤其是对于报表而言很难实现.如果把传统关系型数据库比做火车的话,那么到现在大数据时代,图数据库可比做高铁.它已 ...

  3. win10 uwp App-to-app communication 应用通信

    这篇文章都是乱说的,如果觉得有不好的,可以发我邮箱 应用之间需要相互的发送信息,就是我们经常用的分享 有个人看到一个网页很好,于是就希望把这个网页发送到邮件,那么这样的话就是使用应用通信. 因为每个应 ...

  4. day2_python的数据类型,sys,os模块,编码解码,列表,字典

    今天主要了解了python的数据类型,sys,os模块,编码解码,列表,字典 1.数据类型:int(python3没有长整型)文本总是Unicode,str表示二进制用byte类表示布尔型:True( ...

  5. 【转】Linux设备驱动--块设备(一)之概念和框架

    原文地址:Linux设备驱动--块设备(一)之概念和框架 基本概念   块设备(blockdevice) --- 是一种具有一定结构的随机存取设备,对这种设备的读写是按块进行的,他使用缓冲区来存放暂时 ...

  6. eclipse项目中引入shiro-freemarker-tags会jar包冲突

    maven项目中引入了这个依赖. <dependency> <groupId>net.mingsoft</groupId> <artifactId>sh ...

  7. python字符串,列表,字符串,元组,集合的一些方法

    字符串方法 __contains__ #等同in name = 'erroy' result = name.__contains__('er') #判断元素是否包含er print(result) T ...

  8. js验证是否为数字的总结(转)

    作者: 字体:[增加 减小] 类型:转载 时间:2013-04-14我要评论 js验证是否为数字的总结,需要的朋友可以参考一下 js验证是否为数字,最简单的方法: isNaN函数的使用: functi ...

  9. 初入WebService

    搭建webservice需要用到的jar applicationContext.xml配置文件 <?xml version="1.0" encoding="UTF- ...

  10. outline

    a标签 两种button按钮  默认带有一个虚线 outline  当他们被单击 和 激活以后   outline和border 很类似 ,但是有不同 1.outline 不能针对特定的边赋值 ,也就 ...