常用的梯度下降法分为:

  • 批量梯度下降法(Batch Gradient Descent)
  • 随机梯度下降法(Stochastic Gradient Descent)
  • 小批量梯度下降法(Mini-Batch Gradient Descent)

简单的算法示例

数据

x = np.random.uniform(-3,3,100)
X = x.reshape(-1,1)
y = x * 2 + 5 + np.random.normal(0, 1, 100)

BGD

批量梯度下降法的简单实现:

def gradient_descent(X_b, y, initial_theta, eta, n_iters=1e4, epsilon=1e-8):
def J(theta):
return np.mean((X_b.dot(theta) - y) ** 2) def dj(theta):
return X_b.T.dot((X_b.dot(theta) - y)) * (2 / len(y)) theta = initial_theta
for i in range(1, int(n_iters)):
gradient = dj(theta) # 获得梯度
last_theta = theta
theta = theta - eta * gradient # 迭代梯度
if np.absolute(J(theta) - J(last_theta)) < epsilon:
break # 满足条件就跳出 return theta

结果是:

X_b = np.hstack([np.ones((len(y), 1)), X])
initial_theta = np.ones(X_b.shape[1])
eta = 0.1
%time s_gradient_descent(X_b, y, initial_theta, eta, n_iters=1) ## array([4.72619109, 3.08239321])

SGD

这里n_iters表示将所有数据迭代的轮数。

def s_gradient_descent(X_b, y, initial_theta, eta, batch_size=10, n_iters=10, epsilon=1e-8):
def J(theta):
return np.mean((X_b.dot(theta) - y) ** 2) # 这是随机梯度下降的,随机一个样本的梯度
def dj_sgd(X_b_i, y_i, theta):
# return X_b.T.dot((X_b.dot(theta) - y)) * (2 / len(y))
return 2 * X_b_i.T.dot(X_b_i.dot(theta) - y_i) theta = initial_theta
for i in range(0, int(n_iters)):
for j in range(batch_size, len(y), batch_size):
gradient = dj_sgd(X_b[j,:], y[j], theta)
last_theta = theta
theta = theta - eta * gradient # 迭代梯度
if np.absolute(J(theta) - J(last_theta)) < epsilon:
break # 满足条件就跳出
return theta

结果是:

X_b = np.hstack([np.ones((len(y), 1)), X])
initial_theta = np.ones(X_b.shape[1])
eta = 0.1
%time s_gradient_descent(X_b, y, initial_theta, eta, n_iters=1) ## array([4.72619109, 3.08239321])

MBGD

在随机梯度下降的基础上,对dj做了一点点修改,batch_size指定批量的大小,dj每次计算batch_size个样本的梯度并取平均值。

不得不说,同样是迭代一轮数据,小批量梯度下降法的准确度要比随机梯度下降法高多了。

def b_gradient_descent(X_b, y, initial_theta, eta, batch_size=10, n_iters=10, epsilon=1e-8):
def J(theta):
return np.mean((X_b.dot(theta) - y) ** 2) # 这是小批量梯度下降的,随机一个样本的梯度
def dj_bgd(X_b_b, y_b, theta):
# return X_b.T.dot((X_b.dot(theta) - y)) * (2 / len(y))
return X_b_b.T.dot(X_b_b.dot(theta) - y_b) * (2 / len(y_b)) theta = initial_theta
for i in range(0, int(n_iters)):
for j in range(batch_size, len(y), batch_size):
gradient = dj_bgd(X_b[j-batch_size:j,:], y[j-batch_size:j], theta)
last_theta = theta
theta = theta - eta * gradient # 迭代梯度
if np.absolute(J(theta) - J(last_theta)) < epsilon:
break # 满足条件就跳出
return theta

结果是:

X_b = np.hstack([np.ones((len(y), 1)), X])
initial_theta = np.ones(X_b.shape[1])
eta = 0.1
%time b_gradient_descent(X_b, y, initial_theta, eta, n_iters=1) array([4.4649369 , 2.27164876])

