前言

我们在训练网络的时候经常会设置 batch_size,这个 batch_size 究竟是做什么用的,一万张图的数据集,应该设置为多大呢,设置为 1、10、100 或者是 10000 究竟有什么区别呢?

# 手写数字识别网络训练方法
network.fit(
train_images,
train_labels,
epochs=5,
batch_size=128)

批量梯度下降(Batch Gradient Descent,BGD)

梯度下降算法一般用来最小化损失函数:把原始的数据网络喂给网络,网络会进行一定的计算,会求得一个损失函数,代表着网络的计算结果与实际的差距,梯度下降算法用来调整参数,使得训练出的结果与实际更好的拟合,这是梯度下降的含义。

批量梯度下降是梯度下降最原始的形式,它的思想是使用所有的训练数据一起进行梯度的更新,梯度下降算法需要对损失函数求导数,可以想象,如果训练数据集比较大,所有的数据需要一起读入进来,一起在网络中去训练,一起求和,会是一个庞大的矩阵,这个计算量将非常巨大。当然,这也是有优点的,那就是因为考虑到所有训练集的情况,因此网络一定在向最优(极值)的方向在优化。

随机梯度下降(Stochastic Gradient Descent,SGD)

与批量梯度下降不同,随机梯度下降的思想是每次拿出训练集中的一个,进行拟合训练,进行迭代去训练。训练的过程就是先拿出一个训练数据,网络修改参数去拟合它并修改参数,然后拿出下一个训练数据,用刚刚修改好的网络再去拟合和修改参数,如此迭代,直到每个数据都输入过网络,再从头再来一遍,直到参数比较稳定,优点就是每次拟合都只用了一个训练数据,每一轮更新迭代速度特别快,缺点是每次进行拟合的时候,只考虑了一个训练数据,优化的方向不一定是网络在训练集整体最优的方向,经常会抖动或收敛到局部最优。

小批量梯度下降(Mini-Batch Gradient Descent,MBGD)

小批量梯度下降采用的还是计算机中最常用的折中的解决办法,每次输入网络进行训练的既不是训练数据集全体,也不是训练数据集中的某一个,而是其中的一部分,比如每次输入 20 个。可以想象,这既不会造成数据量过大计算缓慢,也不会因为某一个训练样本的某些噪声特点引起网络的剧烈抖动或向非最优的方向优化。

对比一下这三种梯度下降算法的计算方式:批量梯度下降是大矩阵的运算,可以考虑采用矩阵计算优化的方式进行并行计算,对内存等硬件性能要求较高;随机梯度下降每次迭代都依赖于前一次的计算结果,因此无法并行计算,对硬件要求较低;而小批量梯度下降,每一个次迭代中,都是一个较小的矩阵,对硬件的要求也不高,同时矩阵运算可以采用并行计算,多次迭代之间采用串行计算,整体来说会节省时间。

看下面一张图,可以较好的体现出三种剃度下降算法优化网络的迭代过程,会有一个更加直观的印象。

总结

梯度下降算法的调优,训练数据集很小,直接采用批量梯度下降;每次只能拿到一个训练数据,或者是在线实时传输过来的训练数据,采用随机梯度下降;其他情况或一般情况采用批量梯度下降算法更好。

  • 本文首发自: RAIS

