大叔学ML第一:梯度下降
原理
梯度下降是一个很常见的通过迭代求解函数极值的方法,当函数非常复杂,通过求导寻找极值很困难时可以通过梯度下降法求解。梯度下降法流程如下:

上图中,用大写字母表示向量,用小写字母表示标量。
假设某人想入坑,他站在某点,他每移动一小步,都朝着他所在点的梯度的负方向移动,这样能保证他尽快入坑,因为某个点的梯度方向是最陡峭的方向(实际上,梯度下降法有时候不是最快的下降方向,比如我们下山时,可能前方遇到一个梁,跨过去是最快的下山方式,而不是绕开,如果是梯度下降法,肯定会绕开。),如下图所示,此图画的不太能表达这个观点,但是懒得盗图了,意会吧:

以下举两个例子,两个例子中的被求函数都很简单,其实直接求导算极值更好,此处仅用来说明梯度下降法的步骤。
实践一:求\(y = x^2 - 4x + 1\)的最小值
# -*- coding: utf-8 -*-
import numpy as np
import matplotlib.pyplot as plt
def descent(p, original_x = 50, steplength = 0.01):
''' gradient descent, return min y '''
deriv = p.deriv(m = 1) # 多项式p的导函数
Y = [] # 保存每次迭代后的y值,方便绘图
count = 0 # 迭代次数
x = original_x # 设置x初始值
d = deriv(x) # x位置的导数
threshold = 0.001 # 阈值,当梯度小于此值时停止迭代
while np.abs(d) > threshold:
x = x - d * steplength
y = p(x)
Y.append(y)
count += 1
d = deriv(x)
plt.plot(np.arange(1, count + 1), Y)
plt.show()
return y
if __name__ == "__main__":
p = np.poly1d([2, -4, 1])
min_y = descent(p)
print(min_y)
把迭代数和对应的函数值绘制出来以查看迭代效果:

实践二:求\(z = x^2 + y^2 + 5\)的最小值
以下代码中,把一组x和y当成一个向量处理,即\(z = X^TX + 5\),其中\(X=[x\ y]^T\)
import numpy as np
import matplotlib.pyplot as plt
def deriv(xy):
dxy = 2 * xy
return dxy
def descent(xy, steplength = 0.01):
''' gradient descent, return min y '''
d = deriv(xy) # x^2 + y^2 + 5的梯度
Y = [] # 保存每次迭代后的y值,方便绘图
count = 0 # 迭代次数
threshold = 0.001 # 阈值,当梯度的模小于此值时停止迭代
while np.linalg.norm(d) > threshold:
xy = xy - d * steplength
y = np.dot(xy, xy) + 5
Y.append(y)
count += 1
d = deriv(xy)
plt.plot(np.arange(1, count + 1), Y)
plt.show()
return Y[-1]
if __name__ == "__main__":
y = descent(np.array([50, 50]))
print(y)
把迭代数和对应的函数值绘制出来以查看迭代效果:

问答时间
Q:无法收敛到某个足够小的函数值,最后报错: overflow ...
A:步长设置太大,步子大了,容易跨过最低点,导致函数值在最低点上下震荡或发散,如图:

可以人为设置迭代次数(而不是通过阈值控制是否继续迭代),然后观察函数值是否收敛:

