在机器学习中,学习的目标是选择期望风险\(R_{exp}\)(expected loss)最小的模型,但在实际情况下,我们不知道数据的真实分布(包含已知样本和训练样本),仅知道训练集上的数据分布。因此,我们的目标转化为最小化训练集上的平均损失,这也被称为经验风险\(R_{emp}\)(empirical loss)。

严格地说,我们应该计算所有训练数据的损失函数的总和,以此来更新模型参数(Batch Gradient Descent)。但随着数据集的不断增大,以ImagNet数据集为例,该数据集的数据量有百万之多,计算所有数据的损失函数之和显然是不现实的。若采用计算单个样本的损失函数更新参数的方法(Stochastic Gradient Descent),会导致\(R_{emp}\)难以达到最小值,而且在数值处理上不能使用向量化的方法提高运算速度

于是,我们采取一种折衷的想法,即取一部分数据,作为全部数据的代表,让神经网络从这每一批数据中学习,这里的“一部分数据”称为mini-batch,这种方法称为mini-batch学习。

以下图为例,蓝色的线表示Batch Gradient Descent,紫色的线表示Stochastic Gradient Descent,绿色的线表示Mini-Batch Gradient Descent。

从上图可以看出,Mini-Batch相当于结合了Batch Gradient Descent和Stochastic Gradient Descent各自的优点,既能利用向量化方法提高运算速度,又能基本接近全局最小值。

对于mini-batch学习的介绍到此为止。下面我们将MINIST数据集上的分类问题作为背景,以交叉熵cross-entropy损失函数为例,来实现一下mini-bacth版的cross-entropy error。

给出cross-entropy error的定义如下:

\[E = - \sum_{k}t_k \log(y_k)\tag{1}
\]

其中\(y_k\)表示神经网络输出,\(t_k\)表示正确解标签。

等式1表示的是针对单个数据的损失函数,现在我们给出在mini-batch下的损失函数,如下

\[E = -\frac{1}{N}\sum_{n}\sum_{k}t_{nk}\log(y_{nk})\tag{2}
\]

其中N表示这一部分数据的数量,\(t_{nk}\)表示第n个数据在第k个元素的值(\(y_{nk}\)表示神经网络输出,\(t_{nk}\)表示监督数据)

我们来看一下用Python如何实现mini-batch版的cross-entropy error。针对监督数据\(t_{nk}\)的标签形式是否为one-hot,我们分类讨论处理。

此外,需要明确的一点是,对于一个分类神经网络,最后一层经过softmax函数处理后,输出\(y_{nk}\)是一个\(n\)x\(k\)的矩阵,\(y_{ij}\)表示第i个数据被预测为\(j(0 \leq j\leq10)\)的概率,特别地,当\(N=1\)时,\(y\)是一个包含10个元素的向量,类似于[0.1,0.2...0.3],其中0.1表示输入数据预测为0的概率为0.1,0.2表示将输入数据预测为1的概率为0.2,其他情况以此类推。

首先,对于\(t_{nk}\)为one-hot表示的情况,代码块1如下

def cross_entropy_error(y,t):
batch_size = y.shape[0]
return -np.sum(t * np.log(y + 1e-7)) / batch_size

在上面的代码中,我们在y上加了一个微小值,防止出现np.log(0)的情况,因为np.log(0)会变成负无穷大-inf,从而导致后续的计算无法继续进行。在等式2中\(y_{nk}\)与\(t_{nk}\)下标相同,所以我们直接使用*做element-wise运算,即对应元素相乘。

但当我们希望同时能够处理单个数据和批量数据时,代码块1还不能满足我们的要求。因为当\(N=1\)时,\(y\)是一个包含10个元素的一维向量,输入到函数中,batch_size将等于10而不是1,于是我们将代码块1进行进一步完善,如下:

def cross_entropy_error(y,t):
if y.ndim == 1:
y = y.reshape(1,y.size)
t = t.reshape(1,t.size) batch_size = y.shape[0]
return -np.sum(t * np.log(y + 1e-7)) / batch_size

最后,来讨论一下\(t_{nk}\)为非one-hot表示的情况。在one-hot情况的计算中,t为0的元素cross-entropy error也为0,所以对于这些元素的计算可以忽略。换言之,在非one-hot表示的情况下,我们只需要计算正确解标签的交叉熵误差即可。代码如下:

def cross_entropy_error(y,t):
if y.ndim == 1:
y = y.reshape(1,y.size)
t = t.reshape(1,t.size) batch_size = y.shape[0]
return -np.sum(1 * np.log(y[np.arange(batch_size),t]+1e-7))/batch_size

在上面的代码中,y[np.arange(batch_size),t]表示将从神经网络的输出中抽出与正确解标签相对应的元素。

参考文献

[1] 深度学习入门

[2] DeepLearning.ai深度学习课程笔记

[3] 统计学习方法

