[ch04-02] 用梯度下降法解决线性回归问题
系列博客,原文在笔者所维护的github上:https://aka.ms/beginnerAI,
点击star加星不要吝啬,星越多笔者越努力。
4.2 梯度下降法
有了上一节的最小二乘法做基准,我们这次用梯度下降法求解w和b,从而可以比较二者的结果。
4.2.1 数学原理
在下面的公式中,我们规定x是样本特征值(单特征),y是样本标签值,z是预测值,下标 \(i\) 表示其中一个样本。
预设函数(Hypothesis Function)
为一个线性函数:
\[z_i = x_i \cdot w + b \tag{1}\]
损失函数(Loss Function)
为均方差函数:
\[loss(w,b) = \frac{1}{2} (z_i-y_i)^2 \tag{2}\]
与最小二乘法比较可以看到,梯度下降法和最小二乘法的模型及损失函数是相同的,都是一个线性模型加均方差损失函数,模型用于拟合,损失函数用于评估效果。
区别在于,最小二乘法从损失函数求导,直接求得数学解析解,而梯度下降以及后面的神经网络,都是利用导数传递误差,再通过迭代方式一步一步逼近近似解。
4.2.2 梯度计算
计算z的梯度
根据公式2:
\[
{\partial loss \over \partial z_i}=z_i - y_i \tag{3}
\]
计算w的梯度
我们用loss的值作为误差衡量标准,通过求w对它的影响,也就是loss对w的偏导数,来得到w的梯度。由于loss是通过公式2->公式1间接地联系到w的,所以我们使用链式求导法则,通过单个样本来求导。
根据公式1和公式3:
\[
{\partial{loss} \over \partial{w}} = \frac{\partial{loss}}{\partial{z_i}}\frac{\partial{z_i}}{\partial{w}}=(z_i-y_i)x_i \tag{4}
\]
计算b的梯度
\[
\frac{\partial{loss}}{\partial{b}} = \frac{\partial{loss}}{\partial{z_i}}\frac{\partial{z_i}}{\partial{b}}=z_i-y_i \tag{5}
\]
4.2.3 代码实现
if __name__ == '__main__':
reader = SimpleDataReader()
reader.ReadData()
X,Y = reader.GetWholeTrainSamples()
eta = 0.1
w, b = 0.0, 0.0
for i in range(reader.num_train):
# get x and y value for one sample
xi = X[i]
yi = Y[i]
# 公式1
zi = xi * w + b
# 公式3
dz = zi - yi
# 公式4
dw = dz * xi
# 公式5
db = dz
# update w,b
w = w - eta * dw
b = b - eta * db
print("w=", w)
print("b=", b)
大家可以看到,在代码中,我们完全按照公式推导实现了代码,所以,大名鼎鼎的梯度下降,其实就是把推导的结果转化为数学公式和代码,直接放在迭代过程里!另外,我们并没有直接计算损失函数值,而只是把它融入在公式推导中。
4.2.4 运行结果
w= [1.71629006]
b= [3.19684087]
读者可能会注意到,上面的结果和最小二乘法的结果(w1=2.056827, b1=2.965434)相差比较多,这个问题我们留在本章稍后的地方解决。
代码位置
ch04, Level2
[ch04-02] 用梯度下降法解决线性回归问题的更多相关文章
- C / C ++ 基于梯度下降法的线性回归法(适用于机器学习)
写在前面的话: 在第一学期做项目的时候用到过相应的知识,觉得挺有趣的,就记录整理了下来,基于C/C++语言 原贴地址:https://helloacm.com/cc-linear-regression ...
- tensorflow实现svm多分类 iris 3分类——本质上在使用梯度下降法求解线性回归(loss是定制的而已)
# Multi-class (Nonlinear) SVM Example # # This function wll illustrate how to # implement the gaussi ...
- tensorflow实现svm iris二分类——本质上在使用梯度下降法求解线性回归(loss是定制的而已)
iris二分类 # Linear Support Vector Machine: Soft Margin # ---------------------------------- # # This f ...
- 机器学习中梯度下降法原理及用其解决线性回归问题的C语言实现
本文讲梯度下降(Gradient Descent)前先看看利用梯度下降法进行监督学习(例如分类.回归等)的一般步骤: 1, 定义损失函数(Loss Function) 2, 信息流forward pr ...
- 梯度下降法及一元线性回归的python实现
梯度下降法及一元线性回归的python实现 一.梯度下降法形象解释 设想我们处在一座山的半山腰的位置,现在我们需要找到一条最快的下山路径,请问应该怎么走?根据生活经验,我们会用一种十分贪心的策略,即在 ...
- 最小二乘法 及 梯度下降法 分别对存在多重共线性数据集 进行线性回归 (Python版)
网上对于线性回归的讲解已经很多,这里不再对此概念进行重复,本博客是作者在听吴恩达ML课程时候偶然突发想法,做了两个小实验,第一个实验是采用最小二乘法对数据进行拟合, 第二个实验是采用梯度下降方法对数据 ...
- 梯度下降法实现最简单线性回归问题python实现
梯度下降法是非常常见的优化方法,在神经网络的深度学习中更是必会方法,但是直接从深度学习去实现,会比较复杂.本文试图使用梯度下降来优化最简单的LSR线性回归问题,作为进一步学习的基础. import n ...
- 机器学习---用python实现最小二乘线性回归算法并用随机梯度下降法求解 (Machine Learning Least Squares Linear Regression Application SGD)
在<机器学习---线性回归(Machine Learning Linear Regression)>一文中,我们主要介绍了最小二乘线性回归算法以及简单地介绍了梯度下降法.现在,让我们来实践 ...
- 简单线性回归(梯度下降法) python实现
grad_desc .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { bord ...
随机推荐
- Python中的is和==的区别
Python中的is和==的区别 1. is 是比较内存地址id() a = "YongJie" b = "YongJie" print(id(a)) #233 ...
- 前端技术之:如何创建一个NodeJs命令行交互项目
方法一:通过原生的NodeJs API,方法如下: #!/usr/bin/env node # test.js var argv = process.argv; console.log(argv) ...
- 在VMware下的Linux中的RAID10校验位算法下的磁盘管理
988年由加利福尼亚大学伯克利分校发表的文章首次提到并定义了RAID,当今CPU性能每年可提升30%-50%但硬盘仅提升7%,渐渐的已经成为计算机整体性能的瓶颈,并且为了避免硬盘的突然损坏导致数据丢失 ...
- CMMS系统中的物联监测
有条件的设备物联后,可时实查看设备运行状态,如发现异常,可提前干预.
- [考试反思]1010csp-s模拟测试67:摸索
嗯...所谓RP守恒? 仍然延续着好一场烂一场的规律. 虽说我也想打破这个规律,但是并不想在考烂之后打破这个规律.(因为下一场要考好???) 我也不知道我现在是什么状态,相较于前一阶段有所提升(第一鸡 ...
- [考试反思]0809NOIP模拟测试15:解剖
说在前面: 不建议阅读.这里没有考试经验,只有一大堆负面情绪. 看了你不会有什么收获.看完了就不要怪我影响了你的心情. 以后不粘排行榜了.没什么意思没什么用. 但是我的意思并不是因为这次没考好的一时兴 ...
- 测试面试题集-测试用例设计:登录、购物车、QQ收藏表情、转账、充值、提现
以下内容首发于微信公众号[ITester软件测试小栈]: 测试面试题集-2.测试用例设计 大家好 我是coco小锦鲤 上周五给大家分享了测试基础理论题 这个周五给大家分享测试用例设计题 测试用例的考察 ...
- jsp页面时间的转换js
/** * 日期 转换为 Unix时间戳 * @param <string> 2014-01-01 20:2 ...
- [springboot 开发单体web shop] 6. 商品分类和轮播广告展示
商品分类&轮播广告 因最近又被困在了OSGI技术POC,更新进度有点慢,希望大家不要怪罪哦. 上节 我们实现了登录之后前端的展示,如: 接着,我们来实现左侧分类栏目的功能. ## 商品分类|P ...
- 插入排序的代码实现(C语言)
void insert_sort(int arr[], int len) { for (int i = 1; i < len; ++i) { if (arr[i] < arr[i - 1] ...