The Impact of Imbalanced Training Data for Convolutional Neural Networks

Paulina Hensman and David Masko

摘要

本论文从实验的角度调研了训练数据的不均衡性对采用CNN解决图像分类问题的性能影响。CIFAR-10数据集包含10个不同类别的60000个图像,用来构建不同类间分布的数据集。例如,一些训练集中包含一个类别的图像数目与其他类别的图像数目比例失衡。用这些训练集分别来训练一个CNN,度量其得到的网络的分类性能。实验结果表明:不均衡的训练数据对CNN的整体性能可能具有严重的负面影响,而均衡的训练数据能产生最好的性能。Oversampling技术在不均衡训练数据上可以将性能提升到均衡数据上的水平,所以它是一种对抗不均衡性的重要技术。

概况

在过去的几年里,由于在诸如机器视觉、语音识别及自然语言处理等几个领域获得重大突破,人工神经网络(Artificial Neural Networks)得到广泛的关注。没有任何先验与假设,这些网络采用统计的方法可以近似大量数据中潜在的函数与模式。DNN(Deep Neural Networks)以及CNN(Convolutional Neural Networks)两类特殊的神经网络是常用来解决复杂问题的现代方法。不利的一面是,为了学习到一个令人满意的神经网络,通常需要大量的数据。对于有监督的学习,还需要大量的标注数据。众所周知,标注数据通常是依赖于人工标注获得,因此获取困难。有一些标注好的图像数据是公开可用的,这些数据为研究与应用人员提供了标准资源,便于比较不同分类方法的,用以证明在该领域取得了一些进展。经验上讲,平衡数据集优于非平衡数据集,然而在真实的情况下,可用的数据集通常是不均衡的。如何处理不均衡数据是机器学习中一个很大的挑战。一些方法能够减轻不均衡数据带来的影响,但是并没有系统的研究结果表明DNN与CNN在标准数据集上如何受不均衡数据的影响。

本文重点研究由于训练数据的类别不均衡带来的CNN分类性能的损失。由此进一步探索:什么类型的分布对性能有损?Oversampling在提升性能方面起多大的作用?具体来讲主要包含以下四个问题:

(1)训练数据中均衡的类别分别对CNN的重要性有多大?

(2)CNN的性能如何受训练数据中不同类别分布的影响?

(3)通过调整训练数据的类别分布能否改善CNN的性能?

(4)有什么可行的方法来实现这种调整?

图像分类是判断给定的图像属于哪一类别的过程,直观来讲,就是图像包含了哪些物体。图像分类主要有两种形式:图像级别标注与对象级别标注。图像级别标注是一个二值变量,用来指示一个对象是否出现在图像上,例如,图像上是否有一只猫。对象级别的标注是具体到对象在图像中出现的位置。例如,螺丝刀中心位于(20,25),宽为50像素,高为30像素。本文关注图像级别的标注。

不均衡数据是指机器学习算法在训练的过程中所采用的数据在不同类别上的分布是不均衡的。由于采用均衡数据学习的算法性能远优于不均衡数据的,所以不均衡数据给分类问题带来了挑战。实际中可用的数据通常是不均衡的。然而,大多数的学习算法假设训练数据是均衡,也同样假设未标注的数据也是类间均衡的。若训练数据的分布于测试集并不相同,这类算法通常会降低性能。进一步来讲,多数算法的目标在于最小化整体的错误率,这会导致训练数据中的小众类由于训练数据少而性能不佳。当小众类非常重要时,这种影响是完全负面的。例如,罕见疾病的诊断。不均衡数据已经得到了广泛的关注,有许多有效的方法可以解决这个问题。

已有提升不均衡数据上的学习性能的方法大致分为三类: (1)sampling techniques;(2)Cost sensitive techniques;(3)One-class learning。采样技术改变原始的数据集,从而创建均衡数据集。简单的采样技术包括oversampling(从小众类中重复采样直至均衡),undersampling(移除over-represented类别的数据)与其他采样技术。然而有研究表明将oversampling与undersampling结合可能是应对极端不均衡数据的方式。

  • budget-sensitive progressive sampling algorithm

训练数据数目n

该采样策略依赖于几个假设:(1)与获取训练数据相比,学习算法的执行代价是可以忽略的,因为在该采样算法中学习算法需要运行多次。当训练数据获取代价高时,这一点是成立的。(2)假设每个类别的获取代价是相同的。这样的话预算数目n与训练实例数是一致的。这个假设大多时候是成立的,但也有例外。如,先前提及的电话数据,获取普通消费者和商业电话的代价是一样的,但是欺诈电话的识别代价是高昂的。

  • combination of cost-sensitive technique and undersampling
实验设置

数据集:选用CIFAR-10,包含10个不同的类别,数据集较小,仅包含60000左右的images(不选择ImageNet的原因),便于做批量的实验,但又不至于任务太简单(如MNIST)

数据集划分:5000 images per category for training and 1000 for testing

类别分布:选择11个不同的类别分布,分别考察其分类性能,每种分布其实都是具有代表性的,毕竟10个类别的分布均衡,是很难量化的一个指标,所以这里只是举出几个典型的例子来说明。在本文中,并没有给出class imbalance的一个明确的量化的定义。

网络结构:use caffe to create and train a CNN

