本篇借鉴了这篇文章,如果有兴趣,大家可以看看:https://blog.csdn.net/geter_CS/article/details/84857220

1、交叉熵:交叉熵主要是用来判定实际的输出与期望的输出的接近程度

2、CrossEntropyLoss()损失函数结合了nn.LogSoftmax()和nn.NLLLoss()两个函数。它在做分类(具体几类)训练的时候是非常有用的。

3、softmax用于多分类过程中,它将多个神经元的输出,映射到(0,1)区间内,可以看成概率来理解,从而来进行多分类!

其公式如下:

numpy计算代码:

import numpy as np
z = np.array([1.0, 2.0, 3.0, 4.0, 1.0, 2.0, 3.0])
print(np.exp(z)/sum(np.exp(z)))

4、LogSoftmax能够解决函数上溢和下溢的问题,加快运算速度,提高数据稳定性

其计算公式:

M是max(x_i),这样可以解决上溢下溢的问题.(但这样输出概率和就不是1了)

代码:

import torch
x = torch.Tensor([-4, 2, -3.2, 0, 7])
softmax = torch.exp(x)/torch.sum(torch.exp(x))
print("softmax\n",softmax)
print("sum:",torch.sum(softmax))
LogSoftmax = torch.log(softmax)
print("LogSoftmax\n",LogSoftmax)
print("sum:",torch.sum(LogSoftmax))

结果:

5、NllLoss:即负对数似然损失函数(Negtive Log Likehood)。

公式:

其中 y是one_hot编码后的数据标签,NLLLoss()得到的结果即是 y与logsoftmax()激活后的结果相乘再求均值再取反。(实际在用封装好的函数时,传入的标签无需进行one_hot编码)

代码:

import torch
import torch.nn.functional as F
import torch.nn as nn
x = torch.randn(5,5)
print("x:\n",x)
target = torch.tensor([0,2,3,1,4])
one_hot = F.one_hot(target).float()
print("one_hot:\n", one_hot)
softmax = torch.exp(x)/torch.sum(torch.exp(x), dim=1).reshape(-1,1)
print("soft_max:\n",softmax)
LogSoftmax = torch.log(softmax)
nllloss = -torch.sum(one_hot*LogSoftmax)/target.shape[0]
print("nllLoss:",nllloss)
#利用torch.nn.funcation实现
logsoftmax = F.log_softmax(x, dim=1)
nllloss = F.nll_loss(logsoftmax, target)
print("torch_nllLoss:",nllloss) #直接用torch.nn.CrossEntropyLoss验证
cross_entropy = F.cross_entropy(x, target)
print("cross_entropy:",cross_entropy)

结果:

5、没有权重的损失函数的计算如下:

有权重的损失函数的计算如下:

注意这里的标签值class,并不参与直接计算,而是作为一个索引,索引对象为实际类别

6、交叉熵损失(CE)和负对数极大似然估计(NLL)的关系:交叉熵是定义在两个one-hot向量之间的,更具体地说是定义在两个概率向量之间nll是定义在一个模型上的,取决于模型本身可以取不同的形式。

似然函数:都是指某种事件发生的可能性,但是在统计学中,“似然性”和“概率”(或然性)有明确的区分:概率,用于在已知一些参数的情况下,预测接下来在观测上所得到的结果;似然性,则是用于在已知某些观测所得到的结果时,对有关事物之性质的参数进行估值,也就是说已观察到某事件后,对相关参数进行猜测。

下图出处:

举个栗子,我们一共有三种类别,批量大小为1(为了好计算),那么输入size为(1,3),具体值为torch.Tensor([[-0.7715, -0.6205,-0.2562]])。标签值为target = torch.tensor([0]),这里标签值为0,表示属于第0类。loss计算如下:

import torch
import torch.nn as nn
import numpy as np
entroy = nn.CrossEntropyLoss()
input = torch.Tensor([[-0.7715,-0.6205,-0.2562]])
target = torch.tensor([0])
output = entroy(input,target)
print(output) #采用CrossEntropyLoss计算的结果。
myselfout = -(input[:,0])+np.log(np.exp(input[:,0])+np.exp(input[:,1])+np.exp(input[:,2])) #自己带公式计算的结果
print(myselfout)
lsf = nn.LogSoftmax()
loss = nn.NLLLoss()
lsfout = lsf(input)
lsfnout = loss(lsfout,target)
print(lsfnout)

结果:

tensor(1.3447)
tensor([1.3447])
tensor(1.3447)
softmax
SoftMax
 

