系列博客,原文在笔者所维护的github上:https://aka.ms/beginnerAI
点击star加星不要吝啬,星越多笔者越努力。

4.2 梯度下降法

有了上一节的最小二乘法做基准,我们这次用梯度下降法求解w和b,从而可以比较二者的结果。

4.2.1 数学原理

在下面的公式中,我们规定x是样本特征值(单特征),y是样本标签值,z是预测值,下标 \(i\) 表示其中一个样本。

预设函数(Hypothesis Function)

为一个线性函数:

\[z_i = x_i \cdot w + b \tag{1}\]

损失函数(Loss Function)

为均方差函数:

\[loss(w,b) = \frac{1}{2} (z_i-y_i)^2 \tag{2}\]

与最小二乘法比较可以看到,梯度下降法和最小二乘法的模型及损失函数是相同的,都是一个线性模型加均方差损失函数,模型用于拟合,损失函数用于评估效果。

区别在于,最小二乘法从损失函数求导,直接求得数学解析解,而梯度下降以及后面的神经网络,都是利用导数传递误差,再通过迭代方式一步一步逼近近似解。

4.2.2 梯度计算

计算z的梯度

根据公式2:
\[
{\partial loss \over \partial z_i}=z_i - y_i \tag{3}
\]

计算w的梯度

我们用loss的值作为误差衡量标准,通过求w对它的影响,也就是loss对w的偏导数,来得到w的梯度。由于loss是通过公式2->公式1间接地联系到w的,所以我们使用链式求导法则,通过单个样本来求导。

根据公式1和公式3:

\[
{\partial{loss} \over \partial{w}} = \frac{\partial{loss}}{\partial{z_i}}\frac{\partial{z_i}}{\partial{w}}=(z_i-y_i)x_i \tag{4}
\]

计算b的梯度

\[
\frac{\partial{loss}}{\partial{b}} = \frac{\partial{loss}}{\partial{z_i}}\frac{\partial{z_i}}{\partial{b}}=z_i-y_i \tag{5}
\]

4.2.3 代码实现

if __name__ == '__main__':

    reader = SimpleDataReader()
    reader.ReadData()
    X,Y = reader.GetWholeTrainSamples()

    eta = 0.1
    w, b = 0.0, 0.0
    for i in range(reader.num_train):
        # get x and y value for one sample
        xi = X[i]
        yi = Y[i]
        # 公式1
        zi = xi * w + b
        # 公式3
        dz = zi - yi
        # 公式4
        dw = dz * xi
        # 公式5
        db = dz
        # update w,b
        w = w - eta * dw
        b = b - eta * db

    print("w=", w)
    print("b=", b)

大家可以看到,在代码中,我们完全按照公式推导实现了代码,所以,大名鼎鼎的梯度下降,其实就是把推导的结果转化为数学公式和代码,直接放在迭代过程里!另外,我们并没有直接计算损失函数值,而只是把它融入在公式推导中。

4.2.4 运行结果

w= [1.71629006]
b= [3.19684087]

读者可能会注意到,上面的结果和最小二乘法的结果(w1=2.056827, b1=2.965434)相差比较多,这个问题我们留在本章稍后的地方解决。

代码位置

ch04, Level2