参数设置:3 convolutional layers and 10 output nodes, trained with learning rate 0.001 for 8 epochs + learning rate 0.0001 for 2 epochs, momentum set to 0.9, weight decay to 0.004

测试数据:mean results of three runs

评价指标:the percentage of correct answers for each class,然后再做平均。

实验结果

(1)数据越均衡,分类性能越好

(2)oversampling可以给imbalance 数据带来性能的提升,数据越不均衡提升越明显。

阅读笔记 The Impact of Imbalanced Training Data for Convolutional Neural Networks [DegreeProject2015] 数据分析型的更多相关文章

  1. 论文笔记(Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration)

    这是CVPR 2019的一篇oral. 预备知识点:Geometric median 几何中位数 \begin{equation}\underset{y \in \mathbb{R}^{n}}{\ar ...

  2. 【论文阅读】Learning Dual Convolutional Neural Networks for Low-Level Vision

    论文阅读([CVPR2018]Jinshan Pan - Learning Dual Convolutional Neural Networks for Low-Level Vision) 本文针对低 ...

  3. [CVPR2015] Is object localization for free? – Weakly-supervised learning with convolutional neural networks论文笔记

    p.p1 { margin: 0.0px 0.0px 0.0px 0.0px; font: 13.0px "Helvetica Neue"; color: #323333 } p. ...

  4. 论文笔记之:Spatially Supervised Recurrent Convolutional Neural Networks for Visual Object Tracking

    Spatially Supervised Recurrent Convolutional Neural Networks for Visual Object Tracking  arXiv Paper ...

  5. 论文笔记之:Learning Multi-Domain Convolutional Neural Networks for Visual Tracking

    Learning Multi-Domain Convolutional Neural Networks for Visual Tracking CVPR 2016 本文提出了一种新的CNN 框架来处理 ...

  6. [论文阅读] MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications (MobileNet)

    论文地址:MobileNets: Efficient Convolutional Neural Networks for Mobile Vision Applications 本文提出的模型叫Mobi ...

  7. 深度学习笔记 (一) 卷积神经网络基础 (Foundation of Convolutional Neural Networks)

    一.卷积 卷积神经网络(Convolutional Neural Networks)是一种在空间上共享参数的神经网络.使用数层卷积,而不是数层的矩阵相乘.在图像的处理过程中,每一张图片都可以看成一张“ ...

  8. Bag of Tricks for Image Classification with Convolutional Neural Networks笔记

    以下内容摘自<Bag of Tricks for Image Classification with Convolutional Neural Networks>. 1 高效训练 1.1 ...

  9. 论文笔记之《Event Extraction via Dynamic Multi-Pooling Convolutional Neural Network》

    1. 文章内容概述 本人精读了事件抽取领域的经典论文<Event Extraction via Dynamic Multi-Pooling Convolutional Neural Networ ...

随机推荐

  1. php和js一起实现倒计时功能

    里获取的php服务端的时间 纯JS是获取客服端时间! <?php //php的时间是以秒算.js的时间以毫秒算 date_default_timezone_set('PRC'); //date_ ...

  2. each循环

    var NA_COUNT=0; var NG_OK_COUNT=0; //获取所有检验明细为同一个编号的下拉选项,看有没有不是N/A的下拉选项 $("#@(Perfix)tbData sel ...

  3. wdcp的安装扩展模块

    其实就是官方包里面的所有附加模块全部支持啦.~~是在官方的基础上修改的优化了每次都解压缩php源码包,按需解压缩使用方法如下wget http://git.oschina.net/loblog/mem ...

  4. [Spring MVC] - 地址路由使用(一)

    常用的一些Spring MVC的路由写法以及参数传递方式. 参考引用: http://docs.spring.io/spring/docs/3.0.x/spring-framework-referen ...

  5. Python 异常机制

    1.异常基础 在编程过程中为了增加友好性,在程序出现bug时一般不会将错误信息显示给用户,而是现实一个提示的页面,通俗来说就是不让用户看见大黄页!!! try: pass # 程序正常执行时做什么操作 ...

  6. SSIS 项目部署模型

    微软 BI 系列随笔 - SSIS 2012 基础 - SSIS 项目部署模型 关于部署 SSIS 2012 支持两种部署模型:项目部署模型和包部署模型. 使用项目部署模型可以将项目部署到 Integ ...

  7. JAVA虚拟机类型转换学习

    Java虚拟机包括血多进行基本类型转换工作的操作码,这些执行转换工作的操作码后面没有操作数,转换的值从栈顶端获得.Java虚拟机从栈顶端弹出一个值,对它进行转换,然后再把转换结果压入栈.进行int.l ...

  8. CRM 403错误

    1 IIS 正常 2 CRM 各项服务正常. 3  应该程序池--CRMAppPool 停止

  9. CSS :hover伪类选择定义和用法

    伪类选择符E:hover的定义和用法: 设置元素在其鼠标悬停时的样式.E元素可以通过其他选择器进行选择,比如使用类选择符.id选择符.类型选择符等等.特别说明:IE6并非不支持此选择符,但能够支持a元 ...

  10. Embed dll Files Within an exe (C# WinForms)—Winform 集成零散dll进exe的方法

    A while back I was working on a small C# WinForms application in Visual Studio 2008. For the sake of ...