pytorch 损失函数(nn.BCELoss 和 nn.CrossEntropyLoss)(思考多标签分类问题)
一、BCELoss 二分类损失函数
输入维度为(n, ), 输出维度为(n, )
如果说要预测二分类值为1的概率,则建议用该函数!
输入比如是3维,则每一个应该是在0——1区间内(随意通常配合sigmoid函数使用),举例如下:
import torch
import torch.nn as nn
m = nn.Sigmoid()
loss = nn.BCELoss()
input = torch.randn(3,requires_grad=True)
target = torch.empty(3).random_(2)
output = loss(m(input), target)
output.backward() input,target,output 返回值:
(tensor([-0.8728, 0.3632, -0.0547], requires_grad=True),
tensor([1., 0., 0.]),
tensor(0.9264, grad_fn=<BinaryCrossEntropyBackward>)) m(input)结果为:
tensor([0.2947, 0.5898, 0.4863]) 计算output = (1 * ln 0.2947+(1-1)*ln(1-0.2947) + 0*ln0.5898 + (1-0)*ln(1-0.5898) + 0*ln0.4863 + (1-0)*ln(1-0.4863)) / 3 = 0.9264
二、nn.CrossEntropyLoss 交叉熵损失函数
输入维度(batch_size, feature_dim)
输出维度 (batch_size, 1)
X_input = torch.tensor[ [2.8883, 0.1760, 1.0774],
[1.1216, -0.0562, 0.0660],
[-1.3939, -0.0967, 0.5853]]
y_target = torch.tensor([1,2,0])
loss_func = nn.CrossEntropyLoss()
loss = loss_func(X_input, y_target)
计算流程:第一,x先softmax再log,得到x_hat 第二,y转0-1编码[1,2,0] 转[[0,1,0], [0,0,1], [1,0,0]] 再与x_hat相乘,取负取平均值
思考问题:多标签的分类任务中,怎么使用损失函数呢,是拆分是多个二分类问题呢,还是不用拆分直接用BCE呢(https://blog.csdn.net/rosefun96/article/details/88058708,参考:BCE 可以应用到多标签的分类任务中)?有什么区别呢?
pytorch 损失函数(nn.BCELoss 和 nn.CrossEntropyLoss)(思考多标签分类问题)的更多相关文章
- pytorch中文文档-torch.nn常用函数-待添加-明天继续
https://pytorch.org/docs/stable/nn.html 1)卷积层 class torch.nn.Conv2d(in_channels, out_channels, kerne ...
- [pytorch笔记] torch.nn vs torch.nn.functional; model.eval() vs torch.no_grad(); nn.Sequential() vs nn.moduleList
1. torch.nn与torch.nn.functional之间的区别和联系 https://blog.csdn.net/GZHermit/article/details/78730856 nn和n ...
- pytorch中文文档-torch.nn.init常用函数-待添加
参考:https://pytorch.org/docs/stable/nn.html torch.nn.init.constant_(tensor, val) 使用参数val的值填满输入tensor ...
- 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())
在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...
- Pytorch本人疑问(1) torch.nn和torch.nn.functional之间的区别
在写代码时发现我们在定义Model时,有两种定义方法: torch.nn.Conv2d()和torch.nn.functional.conv2d() 那么这两种方法到底有什么区别呢,我们通过下述代码看 ...
- 从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系
从 relu 的多种实现来看 torch.nn 与 torch.nn.functional 的区别与联系 relu多种实现之间的关系 relu 函数在 pytorch 中总共有 3 次出现: torc ...
- pytorch 损失函数
pytorch损失函数: http://blog.csdn.net/zhangxb35/article/details/72464152?utm_source=itdadao&utm_medi ...
- Multi label 多标签分类问题(Pytorch,TensorFlow,Caffe)
适用场景:一个输入对应多个label,或输入类别间不互斥 调用函数: 1. Pytorch使用torch.nn.BCEloss 2. Tensorflow使用tf.losses.sigmoid_cro ...
- Pytorch的默认初始化分布 nn.Embedding.weight初始化分布
一.nn.Embedding.weight初始化分布 nn.Embedding.weight随机初始化方式是标准正态分布 ,即均值$\mu=0$,方差$\sigma=1$的正态分布. 论据1——查看 ...
随机推荐
- 学习.NET 挑战赛
今天访问dot.net 网站看到了一个学习.NET 挑战赛,发现已经赛程过半了,这是一个为那些想更多地了解 C# 和 .NET 的人举办的一个完全免费的课程活动,这些模块必须在 11 月底前完成.参加 ...
- 为什么继承 Python 内置类型会出问题?!
本文出自"Python为什么"系列,请查看全部文章 不久前,Python猫 给大家推荐了一本书<流畅的Python>(点击可跳转阅读),那篇文章有比较多的"溢 ...
- Find Any File for Mac(文件搜索软件)v2.1.2b6
Find Any File for Mac是应用在Mac上的一款文件搜索工具,Find Any File Mac可以通过名称.创建或修改日期,大小或类型和创建者代码(而不是内容)在本地磁盘上搜索文件. ...
- CSS3之animation属性
CSS中的animation属性可用于为许多其他CSS属性设置动画,例如颜色,背景色,高度或宽度. 每个动画都需要使用@keyframes这种at-rule语句定义,然后使用animation属性来调 ...
- [MIT6.006] 9. Table Doubling, Karp-Rabin 双散列表, Karp-Rabin
在整理课程笔记前,先普及下课上没细讲的东西,就是下图,如果有个操作g(x),它最糟糕的时间复杂度为Ο(c2 * n),它最好时间复杂度是Ω(c1 * n),那么θ则为Θ(n).简单来说:如果O和Ω可以 ...
- Tim Urban:如何选择真正适合你的职业?
Wait But Why是一个专注于写长博客的网站,Tim Urban是网站的创始人之一.Tim Urban专注于写长论文,与时下的轻度阅读完全背道而驰,文章动辄几千甚至上万字,但令人吃惊的是却拥有惊 ...
- Kubernetes+Promethues+Cloud Alert实践分享
前言 容器集群管理系统 Kubernetes(简称K8s),为容器化的应用提供部署运行.容器编排.负载均衡.服务发现和动态伸缩等一系列完整功能,Prometheus 对 K8s 支持非常棒,能够自动发 ...
- //*[starts-with(@class,'btn')][text()='差'] 正则定位元素
starts-with? //*[starts-with(@class,'btn')][text()='差'] 意思找从头开始的这个class
- Linux内核源码分析之set_arch (一)
1. 概述 之前已经写了几篇Linux内核启动相关的文章,比如:<解压内核镜像><调用 start_kernel>都是用汇编语言写的,这些代码的作用仅仅是把内核镜像放置到特定的 ...
- javascript九宫格碰撞检测
JS九宫格碰撞检测这个东西 以前学过 这次主要是做面试项目web版的win10 桌面图片需要用碰撞检测 再写的时候竟然完全忘记了碰撞检测原理 和怎么写 综合来说还是写的太少 今天再学了一下 理 ...