适用场景:一个输入对应多个label,或输入类别间不互斥

调用函数:

1. Pytorch使用torch.nn.BCEloss

2. Tensorflow使用tf.losses.sigmoid_cross_entropy

3. Caffe使用SigmoidCrossEntropyLoss

在output和target之间构建binary cross entropy,其中i为每一个类。

以pytorch为例:Caffe,TensorFlow版本类比,输入均为相同形式的向量

m = nn.Sigmoid()
loss = nn.BCELoss()
input = autograd.Variable(torch.randn(3), requires_grad=True)
target = autograd.Variable(torch.FloatTensor(3).random_(2))
output = loss(m(input), target)
output.backward()

注意target的形式,要写成01编码形式,eg:如果同时为第一类和第三类则,[1, 0, 1]

主要是结合sigmoid来使用,经过classifier分类过后的输出为(batch_size,num_class)为每个数据的标签, 标签不是one-hot的主要体现在sigmoid输出之后,仍然为(batch_size,num_class),对于一个实例,它的各个label的分数加起来不一定等于1,bceloss在每个类维度上求cross entropy loss然后加和求平均得到,这里就体现了多标签的思想。

[CVPR2015] Is object localization for free? – Weakly-supervised learning with convolutional neural networks这篇论文里设计了针对多标签问题的loss,传统的类别分类不适用,作者把这个任务视为多个二分类问题,loss function和分类的分数如下:

Multi label 多标签分类问题(Pytorch,TensorFlow,Caffe)的更多相关文章

  1. Tensorflow学习教程------lenet多标签分类

    本文在上篇的基础上利用lenet进行多标签分类.五个分类标准,每个标准分两类.实际来说,本文所介绍的多标签分类属于多任务学习中的联合训练,具体代码如下. #coding:utf-8 import te ...

  2. 实战caffe多标签分类——汽车品牌与车辆外观(C++接口)[详细实现+数据集]

    前言 很多地方我们都需要用到多标签分类,比如一张图片,上面有只蓝猫,另一张图片上面有一只黄狗,那么我们要识别的时候,就可以采用多标签分类这一思想了.任务一是识别出这个到底是猫还是狗?(类型)任务二是识 ...

  3. scikit-learn一般实例之八:多标签分类

    本例模拟一个多标签文档分类问题.数据集基于下面的处理随机生成: 选取标签的数目:泊松(n~Poisson,n_labels) n次,选取类别C:多项式(c~Multinomial,theta) 选取文 ...

  4. CSS.02 -- 样式表 及标签分类(块、行、行内块元素)、CSS三大特性、背景属性

    样式表书写位置  内嵌式写法 <head> <style type="text/css"> 样式表写法 </style> </head&g ...

  5. 前端 HTML 标签分类

    三种: 1.块级标签: 独占一行,可设置宽度,高度.如果设置了宽度和高度,则就是当前的宽高.如果宽度和高度没有设置,宽度是父盒子的宽度,高度根据内容填充. 2.行内标签:在一行内显示,不能设置宽度,高 ...

  6. 如何用softmax和sigmoid来做多分类和多标签分类

    首先,说下多类分类和多标签分类的区别 多标签分类:一个样本可以属于多个类别(或标签),不同类之间是有关联的,比如一个文本被被划分成“人物”和“体育人物”两个标签.很显然这两个标签不是互斥的,而是有关联 ...

  7. 使用 scikit-learn 实现多类别及多标签分类算法

    多标签分类格式 对于多标签分类问题而言,一个样本可能同时属于多个类别.如一个新闻属于多个话题.这种情况下,因变量yy需要使用一个矩阵表达出来. 而多类别分类指的是y的可能取值大于2,但是y所属类别是唯 ...

  8. 使用MXNet远程编写卷积神经网络用于多标签分类

    最近试试深度学习能做点什么事情.MXNet是一个与Tensorflow类似的开源深度学习框架,在GPU显存利用率上效率高,比起Tensorflow显著节约显存,并且天生支持分布式深度学习,单机多卡.多 ...

  9. k-近邻算法 标签分类

    k-近邻算法根据特征比较,然后提取样本集中特征最相似数据(最邻近)的分类标签.那么,如何进行比较呢? 怎么判断红色圆点标记的电影所属的类别呢? 如下图所示. 答:距离度量.这个电影分类的例子有2个特征 ...

随机推荐

  1. 关于apache配置映射端口

    step1.打开httpd.conf找到Listen 80这一行在后面添加Listen 8080Listen 8001Listen 8002Listen 8003也就是意味着每个项目占用一个端口,就像 ...

  2. linux快速将磁盘额外空间扩展到某一挂载点

    由于之前在创建用户时,为该用户目录分配的空间只有5G,在后续的开发,存放的东西越来越多,空间眼看就不够用了,网上查了一下,很多都是教我们将其余挂载点分配过多的空间分配到空间不足的挂载点,步骤还不算太复 ...

  3. hostapd修改beacon帧和probe response帧

    在AP模式下,热点会不断定期地发送Beacon帧来宣告自己存在,告知设备可以加入网络: Probe Response帧是用于应答Probe Request帧,Probe Request帧是移动工作站用 ...

  4. python操作mysql数据库增删改查的dbutils实例

    python操作mysql数据库增删改查的dbutils实例 # 数据库配置文件 # cat gconf.py #encoding=utf-8 import json # json里面的字典不能用单引 ...

  5. java乱码问题解决

    1.通过统一的过滤器进行了页面过滤(问题排除) 2.通过debug功能发现页面传到servelet和DAO中文都是OK的,可以说明在web程序端没有问题 问题就可能出现在数据库上面 首先查看数据库编码 ...

  6. LabVIEW--为设备添加配置文件.ini

    需求:我同一个程序下载到两台机器人上,有些参数是不一样的,比如说服务器的ID或者端口,以及存放文件的位置,如果我每次下载之前改程序的话就非常麻烦了(虽然在程序里面是作为全局变量来存的),不利于后期的更 ...

  7. python学习第11天 迭代器

    函数的名称 闭包 迭代器 递归

  8. Mysql 的安装(压缩文件)和基本管理

    MySql安装和基本管理   本节掌握内容: mysql的安装.启动 mysql破解密码 统一字符编码 MySQL是一个关系型数据库管理系统,由瑞典MySQL AB 公司开发,目前属于 Oracle ...

  9. 快速开发android,离不开这10个优秀的开源项目

    作为一名菜鸡Android,时常瞻仰大佬们的开源项目是非常必要的.这里我为大家收集整理了10个优秀的开源项目,方便我们日常开发中学习! 作者:ListenToCode博客:https://www.ji ...

  10. 配置 Confluence 6 安全的最佳实践

    让一个系统能够变得更加坚固的最好办法是将系统独立出来.请参考你公司的安全管理策略和相关人员来找到你公司应该采用何种安全策略.这里有很多事情需要我们考虑,例如考虑如何安装我们的操作系统,应用服务器,数据 ...