[ch04-02] 用梯度下降法解决线性回归问题的更多相关文章

  1. C / C ++ 基于梯度下降法的线性回归法(适用于机器学习)

    写在前面的话: 在第一学期做项目的时候用到过相应的知识,觉得挺有趣的,就记录整理了下来,基于C/C++语言 原贴地址:https://helloacm.com/cc-linear-regression ...

  2. tensorflow实现svm多分类 iris 3分类——本质上在使用梯度下降法求解线性回归(loss是定制的而已)

    # Multi-class (Nonlinear) SVM Example # # This function wll illustrate how to # implement the gaussi ...

  3. tensorflow实现svm iris二分类——本质上在使用梯度下降法求解线性回归(loss是定制的而已)

    iris二分类 # Linear Support Vector Machine: Soft Margin # ---------------------------------- # # This f ...

  4. 机器学习中梯度下降法原理及用其解决线性回归问题的C语言实现

    本文讲梯度下降(Gradient Descent)前先看看利用梯度下降法进行监督学习(例如分类.回归等)的一般步骤: 1, 定义损失函数(Loss Function) 2, 信息流forward pr ...

  5. 梯度下降法及一元线性回归的python实现

    梯度下降法及一元线性回归的python实现 一.梯度下降法形象解释 设想我们处在一座山的半山腰的位置,现在我们需要找到一条最快的下山路径,请问应该怎么走?根据生活经验,我们会用一种十分贪心的策略,即在 ...

  6. 最小二乘法 及 梯度下降法 分别对存在多重共线性数据集 进行线性回归 (Python版)

    网上对于线性回归的讲解已经很多,这里不再对此概念进行重复,本博客是作者在听吴恩达ML课程时候偶然突发想法,做了两个小实验,第一个实验是采用最小二乘法对数据进行拟合, 第二个实验是采用梯度下降方法对数据 ...

  7. 梯度下降法实现最简单线性回归问题python实现

    梯度下降法是非常常见的优化方法,在神经网络的深度学习中更是必会方法,但是直接从深度学习去实现,会比较复杂.本文试图使用梯度下降来优化最简单的LSR线性回归问题,作为进一步学习的基础. import n ...

  8. 机器学习---用python实现最小二乘线性回归算法并用随机梯度下降法求解 (Machine Learning Least Squares Linear Regression Application SGD)

    在<机器学习---线性回归(Machine Learning Linear Regression)>一文中,我们主要介绍了最小二乘线性回归算法以及简单地介绍了梯度下降法.现在,让我们来实践 ...

  9. 简单线性回归(梯度下降法) python实现

    grad_desc .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { bord ...

随机推荐

  1. UNIX环境高级编程 使用方法

    1.解压文件到apue.2e目录2.修改相应平台的文件,我使用的是linux,所以修改Make.defines.linux你修改的只需要这一行WKDIR=/home/your_dir/apue2e_s ...

  2. 基于jquery,php实现AJAX长轮询(LongPoll),类似推送机制

    HTTP是无状态.单向的协议,用户只能够通过客服端向服务器发送请求并由服务器处理发回一个响应.若要实现聊天室.WEBQQ.在线客服.邮箱等这些即时通讯的应用,就要用到“ 服务器推送技术(Comet)” ...

  3. django & celery - 关于并发处理能力和内存使用的小结

    背景 众所周知,celery 是python世界里处理分布式任务的好助手,它的出现结合赋予了我们强大的处理异步请求,分布式任务,周期任务等复杂场景的能力. 然鹅,今天我们所要讨论的则是如何更好的在使用 ...

  4. HTTP 304状态码的详细讲解

    首先,对于304状态码不应该认为是一种错误,而是对客户端有缓存情况下服务端的一种响应. 客户端在请求一个文件的时候,发现自己缓存的文件有 Last Modified ,那么在请求中会包含 If Mod ...

  5. SpringBoot自定义starter及自动配置

    SpringBoot的核心就是自动配置,而支持自动配置的是一个个starter项目.除了官方已有的starter,用户自己也可以根据规则自定义自己的starter项目. 自定义starter条件 自动 ...

  6. 划艇:dp/组合数/区间离散化

    Description 在首尔城中,汉江横贯东西.在汉江的北岸,从西向东星星点点地分布着 N 个划艇学校,编号依次为 1 到 N.每个学校都拥有若干艘划艇.同一所学校的所有划艇颜色相同,不同的学校的划 ...

  7. 精心整理(含图版)|你要的全拿走!(R数据分析,可视化,生信实战)

    本文首发于“生信补给站”公众号,https://mp.weixin.qq.com/s/ZEjaxDifNATeV8fO4krOIQ更多关于R语言,ggplot2绘图,生信分析的内容,敬请关注小号. 为 ...

  8. Redis 的底层数据结构(压缩列表)

    上一篇我们介绍了 redis 中的整数集合这种数据结构的实现,也谈到了,引入这种数据结构的一个很大的原因就是,在某些仅有少量整数元素的集合场景,通过整数集合既可以达到字典的效率,也能使用远少于字典的内 ...

  9. mysql批量更新写法

    mysql批量更新写法<pre> $namedmp=filter($_POST['namedmp']); $namedsp=filter($_POST['namedsp']); $name ...

  10. 在 ASP.NET Core 项目中使用 MediatR 实现中介者模式

    一.前言  最近有在看 DDD 的相关资料以及微软的 eShopOnContainers 这个项目中基于 DDD 的架构设计,在 Ordering 这个示例服务中,可以看到各层之间的代码调用与我们之前 ...