常用的梯度下降法分为:

  • 批量梯度下降法(Batch Gradient Descent)
  • 随机梯度下降法(Stochastic Gradient Descent)
  • 小批量梯度下降法(Mini-Batch Gradient Descent)

简单的算法示例

数据

x = np.random.uniform(-3,3,100)
X = x.reshape(-1,1)
y = x * 2 + 5 + np.random.normal(0, 1, 100)

BGD

批量梯度下降法的简单实现:

def gradient_descent(X_b, y, initial_theta, eta, n_iters=1e4, epsilon=1e-8):
def J(theta):
return np.mean((X_b.dot(theta) - y) ** 2) def dj(theta):
return X_b.T.dot((X_b.dot(theta) - y)) * (2 / len(y)) theta = initial_theta
for i in range(1, int(n_iters)):
gradient = dj(theta) # 获得梯度
last_theta = theta
theta = theta - eta * gradient # 迭代梯度
if np.absolute(J(theta) - J(last_theta)) < epsilon:
break # 满足条件就跳出 return theta

结果是:

X_b = np.hstack([np.ones((len(y), 1)), X])
initial_theta = np.ones(X_b.shape[1])
eta = 0.1
%time s_gradient_descent(X_b, y, initial_theta, eta, n_iters=1) ## array([4.72619109, 3.08239321])

SGD

这里n_iters表示将所有数据迭代的轮数。

def s_gradient_descent(X_b, y, initial_theta, eta, batch_size=10, n_iters=10, epsilon=1e-8):
def J(theta):
return np.mean((X_b.dot(theta) - y) ** 2) # 这是随机梯度下降的,随机一个样本的梯度
def dj_sgd(X_b_i, y_i, theta):
# return X_b.T.dot((X_b.dot(theta) - y)) * (2 / len(y))
return 2 * X_b_i.T.dot(X_b_i.dot(theta) - y_i) theta = initial_theta
for i in range(0, int(n_iters)):
for j in range(batch_size, len(y), batch_size):
gradient = dj_sgd(X_b[j,:], y[j], theta)
last_theta = theta
theta = theta - eta * gradient # 迭代梯度
if np.absolute(J(theta) - J(last_theta)) < epsilon:
break # 满足条件就跳出
return theta

结果是:

X_b = np.hstack([np.ones((len(y), 1)), X])
initial_theta = np.ones(X_b.shape[1])
eta = 0.1
%time s_gradient_descent(X_b, y, initial_theta, eta, n_iters=1) ## array([4.72619109, 3.08239321])

MBGD

在随机梯度下降的基础上,对dj做了一点点修改,batch_size指定批量的大小,dj每次计算batch_size个样本的梯度并取平均值。

不得不说,同样是迭代一轮数据,小批量梯度下降法的准确度要比随机梯度下降法高多了。

def b_gradient_descent(X_b, y, initial_theta, eta, batch_size=10, n_iters=10, epsilon=1e-8):
def J(theta):
return np.mean((X_b.dot(theta) - y) ** 2) # 这是小批量梯度下降的,随机一个样本的梯度
def dj_bgd(X_b_b, y_b, theta):
# return X_b.T.dot((X_b.dot(theta) - y)) * (2 / len(y))
return X_b_b.T.dot(X_b_b.dot(theta) - y_b) * (2 / len(y_b)) theta = initial_theta
for i in range(0, int(n_iters)):
for j in range(batch_size, len(y), batch_size):
gradient = dj_bgd(X_b[j-batch_size:j,:], y[j-batch_size:j], theta)
last_theta = theta
theta = theta - eta * gradient # 迭代梯度
if np.absolute(J(theta) - J(last_theta)) < epsilon:
break # 满足条件就跳出
return theta

结果是:

X_b = np.hstack([np.ones((len(y), 1)), X])
initial_theta = np.ones(X_b.shape[1])
eta = 0.1
%time b_gradient_descent(X_b, y, initial_theta, eta, n_iters=1) array([4.4649369 , 2.27164876])

