梯度下降法及一元线性回归的python实现

一、梯度下降法形象解释

  设想我们处在一座山的半山腰的位置,现在我们需要找到一条最快的下山路径,请问应该怎么走?根据生活经验,我们会用一种十分贪心的策略,即在现在所处的位置上找到一个能够保证我们下山最快的方向,然后向着该方向行走;每到一个新位置,重复地应用上述贪心策略,我们就可以顺利到达山底了。其实梯度下降法的运行过程和上述下山的例子没有什么区别,不同的是我们人类可以凭借我们的感官直觉,根据所处的位置来选择最佳的行走方向,而梯度下降法所依据的是严格的数学法则来进行每一步的更新。本文不再对该算法进行严格的数理讨论,只介绍梯度下降法进行数据拟合的流程和利用梯度下降法解决一元线性回归的python实现。

二、梯度下降法算法应用流程

  假设有一组数据X=[x1,x2,x3,...],Y=[y1,y2,y3,...],现求由X到Y的函数关系:

  1、为所需要拟合的数据,构造合适的假设函数:y=f(x;θ),以θ=[θ123,...]为参数;

  2、选择合适的损失函数:cost(θ),用损失函数来衡量假设函数对数据的拟合程度;

  3、设定梯度下降法的学习率 α,参数的优化初始值及迭代终止条件;

  4、迭代更新θ,直到满足迭代终止条件,更新公式为:

    θ11-α*dcost(θ)/dθ1

    θ22-α*dcost(θ)/dθ2,...

三、一元线性回归的python实现

  下面以一个一元线性回归的例子来更进一步理解梯度下降法的过程。笔者通过在函数y=3*x+2的基础之上添加一些服从均匀分布的随机数来构造如下的待拟合数据:X,Y,训练数据图像如下图1所示。假设函数为一元线性函数: y=f(x;θ,k)=θ*x+k,损失函数为:cost(θ,k)=1/2*∑(f(xi;θ,k)-yi),xi属于X,yi属于Y,损失函数的图像如下图2所示。应用梯度下降法进行参数更新的过程如图3中的蓝色圆点所示。 

(1) 

(2)

(3)

  程序源代码如下:

 import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D np.random.seed(1)
#生成样本数据
x=np.arange(-1,1,step=0.04)#自变量
noise=np.random.uniform(low=-0.5,high=0.5,size=50)#噪声
y=x*3+2+noise#因变量
#显示待拟合数据
plt.figure(1)
plt.xlabel('x')
plt.ylabel('y')
plt.scatter(x,y) #假设函数为一元线性函数:y=theta*x+k,需要求解的参数为theta和k
#损失函数为
def cost(theta, k, x, y):
return 1/2*np.mean((theta*x+k-y)**2) def cost_mesh(theta_m, k_m, x, y):
z_m=np.zeros((theta_m.shape[0],theta_m.shape[1]))
for i in range(theta_m.shape[0]):
for j in range(theta_m.shape[1]):
z_m[i,j]=cost(theta_m[i,j], k_m[i,j],x,y)
return z_m
#可视化损失函数
theta_axis=np.linspace(start=0, stop=5,num=50)
k_axis=np.linspace(start=0, stop=5,num=50)
(theta_m, k_m)=np.meshgrid(theta_axis,k_axis)#网格化
z_m=cost_mesh(theta_m, k_m, x, y)
#绘制损失函数的3D图像
fig=plt.figure(2)
ax=Axes3D(fig)#为figure添加3D坐标轴
ax.set_xlabel('theta')
ax.set_ylabel('k')
ax.set_zlabel('cost')
ax.plot_surface(theta_m, k_m, z_m,rstride=1, cstride=1,cmap=plt.cm.hot, alpha=0.5)#绘制3D的表面, rstide为行跨度,cstride为列跨度 #梯度下降法
#参数设置
lr=0.01#学习率
epoches=600#迭代次数,即迭代终止条件 #参数初始数值
theta=0
k=0 #迭代更新参数
for i in range(epoches):
theta_gra=np.mean((theta*x+k-y)*x)#theta梯度
k_gra=np.mean(theta*x+k-y)#k梯度
#更新梯度
theta-=theta_gra*lr
k-=k_gra*lr
#绘制当前参数所在的位置
if i%50==0:
ax.scatter3D(theta, k, cost(theta, k, x,y), marker='o', s=30, c='b')
print('最终的结果为:theta=%f, k=%f'%(theta, k))
plt.show()