Q:如何选择合适的步长
A:步长太大会导致函数值不收敛,步长太小又浪费性能,可以通过绘制如上面的迭代次数和函数值关系图,刚才结果后调整步长,尽量选择满足需求的最大步长。达爷在他的网课中给出的建议是:按照这样的序列试验步长:..., 0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1, ...。通过算法自动预测步长十分复杂,非大叔所能为。
Q:何时停止迭代?
A:可设定一个阈值,当梯度的模长小于这个阈值时停止迭代(当函数接近极值时,梯度接近0)。也可以人为通过刚才迭代次数和函数值图像设定迭代次数。
Q:是否还有其他迭代法?
A:还有牛顿法和拟牛顿法,和梯度下降法的区别是牛顿法不是沿着梯度负方向下降的,而是另一套算法得出的方向,下降速度更快。
Q:迭代法是否一定会找到函数值域内的最小值?
A:不是,如果函数不是一个凸函数,那么迭代法可能会找到一个局部最小值或鞍点值。
Q:函数最大值怎么找
A:给函数取个负号然后找最小值,或者沿着梯度方向前进而不是负梯度方向前进
大叔学ML第一:梯度下降的更多相关文章
- 大叔学ML第二:线性回归
目录 基本形式 求解参数\(\vec\theta\) 梯度下降法 正规方程导法 调用函数库 基本形式 线性回归非常直观简洁,是一种常用的回归模型,大叔总结如下: 设有样本\(X\)形如: \[\beg ...
- 大叔学ML第五:逻辑回归
目录 基本形式 代价函数 用梯度下降法求\(\vec\theta\) 扩展 基本形式 逻辑回归是最常用的分类模型,在线性回归基础之上扩展而来,是一种广义线性回归.下面举例说明什么是逻辑回归:假设我们有 ...
- 大叔学ML第四:线性回归正则化
目录 基本形式 梯度下降法中应用正则化项 正规方程中应用正则化项 小试牛刀 调用类库 扩展 正则:正则是一个汉语词汇,拼音为zhèng zé,基本意思是正其礼仪法则:正规:常规:正宗等.出自<楚 ...
- 大叔学ML第三:多项式回归
目录 基本形式 小试牛刀 再试牛刀 调用类库 基本形式 上文中,大叔说道了线性回归,线性回归是个非常直观又简单的模型,但是很多时候,数据的分布并不是线性的,如: 如果我们想用高次多项式拟合上面的数据应 ...
- ML(附录1)——梯度下降
梯度下降是迭代法的一种,可以用于求解最小二乘问题(线性和非线性都可以).在求解机器学习算法的模型参数,即无约束优化问题时,梯度下降(Gradient Descent)是最常采用的方法之一,另一种常用的 ...
- ML:梯度下降(Gradient Descent)
现在我们有了假设函数和评价假设准确性的方法,现在我们需要确定假设函数中的参数了,这就是梯度下降(gradient descent)的用武之地. 梯度下降算法 不断重复以下步骤,直到收敛(repeat ...
- ML:多变量代价函数和梯度下降(Linear Regression with Multiple Variables)
代价函数cost function 公式: 其中,变量θ(Rn+1或者R(n+1)*1) 向量化: Octave实现: function J = computeCost(X, y, theta) %C ...
- 机器学习(ML)十五之梯度下降和随机梯度下降
梯度下降和随机梯度下降 梯度下降在深度学习中很少被直接使用,但理解梯度的意义以及沿着梯度反方向更新自变量可能降低目标函数值的原因是学习后续优化算法的基础.随后,将引出随机梯度下降(stochastic ...
- 深度学习(二)BP求解过程和梯度下降
一.原理 重点:明白偏导数含义,是该函数在该点的切线,就是变化率,一定要理解变化率. 1)什么是梯度 梯度本意是一个向量(矢量),当某一函数在某点处沿着该方向的方向导数取得该点处的最大值,即函数在该点 ...
随机推荐
- HDU 1522 Marriage is Stable 稳定婚姻匹配
http://acm.hdu.edu.cn/showproblem.php?pid=1522 #include<bits/stdc++.h> #define INF 0x3f3f3f3f ...
- OO_多项式求导_单元总结
概述: 面向对象第一单元的作业是三次难度依次递增的多项式求导.第一次作业是仅包含带符号整数和幂函数的多项式求导,例如:-1+xˆ233-xˆ06:第二次是在前面的基础上增加了三角函数的求导,例如:-1 ...
- Codeforces 1083C Max Mex
Description 一棵\(N\)个节点的树, 每个节点上都有 互不相同的 \([0, ~N-1]\) 的数. 定义一条路径上的数的集合为 \(S\), 求一条路径使得 \(Mex(S)\) 最大 ...
- SpringBoot跨域问题
1.先来说说跨域原理: 跨域原理简单来说就是发起跨域请求的时候,浏览器会对请求域返回的响应信息检查HTTP头,如果Access-Control-Allow-Origin包含了自身域,则允许访问,否则报 ...
- Sublime Text 3安装emmet(ZenCoding)
1.安装 Package Ctrol: 使用 ctrl + - 打开控制台,输入以下代码 import urllib.request,os; pf = 'Package Control.sublime ...
- 父组件传值给子组件的v-model属性
父组件如何修改子组件中绑定的v-model属性 因为v-model属性是双向数据绑定,而vue的通信方式又是单向通信,所以,当子组件想要改变父组件传过来的值的属性时,就会报错,典型的就是父组件传值给子 ...
- 2019浙大校赛--G--Postman(简单思维题)
一个思维水题 题目大意为,一个邮递员要投递N封信,一次从邮局来回只能投递K封.求最短的投递总距离.需注意,最后一次投递后无需返回邮局. 本题思路要点: 1.最后一次投递无需返回邮局,故最后一次投递所行 ...
- 关于数据库连接时URL的问题
最近在写一个简单的增删改查的代码时,遇到保存的中文都会变成问号(?),由于刚开始只是一些数据的保存,所以认为之后只要对数据库的编码进行修改即可,但是后来要对数据进行查找的时候发现根本查找不到, 当时用 ...
- 第四次scrum冲刺
一.第四次Scrum任务 继续上次的任务,完成校园服务中的成绩查询,失物招领,长大集市的功能. 小组的地址链接:https://github.com/Weifeng513/-1/tree/master ...
- python_flask 基础巩固 (URL_FOR 详解)
URL_FOR 详解 url_for 通过 视图函数能够返回对应的url,url_for 有两个参数,endpoint(视图 函数)和关键字参数 url_for('my_list',page=2),多 ...