三种梯度下降法的对比(BGD & SGD & MBGD)的更多相关文章

  1. 三种梯度下降算法的区别(BGD, SGD, MBGD)

    前言 我们在训练网络的时候经常会设置 batch_size,这个 batch_size 究竟是做什么用的,一万张图的数据集,应该设置为多大呢,设置为 1.10.100 或者是 10000 究竟有什么区 ...

  2. 各种优化器对比--BGD/SGD/MBGD/MSGD/NAG/Adagrad/Adam

    指数加权平均 (exponentially weighted averges) 先说一下指数加权平均, 公式如下: \[v_{t}=\beta v_{t-1}+(1-\beta) \theta_{t} ...

  3. python笔记-20 django进阶 (model与form、modelform对比,三种ajax方式的对比,随机验证码,kindeditor)

    一.model深入 1.model的功能 1.1 创建数据库表 1.2 操作数据库表 1.3 数据库的增删改查操作 2.创建数据库表的单表操作 2.1 定义表对象 class xxx(models.M ...

  4. iOS- NSThread/NSOperation/GCD 三种多线程技术的对比及实现

    1.iOS的三种多线程技术 1.NSThread 每个NSThread对象对应一个线程,量级较轻(真正的多线程) 2.以下两点是苹果专门开发的“并发”技术,使得程序员可以不再去关心线程的具体使用问题 ...

  5. iOS- NSThread/NSOperation/GCD 三种多线程技术的对比及实现 -- 转

    1.iOS的三种多线程技术 1.NSThread 每个NSThread对象对应一个线程,量级较轻(真正的多线程) 2.以下两点是苹果专门开发的“并发”技术,使得程序员可以不再去关心线程的具体使用问题 ...

  6. 几种梯度下降方法对比(Batch gradient descent、Mini-batch gradient descent 和 stochastic gradient descent)

    https://blog.csdn.net/u012328159/article/details/80252012 我们在训练神经网络模型时,最常用的就是梯度下降,这篇博客主要介绍下几种梯度下降的变种 ...

  7. Dynamics CRM2016 查询数据的三种方式的性能对比

    之前写过一个博客,对非声明验证方式下连接组织服务的两种方式的性能进行了对比,但当时只是对比了实例化组织服务的时间,并没有对查询数据的时间进行对比,那有朋友也在我的博客中留言了反映了查询的时间问题,一直 ...

  8. 两个Map的对比,三种方法,将对比结果写入文件。

    三种方法的思维都是遍历一个map的Key,然后2个Map分别取这2个Key值所得到的Value. #第一种用entry private void compareMap(Map<String, S ...

  9. java对象头信息和三种锁的性能对比

    java头的信息分析 首先为什么我要去研究java的对象头呢? 这里截取一张hotspot的源码当中的注释 这张图换成可读的表格如下 |-------------------------------- ...

随机推荐

  1. 基于 Apache Hudi 和DBT 构建开放的Lakehouse

    本博客的重点展示如何利用增量数据处理和执行字段级更新来构建一个开放式 Lakehouse. 我们很高兴地宣布,用户现在可以使用 Apache Hudi + dbt 来构建开放Lakehouse. 在深 ...

  2. grep使用常用操作十五条

    grep的全部使用语法参照grep --help,日常工作常用的语法如下:构造数据如下:test001.txt与test002.txt 一.在单个文件中查询指定字符串 grep abc test01/ ...

  3. bat-CSV文件转MD文件

    目录 1. bat文件里面写死文件名 2. 拖入文件 1. bat文件里面写死文件名 @echo off & setlocal enabledelayedexpansion SET filep ...

  4. 区块相隔虽一线,俱在支付同冶熔,Vue3.0+Tornado6前后端分离集成Web3.0之Metamask区块链虚拟三方支付功能

    最近几年区块链技术的使用外延持续扩展,去中心化的节点认证机制可以大幅度改进传统的支付结算模式的经营效率,降低交易者的成本并提高收益.但不能否认的是,区块链技术也存在着极大的风险,所谓身怀利器,杀心自起 ...

  5. KingbaseES 如何查看应用执行的SQL的执行计划

    通过explain ,我们可以获取特定SQL 的执行计划.但对于同一条SQL,不同的变量.不同的系统负荷,其执行计划可能不同.我们要如何取得SQL执行时间点的执行计划?KingbaseES 提供了 a ...

  6. Netty使用手册翻译

    前言 痛点 时至今日,我们通常会使用应用程序或第三方库去提供通信功能.比如:我们通常使用HTTP客户端库去Web服务器检索信息;通过web服务调用一个远程程序.然而,一个通用协议或者它的实现往往不能适 ...

  7. 操作服务器的神奇工具Tmux

    Tmux 是什么? 会话与进程 命令行的典型使用方式是,打开一个终端窗口(terminal window,以下简称"窗口"),在里面输入命令.用户与计算机的这种临时的交互,称为一次 ...

  8. vue3+three.js实现疫情可视化

    前言 自成都九月份以来疫情原因被封了一两周,居家着实无聊,每天都是盯着微信公众号发布的疫情数据看,那种页面,就我一个前端仔来说,看着是真的丑啊!(⊙_⊙)?既然丑,那就自己动手开整!项目是2022.9 ...

  9. 深度剖析Istio共享代理新模式Ambient Mesh

    摘要:今年9月份,Istio社区宣布Ambient Mesh开源,由此引发国内外众多开发者的热烈讨论. 本文分享自华为云社区<深度剖析!Istio共享代理新模式Ambient Mesh>, ...

  10. 升级Gogs版本

    今天早上收到阿里云发的报警短信,大致内容如下: 前提分析: 公司代码代码仓库使用是Gogs搭建的,版本是0.11.34,二进制方式安装的,连接的是其他主机上的MySQL数据库,因此被检测到有这个漏洞 ...