The Impact of Imbalanced Training Data for Convolutional Neural Networks

Paulina Hensman and David Masko

摘要

本论文从实验的角度调研了训练数据的不均衡性对采用CNN解决图像分类问题的性能影响。CIFAR-10数据集包含10个不同类别的60000个图像,用来构建不同类间分布的数据集。例如,一些训练集中包含一个类别的图像数目与其他类别的图像数目比例失衡。用这些训练集分别来训练一个CNN,度量其得到的网络的分类性能。实验结果表明:不均衡的训练数据对CNN的整体性能可能具有严重的负面影响,而均衡的训练数据能产生最好的性能。Oversampling技术在不均衡训练数据上可以将性能提升到均衡数据上的水平,所以它是一种对抗不均衡性的重要技术。

概况

在过去的几年里,由于在诸如机器视觉、语音识别及自然语言处理等几个领域获得重大突破,人工神经网络(Artificial Neural Networks)得到广泛的关注。没有任何先验与假设,这些网络采用统计的方法可以近似大量数据中潜在的函数与模式。DNN(Deep Neural Networks)以及CNN(Convolutional Neural Networks)两类特殊的神经网络是常用来解决复杂问题的现代方法。不利的一面是,为了学习到一个令人满意的神经网络,通常需要大量的数据。对于有监督的学习,还需要大量的标注数据。众所周知,标注数据通常是依赖于人工标注获得,因此获取困难。有一些标注好的图像数据是公开可用的,这些数据为研究与应用人员提供了标准资源,便于比较不同分类方法的,用以证明在该领域取得了一些进展。经验上讲,平衡数据集优于非平衡数据集,然而在真实的情况下,可用的数据集通常是不均衡的。如何处理不均衡数据是机器学习中一个很大的挑战。一些方法能够减轻不均衡数据带来的影响,但是并没有系统的研究结果表明DNN与CNN在标准数据集上如何受不均衡数据的影响。

本文重点研究由于训练数据的类别不均衡带来的CNN分类性能的损失。由此进一步探索:什么类型的分布对性能有损?Oversampling在提升性能方面起多大的作用?具体来讲主要包含以下四个问题:

(1)训练数据中均衡的类别分别对CNN的重要性有多大?

(2)CNN的性能如何受训练数据中不同类别分布的影响?

(3)通过调整训练数据的类别分布能否改善CNN的性能?

(4)有什么可行的方法来实现这种调整?

图像分类是判断给定的图像属于哪一类别的过程,直观来讲,就是图像包含了哪些物体。图像分类主要有两种形式:图像级别标注与对象级别标注。图像级别标注是一个二值变量,用来指示一个对象是否出现在图像上,例如,图像上是否有一只猫。对象级别的标注是具体到对象在图像中出现的位置。例如,螺丝刀中心位于(20,25),宽为50像素,高为30像素。本文关注图像级别的标注。

不均衡数据是指机器学习算法在训练的过程中所采用的数据在不同类别上的分布是不均衡的。由于采用均衡数据学习的算法性能远优于不均衡数据的,所以不均衡数据给分类问题带来了挑战。实际中可用的数据通常是不均衡的。然而,大多数的学习算法假设训练数据是均衡,也同样假设未标注的数据也是类间均衡的。若训练数据的分布于测试集并不相同,这类算法通常会降低性能。进一步来讲,多数算法的目标在于最小化整体的错误率,这会导致训练数据中的小众类由于训练数据少而性能不佳。当小众类非常重要时,这种影响是完全负面的。例如,罕见疾病的诊断。不均衡数据已经得到了广泛的关注,有许多有效的方法可以解决这个问题。

已有提升不均衡数据上的学习性能的方法大致分为三类: (1)sampling techniques;(2)Cost sensitive techniques;(3)One-class learning。采样技术改变原始的数据集,从而创建均衡数据集。简单的采样技术包括oversampling(从小众类中重复采样直至均衡),undersampling(移除over-represented类别的数据)与其他采样技术。然而有研究表明将oversampling与undersampling结合可能是应对极端不均衡数据的方式。

  • budget-sensitive progressive sampling algorithm

训练数据数目n

该采样策略依赖于几个假设:(1)与获取训练数据相比,学习算法的执行代价是可以忽略的,因为在该采样算法中学习算法需要运行多次。当训练数据获取代价高时,这一点是成立的。(2)假设每个类别的获取代价是相同的。这样的话预算数目n与训练实例数是一致的。这个假设大多时候是成立的,但也有例外。如,先前提及的电话数据,获取普通消费者和商业电话的代价是一样的,但是欺诈电话的识别代价是高昂的。

  • combination of cost-sensitive technique and undersampling
实验设置

数据集:选用CIFAR-10,包含10个不同的类别,数据集较小,仅包含60000左右的images(不选择ImageNet的原因),便于做批量的实验,但又不至于任务太简单(如MNIST)

数据集划分:5000 images per category for training and 1000 for testing

类别分布:选择11个不同的类别分布,分别考察其分类性能,每种分布其实都是具有代表性的,毕竟10个类别的分布均衡,是很难量化的一个指标,所以这里只是举出几个典型的例子来说明。在本文中,并没有给出class imbalance的一个明确的量化的定义。

网络结构:use caffe to create and train a CNN

