pytorch 损失函数(nn.BCELoss 和 nn.CrossEntropyLoss)(思考多标签分类问题)
一、BCELoss 二分类损失函数
输入维度为(n, ), 输出维度为(n, )
如果说要预测二分类值为1的概率,则建议用该函数!
输入比如是3维,则每一个应该是在0——1区间内(随意通常配合sigmoid函数使用),举例如下:
import torch
import torch.nn as nn
m = nn.Sigmoid()
loss = nn.BCELoss()
input = torch.randn(3,requires_grad=True)
target = torch.empty(3).random_(2)
output = loss(m(input), target)
output.backward() input,target,output 返回值:
(tensor([-0.8728, 0.3632, -0.0547], requires_grad=True),
tensor([1., 0., 0.]),
tensor(0.9264, grad_fn=<BinaryCrossEntropyBackward>)) m(input)结果为:
tensor([0.2947, 0.5898, 0.4863]) 计算output = (1 * ln 0.2947+(1-1)*ln(1-0.2947) + 0*ln0.5898 + (1-0)*ln(1-0.5898) + 0*ln0.4863 + (1-0)*ln(1-0.4863)) / 3 = 0.9264
二、nn.CrossEntropyLoss 交叉熵损失函数
输入维度(batch_size, feature_dim)
输出维度 (batch_size, 1)
X_input = torch.tensor[ [2.8883, 0.1760, 1.0774],
[1.1216, -0.0562, 0.0660],
[-1.3939, -0.0967, 0.5853]]
y_target = torch.tensor([1,2,0])
loss_func = nn.CrossEntropyLoss()
loss = loss_func(X_input, y_target)
计算流程:第一,x先softmax再log,得到x_hat 第二,y转0-1编码[1,2,0] 转[[0,1,0], [0,0,1], [1,0,0]] 再与x_hat相乘,取负取平均值
思考问题:多标签的分类任务中,怎么使用损失函数呢,是拆分是多个二分类问题呢,还是不用拆分直接用BCE呢(https://blog.csdn.net/rosefun96/article/details/88058708,参考:BCE 可以应用到多标签的分类任务中)?有什么区别呢?
pytorch 损失函数(nn.BCELoss 和 nn.CrossEntropyLoss)(思考多标签分类问题)的更多相关文章
- pytorch中文文档-torch.nn常用函数-待添加-明天继续
https://pytorch.org/docs/stable/nn.html 1)卷积层 class torch.nn.Conv2d(in_channels, out_channels, kerne ...
- [pytorch笔记] torch.nn vs torch.nn.functional; model.eval() vs torch.no_grad(); nn.Sequential() vs nn.moduleList
1. torch.nn与torch.nn.functional之间的区别和联系 https://blog.csdn.net/GZHermit/article/details/78730856 nn和n ...
- pytorch中文文档-torch.nn.init常用函数-待添加
参考:https://pytorch.org/docs/stable/nn.html torch.nn.init.constant_(tensor, val) 使用参数val的值填满输入tensor ...
- 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())
在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...
- Pytorch本人疑问(1) torch.nn和torch.nn.functional之间的区别
在写代码时发现我们在定义Model时,有两种定义方法: torch.nn.Conv2d()和torch.nn.functional.conv2d() 那么这两种方法到底有什么区别呢,我们通过下述代码看 ...
- 从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系
从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系 relu多种实现之间的关系 relu 函数在 pytorch 中总共有 3 次出现: torc ...
- pytorch 损失函数
pytorch损失函数: http://blog.csdn.net/zhangxb35/article/details/72464152?utm_source=itdadao&utm_medi ...
- Multi label 多标签分类问题(Pytorch,TensorFlow,Caffe)
适用场景:一个输入对应多个label,或输入类别间不互斥 调用函数: 1. Pytorch使用torch.nn.BCEloss 2. Tensorflow使用tf.losses.sigmoid_cro ...
- Pytorch的默认初始化分布 nn.Embedding.weight初始化分布
一.nn.Embedding.weight初始化分布 nn.Embedding.weight随机初始化方式是标准正态分布 ,即均值$\mu=0$,方差$\sigma=1$的正态分布. 论据1——查看 ...
随机推荐
- 《<SPRING5高级编程(第5版)>_王净译》笔记-【目录】
第一次写这玩意,不知道什么时候能写完,今天项目比较近,期望年底能看完吧. 先定个小目标 20201228 完成 第1章 Spring介绍 第2章 入门 第3章 在Spring中引入IoC和DI 第4章 ...
- 11content_processor
1,content_processor 上下文处理器应该返回一个字典,字典中的key会被模板中当成变量来渲染 上下文处理器返回的字典,在所有页面中都是可以使用的 被这个装饰器修饰的钩子函数,必须要返回 ...
- 关于“Failed to configure a DataSource: 'url' attribute is not specified and no embedded datasource could be configured.”
Consider the following: If you want an embedded database (H2, HSQL or Derby), please put it on the c ...
- 企业网络拓扑MSTP功能实例(二)
组网图形 MSTP简介 以太网交换网络中为了进行链路备份,提高网络可靠性,通常会使用冗余链路.但是使用冗余链路会在交换网络上产生环路,引发广播风暴以及MAC地址表不稳定等故障现象,从而导致用户通信质量 ...
- ASCII、Unicode、UTF-8、UTF-8(without BOM)、UTF-16、UTF-32傻傻分不清
ASCII.Unicode.UTF-8.UTF-8(without BOM).UTF-16.UTF-32傻傻分不清 目录 ASCII.Unicode.UTF-8.UTF-8(without BOM). ...
- Java基础 之一 基本知识
Java基础 之一 基本知识 1.数据类型 Java有8种基本数据类型 int.short .long.byte.float.double.char.boolean 先说明以下单位之间的关系 1位 = ...
- 科普干货|漫谈鸿蒙LiteOS-M与HUAWEI LiteOS内核的几大不同
摘要:鸿蒙和LiteOS的内核都是一样的名字,可它们究竟有什么不同呢?一起来对比一下文件吧! HarmonyOS系统 HarmonyOS是一款"面向未来".面向全场景(移动办公.运 ...
- rbd-mirror新功能
RBD 的 mirroring 功能将会在下一个稳定版本Jewel中实现,这个Jewel版本已经发布了第一个版本10.1.0,这个功能已经在这个发布的版本中实现了 一.基本原理 我们试图解决的或者至少 ...
- 【javascript】掌握ES6-10,附xmind思维导图,每个知识点备注说明案例,请享用
前段时间一直想掌握ES6-10,陆陆续续花了1个月的时间,自学了ES6-10的新知识点,大部分都是非常实用的,花了2天时间整理思维导图 思维导图已上传博客园,请享用. ES6-10思维导图xmind ...
- BeanFactory and FactoryBean
BeanFactory,这是Spring容器的基础实现类,它负责生产和管理Bean的一个工厂.当然BeanFactory只是一个接口,它的常用实现有XmlBeanFactory.DefaultList ...