一、BCELoss 二分类损失函数

输入维度为(n, ), 输出维度为(n, )

如果说要预测二分类值为1的概率,则建议用该函数!

输入比如是3维,则每一个应该是在0——1区间内(随意通常配合sigmoid函数使用),举例如下:

import torch
import torch.nn as nn

m = nn.Sigmoid()
loss = nn.BCELoss()
input = torch.randn(3,requires_grad=True)
target = torch.empty(3).random_(2)
output = loss(m(input), target)
output.backward() input,target,output 返回值:
(tensor([-0.8728, 0.3632, -0.0547], requires_grad=True),
tensor([1., 0., 0.]),
tensor(0.9264, grad_fn=<BinaryCrossEntropyBackward>)) m(input)结果为:
tensor([0.2947, 0.5898, 0.4863]) 计算output = (1 * ln 0.2947+(1-1)*ln(1-0.2947) + 0*ln0.5898 + (1-0)*ln(1-0.5898) + 0*ln0.4863 + (1-0)*ln(1-0.4863)) / 3 = 0.9264

二、nn.CrossEntropyLoss 交叉熵损失函数

输入维度(batch_size, feature_dim)

输出维度  (batch_size, 1)

X_input = torch.tensor[ [2.8883, 0.1760, 1.0774],

          [1.1216, -0.0562, 0.0660],

          [-1.3939, -0.0967, 0.5853]]

y_target = torch.tensor([1,2,0])

loss_func = nn.CrossEntropyLoss()

loss = loss_func(X_input, y_target)

计算流程:第一,x先softmax再log,得到x_hat  第二,y转0-1编码[1,2,0] 转[[0,1,0], [0,0,1], [1,0,0]] 再与x_hat相乘,取负取平均值

思考问题:多标签的分类任务中,怎么使用损失函数呢,是拆分是多个二分类问题呢,还是不用拆分直接用BCE呢(https://blog.csdn.net/rosefun96/article/details/88058708,参考:BCE 可以应用到多标签的分类任务中)?有什么区别呢?

pytorch 损失函数(nn.BCELoss 和 nn.CrossEntropyLoss)(思考多标签分类问题)的更多相关文章

  1. pytorch中文文档-torch.nn常用函数-待添加-明天继续

    https://pytorch.org/docs/stable/nn.html 1)卷积层 class torch.nn.Conv2d(in_channels, out_channels, kerne ...

  2. [pytorch笔记] torch.nn vs torch.nn.functional; model.eval() vs torch.no_grad(); nn.Sequential() vs nn.moduleList

    1. torch.nn与torch.nn.functional之间的区别和联系 https://blog.csdn.net/GZHermit/article/details/78730856 nn和n ...

  3. pytorch中文文档-torch.nn.init常用函数-待添加

    参考:https://pytorch.org/docs/stable/nn.html torch.nn.init.constant_(tensor, val) 使用参数val的值填满输入tensor ...

  4. 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())

    在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...

  5. Pytorch本人疑问(1) torch.nn和torch.nn.functional之间的区别

    在写代码时发现我们在定义Model时,有两种定义方法: torch.nn.Conv2d()和torch.nn.functional.conv2d() 那么这两种方法到底有什么区别呢,我们通过下述代码看 ...

  6. 从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系

    从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系 relu多种实现之间的关系 relu 函数在 pytorch 中总共有 3 次出现: torc ...

  7. pytorch 损失函数

    pytorch损失函数: http://blog.csdn.net/zhangxb35/article/details/72464152?utm_source=itdadao&utm_medi ...

  8. Multi label 多标签分类问题(Pytorch,TensorFlow,Caffe)

    适用场景:一个输入对应多个label,或输入类别间不互斥 调用函数: 1. Pytorch使用torch.nn.BCEloss 2. Tensorflow使用tf.losses.sigmoid_cro ...

  9. Pytorch的默认初始化分布 nn.Embedding.weight初始化分布

    一.nn.Embedding.weight初始化分布 nn.Embedding.weight随机初始化方式是标准正态分布  ,即均值$\mu=0$,方差$\sigma=1$的正态分布. 论据1——查看 ...

随机推荐

  1. php邮箱发送

    php发送邮件 -------------------------------------------------------------------------------- <?php he ...

  2. 第05组 Alpha冲刺 (2/6)(组长)

    .th1 { font-family: 黑体; font-size: 25px; color: rgba(0, 0, 255, 1) } #ka { margin-top: 50px } .aaa11 ...

  3. fashion数据集训练

    下载数据集 fashion数据集总共有7万张28*28像素点的灰度图片和标签,涵盖十个分类:T恤.裤子.套头衫.连衣裙.外套.凉鞋.衬衫.运动鞋.包.靴子. 其中6万张用于训练,1万张用于测试. im ...

  4. 幻读在 InnoDB 中是被如何解决的?(转)

    在MySQL事务初识中,我们了解到不同的事务隔离级别会引发不同的问题,如在 RR 级别下会出现幻读.但如果将存储引擎选为 InnoDB ,在 RR 级别下,幻读的问题就会被解决.在这篇文章中,会先介绍 ...

  5. UNP第13章——守护进程

    1. 守护进程的启动方法 (1)系统初始化脚本启动,在系统启动阶段,按照如/etc目录或/etc/rc开头的目录中的某些脚本启动,这些守护进程一开始就有超级用户权限.如inetd,cron,Web服务 ...

  6. oracle 相关笔记

    1.查询语句执行顺序 from->where->[group by ]-> select ->distinct->count(某一列) 2.用命令执行存储过程用 exec ...

  7. tar命令打包和压缩与解压

    Linux里压缩与打包时分开的: 打包:多个文件变一个文件.该一个文件会大于整体所有文件,因为会添加各个信息说明哪到哪是一个文件. 压缩:大文件变小文件. 归档:将多个文件变成一个文件,这个文件就是归 ...

  8. 接口自动化测试:apiAutoTest使用re 处理数据依赖

    目录 废话 2020/11/19 参数依赖 更新后的效果 新版依赖数据如何使用 源码地址 道谢 废话 目前在工作中写脚本的时候发现了一些之前开源的apiAutoTest的可优化项,后面应该也是会慢慢的 ...

  9. 创建一个自定义名称的Ceph集群

    前言 这里有个条件,系统环境是Centos 7 ,Ceph 的版本为Jewel版本,因为这个组合下是由systemctl来进行服务控制的,所以需要做稍微的改动即可实现 准备工作 部署mon的时候需要修 ...

  10. Bad magic number ImportError in python

    是源码编译里面版本不对,删除掉源码pyc然后重新编译就可以了 find .-name '*.pyc'-delete python -m compileall . 更新历史 why when 创建 20 ...