Pytorch常用的交叉熵损失函数CrossEntropyLoss()详解的更多相关文章

  1. 交叉熵损失CrossEntropyLoss

    在各种深度学习框架中,我们最常用的损失函数就是交叉熵,熵是用来描述一个系统的混乱程度,通过交叉熵我们就能够确定预测数据与真实数据的相近程度.交叉熵越小,表示数据越接近真实样本. 1 分类任务的损失计算 ...

  2. [ch03-02] 交叉熵损失函数

    系列博客,原文在笔者所维护的github上:https://aka.ms/beginnerAI, 点击star加星不要吝啬,星越多笔者越努力. 3.2 交叉熵损失函数 交叉熵(Cross Entrop ...

  3. 【转载】深度学习中softmax交叉熵损失函数的理解

    深度学习中softmax交叉熵损失函数的理解 2018-08-11 23:49:43 lilong117194 阅读数 5198更多 分类专栏: Deep learning   版权声明:本文为博主原 ...

  4. 深度学习基础5:交叉熵损失函数、MSE、CTC损失适用于字识别语音等序列问题、Balanced L1 Loss适用于目标检测

    深度学习基础5:交叉熵损失函数.MSE.CTC损失适用于字识别语音等序列问题.Balanced L1 Loss适用于目标检测 1.交叉熵损失函数 在物理学中,"熵"被用来表示热力学 ...

  5. 深度学习原理与框架-神经网络结构与原理 1.得分函数 2.SVM损失函数 3.正则化惩罚项 4.softmax交叉熵损失函数 5. 最优化问题(前向传播) 6.batch_size(批量更新权重参数) 7.反向传播

    神经网络由各个部分组成 1.得分函数:在进行输出时,对于每一个类别都会输入一个得分值,使用这些得分值可以用来构造出每一个类别的概率值,也可以使用softmax构造类别的概率值,从而构造出loss值, ...

  6. 关于交叉熵损失函数Cross Entropy Loss

    1.说在前面 最近在学习object detection的论文,又遇到交叉熵.高斯混合模型等之类的知识,发现自己没有搞明白这些概念,也从来没有认真总结归纳过,所以觉得自己应该沉下心,对以前的知识做一个 ...

  7. softmax交叉熵损失函数求导

    来源:https://www.jianshu.com/p/c02a1fbffad6 简单易懂的softmax交叉熵损失函数求导 来写一个softmax求导的推导过程,不仅可以给自己理清思路,还可以造福 ...

  8. 最强常用开发库总结 - JSON库详解

    最强常用开发库总结 - JSON库详解 JSON应用非常广泛,对于Java常用的JSON库要完全掌握.@pdai JSON简介 JSON是什么 JSON 指的是 JavaScript 对象表示法(Ja ...

  9. 常用开发库 - MapStruct工具库详解

    常用开发库 - MapStruct工具库详解 MapStruct是一款非常实用Java工具,主要用于解决对象之间的拷贝问题,比如PO/DTO/VO/QueryParam之间的转换问题.区别于BeanU ...

  10. Pytorch里的CrossEntropyLoss详解

    在使用Pytorch时经常碰见这些函数cross_entropy,CrossEntropyLoss, log_softmax, softmax.看得我头大,所以整理本文以备日后查阅. 首先要知道上面提 ...

随机推荐

  1. CentOS 7 下将 jar 包注册为服务

    前提条件 因为 jar 包启动需要用到 jdk,所以服务器上必须要安装jdk或者jre,这方面的教程网上有非常多,可以去百度一下 创建文件 创建website.service文件, 内容如下: [ro ...

  2. 【ELK】Kibana-7.13.1版本 启动报错 Centos6

    报错信息: [root@centos6-1 gcc-4.8.2]# /opt/kibana-7.13.1-linux-x86_64/bin/kibana /opt/kibana-7.13.1-linu ...

  3. 【DataBase】MySQL 09 SQL函数 单行函数其三 日期函数

    日期函数 日期&时间函数 NOW 当前日期时间. CURDATE 当前日期. CURTIME 当前时间 -- NOW();返回系统日期+时间 SELECT NOW(); -- CURDATE( ...

  4. 【C】Re02

    一.命令行参数 #include <stdio.h> /** * 运行执行程序的命令携带 一些附加参数,传递给程序执行 * @param argc 命令行参数的个数 * @param ar ...

  5. DirectX9(D3D9)游戏开发:高光时刻录制和共享纹理的踩坑

    共享纹理 老游戏使用directx9无法直接与cc高光sdk(d3d11)对接,但是d3d9ex有共享纹理,我们通过共享纹理把游戏画面共享给cc录制,记录一些踩坑的笔记. 共享纹理示例: // 初始化 ...

  6. AI大模型 —— 国产大模型 —— 华为大模型

    有这么一句话,那就是AI大模型分两种,一种是大模型:另一种是华为大模型. 如果从技术角度来分析,华为的技术不论是在软件还是硬件都比国外的大公司差距极大,甚至有些技术评论者认为华为的软硬件技术至少落后2 ...

  7. Python按条件删除Excel表格数据的方法

      本文介绍基于Python语言,读取Excel表格文件,基于我们给定的规则,对其中的数据加以筛选,将不在指定数据范围内的数据剔除,保留符合我们需要的数据的方法.   首先,我们来明确一下本文的具体需 ...

  8. Git-HEAD 的含义

    在 Git 中,"HEAD" 是一个特殊的引用,它指向当前所处的分支或提交. 当你进行一些操作时,比如提交代码.切换分支等,HEAD 的指向会随之改变.下面是 HEAD 在不同情况 ...

  9. [学习笔记] 树链剖分(重链剖分) - 图论 & 数据结构

    树链剖分 树链剖分,用于解决一系列的树中链上问题的算法(数据结构).其实对于树链修改和树链求和问题可以使用更加方便的树上差分解决,但是对于像求树链最大(小)权值之类的更复杂的问题,差分就显得不够用了. ...

  10. JavaScript设计模式样例十五 —— 状态模式

    状态模式(State Pattern) 定义:创建表示各种状态的对象和一个行为随着状态对象改变而改变的 context 对象.目的:允许对象在内部状态发生改变时改变它的行为,对象看起来好像修改了它的类 ...