三种梯度下降法的对比(BGD & SGD & MBGD)的更多相关文章

  1. 三种梯度下降算法的区别(BGD, SGD, MBGD)

    前言 我们在训练网络的时候经常会设置 batch_size,这个 batch_size 究竟是做什么用的,一万张图的数据集,应该设置为多大呢,设置为 1.10.100 或者是 10000 究竟有什么区 ...

  2. 各种优化器对比--BGD/SGD/MBGD/MSGD/NAG/Adagrad/Adam

    指数加权平均 (exponentially weighted averges) 先说一下指数加权平均, 公式如下: \[v_{t}=\beta v_{t-1}+(1-\beta) \theta_{t} ...

  3. python笔记-20 django进阶 (model与form、modelform对比,三种ajax方式的对比,随机验证码,kindeditor)

    一.model深入 1.model的功能 1.1 创建数据库表 1.2 操作数据库表 1.3 数据库的增删改查操作 2.创建数据库表的单表操作 2.1 定义表对象 class xxx(models.M ...

  4. iOS- NSThread/NSOperation/GCD 三种多线程技术的对比及实现

    1.iOS的三种多线程技术 1.NSThread 每个NSThread对象对应一个线程,量级较轻(真正的多线程) 2.以下两点是苹果专门开发的“并发”技术,使得程序员可以不再去关心线程的具体使用问题 ...

  5. iOS- NSThread/NSOperation/GCD 三种多线程技术的对比及实现 -- 转

    1.iOS的三种多线程技术 1.NSThread 每个NSThread对象对应一个线程,量级较轻(真正的多线程) 2.以下两点是苹果专门开发的“并发”技术,使得程序员可以不再去关心线程的具体使用问题 ...

  6. 几种梯度下降方法对比(Batch gradient descent、Mini-batch gradient descent 和 stochastic gradient descent)

    https://blog.csdn.net/u012328159/article/details/80252012 我们在训练神经网络模型时,最常用的就是梯度下降,这篇博客主要介绍下几种梯度下降的变种 ...

  7. Dynamics CRM2016 查询数据的三种方式的性能对比

    之前写过一个博客,对非声明验证方式下连接组织服务的两种方式的性能进行了对比,但当时只是对比了实例化组织服务的时间,并没有对查询数据的时间进行对比,那有朋友也在我的博客中留言了反映了查询的时间问题,一直 ...

  8. 两个Map的对比,三种方法,将对比结果写入文件。

    三种方法的思维都是遍历一个map的Key,然后2个Map分别取这2个Key值所得到的Value. #第一种用entry private void compareMap(Map<String, S ...

  9. java对象头信息和三种锁的性能对比

    java头的信息分析 首先为什么我要去研究java的对象头呢? 这里截取一张hotspot的源码当中的注释 这张图换成可读的表格如下 |-------------------------------- ...

随机推荐

  1. 技术专家说 | 如何基于 Spark 和 Z-Order 实现企业级离线数仓降本提效?

    [点击了解更多大数据知识] 市场的变幻,政策的完善,技术的革新--种种因素让我们面对太多的挑战,这仍需我们不断探索.克服. 今年,网易数帆将持续推出新栏目「金融专家说」「技术专家说」「产品专家说」等, ...

  2. APT 安装 MySQL 提示错误:dpkg: error: dpkg frontend lock is locked by another process

    在安装 MySQL 的时候提示错误: ubuntu@VM-0-6-ubuntu:/opt$ sudo dpkg -i mysql-apt-config_0.8.22-1_all.deb dpkg: e ...

  3. RabbitMQ 入门系列:2、基础含义理解:链接、通道、队列、交换机

    系列目录 RabbitMQ 入门系列:1.MQ的应用场景的选择与RabbitMQ安装. RabbitMQ 入门系列:2.基础含义:链接.通道.队列.交换机. RabbitMQ 入门系列:3.基础含义: ...

  4. linux --stdin 管道 标准输入重定向

    linux --stdin 标准输入重定向 --stdin This option is used to indicate that passwd should read the new passwo ...

  5. ARC122D XOR Game(博弈论?字典树,贪心)

    题面 ARC122D XOR Game 黑板上有 2 N 2N 2N 个数,第 i i i 个数为 A i A_i Ai​. O I D \rm OID OID(OneInDark) 和 H I D ...

  6. 轻量级RTSP服务和内置RTSP网关有什么不同?

    好多开发者疑惑,什么是内置RTSP网关,和轻量级RTSP服务又有什么区别和联系?本文就以上问题,做个简单的介绍: 轻量级RTSP服务 为满足内网无纸化/电子教室等内网超低延迟需求,避免让用户配置单独的 ...

  7. 阿里云CentOS7安装K8S

    1. 在阿里云山申请三台云服务器 1.1 环境准备 完成配置后的信息 服务器IP 操作系统 CPU 内存 硬盘 主机名 节点角色 172.18.119.145 centos7 2 4G 50G k8s ...

  8. 手把手教你君正X2000开发板的OpenHarmony环境搭建

    摘要:本文主要介绍基于君正X2000开发板的OpenHarmony环境搭建以及简单介绍网络配置情况 本文分享自华为云社区<君正X2000开发板的OpenHarmony环境搭建>,作者: 星 ...

  9. jenkins流水线部署springboot应用到k8s集群(k3s+jenkins+gitee+maven+docker)(1)

    前言:前面写过2篇文章,介绍jenkins通过slave节点部署构建并发布应用到虚拟机中,本篇介绍k8s(k3s)环境下,部署jenkins,通过流水线脚本方式构建发布应用到k8s(k3s)集群环境中 ...

  10. ProxySQL(1):简介和安装

    文章转载自:https://www.cnblogs.com/f-ck-need-u/p/9278818.html ProxySQL有两个版本:官方版和percona版,percona版是在官方版的基础 ...