在使用Pytorch时经常碰见这些函数cross_entropy,CrossEntropyLoss, log_softmax, softmax。看得我头大,所以整理本文以备日后查阅。

首先要知道上面提到的这些函数一部分是来自于torch.nn,而另一部分则来自于torch.nn.functional(常缩写为F)。二者函数的区别可参见 知乎:torch.nn和funtional函数区别是什么?

下面是对与cross entropy有关的函数做的总结:

torch.nn torch.nn.functional (F)
CrossEntropyLoss cross_entropy
LogSoftmax log_softmax
NLLLoss nll_loss

下面将主要介绍torch.nn.functional中的函数为主,torch.nn中对应的函数其实就是对F里的函数进行包装以便管理变量等操作。

在介绍cross_entropy之前先介绍两个基本函数:

log_softmax

这个很好理解,其实就是logsoftmax合并在一起执行。

nll_loss

该函数的全程是negative log likelihood loss,函数表达式为

\[f(x,class)=-x[class]
\]

例如假设\(x=[1,2,3], class=2\),那额\(f(x,class)=-x[2]=-3\)

cross_entropy

交叉熵的计算公式为:

\[cross\_entropy=-\sum_{k=1}^{N}\left(p_{k} * \log q_{k}\right)
\]

其中\(p\)表示真实值,在这个公式中是one-hot形式;\(q\)是预测值,在这里假设已经是经过softmax后的结果了。

仔细观察可以知道,因为\(p\)的元素不是0就是1,而且又是乘法,所以很自然地我们如果知道1所对应的index,那么就不用做其他无意义的运算了。所以在pytorch代码中target不是以one-hot形式表示的,而是直接用scalar表示。所以交叉熵的公式(m表示真实类别)可变形为:

\[cross\_entropy=-\sum_{k=1}^{N}\left(p_{k} * \log q_{k}\right)=-log \, q_m
\]

仔细看看,是不是就是等同于log_softmaxnll_loss两个步骤。

所以Pytorch中的F.cross_entropy会自动调用上面介绍的log_softmaxnll_loss来计算交叉熵,其计算方式如下:

\[\operatorname{loss}(x, \text {class})=-\log \left(\frac{\exp (x[\operatorname{class}])}{\sum_{j} \exp (x[j])}\right)
\]

代码示例

>>> input = torch.randn(3, 5, requires_grad=True)
>>> target = torch.randint(5, (3,), dtype=torch.int64)
>>> loss = F.cross_entropy(input, target)
>>> loss.backward()

微信公众号:AutoML机器学习

MARSGGBO♥原创

如有意合作或学术讨论欢迎私戳联系~
邮箱:marsggbo@foxmail.com


2019-2-19

