简单的采样策略

首先介绍三种简单采样策略:

  1. Instance-balanced sampling, 实例平衡采样。
  2. Class-balanced sampling, 类平衡采样。
  3. Square-root sampling, 平方根采样。

它们可抽象为:

\[p_j=\frac{n_j^q}{\sum_{i=1}^Cn_i^q},
\]

\(p_j\)表示从j类采样数据的概率;\(C\)表示类别数量;\(n_j\)表示j类样本数;\(q\in\{1,0,\frac{1}{2}\}\)

Instance-balanced sampling

最常见的数据采样方式,其中每个训练样本被选择的概率相等(\(q=1\))。j类被采样的概率\(p^{\mathbf{IB}}_j\)与j类样本数\(n_j\)成正比,即\(p^{\mathbf{IB}}_j=\frac{n_j}{\sum_{i=1}^Cn_i}\)。

Class-balanced sampling

实例平衡采样在不平衡的数据集中往往表现不佳,类平衡采样让所有的类有相同的被采样概率:\(p^{\mathbf{CB}}_j=\frac{1}{C}\)。采样可分为两个阶段:1. 从类集中统一选择一个类;2. 对该类中的实例进行统一采样。

Square-root sampling

平方根采样最常见的变体,\(q=\frac{1}{2}\)

由于这三种采样策略都是调整类别的采样概率(权重),因此可用PyTorch提供的WeightedRandomSampler实现:

import numpy as np
from torch.utils.data.sampler import WeightedRandomSampler
def get_sampler(sampling_type, targets):
cls_counts = np.bincount(targets)
if sampling_type == 'instance-balanced':
cls_weights = cls_counts / np.sum(cls_counts) elif sampling_type == 'class-balanced':
cls_num = len(cls_counts)
cls_weights = [1. / cls_num] * cls_num elif sampling_type == 'square-root':
sqrt_and_sum = np.sum([num**0.5 for num in cls_counts])
cls_weights = [num**0.5 / sqrt_and_sum for num in cls_counts]
else:
raise ValueError('sampling_type should be instance-balanced, class-balanced or square-root') cls_weights = np.array(cls_weights)
return WeightedRandomSampler(cls_weights[targets], len(targets), replacement=True)

WeightedRandomSampler,第一个参数表示每个样本的权重,第二个参数表示采样的样本数,第三个参数表示是否有放回采样。

在模拟的长尾数据集测试下:

import torch
from torch.utils.data import Dataset, DataLoader
torch.manual_seed(0)
np.random.seed(0)
class LongTailDataset(Dataset):
def __init__(self, num_classes, max_samples_per_class):
self.num_classes = num_classes
self.max_samples_per_class = max_samples_per_class # Generate number of samples for each class inversely proportional to class index
self.samples_per_class = [self.max_samples_per_class // (i + 1) for i in range(self.num_classes)]
self.total_samples = sum(self.samples_per_class) # Generate targets for the dataset
self.targets = torch.cat([torch.full((samples,), i, dtype=torch.long) for i, samples in enumerate(self.samples_per_class)]) def __len__(self):
return self.total_samples def __getitem__(self, idx):
# For simplicity, just return the index as the data
return idx, self.targets[idx] # Parameters
num_classes = 25
max_samples_per_class = 1000 # Create dataset
dataset = LongTailDataset(num_classes, max_samples_per_class) # Create dataloader
batch_size = 64
sampler1 = get_sampler('instance-balanced', dataset.targets.numpy())
sampler2 = get_sampler('class-balanced', dataset.targets.numpy())
sampler3 = get_sampler('square-root', dataset.targets.numpy())
dataloader1 = DataLoader(dataset, batch_size=64, sampler=sampler1)
dataloader2 = DataLoader(dataset, batch_size=64, sampler=sampler2)
dataloader3 = DataLoader(dataset, batch_size=64, sampler=sampler3) for (_, target1), (_, target2), (_, target3) in zip(dataloader1, dataloader2, dataloader3):
print('Instance-balanced:')
cls_idx, cls_counts = np.unique(target1.numpy(), return_counts=True)
print(f'Class indices: {cls_idx}')
print(f'Class counts: {cls_counts}')
print('-'*20)
print('Class-balanced:')
cls_idx, cls_counts = np.unique(target2.numpy(), return_counts=True)
print(f'Class indices: {cls_idx}')
print(f'Class counts: {cls_counts}')
print('-'*20)
print('Square-root:')
cls_idx, cls_counts = np.unique(target3.numpy(), return_counts=True)
print(f'Class indices: {cls_idx}')
print(f'Class counts: {cls_counts}')
break # just show one batch

Output:

Instance-balanced:
Class indices: [ 0 1 2 3 5 16 22 23]
Class counts: [43 9 5 2 2 1 1 1]
--------------------
Class-balanced:
Class indices: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 16 17 20 21 23]
Class counts: [21 8 6 4 2 1 2 2 3 3 1 2 1 1 1 1 2 1 1 1]
--------------------
Square-root:
Class indices: [ 0 1 2 3 4 5 6 9 10 21 22 23]
Class counts: [37 8 3 6 3 1 1 1 1 1 1 1]

