[loss]Triphard loss优雅的写法
之前一直自己手写各种triphard,triplet损失函数, 写的比较暴力,然后今天一个学长给我在github上看了一个别人的triphard的写法,一开始没看懂,用的pytorch函数没怎么见过,看懂了之后, 被惊艳到了。。因此在此记录一下,以及详细注释一下
class TripletLoss(nn.Module):
def __init__(self, margin=0.3):
super(TripletLoss, self).__init__()
self.margin = margin
self.ranking_loss = nn.MarginRankingLoss(margin=margin) # 获得一个简单的距离triplet函数
def forward(self, inputs, labels):
n = inputs.size(0) # 获取batch_size
# Compute pairwise distance, replace by the official when merged
dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) # 每个数平方后, 进行加和(通过keepdim保持2维),再扩展成nxn维
dist = dist + dist.t() # 这样每个dis[i][j]代表的是第i个特征与第j个特征的平方的和
dist.addmm_(1, -2, inputs, inputs.t()) # 然后减去2倍的 第i个特征*第j个特征 从而通过完全平方式得到 (a-b)^2
dist = dist.clamp(min=1e-12).sqrt() # 然后开方
# For each anchor, find the hardest positive and negative
mask = labels.expand(n, n).eq(labels.expand(n, n).t()) # 这里dist[i][j] = 1代表i和j的label相同, =0代表i和j的label不相同
dist_ap, dist_an = [], []
for i in range(n):
dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) # 在i与所有有相同label的j的距离中找一个最大的
dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) # 在i与所有不同label的j的距离找一个最小的
dist_ap = torch.cat(dist_ap) # 将list里的tensor拼接成新的tensor
dist_an = torch.cat(dist_an)
# Compute ranking hinge loss
y = torch.ones_like(dist_an) # 声明一个与dist_an相同shape的全1tensor
loss = self.ranking_loss(dist_an, dist_ap, y)
return loss
[loss]Triphard loss优雅的写法的更多相关文章
- C#中一种替换switch语句更优雅的写法
今天在项目中遇到了使用switch语句判断条件,但问题是条件比较多,大概有几十个条件,满屏幕的case判断,是否有更优雅的写法替代switch语句呢? 假设有这样的一个场景:商场经常会根据情况采取不同 ...
- 会了这十种Python优雅的写法,让你工作效率翻十倍,一人顶十人用!
我们都知道,Python 的设计哲学是「优雅」.「明确」.「简单」.这也许很多人选择 Python 的原因.但是我收到有些伙伴反馈,他写的 Python 并不优雅,甚至很臃肿,那可能是你的姿势不对 ...
- L1 loss, L2 loss以及Smooth L1 Loss的对比
总结对比下\(L_1\) 损失函数,\(L_2\) 损失函数以及\(\text{Smooth} L_1\) 损失函数的优缺点. 均方误差MSE (\(L_2\) Loss) 均方误差(Mean Squ ...
- L1 loss L2 loss
https://www.letslearnai.com/2018/03/10/what-are-l1-and-l2-loss-functions.html http://rishy.github.io ...
- 处理样本不平衡的LOSS—Focal Loss
0 前言 Focal Loss是为了处理样本不平衡问题而提出的,经时间验证,在多种任务上,效果还是不错的.在理解Focal Loss前,需要先深刻理一下交叉熵损失,和带权重的交叉熵损失.然后我们从样本 ...
- C# 多个个Dictionary合并更优雅的写法
Dictionary 现在有两个Dictionary的对象,想把两个对象的中数据合并成一个. 使用for循环的话觉得非常不合适,于是考虑是否有相应的方法,网上找了很多,都是for循环,最后终于找到了一 ...
- if else 更优雅的写法(转)
https://www.cnblogs.com/y896926473/articles/9675819.html
- 损失函数(Loss Function) -1
http://www.ics.uci.edu/~dramanan/teaching/ics273a_winter08/lectures/lecture14.pdf Loss Function 损失函数 ...
- Cross-Entropy Loss 与Accuracy的数值关系
以分类任务为例, 假设要将样本分为\(n\)个类别. 先考虑单个样本\((X, z)\). 将标题\(z\)转化为一个\(n\)维列向量\(y = (y_1, \dots y_k, \dots, y_ ...
随机推荐
- 实习培训——Servlet(7)
实习培训——Servlet(7) 1 Servlet 异常处理 当一个 Servlet 抛出一个异常时,Web 容器在使用了 exception-type 元素的 web.xml 中搜索与抛出异常类 ...
- 关于LUA中的随机数问题
也许很多人会奇怪为什么使用LUA的时候,第一个随机数总是固定,而且常常是最小的那个值,下面我就简要的说明一下吧,说得不好,还请谅解.我现在使用的4.0版本的LUA,看的代码是5.0的,呵呵 LUA4. ...
- (转)Linux Oracle服务启动&停止脚本与开机自启动
在CentOS 6.3下安装完Oracle 10g R2,重开机之后,你会发现Oracle没有自行启动,这是正常的,因为在Linux下安装Oracle的确不会自行启动,必须要自行设定相关参数,首先先介 ...
- jmeter 线程组之间的参数传递(加密接口测试三)
场景测试中,一次登录后做多个接口的操作,然后登录后的uid需要关联传递给其他接口发送请求的时候使用. 1.在登录接口响应信息中提取uid字段值 1>login请求 -->添加 --> ...
- char* a与char a[]的区别
char *a 与char a[] 的区别 char *a = "hello" 中的a是指向第一个字符‘a'的一个指针 char a[20] = "hello&quo ...
- GreenPlum安装greenplum-cc-web监控
一. GreenPlum集群安装环境 由虚拟机搭建的一台master两台segment. 二.安装前准备 1) 所需安装包 GreenPlum监控安装包: greenplum-cc-web-3.0.2 ...
- php深入学习
关于PHP程序员解决问题的能力 http://rango.swoole.com/archives/340 深入理解PHP内核 by xuhong大牛 http://www.php-internals. ...
- 下载YouTube视频的方法
这个网站就可以: http://www.clipconverter.cc/ 更多的网站及介绍参考知乎:http://www.zhihu.com/question/19964181
- 015-awk
1.awk的处理方式: 一次处理一行. 对每行可以进行切片处理,针对字段处理.2.awk的格式: awk BEGIN{循环之前初始化} [options] 'command' END{循环之后结尾} ...
- SV中的OOP
OOP:Object-Oriented Programming,有两点个人认为适合验证环境的搭建:1)Property(变量)和Method(function/task)的封装,其实是BFM模型更方便 ...