三种梯度下降算法的区别(BGD, SGD, MBGD)的更多相关文章

  1. 三种梯度下降法的对比(BGD & SGD & MBGD)

    常用的梯度下降法分为: 批量梯度下降法(Batch Gradient Descent) 随机梯度下降法(Stochastic Gradient Descent) 小批量梯度下降法(Mini-Batch ...

  2. Java语言----三种循环语句的区别

    ------- android培训.java培训.期待与您交流! ---------- 第一种:for循环 循环结构for语句的格式:       for(初始化表达式;条件表达式;循环后的操作表达式 ...

  3. png、jpg、gif三种图片格式的区别

    png.jpg.gif三种图片格式的区别   2014-06-17 为什么想整理这方面的类容,我觉得就像油画家要了解他的颜料和画布.雕塑家要了解他的石材一样,作为网页设计师也应该对图片格式的特性有一定 ...

  4. VMware 三种网络模式的区别

    VMware 三种网络模式的区别 VMware 三种网络模式的区别 我们首先说一下VMware的几个虚拟设备 VMnet0:用于虚拟桥接网络下的虚拟交换机 VMnet1:用于虚拟Host-Only网络 ...

  5. JavaScript:学习笔记(7)——VAR、LET、CONST三种变量声明的区别

    JavaScript:学习笔记(7)——VAR.LET.CONST三种变量声明的区别 ES2015(ES6)带来了许多闪亮的新功能,自2017年以来,许多JavaScript开发人员已经熟悉并开始使用 ...

  6. 转:VMware中三种网络连接的区别

    转自:http://www.cnblogs.com/rainman/archive/2013/05/06/3063925.html VMware中三种网络连接的区别   1.概述 2.bridged( ...

  7. (转)VMware虚拟机三种网络模式的区别及配置方法;

    我的一点实际经验理解桥接和NAT 桥接是虚拟机完全作为一个独立的地址接在局域网中,NAT是虚拟机依赖宿主主机地址转换的一种方式 例子我的虚拟机如果用桥接模式,连接外部网站如百度时会提示此pc没有装公司 ...

  8. 各种梯度下降 bgd sgd mbgd adam

    转载  https://blog.csdn.net/itchosen/article/details/77200322 各种神经网络优化算法:从梯度下降到Adam方法     在调整模型更新权重和偏差 ...

  9. JdbcTemplate查询数据 三种callback之间的区别

    JdbcTemplate针对数据查询提供了多个重载的模板方法,你可以根据需要选用不同的模板方法. 如果你的查询很简单,仅仅是传入相应SQL或者相关参数,然后取得一个单一的结果,那么你可以选择如下一组便 ...

随机推荐

  1. Unity GameObject

    GameObject 游戏对象 GameObject是unity所有实体的基类 gameObject 获取当前脚本所挂载的游戏对象 一般来说,在属性视图中能看到或修改的属性,我们同样可以在脚本中获取并 ...

  2. 解决因缺少驱动程序,导致“未在本地计算机上注册microsoft.ace.12.0”异常

    写了一个winform程序,功能是选择一个excel表格,把里面的内容写进sqlite数据库中,在本地测试没问题,但是在其他电脑上就会报错"未在本地计算机上注册microsoft.ace.1 ...

  3. 移动端 Swiper

    一.什么是swiper 开源.免费.强大的触摸滑动插件 Swiper常用于移动端网站的内容触摸滑动 Swiper能实现触屏焦点图.触屏Tab切换.触屏多图切换等常用效果 #二.如何使用 1.首先加载插 ...

  4. CTF练习 ①

    最近学校要打比赛,,,把我这个混子也给算上了,,不得不赶紧学习学习. 今天学习的是SQL注入的一道题,参考的文章是  https://blog.csdn.net/qq_42939527/article ...

  5. 安装nodejs 版本控制器

    安装下载地址: https://pan.baidu.com/s/1Ed_IPDTOHxR9NShUEau-ZA 下载好后,放在安装nodejs的文件夹下 然后敲cmd,进入安装nodejs的文件夹下. ...

  6. 不一样的资产安全 3D 可视化平台

    前言   数字经济时代,应用好数据是企业数字化转型的关键,基于前沿科学技术进行数据的有效管控,更是对数字增值服务的新趋势.近年来,整个安全行业对资产管理的重视程度正在提高.据IDC发布的相关数据显示, ...

  7. Visual Studio 2013中安装Resharper之后一些快捷键无法使用,比如F6和F12

    快捷键是一个很好用的东西,尤其对于计算机从业者来说,好的快捷键能够高程度提高工作效率.像我们程序员经常需要团队开发,我们会遇到一个问题,那就是快捷键不一致问题,我一般会安装resharper,但是有的 ...

  8. Java Int类型与字符,汉字之间的转换

    /** * java 中的流主要是分为字节流和字符流 * 再一个角度分析的话可以分为输入流和输出流 * 输入和输出是一个相对的概念 相对的分别是jvm虚拟机的内存大小 * 从另一个角度讲Java或者用 ...

  9. Oracle数据库常见sql

    -新建表:create table table_name( id varchar2(300) primary key, name varchar2(200) not null); --插入数据 ins ...

  10. 事务的概念,以及事务在JDBC编程中处理事务的步骤

    事务是作为单个逻辑工作单元执行的一系列操作,一个逻辑工作单元必须有四个属性,称为原子性.一致性.隔离性和持久性 (ACID) 属性,只有这样才能成为一个事务 .JDBC处理事务有如下操作: 1,con ...