参数设置:3 convolutional layers and 10 output nodes, trained with learning rate 0.001 for 8 epochs + learning rate 0.0001 for 2 epochs, momentum set to 0.9, weight decay to 0.004

测试数据:mean results of three runs

评价指标:the percentage of correct answers for each class,然后再做平均。

实验结果

(1)数据越均衡,分类性能越好

(2)oversampling可以给imbalance 数据带来性能的提升,数据越不均衡提升越明显。

阅读笔记 The Impact of Imbalanced Training Data for Convolutional Neural Networks [DegreeProject2015] 数据分析型的更多相关文章

  1. 论文笔记(Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration)

    这是CVPR 2019的一篇oral. 预备知识点:Geometric median 几何中位数 \begin{equation}\underset{y \in \mathbb{R}^{n}}{\ar ...

  2. 【论文阅读】Learning Dual Convolutional Neural Networks for Low-Level Vision

    论文阅读([CVPR2018]Jinshan Pan - Learning Dual Convolutional Neural Networks for Low-Level Vision) 本文针对低 ...

  3. [CVPR2015] Is object localization for free? – Weakly-supervised learning with convolutional neural networks论文笔记

    p.p1 { margin: 0.0px 0.0px 0.0px 0.0px; font: 13.0px "Helvetica Neue"; color: #323333 } p. ...

  4. 论文笔记之:Spatially Supervised Recurrent Convolutional Neural Networks for Visual Object Tracking

    Spatially Supervised Recurrent Convolutional Neural Networks for Visual Object Tracking  arXiv Paper ...

  5. 论文笔记之:Learning Multi-Domain Convolutional Neural Networks for Visual Tracking

    Learning Multi-Domain Convolutional Neural Networks for Visual Tracking CVPR 2016 本文提出了一种新的CNN 框架来处理 ...

  6. [论文阅读] MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications (MobileNet)

    论文地址:MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications 本文提出的模型叫Mobi ...

  7. 深度学习笔记 (一) 卷积神经网络基础 (Foundation of Convolutional Neural Networks)

    一.卷积 卷积神经网络(Convolutional Neural Networks)是一种在空间上共享参数的神经网络.使用数层卷积,而不是数层的矩阵相乘.在图像的处理过程中,每一张图片都可以看成一张“ ...

  8. Bag of Tricks for Image Classification with Convolutional Neural Networks笔记

    以下内容摘自<Bag of Tricks for Image Classification with Convolutional Neural Networks>. 1 高效训练 1.1 ...

  9. 论文笔记之《Event Extraction via Dynamic Multi-Pooling Convolutional Neural Network》

    1. 文章内容概述 本人精读了事件抽取领域的经典论文<Event Extraction via Dynamic Multi-Pooling Convolutional Neural Networ ...

随机推荐

  1. sp_change_users_login解决孤立用户问题

    孤立帐户,指的是某个数据库的帐户只有用户名而没有登录名,这样的用户在用户库的sysusers系统表中存在,而在master数据库的syslogins中却没有对应的记录. 孤立帐户的产生一般是一下两种: ...

  2. JAVA单向/双向链表的实现

    一.JAVA单向链表的操作(增加节点.查找节点.删除节点) class Link { // 链表类 class Node { // 保存每一个节点,此处为了方便直接定义成内部类 private Str ...

  3. python---hashlib

    简介 用于加密相关的操作,代替了md5模块和sha模块,主要提供SHA1,SHA224,SHA256,SHA384,SHA512,MD5算法. 在python3中已经废弃了md5和sha模块,简单说明 ...

  4. Hadoop学习17--yarn配置篇-内存管理

    这篇文章来自于:董的博客,记录备查 内存管理,主要是管理nodemanager上的物理内存和虚拟内存. YARN允许用户配置每个节点上可用的物理内存资源,注意,这里是“可用的”,因为一个节点上的内存会 ...

  5. 有关项目上潜在需要的移动端GIS系统源码整理,待后续更新

    GPS Tools For Android 前言: GPS数据在做GIS开发时的一份宝贵的数据,在不侵犯他人隐私的情况下通过互联网的模式收集GPS是成本最为低廉的一种模式. 背景: 现在公司在做一个项 ...

  6. load、init和initialize的区别

    在NSObject.h中找到三个方法 + (void)load; + (void)initialize; - (instancetype)init 1. 可知三个方法类型,两个类方法,一个对象方法 2 ...

  7. SpringMVC 中获取所有的路由配置。

    ApplicationContext context = TMSContextLookup.getApplicationContext(); String[] controllerList = con ...

  8. 网络--三种网络通讯方式及Android的网络通讯机制

    Android平台有三种网络接口可以使用,他们分别是:java.net.*(标准Java接口).Org.apache接口和Android.net.*(Android网络接口).下面分别介绍这些接口的功 ...

  9. DirectBuffer

    1.如何分配,分配是哪里的内存 ByteBuffer.allocateDirect()来分配(ByteBuffer.allocate()分配堆内内存),分配的是非Heap(堆外)的内存,不排除操作系统 ...

  10. NGUI之渲染DrawCall的合并

    在Unity中,每次引擎准备数据并通知GPU的过程称为一次Draw Call.Draw Call值越低,会得到更好的渲染性能. (NGUI 查看DrawCall工具(NGUI-OPEN-Draw Ca ...