基于cifar数据集合成含开集、闭集噪声的数据集
前言
噪声标签学习下的一个任务是:训练集上存在开集噪声和闭集噪声;然后在测试集上对闭集样本进行分类。
训练集中被加入的开集样本,会被均匀得打上闭集样本的标签充当开集噪声;而闭集噪声的设置与一般的噪声标签学习一致,分为对称噪声:随机将闭集样本的标签替换为其他类别;和非对称噪声:将闭集样本的标签替换为特定的类别。
论文实验中,常用cifar数据集模拟这类任务。目前已知有两类方法:
第一类基于cifar100,将100个类的一部分,通常是20个类作为开集样本,将它们标签替换了前80个类作为开集噪声;然后对于后续80个类,选择部分样本设置为对称/非对称闭集噪声。CVPR2022的PNP: Robust Learning From Noisy Labels by Probabilistic Noise Prediction提供的代码中,使用了这种方法。但是,如果要考虑非对称噪声,在cifar10上就很难实现,cifar10的类的顺序不像cifar100那样有规律,不好设置闭集噪声。
第二类方法适用cifar10和cifar100,保持原始数据集的样本数不变,使用额外的数据集(通常是imagenet32、places365)代替部分样本作为开集噪声,对于剩下的非开集噪声样本再设置闭集噪声。ECCV2022的Embedding contrastive unsupervised features to cluster in-and out-of-distribution noise in corrupted image datasets提供的代码使用了这种方式。
places365可以使用torchvision.datasets.Places365下载,由于训练集较大,通常是用它的验证集作为辅助数据集。
imagenet32是imagnet的32x32版本,同样是1k类,但是类的具体含义的顺序与imagenet不同,imagenet32类的具体含义可见这里。image32下载地址在对应论文A downsampled variant of imagenet as an alternative to the cifar datasets提供的链接。
接下来是用第二种方法,辅助数据集使用imagenet32,基于cifar构造含开集闭集噪声的训练集。
实验
设计imagenet32数据集
import os
import pickle
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
_train_list = ['train_data_batch_1',
'train_data_batch_2',
'train_data_batch_3',
'train_data_batch_4',
'train_data_batch_5',
'train_data_batch_6',
'train_data_batch_7',
'train_data_batch_8',
'train_data_batch_9',
'train_data_batch_10']
_val_list = ['val_data']
def get_dataset(transform_train, transform_test):
# prepare datasets
# Train set
train = Imagenet32(train=True, transform=transform_train) # Load all 1000 classes in memory
# Test set
test = Imagenet32(train=False, transform=transform_test) # Load all 1000 test classes in memory
return train, test
class Imagenet32(Dataset):
def __init__(self, root='~/data/imagenet32', train=True, transform=None):
if root[0] == '~':
root = os.path.expanduser(root)
self.transform = transform
size = 32
# Now load the picked numpy arrays
if train:
data, labels = [], []
for f in _train_list:
file = os.path.join(root, f)
with open(file, 'rb') as fo:
entry = pickle.load(fo, encoding='latin1')
data.append(entry['data'])
labels += entry['labels']
data = np.concatenate(data)
else:
f = _val_list[0]
file = os.path.join(root, f)
with open(file, 'rb') as fo:
entry = pickle.load(fo, encoding='latin1')
data = entry['data']
labels = entry['labels']
data = data.reshape((-1, 3, size, size))
self.data = data.transpose((0, 2, 3, 1)) # Convert to HWC
labels = np.array(labels) - 1
self.labels = labels.tolist()
def __getitem__(self, index):
img, target = self.data[index], self.labels[index]
img = Image.fromarray(img)
if self.transform is not None:
img = self.transform(img)
return img, target, index
def __len__(self):
return len(self.data)
目录结构:
imagenet32
├─ train_data_batch_1
├─ train_data_batch_10
├─ train_data_batch_2
├─ train_data_batch_3
├─ train_data_batch_4
├─ train_data_batch_5
├─ train_data_batch_6
├─ train_data_batch_7
├─ train_data_batch_8
├─ train_data_batch_9
└─ val_data
设计cifar数据集
import torchvision
import numpy as np
from dataset.imagenet32 import Imagenet32
class CIFAR10(torchvision.datasets.CIFAR10):
def __init__(self, root='~/data', train=True, transform=None,
r_ood=0.2, r_id=0.2, seed=0, corruption='imagenet', ):
nb_classes = 10
self.nb_classes = nb_classes
super().__init__(root, train=train, transform=transform)
if train is False:
return
np.random.seed(seed)
if r_ood > 0.:
ids_ood = [i for i in range(len(self.targets)) if np.random.random() < r_ood]
if corruption == 'imagenet':
imagenet32 = Imagenet32(root='~/data/imagenet32', train=True)
img_ood = imagenet32.data[np.random.permutation(range(len(imagenet32)))[:len(ids_ood)]]
else:
raise ValueError(f'Unknown corruption: {corruption}')
self.ids_ood = ids_ood
self.data[ids_ood] = img_ood
if r_id > 0.:
ids_not_ood = [i for i in range(len(self.targets)) if i not in ids_ood]
ids_id = [i for i in ids_not_ood if np.random.random() < (r_id / (1 - r_ood))]
for i, t in enumerate(self.targets):
if i in ids_id:
self.targets[i] = int(np.random.random() * nb_classes)
self.ids_id = ids_id
class CIFAR100(torchvision.datasets.CIFAR100):
def __init__(self, root='~/data', train=True, transform=None,
r_ood=0.2, r_id=0.2, seed=0, corruption='imagenet', ):
nb_classes = 100
self.nb_classes = nb_classes
super().__init__(root, train=train, transform=transform)
if train is False:
return
np.random.seed(seed)
if r_ood > 0.:
ids_ood = [i for i in range(len(self.targets)) if np.random.random() < r_ood]
if corruption == 'imagenet':
imagenet32 = Imagenet32(root='~/data/imagenet32', train=True)
img_ood = imagenet32.data[np.random.permutation(range(len(imagenet32)))[:len(ids_ood)]]
else:
raise ValueError(f'Unknown corruption: {corruption}')
self.ids_ood = ids_ood
self.data[ids_ood] = img_ood
if r_id > 0.:
ids_not_ood = [i for i in range(len(self.targets)) if i not in ids_ood]
ids_id = [i for i in ids_not_ood if np.random.random() < (r_id / (1 - r_ood))]
for i, t in enumerate(self.targets):
if i in ids_id:
self.targets[i] = int(np.random.random() * nb_classes)
self.ids_id = ids_id
查看统计结果
import pandas as pd
import altair as alt
from dataset.cifar import CIFAR10, CIFAR100
# Initialize CIFAR10 dataset
cifar10 = CIFAR10(r_imb=0.)
cifar100 = CIFAR100(r_imb=0.)
def statistics_samples(dataset):
ids_ood = dataset.ids_ood
ids_id = dataset.ids_id
# Collect statistics
statistics = []
for i in range(dataset.nb_classes):
statistics.append({
'class': i,
'id': 0,
'ood': 0,
'clear': 0
})
for i, t in enumerate(dataset.targets):
if i in ids_ood:
statistics[t]['ood'] += 1
elif i in ids_id:
statistics[t]['id'] += 1
else:
statistics[t]['clear'] += 1
df = pd.DataFrame(statistics)
# Melt the DataFrame for Altair
df_melt = df.melt(id_vars='class', var_name='type', value_name='count')
# Create the bar chart
chart = alt.Chart(df_melt).mark_bar().encode(
x=alt.X('class:O', title='Classes'),
y=alt.Y('count:Q', title='Sample Count'),
color='type:N'
)
return chart
chart1 = statistics_samples(cifar10)
chart2 = statistics_samples(cifar100)
chart1 = chart1.properties(
title='cifar10',
width=100, # Adjust width to fit both charts side by side
height=400
)
chart2 = chart2.properties(
title='cifar100',
width=800,
height=400
)
combined_chart = alt.hconcat(chart1, chart2).configure_axis(
labelFontSize=12,
titleFontSize=14
).configure_legend(
titleFontSize=14,
labelFontSize=12
)
combined_chart
运行环境
# Name Version Build Channel
altair 5.3.0 pypi_0 pypi
pytorch 2.3.1 py3.12_cuda12.1_cudnn8_0 pytorch
pandas 2.2.2 pypi_0 pypi
基于cifar数据集合成含开集、闭集噪声的数据集的更多相关文章
- 机器学习数据集,主数据集不能通过,人脸数据集介绍,从r包中获取数据集,中国河流数据集
机器学习数据集,主数据集不能通过,人脸数据集介绍,从r包中获取数据集,中国河流数据集 选自Microsoft www.tz365.Cn 作者:Lee Scott 机器之心编译 参与:李亚洲.吴攀. ...
- R_Studio(决策树算法)鸢尾花卉数据集Iris是一类多重变量分析的数据集【精】
鸢尾花卉数据集Iris是一类多重变量分析的数据集 通过花萼长度,花萼宽度,花瓣长度,花瓣宽度4个属性预测鸢尾花卉属于(Setosa,Versicolour,Virginica)三个种类中的哪一类 针对 ...
- 基于用户的最近邻协同过滤算法(MovieLens数据集)
基于用户的最近邻算法(User-Based Neighbor Algorithms),是一种非概率性的协同过滤算法,也是推荐系统中最最古老,最著名的算法. 我们称那些兴趣相似的用户为邻居,如果用户 ...
- java 实现基于opencv全景图合成
因项目需要,自己做了demo,从中学习很多,所以分享出来,希望有这方面需求的少走一些弯路,opencv怎么安装网上教程多多,这里不加详细说明,我安装的opencv-3.3.0 如上图所示,找到相应的j ...
- Pytorch文本分类(imdb数据集),含DataLoader数据加载,最优模型保存
用pytorch进行文本分类,数据集为keras内置的imdb影评数据(二分类),代码包含六个部分(详见代码) 使用环境: pytorch:1.1.0 cuda:10.0 gpu:RTX2070 (1 ...
- pyTorch 基于以resnet50为backbone的PSPNet 训练VOC2012数据集
代码链接:https://github.com/ggyyzm/pytorch_segmentation 使用PSPNet作为主干分类网络 1.将VOC2012数据集下载并解压到data/VOCtrai ...
- MyBatis操作指南-搭建项目基础环境(基于Java API)含log4j2配置
- 基于tensorflow的bilstm_crf的命名实体识别(数据集是msra命名实体识别数据集)
github地址:https://github.com/taishan1994/tensorflow-bilstm-crf 1.熟悉数据 msra数据集总共有三个文件: train.txt:部分数据 ...
- Windows10+YOLOv3实现检测自己的数据集(1)——制作自己的数据集
本文将从以下三个方面介绍如何制作自己的数据集 数据标注 数据扩增 将数据转化为COCO的json格式 参考资料 一.数据标注 在深度学习的目标检测任务中,首先要使用训练集进行模型训练.训练的数据集好坏 ...
- cifar数据集介绍及到图像转换的实现
CIFAR是一个用于普通物体识别的数据集.CIFAR数据集分为两种:CIFAR-10和CIFAR-100.The CIFAR-10 and CIFAR-100 are labeled subsets ...
随机推荐
- python教程6.5-excel处理模块
第三方开源模块安装 创建文件 打开已有文件 写数据 选择表 保存表 遍历表 按行遍历 按列遍历 遍历指定行列 遍历指定第几列数据 删除表 设置单元格样式 字体 对齐 设置行高列宽
- .net core 微信支付-微信小程序支付(服务端C#代码)
前言 前段时间研究了下微信支付-小程序支付的功能.但微信支付文档中关于.net C#的语言的sdk没有,只有java go 和php版本的,当然社区也有很多已经集成好的微信支付.net core sd ...
- 『手撕Vue-CLI』处理不同指令
前言 在上一篇『手撕Vue-CLI』添加自定义指令中,已经实现了自定义指令的添加,但是指令还是比较简单的,只是简单的打印一句话,那么在实际运用场景中,可能会有更多的需求,比如可能需要在指令中传递参数, ...
- Django 视图views的基本使用
在 Django 中,视图函数是一个 Python 函数或者类,开发者主要通过编写视图函数来实现业务逻辑.视图函数首先接受来自浏览器或者客户端的请求,并最终返回响应,视图函数返回的响应可以是 HTML ...
- golang beego 使用supervisor 部署后台进程管理. 静态文件找不到的解决办法.
directory=/root/go/src/you_self_dir 请在客户端配置文件*.ini中加入一行命令, 等于号后面就是自己的项目目录,这时就能找到项目文件了.
- 提速15%,PaddleOCRSharp新版v4.3发布
PaddleOCRSharp v4.3版本,已经于5月23日发布.该版本的发布,在不影响识别精度的同时,带来了10%~15%速度的提升. 项目地址:https://gitee.com/raoyutia ...
- NET框架下如何使用PaddleOCRSharp
打开VSIDE,新建Windows窗体应用(.NET Framework)类型的项目,选择一个.NET框架,如.NET Framework 4.0,右键点击项目,选择属性>生成,目标平台设置成X ...
- NOIP模拟49
虚伪的眼泪,会伤害别人,虚伪的笑容,会伤害自己. 前言 暑假集训过后的第一次考试,成绩一般,没啥好说的 T1 Reverse 解题思路 看到这个题的第一眼就感觉是最短路,毕竟题目的样子就好像之前做过的 ...
- C++笔记(12) 标准模板库STL
STL提供了一组表示容器.迭代器.函数.函数对象和算法的模板.STL不是面向对象的编程,而是一种不同的编程模式--泛型编程. 容器:与数组类似的单元,可以存储若干个值,存储的值的类型相同: 算法:完成 ...
- Linux设备驱动--阻塞与非阻塞I/O
注:本文是<Linux设备驱动开发详解:基于最新的Linux 4.0内核 by 宋宝华 >一书学习的笔记,大部分内容为书籍中的内容. 书籍可直接在微信读书中查看:Linux设备驱动开发详解 ...