原文: http://www.voidcn.com/article/p-rtzqgqkz-bpg.html


最近看了下 PyTorch 的损失函数文档,整理了下自己的理解,重新格式化了公式如下,以便以后查阅。

注意下面的损失函数都是在单个样本上计算的,粗体表示向量,否则是标量。向量的维度用

N
表示。

nn.L1Loss

loss(x,y)=1N∑i=1N|x−y|

nn.SmoothL1Loss

也叫作 Huber Loss,误差在 (-1,1) 上是平方损失,其他情况是 L1 损失。

loss(x,y)=1N⎧⎩⎨⎪⎪⎪⎪12(xi−yi)2|xi−yi|−12,if |xi−yi|<1otherwise

nn.MSELoss

平方损失函数

loss(x,y)=1N∑i=1N|x−y|2

nn.BCELoss

二分类用的交叉熵,TODO

loss(o,t)=−1N∑i=1N[ti∗log(oi)+(1−ti)∗log(1−oi)]

nn.CrossEntropyLoss

交叉熵损失函数

loss(x,label)=−logexlabel∑Nj=1exj=−xlabel+log∑j=1Nexj

x
是没有经过 Softmax 的激活值。参考 cs231n 作业里对 Softmax Loss 的推导。

nn.NLLLoss

负对数似然损失函数(Negative Log Likelihood)

loss(x,label)=−xlabel

在前面接上一个 LogSoftMax 层就等价于交叉熵损失了。注意这里的

xlabel
和上个交叉熵损失里的不一样(虽然符号我给写一样了),这里是经过

log
运算后的数值,

nn.NLLLoss2d

和上面类似,但是多了几个维度,一般用在图片上。

  • input, (N, C, H, W)
  • target, (N, H, W)

比如用全卷积网络做 Semantic Segmentation 时,最后图片的每个点都会预测一个类别标签。

nn.KLDivLoss

KL 散度,又叫做相对熵,算的是两个分布之间的距离,越相似则越接近零。

loss(x,y)=1N∑i=1N[yi∗(logyi−xi)]

注意这里的

xi

log
概率,刚开始还以为 API 弄错了。

nn.MarginRankingLoss

评价相似度的损失

loss(x1,x2,y)=max(0,−y∗(x1−x2)+margin)

这里的三个都是标量,y 只能取 1 或者 -1,取 1 时表示 x1 比 x2 要大;反之 x2 要大。参数 margin 表示两个向量至少要相聚 margin 的大小,否则 loss 非负。默认 margin 取零。

nn.MultiMarginLoss

多分类(multi-class)的 Hinge 损失,

loss(x,y)=1N∑i=1,i≠yNmax(0,(margin−xy+xi)p)

其中

1≤y≤N
表示标签,

p
默认取 1,

margin
默认取 1,也可以取别的值。参考 cs231n 作业里对 SVM Loss 的推导。

nn.MultiLabelMarginLoss

多类别(multi-class)多分类(multi-classification)的 Hinge 损失,是上面 MultiMarginLoss 在多类别上的拓展。同时限定 p = 1,margin = 1.

loss(x,y)=1N∑i=1,i≠yjn∑j=1yj≠0[max(0,1−(xyj−xi))]

这个接口有点坑,是直接从 Torch 那里抄过来的,见 MultiLabelMarginCriterion 的描述。而 Lua 的下标和 Python 不一样,前者的数组下标是从 1 开始的,所以用 0 表示占位符。有几个坑需要注意,

  1. 这里的

    x,y
    都是大小为

    N
    的向量,如果

    y
    不是向量而是标量,后面的

    ∑j
    就没有了,因此就退化成上面的 MultiMarginLoss.

  2. 限制

    y
    的大小为

    N
    ,是为了处理多标签中标签个数不同的情况,用 0 表示占位,该位置和后面的数字都会被认为不是正确的类。如

    y=[5,3,0,0,4]
    那么就会被认为是属于类别 5 和 3,而 4 因为在零后面,因此会被忽略。

  3. 上面的公式和说明只是为了和文档保持一致,其实在调用接口的时候,用的是 -1 做占位符,而 0 是第一个类别。

举个梨子,

import torch
loss = torch.nn.MultiLabelMarginLoss()
x = torch.autograd.Variable(torch.FloatTensor([[0.1, 0.2, 0.4, 0.8]]))
y = torch.autograd.Variable(torch.LongTensor([[3, 0, -1, 1]]))
print loss(x, y) # will give 0.8500

按照上面的理解,第 3, 0 个是正确的类,1, 2 不是,那么,

loss=14∑i=1,2∑j=3,0[max(0,1−(xj−xi))]=14[(1−(0.8−0.2))+(1−(0.1−0.2))+(1−(0.8−0.4))+(1−(0.1−0.4))]=14[0.4+1.1+0.6+1.3]=0.85

*注意这里推导的第二行,我为了简短,都省略了 max(0, x) 符号。

nn.SoftMarginLoss

多标签二分类问题,这

N
项都是二分类问题,其实就是把

N
个二分类的 loss 加起来,化简一下。其中

y
只能取

1,−1
两种,代表正类和负类。和下面的其实是等价的,只是

y
的形式不同。

loss(x,y)=∑i=1Nlog(1+e−yixi)

nn.MultiLabelSoftMarginLoss

上面的多分类版本,根据最大熵的多标签 one-versue-all 损失,其中

