采用  WeightedRandomSampler:

def make_weights_for_balanced_classes(images, nclasses):
count = [0] * nclasses
for item in images:
count[item[1]] += 1
weight_per_class = [0.] * nclasses
N = float(sum(count))
for i in range(nclasses):
weight_per_class[i] = N/float(count[i])
weight = [0] * len(images)
for idx, val in enumerate(images):
weight[idx] = weight_per_class[val[1]]
return weight
dataset_train = datasets.ImageFolder(traindir)                                                                         

# For unbalanced dataset we create a weighted sampler
weights = make_weights_for_balanced_classes(dataset_train.imgs, len(dataset_train.classes))
weights = torch.DoubleTensor(weights)
sampler = torch.utils.data.sampler.WeightedRandomSampler(weights, len(weights)) train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=args.batch_size, shuffle = True,
sampler = sampler, num_workers=args.workers, pin_memory=True)

Reference:   Balanced Sampling between classes with torchvision DataLoader

参考方法2: 作者给出了均匀采样和非均匀采样的差别

imbalanced-dataset-sampler

Pytorch 类别平衡化处理的更多相关文章

  1. pytorch中网络特征图(feture map)、卷积核权重、卷积核最匹配样本、类别激活图(Class Activation Map/CAM)、网络结构的可视化方法

    目录 0,可视化的重要性: 1,特征图(feture map) 2,卷积核权重 3,卷积核最匹配样本 4,类别激活图(Class Activation Map/CAM) 5,网络结构的可视化 0,可视 ...

  2. PyTorch官方中文文档:torch.nn

    torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...

  3. “你什么意思”之基于RNN的语义槽填充(Pytorch实现)

    1. 概况 1.1 任务 口语理解(Spoken Language Understanding, SLU)作为语音识别与自然语言处理之间的一个新兴领域,其目的是为了让计算机从用户的讲话中理解他们的意图 ...

  4. Pytorch系列教程-使用字符级RNN对姓名进行分类

    前言 本系列教程为pytorch官网文档翻译.本文对应官网地址:https://pytorch.org/tutorials/intermediate/char_rnn_classification_t ...

  5. [深度应用]·实战掌握PyTorch图片分类简明教程

    [深度应用]·实战掌握PyTorch图片分类简明教程 个人网站--> http://www.yansongsong.cn/ 项目GitHub地址--> https://github.com ...

  6. Pytorch: cuda runtime error (59) : device-side assert triggered at /pytorch/aten/src/THC/generic/THCTensorMa

    更换了数据集, 在计算交叉熵损失时出现错误 : cuda runtime error (59) : device-side assert triggered at /pytorch/aten/src/ ...

  7. pytorch实现性别检测

    卷积神经网络的训练是耗时的,很多场合不可能每次都从随机初始化参数开始训练网络.   1.训练 pytorch中自带几种常用的深度学习网络预训练模型,如VGG.ResNet等.往往为了加快学习的进度,在 ...

  8. 【转载】Pytorch tutorial 之Datar Loading and Processing

    前言 上文介绍了数据读取.数据转换.批量处理等等.了解到在PyTorch中,数据加载主要有两种方式: 1.自定义的数据集对象.数据集对象被抽象为Dataset类,实现自定义的数据集需要继承Datase ...

  9. Pytorch里的CrossEntropyLoss详解

    在使用Pytorch时经常碰见这些函数cross_entropy,CrossEntropyLoss, log_softmax, softmax.看得我头大,所以整理本文以备日后查阅. 首先要知道上面提 ...

随机推荐

  1. Python_正则表达式语法

    1.正则表达式中的操作符: 2.re库的使用: import re #search方法要求只要待匹配的字符串中包含正则表达式中的字符串就可以 match = re.search('python+',' ...

  2. 面试中的nginx高可用高并发!

    本文转自:91博客:原文地址:http://www.9191boke.com/439923471.html 面试题: nginx高可用?nginx 是如何实现并发的?为什么nginx不使用多线程?ng ...

  3. spriingboot使用thymeleaf

    1 添加jar包 <dependency> <groupId>org.springframework.boot</groupId> <artifactId&g ...

  4. 洛谷 P1219 八皇后题解

    题目描述 检查一个如下的6 x 6的跳棋棋盘,有六个棋子被放置在棋盘上,使得每行.每列有且只有一个,每条对角线(包括两条主对角线的所有平行线)上至多有一个棋子. 上面的布局可以用序列2 4 6 1 3 ...

  5. kvm创建windows2008虚拟机

    virt-install -n win2008-fushi001 -r 16384 --vcpus=4 --os-type=windows --accelerate -c /data/kvm/imag ...

  6. postgres高可用学习篇一:如何通过patroni如何管理3个postgres节点

    环境: CentOS Linux release 7.6.1810 (Core) 内核版本:3.10.0-957.10.1.el7.x86_64 node1:192.168.216.130 node2 ...

  7. spring配置文件ApplicationContext.xml里面class等没有提示功能

    实现效果: 解决方法: windows–>preference—>myeclipse—>files and editors–>xml—>xmlcatalog 点击add ...

  8. linux中的alias命令详解

    功能说明:设置指令的别名.语 法:alias[别名]=[指令名称]参 数 :若不加任何参数,则列出目前所有的别名设置.举    例 :ermao@lost-desktop:~$ alias       ...

  9. 【Spring】如何配置多个applicationContext.xml文件

    在web.xml中通过contextConfigLocation配置spring 开发Java Web程序,使用ssh架构时,默认情况下,Spring的配置文件applicationContext.x ...

  10. LeetCode 931. Minimum Falling Path Sum

    原题链接在这里:https://leetcode.com/problems/minimum-falling-path-sum/ 题目: Given a square array of integers ...