Pytorch 类别平衡化处理
采用 WeightedRandomSampler:
def make_weights_for_balanced_classes(images, nclasses):
count = [0] * nclasses
for item in images:
count[item[1]] += 1
weight_per_class = [0.] * nclasses
N = float(sum(count))
for i in range(nclasses):
weight_per_class[i] = N/float(count[i])
weight = [0] * len(images)
for idx, val in enumerate(images):
weight[idx] = weight_per_class[val[1]]
return weight
dataset_train = datasets.ImageFolder(traindir) # For unbalanced dataset we create a weighted sampler
weights = make_weights_for_balanced_classes(dataset_train.imgs, len(dataset_train.classes))
weights = torch.DoubleTensor(weights)
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights)) train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, shuffle = True,
sampler = sampler, num_workers=args.workers, pin_memory=True)
Reference: Balanced Sampling between classes with torchvision DataLoader
参考方法2: 作者给出了均匀采样和非均匀采样的差别
imbalanced-dataset-sampler
Pytorch 类别平衡化处理的更多相关文章
- pytorch中网络特征图(feture map)、卷积核权重、卷积核最匹配样本、类别激活图(Class Activation Map/CAM)、网络结构的可视化方法
		目录 0,可视化的重要性: 1,特征图(feture map) 2,卷积核权重 3,卷积核最匹配样本 4,类别激活图(Class Activation Map/CAM) 5,网络结构的可视化 0,可视 ... 
- PyTorch官方中文文档:torch.nn
		torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ... 
- “你什么意思”之基于RNN的语义槽填充(Pytorch实现)
		1. 概况 1.1 任务 口语理解(Spoken Language Understanding, SLU)作为语音识别与自然语言处理之间的一个新兴领域,其目的是为了让计算机从用户的讲话中理解他们的意图 ... 
- Pytorch系列教程-使用字符级RNN对姓名进行分类
		前言 本系列教程为pytorch官网文档翻译.本文对应官网地址:https://pytorch.org/tutorials/intermediate/char_rnn_classification_t ... 
- [深度应用]·实战掌握PyTorch图片分类简明教程
		[深度应用]·实战掌握PyTorch图片分类简明教程 个人网站--> http://www.yansongsong.cn/ 项目GitHub地址--> https://github.com ... 
- Pytorch: cuda runtime error (59) : device-side assert triggered at /pytorch/aten/src/THC/generic/THCTensorMa
		更换了数据集, 在计算交叉熵损失时出现错误 : cuda runtime error (59) : device-side assert triggered at /pytorch/aten/src/ ... 
- pytorch实现性别检测
		卷积神经网络的训练是耗时的,很多场合不可能每次都从随机初始化参数开始训练网络. 1.训练 pytorch中自带几种常用的深度学习网络预训练模型,如VGG.ResNet等.往往为了加快学习的进度,在 ... 
- 【转载】Pytorch tutorial 之Datar Loading and Processing
		前言 上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1.自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Datase ... 
- Pytorch里的CrossEntropyLoss详解
		在使用Pytorch时经常碰见这些函数cross_entropy,CrossEntropyLoss, log_softmax, softmax.看得我头大,所以整理本文以备日后查阅. 首先要知道上面提 ... 
随机推荐
- VMware15.5版本安装CentOS7
			VMware15.5版本安装CentOS7 一.在VMware15.5中新建虚拟机 1.打开VMware,在首页面选择创建新的虚拟机. 2.新建虚拟机向导,选择典型配置.3.选择稍后安装操作系统.4. ... 
- header中Content-Disposition的作用与使用方法
			下载文件的时候会使用: Content-disposition 是 MIME 协议的扩展,MIME 协议指示 MIME 用户代理如何显示附加的文件.Content-disposition其实可以控制用 ... 
- ELK日志分析系统搭建 windows
			1 分别下载elk包 下载地址 https://www.elastic.co/cn/downloads 2 将这三个解压到同一个目录下,便于管理 3 elasticsearch不需要修改配置 默认即可 ... 
- php的希尔排序
			算是改进了的插入排序, 从性能时间上来看,也确实更有改进. 但比起php内置的功能,性能还有十倍之差呢 <?php /** * 原理:把排序的数据根据增量分成几个子序列,对子序列进行插入排序, ... 
- Python语言程序设计(3)--字符串类型及操作--time库进度条
			1.字符串类型的表示: 三引号可做注释,注释其实也是字符串 2.字符串的操作符 3.字符串处理函数 输出: 
- cifar-10数据集的可视化
			import numpy as np from PIL import Image import pickle import os CHANNEL = 3 WIDTH = 32 HEIGHT = 32 ... 
- Spring boot jpa 设定MySQL数据库的自增ID主键值
			内容简介 本文主要介绍在使用jpa向数据库添加数据时,如果表中主键为自增ID,对应实体类的设定方法. 实现步骤 只需要在自增主键上添加@GeneratedValue注解就可以实现自增,如下图: 关键代 ... 
- EventWaitHandle 第一课
			本篇通过一个列子使用EventWaitHandle实现两个线程的同步.请参看下面的列子. using System; using System.Collections.Generic; using S ... 
- linux 出错 “INFO: task java: xxx blocked for more than 120 seconds.” 的3种解决方案
			1 问题描述 最近搭建的一个linux最小系统在运行到241秒时在控制台自动打印如下图信息,并且以后每隔120秒打印一次. 仔细阅读打印信息发现关键信息是“hung_task_timeout_secs ... 
- HAProxy 2.0 and Beyond
			转自:https://www.haproxy.com/blog/haproxy-2-0-and-beyond/ 关于haproxy 2.0 的新特性说明 HAProxy Technologies i ... 