混合采样策略

最早的混合采样是在 \(0\le epoch\le t\)时采用Instance-balanced采样,\(t\le epoch\le T\)时采用Class-balanced采样,这需要设置合适的超参数t。在[1]中,作者提出了soft版本的混合采样策略:Progressively-balanced sampling。随着epoch的增加每个类的采样概率(权重)\(p_j\)也发生变化:

\[p_j^{\mathbf{PB}}(t)=(1-\frac tT)p_j^{\mathbf{IB}}+\frac tTp_j^{\mathbf{CB}}
\]

t表示当前epoch,T表示总epoch数。

不平衡数据集下的采样策略

不平衡的数据集,特别是长尾数据集,为了照顾尾部类,通常设置每个类的采样概率(权重)为样本数的倒数,即\(p_j=\frac{1}{n_j}\)。

...
elif sampling_type == 'inverse':
cls_weights = 1. / cls_counts
...

在[3]中提出了有效数(effective number)的概念,分母的位置不是简单的样本数,而是经过一定计算得到的,这里直接给出结果,证明请详见原论文。关于effective number的计算方式:

\[E_n=(1-\beta^n)/(1-\beta),\ \mathrm{where~}\beta=(N-1)/N.
\]

这里N表示数据集样本总数。

相关代码:

...
elif sampling_type == 'effective':
beta = (len(targets) - 1) / len(targets)
cls_weights = (1.0 - beta) / (1.0 - np.power(beta, cls_counts))
...

Output

Effective number:
Class indices: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 16 17 18 20 21 22 23 24]
Class counts: [2 1 2 3 1 1 4 2 3 4 4 2 3 5 2 4 1 3 1 4 5 6 1]

在和上面一样的模拟长尾数据集上,采样的结果更加均衡。

参考文献

  1. Kang, Bingyi, et al. "Decoupling Representation and Classifier for Long-Tailed Recognition." International Conference on Learning Representations. 2019.
  2. torch.utils.data.WeightedRandomSampler
  3. Cui, Yin, et al. "Class-balanced loss based on effective number of samples." Proceedings of the IEEE/CVF conference on computer vision and pattern recognition. 2019.