Pytorch里的CrossEntropyLoss详解的更多相关文章

  1. pytorch之nn.Conv1d详解

    转自:https://blog.csdn.net/sunny_xsc1994/article/details/82969867,感谢分享 pytorch之nn.Conv1d详解

  2. 全网最全的Windows下Anaconda2 / Anaconda3里Python语言实现定时发送微信消息给好友或群里(图文详解)

    不多说,直接上干货! 缘由: (1)最近看到情侣零点送祝福,感觉还是很浪漫的事情,相信有很多人熬夜为了给爱的人送上零点祝福,但是有时等着等着就睡着了或者时间并不是卡的那么准就有点强迫症了,这是也许程序 ...

  3. 【小白学PyTorch】11 MobileNet详解及PyTorch实现

    文章来自微信公众号[机器学习炼丹术].我是炼丹兄,欢迎加我微信好友交流学习:cyx645016617. @ 目录 1 背景 2 深度可分离卷积 2.2 一般卷积计算量 2.2 深度可分离卷积计算量 2 ...

  4. pytorch nn.LSTM()参数详解

    输入数据格式:input(seq_len, batch, input_size)h0(num_layers * num_directions, batch, hidden_size)c0(num_la ...

  5. Pytorch Bi-LSTM + CRF 代码详解

    久闻LSTM + CRF的效果强大,最近在看Pytorch官网文档的时候,看到了这段代码,前前后后查了很多资料,终于把代码弄懂了.我希望在后来人看这段代码的时候,直接就看我的博客就能完全弄懂这段代码. ...

  6. Yii 框架里数据库操作详解-[增加、查询、更新、删除的方法 'AR模式']

    public function getMinLimit () {        $sql = "...";        $result = yii::app()->db-& ...

  7. 扩展运算符及其在vuex的辅助函数里的应用详解

         一.扩展运算符   <1>为什么扩展运算符会诞生?              因为箭头函数没有arguments,所以才有了扩展运算符       <2>在箭头函数里 ...

  8. pytorch BiLSTM+CRF代码详解 重点

    一. BILSTM + CRF介绍 https://www.jianshu.com/p/97cb3b6db573 1.介绍 基于神经网络的方法,在命名实体识别任务中非常流行和普遍. 如果你不知道Bi- ...

  9. 【小白学PyTorch】12 SENet详解及PyTorch实现

    文章来自微信公众号[机器学习炼丹术].我是炼丹兄,有什么问题都可以来找我交流,近期建立了微信交流群,也在朋友圈抽奖赠书十多本了.我的微信是cyx645016617,欢迎各位朋友. 参考目录: @ 目录 ...

随机推荐

  1. CodeChef Dynamic GCD

    嘟嘟嘟vjudge 我今天解决了一个历史遗留问题! 题意:给一棵树,写一个东西,支持一下两种操作: 1.\(x\)到\(y\)的路径上的每一个点的权值加\(d\). 2.求\(x\)到\(y\)路径上 ...

  2. Oracle的RowId和Rownum

    本文参照来自:https://www.cnblogs.com/whut-helin/p/8024860.html 由sql select p.*,rowid,rownum from promotion ...

  3. OllyDbg使用笔记

    [TOC] OD步过后,返回到之前某位置,重新单步执行 找到你想返回的行, 右键选择New origin here,快捷键Ctrl+Gray *, 然后程序会返回到这一行,再次按F7或者F8等执行即可

  4. Thread类的其他方法,同步锁,死锁与递归锁,信号量,事件,条件,定时器,队列,Python标准模块--concurrent.futures

    参考博客: https://www.cnblogs.com/xiao987334176/p/9046028.html 线程简述 什么是线程?线程是cpu调度的最小单位进程是资源分配的最小单位 进程和线 ...

  5. 【区块链】【一】Hash 算法【转】

    问题导读1.哈希算法在区块链的作用是什么?2.什么是哈希算法?3.哈希算法是否可逆?4.比特币采用的是什么哈希算法? 作用在学习哈希算法前,我们需要知道哈希在区块链的作用哈希算法的作用如下:区块链通过 ...

  6. Linux下安装 Python3

    前言 Linux下大部分系统默认自带python2.x的版本,最常见的是python2.6或python2.7版本,默认的python被系统很多程序所依赖,比如centos下的yum就是python2 ...

  7. 初识:java虚拟机的内存划分

    什么是内存? 内存是计算机中的重要原件,临时存储区域,作用是运行程序.我们编写的程序是存放在硬盘中的,在硬盘中的程序是不会运行的,必须放进内存中才能运行,运行完毕后会清空内存.Java虚拟机要运行程序 ...

  8. HttpInvoker http请求工具类

    import java.io.BufferedReader; import java.io.DataOutputStream; import java.io.IOException; import j ...

  9. MIUI9 解锁并刷入TWRP后,删除解锁密码

    如果因为某种原因导致解锁密码失效(比如刷了其他ROM),还原备份回来之后,解锁密码失效了. 那么可以进入TWRP,然后通过  adb shell 进入\data\system\文件夹 用rm命令删除g ...

  10. python之面向对象初识

    一.面向对象初识 1.结构上 面向对象分成两部分:属性.方法 class A: name = 'xiaoming' # 静态属性.静态变量.静态字段. def func1(self): # 函数.动态 ...