pytorch 损失函数(nn.BCELoss 和 nn.CrossEntropyLoss)(思考多标签分类问题)
一、BCELoss 二分类损失函数
输入维度为(n, ), 输出维度为(n, )
如果说要预测二分类值为1的概率,则建议用该函数!
输入比如是3维,则每一个应该是在0——1区间内(随意通常配合sigmoid函数使用),举例如下:
import torch
import torch.nn as nn
m = nn.Sigmoid()
loss = nn.BCELoss()
input = torch.randn(3,requires_grad=True)
target = torch.empty(3).random_(2)
output = loss(m(input), target)
output.backward() input,target,output 返回值:
(tensor([-0.8728, 0.3632, -0.0547], requires_grad=True),
tensor([1., 0., 0.]),
tensor(0.9264, grad_fn=<BinaryCrossEntropyBackward>)) m(input)结果为:
tensor([0.2947, 0.5898, 0.4863]) 计算output = (1 * ln 0.2947+(1-1)*ln(1-0.2947) + 0*ln0.5898 + (1-0)*ln(1-0.5898) + 0*ln0.4863 + (1-0)*ln(1-0.4863)) / 3 = 0.9264
二、nn.CrossEntropyLoss 交叉熵损失函数
输入维度(batch_size, feature_dim)
输出维度 (batch_size, 1)
X_input = torch.tensor[ [2.8883, 0.1760, 1.0774],
[1.1216, -0.0562, 0.0660],
[-1.3939, -0.0967, 0.5853]]
y_target = torch.tensor([1,2,0])
loss_func = nn.CrossEntropyLoss()
loss = loss_func(X_input, y_target)
计算流程:第一,x先softmax再log,得到x_hat 第二,y转0-1编码[1,2,0] 转[[0,1,0], [0,0,1], [1,0,0]] 再与x_hat相乘,取负取平均值
思考问题:多标签的分类任务中,怎么使用损失函数呢,是拆分是多个二分类问题呢,还是不用拆分直接用BCE呢(https://blog.csdn.net/rosefun96/article/details/88058708,参考:BCE 可以应用到多标签的分类任务中)?有什么区别呢?
pytorch 损失函数(nn.BCELoss 和 nn.CrossEntropyLoss)(思考多标签分类问题)的更多相关文章
- pytorch中文文档-torch.nn常用函数-待添加-明天继续
https://pytorch.org/docs/stable/nn.html 1)卷积层 class torch.nn.Conv2d(in_channels, out_channels, kerne ...
- [pytorch笔记] torch.nn vs torch.nn.functional; model.eval() vs torch.no_grad(); nn.Sequential() vs nn.moduleList
1. torch.nn与torch.nn.functional之间的区别和联系 https://blog.csdn.net/GZHermit/article/details/78730856 nn和n ...
- pytorch中文文档-torch.nn.init常用函数-待添加
参考:https://pytorch.org/docs/stable/nn.html torch.nn.init.constant_(tensor, val) 使用参数val的值填满输入tensor ...
- 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())
在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...
- Pytorch本人疑问(1) torch.nn和torch.nn.functional之间的区别
在写代码时发现我们在定义Model时,有两种定义方法: torch.nn.Conv2d()和torch.nn.functional.conv2d() 那么这两种方法到底有什么区别呢,我们通过下述代码看 ...
- 从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系
从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系 relu多种实现之间的关系 relu 函数在 pytorch 中总共有 3 次出现: torc ...
- pytorch 损失函数
pytorch损失函数: http://blog.csdn.net/zhangxb35/article/details/72464152?utm_source=itdadao&utm_medi ...
- Multi label 多标签分类问题(Pytorch,TensorFlow,Caffe)
适用场景:一个输入对应多个label,或输入类别间不互斥 调用函数: 1. Pytorch使用torch.nn.BCEloss 2. Tensorflow使用tf.losses.sigmoid_cro ...
- Pytorch的默认初始化分布 nn.Embedding.weight初始化分布
一.nn.Embedding.weight初始化分布 nn.Embedding.weight随机初始化方式是标准正态分布 ,即均值$\mu=0$,方差$\sigma=1$的正态分布. 论据1——查看 ...
随机推荐
- php邮箱发送
php发送邮件 -------------------------------------------------------------------------------- <?php he ...
- 第05组 Alpha冲刺 (2/6)(组长)
.th1 { font-family: 黑体; font-size: 25px; color: rgba(0, 0, 255, 1) } #ka { margin-top: 50px } .aaa11 ...
- fashion数据集训练
下载数据集 fashion数据集总共有7万张28*28像素点的灰度图片和标签,涵盖十个分类:T恤.裤子.套头衫.连衣裙.外套.凉鞋.衬衫.运动鞋.包.靴子. 其中6万张用于训练,1万张用于测试. im ...
- 幻读在 InnoDB 中是被如何解决的?(转)
在MySQL事务初识中,我们了解到不同的事务隔离级别会引发不同的问题,如在 RR 级别下会出现幻读.但如果将存储引擎选为 InnoDB ,在 RR 级别下,幻读的问题就会被解决.在这篇文章中,会先介绍 ...
- UNP第13章——守护进程
1. 守护进程的启动方法 (1)系统初始化脚本启动,在系统启动阶段,按照如/etc目录或/etc/rc开头的目录中的某些脚本启动,这些守护进程一开始就有超级用户权限.如inetd,cron,Web服务 ...
- oracle 相关笔记
1.查询语句执行顺序 from->where->[group by ]-> select ->distinct->count(某一列) 2.用命令执行存储过程用 exec ...
- tar命令打包和压缩与解压
Linux里压缩与打包时分开的: 打包:多个文件变一个文件.该一个文件会大于整体所有文件,因为会添加各个信息说明哪到哪是一个文件. 压缩:大文件变小文件. 归档:将多个文件变成一个文件,这个文件就是归 ...
- 接口自动化测试:apiAutoTest使用re 处理数据依赖
目录 废话 2020/11/19 参数依赖 更新后的效果 新版依赖数据如何使用 源码地址 道谢 废话 目前在工作中写脚本的时候发现了一些之前开源的apiAutoTest的可优化项,后面应该也是会慢慢的 ...
- 创建一个自定义名称的Ceph集群
前言 这里有个条件,系统环境是Centos 7 ,Ceph 的版本为Jewel版本,因为这个组合下是由systemctl来进行服务控制的,所以需要做稍微的改动即可实现 准备工作 部署mon的时候需要修 ...
- Bad magic number ImportError in python
是源码编译里面版本不对,删除掉源码pyc然后重新编译就可以了 find .-name '*.pyc'-delete python -m compileall . 更新历史 why when 创建 20 ...