Learning with Mini-Batch的更多相关文章

  1. 转载: scikit-learn学习之K-means聚类算法与 Mini Batch K-Means算法

    版权声明:<—— 本文为作者呕心沥血打造,若要转载,请注明出处@http://blog.csdn.net/gamer_gyt <—— 目录(?)[+] ================== ...

  2. 聚类K-Means和大数据集的Mini Batch K-Means算法

    import numpy as np from sklearn.datasets import make_blobs from sklearn.cluster import KMeans from s ...

  3. Deep Learning 27:Batch normalization理解——读论文“Batch normalization: Accelerating deep network training by reducing internal covariate shift ”——ICML 2015

    这篇经典论文,甚至可以说是2015年最牛的一篇论文,早就有很多人解读,不需要自己着摸,但是看了论文原文Batch normalization: Accelerating deep network tr ...

  4. knn/kmeans/kmeans++/Mini Batch K-means/Affinity Propagation/Mean Shift/层次聚类/DBSCAN 区别

    可以看出来除了KNN以外其他算法都是聚类算法 1.knn/kmeans/kmeans++区别 先给大家贴个简洁明了的图,好几个地方都看到过,我也不知道到底谁是原作者啦,如果侵权麻烦联系我咯~~~~ k ...

  5. Deep Learning中的Large Batch Training相关理论与实践

    背景 [作者:DeepLearningStack,阿里巴巴算法工程师,开源TensorFlow Contributor] 在分布式训练时,提高计算通信占比是提高计算加速比的有效手段,当网络通信优化到一 ...

  6. Deep learning:四十八(Contractive AutoEncoder简单理解)

    Contractive autoencoder是autoencoder的一个变种,其实就是在autoencoder上加入了一个规则项,它简称CAE(对应中文翻译为?).通常情况下,对权值进行惩罚后的a ...

  7. Deep learning:四十二(Denoise Autoencoder简单理解)

    前言: 当采用无监督的方法分层预训练深度网络的权值时,为了学习到较鲁棒的特征,可以在网络的可视层(即数据的输入层)引入随机噪声,这种方法称为Denoise Autoencoder(简称dAE),由Be ...

  8. Machine Learning Algorithms Study Notes(2)--Supervised Learning

    Machine Learning Algorithms Study Notes 高雪松 @雪松Cedro Microsoft MVP 本系列文章是Andrew Ng 在斯坦福的机器学习课程 CS 22 ...

  9. Coursera Deep Learning 2 Improving Deep Neural Networks: Hyperparameter tuning, Regularization and Optimization - week2, Assignment(Optimization Methods)

    声明:所有内容来自coursera,作为个人学习笔记记录在这里. 请不要ctrl+c/ctrl+v作业. Optimization Methods Until now, you've always u ...

  10. 图像分类(二)GoogLenet Inception_v2:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

    Inception V2网络中的代表是加入了BN(Batch Normalization)层,并且使用 2个 3*3卷积替代 1个5*5卷积的改进版,如下图所示: 其特点如下: 学习VGG用2个 3* ...

随机推荐

  1. Vue常用组件,,,持续更新中

    1.vue-lazyload :图片懒加载 2.vue-virtual-scroll-list 和 vue-virtual-scroller: 优化无限列表的场景 3.vue-table-with-t ...

  2. 一种基于Modbus的工业通信网关设计

    近年来,随着工业自动化领域的发展,工业现场对网络的可靠性及成本有极高的要求.传统基于串口的工业网关可以满足工业现场的应用,但却要付出高额成本.一种基于 ModBus 设计的工业通信网关就走进人们的眼中 ...

  3. nchu第二次面向对象编程博客作业

    前言:   本次博客包含的内容有pta题目集4(四边形).5(五边形)以及期中考试三次题目集.其中第四次和第五次题目集难度较大,比较复杂,涉及的知识点也比较多.而期中考试由于是在课堂上完成,难度较小, ...

  4. labuladong数据结构

    缓存淘汰算法:LRU①.LFU② BST③ 完全二叉树④ 序列化和反序列化二叉树⑤ 最近公共祖先⑥ 单调栈⑦ 单调队列⑧ 递归反转链表⑨ k个一组反转链表

  5. eclipse创建基于Web的Maven项目

    用于方便查找 以下是原博主链接 https://www.cnblogs.com/sam-uncle/p/8676529.html 如何使用maven搭建web项目呢? 第一步:首先创建一个maven项 ...

  6. 注释中的Unicode编码也会被转义

    现象 public class Unicode { public static void main(String[] args) { // \u000d System.out.println(&quo ...

  7. idea 部署项目到 docker 运行

    1.在远程服务器上开启 docker 远程连接 $vim /usr/lib/systemd/system/docker.service # ExecStart=/usr/bin/dockerd -H ...

  8. Windows下配置Hadoop的Java开发环境

    最近在学习用java来编写MapReduce程序,我是先在windows中开发完成,运行没有问题之后,再打成jar包,放到Linux集群中运行,由于在配置windows的开发环境的时候就花了大半天的时 ...

  9. SQL Server datetime类型为null的有趣实验

    @data1 --变量 测试用 @data2 --当前时间 当@data1为null 则格式转换错误 直接控制台什么也不显示 也不报错 当定义'' 显示默认时间

  10. MyBatis-Plus数据源失效找不到

    记一次项目排查问题项目应用了MyBatis-Plus多数据源配置但是在执行定时任务时发现没达到想要的结果于是查询日志分析问题最终发现问题所在多数据源注解会合事务注解冲突导致失效@DS("&q ...