y
只能取

1,−1
两种,代表正类和负类。

loss(x,y)=−∑i=1N[yilogexi1+exi+(1−yi)log11+exi]

nn.CosineEmbeddingLoss

余弦相似度的损失,目的是让两个向量尽量相近。注意这两个向量都是有梯度的。

loss(x,y)={1−cos(x,y)max(0,cos(x,y)+margin)if if y==1y==−1

margin 可以取

[−1,1]
,但是比较建议取 0-0.5 较好。

nn.HingeEmbeddingLoss

不知道做啥用的。另外文档里写错了,

x,y
的维度应该是一样的。

loss(x,y)=1N{ximax(0,margin−xi)if if yi==1yi==−1

nn.TripleMarginLoss

L(a,p,n)=1N(∑i=1Nmax(0, d(ai,pi)−d(ai,ni)+margin))

其中

d(xi,yi)=∥xi−yi∥22


[pytorch]pytorch loss function 总结的更多相关文章

  1. loss function

    什么是loss?   loss: loss是我们用来对模型满意程度的指标.loss设计的原则是:模型越好loss越低,模型越差loss越高,但也有过拟合的情况.   loss function: 在分 ...

  2. Derivative of the softmax loss function

    Back-propagation in a nerual network with a Softmax classifier, which uses the Softmax function: \[\ ...

  3. loss function与cost function

    实际上,代价函数(cost function)和损失函数(loss function 亦称为 error function)是同义的.它们都是事先定义一个假设函数(hypothesis),通过训练集由 ...

  4. 损失函数(Loss Function) -1

    http://www.ics.uci.edu/~dramanan/teaching/ics273a_winter08/lectures/lecture14.pdf Loss Function 损失函数 ...

  5. 【caffe】loss function、cost function和error

    @tags: caffe 机器学习 在机器学习(暂时限定有监督学习)中,常见的算法大都可以划分为两个部分来理解它 一个是它的Hypothesis function,也就是你用一个函数f,来拟合任意一个 ...

  6. 惩罚因子(penalty term)与损失函数(loss function)

    penalty term 和 loss function 看起来很相似,但其实二者完全不同. 惩罚因子: penalty term的作用是把受限优化问题转化为非受限优化问题. 比如我们要优化: min ...

  7. 论文笔记之: Person Re-Identification by Multi-Channel Parts-Based CNN with Improved Triplet Loss Function

    Person Re-Identification by Multi-Channel Parts-Based CNN with Improved Triplet Loss Function CVPR 2 ...

  8. [machine learning] Loss Function view

    [machine learning] Loss Function view 有关Loss Function(LF),只想说,终于写了 一.Loss Function 什么是Loss Function? ...

  9. [基础] Loss function (二)

    Loss function = Loss term(误差项) + Regularization term(正则项),上次写的是误差项,这次正则项. 正则项的解释没那么直观,需要知道不适定问题,在经典的 ...

  10. [基础] Loss function(一)

    Loss function = Loss term(误差项) + Regularization term(正则项),我们先来研究误差项:首先,所谓误差项,当然是误差的越少越好,由于不存在负误差,所以为 ...

随机推荐

  1. windows环境下为php打开ssh2扩展

    安装步骤 1. 下载 php extension ssh2下载地址 http://windows.php.net/downloads/pecl/releases/ssh2/0.12/ 根据自己PHP的 ...

  2. 更改docker服务网段分配地址

    docker安装完毕后,会自动生成一个网卡名为docker0的网桥,如果其默认分配的网段地址和已有地址段冲突,可按如下步骤修改. 查看默认地址段如下 docker0: flags=4099<UP ...

  3. 字符串最长子串匹配-dp矩阵[转载]

    转自:https://blog.csdn.net/zls986992484/article/details/69863710 题目描述:求最长公共子串,sea和eat.它们的最长公共子串为ea,长度为 ...

  4. Django-form组件和ModelForm组件

    一. 构建Form表单 通过建一个类,添加需要进行验证的form字段,继而添加验证条件 from django import forms from django.forms import widget ...

  5. How to install MVVM Light Toolkit via NuGet

    Here is how you can install MVVM Light Toolkit  via NuGet in an easy way using only Visual Studio. S ...

  6. Liferay中request

    在liferay中的请求分为renderRequest和actionRequest这两种请求的方式,portletRequest的子类有三个1renderRequest,2EventRequest3C ...

  7. UVM中的regmodel建模(三)

    总结一下UVM中的寄存器访问实现: 后门访问通过add_hdl_path命令来添加寄存器路径,并扩展uvm_reg_backdoor基类,定义read与write函数,最后在uvm_reg_block ...

  8. POJ 1182 并查集

    Description 动物王国中有三类动物A,B,C,这三类动物的食物链构成了有趣的环形.A吃B, B吃C,C吃A. 现有N个动物,以1-N编号.每个动物都是A,B,C中的一种,但是我们并不知道它到 ...

  9. 【转】编程思想之多线程与多进程(3)——Java中的多线程

    <编程思想之多线程与多进程(1)——以操作系统的角度述说线程与进程>一文详细讲述了线程.进程的关系及在操作系统中的表现,这是多线程学习必须了解的基础.本文将接着讲一下Java中多线程程序的 ...

  10. python selenium第一个WebDriver脚本

    #coding=utf-8from selenium import webdriverimport timeimport osos.environ["webdriver.firefox.dr ...