梯度下降法及一元线性回归的python实现的更多相关文章

  1. 最小二乘法 及 梯度下降法 运行结果对比(Python版)

    上周在实验室里师姐说了这么一个问题,对于线性回归问题,最小二乘法和梯度下降方法所求得的权重值是一致的,对此我颇有不同观点.如果说这两个解决问题的方法的等价性的确可以根据数学公式来证明,但是很明显的这个 ...

  2. sklearn中实现随机梯度下降法(多元线性回归)

    sklearn中实现随机梯度下降法 随机梯度下降法是一种根据模拟退火的原理对损失函数进行最小化的一种计算方式,在sklearn中主要用于多元线性回归算法中,是一种比较高效的最优化方法,其中的梯度下降系 ...

  3. 最小二乘法 及 梯度下降法 分别对存在多重共线性数据集 进行线性回归 (Python版)

    网上对于线性回归的讲解已经很多,这里不再对此概念进行重复,本博客是作者在听吴恩达ML课程时候偶然突发想法,做了两个小实验,第一个实验是采用最小二乘法对数据进行拟合, 第二个实验是采用梯度下降方法对数据 ...

  4. 梯度下降法的python代码实现(多元线性回归)

    梯度下降法的python代码实现(多元线性回归最小化损失函数) 1.梯度下降法主要用来最小化损失函数,是一种比较常用的最优化方法,其具体包含了以下两种不同的方式:批量梯度下降法(沿着梯度变化最快的方向 ...

  5. Python实现——一元线性回归(梯度下降法)

    2019/3/25 一元线性回归--梯度下降/最小二乘法_又名:一两位小数点的悲剧_ 感觉这个才是真正的重头戏,毕竟前两者都是更倾向于直接使用公式,而不是让计算机一步步去接近真相,而这个梯度下降就不一 ...

  6. 机器学习---用python实现最小二乘线性回归算法并用随机梯度下降法求解 (Machine Learning Least Squares Linear Regression Application SGD)

    在<机器学习---线性回归(Machine Learning Linear Regression)>一文中,我们主要介绍了最小二乘线性回归算法以及简单地介绍了梯度下降法.现在,让我们来实践 ...

  7. 梯度下降法实现最简单线性回归问题python实现

    梯度下降法是非常常见的优化方法,在神经网络的深度学习中更是必会方法,但是直接从深度学习去实现,会比较复杂.本文试图使用梯度下降来优化最简单的LSR线性回归问题,作为进一步学习的基础. import n ...

  8. 简单线性回归(梯度下降法) python实现

    grad_desc .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { bord ...

  9. [Machine Learning] 单变量线性回归(Linear Regression with One Variable) - 线性回归-代价函数-梯度下降法-学习率

    单变量线性回归(Linear Regression with One Variable) 什么是线性回归?线性回归是利用数理统计中回归分析,来确定两种或两种以上变量间相互依赖的定量关系的一种统计分析方 ...

随机推荐

  1. 解决MVC中Model上的特性在EF框架刷新时清空的问题

    MVC中关于前端数据的效验一般都是通过在Model中相关的类上打上特性来实现. 但是在我们数据库发生改变,EF框架需要刷新时会把我们在Model上的特性全部清除,这样的话,我们前端的验证就会失效. 因 ...

  2. JDK-基于Windows环境搭建

    JDK安装: 毋庸置疑你要跑java程序,肯定少不了JDK,如jemter还有还有~ 下载jdk地址1:https://pan.baidu.com/s/1FIvGNvZSy0EpCBxHCz07nA  ...

  3. Executor线程池原理详解

    线程池 线程池的目的就是减少多线程创建的开销,减少资源的消耗,让系统更加的稳定.在web开发中,服务器会为了一个请求分配一个线程来处理,如果每次请求都创建一个线程,请求结束就销毁这个线程.那么在高并发 ...

  4. 非后端开发Mysql日常使用小结

    数据库的五个概念 数据库服务器 数据库 数据表 数据字段 数据行 那么这里下面既是对上面几个概念进行基本的日常操作. 数据库引擎使用 这里仅仅只介绍常用的两种引擎,而InnoDB是从MySQL 5.6 ...

  5. python 报错TypeError: 'range' object does not support item assignment,解决方法

    贴问题 nums = range(5)#range is a built-in function that creates a list of integers print(nums)#prints ...

  6. 程序员需要了解的硬核知识之CPU

    大家都是程序员,大家都是和计算机打交道的程序员,大家都是和计算机中软件硬件打交道的程序员,大家都是和CPU打交道的程序员,所以,不管你是玩儿硬件的还是做软件的,你的世界都少不了计算机最核心的 - CP ...

  7. Android 列表(ListView、RecyclerView)不断刷新最佳实践

    本文微信公众号「AndroidTraveler」首发. 背景 在 Android 列表开发过程中,有时候我们的 Item 会有一些组件,比如倒计时.这类组件要求不断刷新,这个时候由于列表复用的机制,因 ...

  8. 后渗透神器Cobalt Strike的安装

    0x01 简介 Cobalt Strike集成了端口转发.扫描多模式端口监听Windows exe木马,生成Windows dll(动态链接库)木马,生成java木马,生成office宏病毒,生成木马 ...

  9. Vbox中unbuntu15.10与win10共享文件 及开启复制粘贴功能

    学习linux,一直使用的是VMware虚拟机,虽然功能很强大,但总感觉页面切换很麻烦.所以转入Vbox的使用,下面介绍下unbuntu15.10与win10共享文件. 一 共享文件夹 步骤1:启动u ...

  10. 浏览器安装Tampermonkey(俗称油猴子插件),实现免费观看Vip视频、免费下载付费资源等……

    应用场景 说起浏览器,本人常用google,谷歌浏览器,速度快,里面有很多插件,可以实现用户百度云盘下载限制,破解vip视频.百度广告屏蔽,视频广告的屏蔽,百度网盘资源直接下载等实用功能.今天就来分享 ...