机器学习常见的sampling策略 附PyTorch实现的更多相关文章

  1. 常见性能优化策略的总结 good

    阅读目录 代码 数据库 缓存 异步 NoSQL JVM调优 多线程与分布式 度量系统(监控.报警.服务依赖管理) 案例一:商家与控制区关系的刷新job 案例二:POI缓存设计与实现 案例三:业务运营后 ...

  2. AI - 机器学习常见算法简介(Common Algorithms)

    机器学习常见算法简介 - 原文链接:http://usblogs.pwc.com/emerging-technology/machine-learning-methods-infographic/ 应 ...

  3. 吴恩达《深度学习》-第三门课 结构化机器学习项目(Structuring Machine Learning Projects)-第一周 机器学习(ML)策略(1)(ML strategy(1))-课程笔记

    第一周 机器学习(ML)策略(1)(ML strategy(1)) 1.1 为什么是 ML 策略?(Why ML Strategy?) 希望在这门课程中,可以教给一些策略,一些分析机器学习问题的方法, ...

  4. 深度学习训练过程中的学习率衰减策略及pytorch实现

    学习率是深度学习中的一个重要超参数,选择合适的学习率能够帮助模型更好地收敛. 本文主要介绍深度学习训练过程中的6种学习率衰减策略以及相应的Pytorch实现. 1. StepLR 按固定的训练epoc ...

  5. [Machine Learning] 机器学习常见算法分类汇总

    声明:本篇博文根据http://www.ctocio.com/hotnews/15919.html整理,原作者张萌,尊重原创. 机器学习无疑是当前数据分析领域的一个热点内容.很多人在平时的工作中都或多 ...

  6. paper 12:机器学习常见算法分类汇总

    机器学习无疑是当前数据分析领域的一个热点内容.很多人在平时的工作中都或多或少会用到机器学习的算法.这里南君先生为您总结一下常见的机器学习算法,以供您在工作和学习中参考. 机器学习的算法很多.很多时候困 ...

  7. mysql常见安全加固策略

    原创 2017年01月17日 21:36:50 标签: 数据库 / mysql / 安全加固 5760 常见Mysql配置文件:linux系统下是my.conf,windows环境下是my.ini: ...

  8. 10 种机器学习算法的要点(附 Python 和 R 代码)

    本文由 伯乐在线 - Agatha 翻译,唐尤华 校稿.未经许可,禁止转载!英文出处:SUNIL RAY.欢迎加入翻译组. 前言 谷歌董事长施密特曾说过:虽然谷歌的无人驾驶汽车和机器人受到了许多媒体关 ...

  9. 10 种机器学习算法的要点(附 Python)(转载)

    一.前言 谷歌董事长施密特曾说过:虽然谷歌的无人驾驶汽车和机器人受到了许多媒体关注,但是这家公司真正的未来在于机器学习,一种让计算机更聪明.更个性化的技术 也许我们生活在人类历史上最关键的时期:从使用 ...

  10. 机器学习常见的几种评价指标:精确率(Precision)、召回率(Recall)、F值(F-measure)、ROC曲线、AUC、准确率(Accuracy)

    原文链接:https://blog.csdn.net/weixin_42518879/article/details/83959319 主要内容:机器学习中常见的几种评价指标,它们各自的含义和计算(注 ...

随机推荐

  1. C++ //栈 stack 容器 先进后出 不允许遍历

    1 //栈 stack 容器 先进后出 不允许遍历 2 3 4 #include<iostream> 5 #include<stack> 6 7 using namespace ...

  2. liquibase customChange

    liquibase customChange liquibase changeset 执行Java代码. liquibase支持yml等文件,支持引入sql文件,还支持Java这种方式执行change ...

  3. XAF Blazor FilterPanel 布局样式

    从上一篇关于ListView布局样式的文章中,我们知道XAFBlazor是移动优先的,如果想在PC端有更好的用户体验,我们需要对布局样式进行修改.这篇介绍在之前文章中提到的FilterPanel,它的 ...

  4. Zabbix6.0使用教程 (二)—zabbix6.0常用术语

    上一次我们已经详细介绍了zabbix6.0的新增功能,本篇我们来说说zabbix6.0常用的一些术语,这个对小伙伴日常使用zabbix的时候还是非常有用,建议大家收藏起来,话不多说,附上干货. 概览 ...

  5. linux 无法找到“/usr/bin/core_perl/gcc” vscode

    解决问题的思路 查看有没有gcc,没有安装 有的话就是,修改安装路径就可以? "/usr/bin/core_perl/gcc".修改成Gcc的绝对路径 我的修改是./usr/bin ...

  6. 基于wifi的音频采集及处理解决方案小结

    一沉浮    这些年,一直围绕着音频来做案子,做出来的案子自己都数不清楚了.记得前几年,刚出道的时候,就把wifi音频传输的设备做出来了.可惜的是,当初太超前市场了,鲜有人问.随着时间的推移,在疫情之 ...

  7. 关于初始化page入参的设计思路

    最近在重构老的代码,在写的过程中发现之前的逻辑如果遇到没有入参pageNo会Npe,于是乎我想找找公司项目有啥方式处理page入参的有两种如下 使用三元表达式直接判断是否null,然后赋值 使用map ...

  8. DL基础补全计划(四)---对抗过拟合:权重衰减、Dropout

    PS:要转载请注明出处,本人版权所有. PS: 这个只是基于<我自己>的理解, 如果和你的原则及想法相冲突,请谅解,勿喷. 前置说明   本文作为本人csdn blog的主站的备份.(Bl ...

  9. PAT 甲级【1015 Reversible Primes】

    考察素数判断 考察进制转换 import java.io.IOException; import java.io.InputStreamReader; import java.io.StreamTok ...

  10. Linux 运维工程师面试真题-1-必会Linux 操作系统知识

    Linux 运维工程师面试真题-1-必会Linux 操作系统知识 运维的整个面试流程其实是非常繁杂的,为了方便大家准备,我们特地在这里给大家整理了 一些 Linux 系统运维相关的面试题,有些问题没有 ...