之前一直自己手写各种triphard,triplet损失函数, 写的比较暴力,然后今天一个学长给我在github上看了一个别人的triphard的写法,一开始没看懂,用的pytorch函数没怎么见过,看懂了之后, 被惊艳到了。。因此在此记录一下,以及详细注释一下

class TripletLoss(nn.Module):
def __init__(self, margin=0.3):
super(TripletLoss, self).__init__()
self.margin = margin
self.ranking_loss = nn.MarginRankingLoss(margin=margin) # 获得一个简单的距离triplet函数 def forward(self, inputs, labels): n = inputs.size(0) # 获取batch_size
# Compute pairwise distance, replace by the official when merged
dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) # 每个数平方后, 进行加和(通过keepdim保持2维),再扩展成nxn维
dist = dist + dist.t() # 这样每个dis[i][j]代表的是第i个特征与第j个特征的平方的和
dist.addmm_(1, -2, inputs, inputs.t()) # 然后减去2倍的 第i个特征*第j个特征 从而通过完全平方式得到 (a-b)^2
dist = dist.clamp(min=1e-12).sqrt() # 然后开方 # For each anchor, find the hardest positive and negative
mask = labels.expand(n, n).eq(labels.expand(n, n).t()) # 这里dist[i][j] = 1代表i和j的label相同, =0代表i和j的label不相同
dist_ap, dist_an = [], []
for i in range(n):
dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) # 在i与所有有相同label的j的距离中找一个最大的
dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) # 在i与所有不同label的j的距离找一个最小的
dist_ap = torch.cat(dist_ap) # 将list里的tensor拼接成新的tensor
dist_an = torch.cat(dist_an) # Compute ranking hinge loss
y = torch.ones_like(dist_an) # 声明一个与dist_an相同shape的全1tensor
loss = self.ranking_loss(dist_an, dist_ap, y)
return loss

[loss]Triphard loss优雅的写法的更多相关文章

  1. C#中一种替换switch语句更优雅的写法

    今天在项目中遇到了使用switch语句判断条件,但问题是条件比较多,大概有几十个条件,满屏幕的case判断,是否有更优雅的写法替代switch语句呢? 假设有这样的一个场景:商场经常会根据情况采取不同 ...

  2. 会了这十种Python优雅的写法,让你工作效率翻十倍,一人顶十人用!

      我们都知道,Python 的设计哲学是「优雅」.「明确」.「简单」.这也许很多人选择 Python 的原因.但是我收到有些伙伴反馈,他写的 Python 并不优雅,甚至很臃肿,那可能是你的姿势不对 ...

  3. L1 loss, L2 loss以及Smooth L1 Loss的对比

    总结对比下\(L_1\) 损失函数,\(L_2\) 损失函数以及\(\text{Smooth} L_1\) 损失函数的优缺点. 均方误差MSE (\(L_2\) Loss) 均方误差(Mean Squ ...

  4. L1 loss L2 loss

    https://www.letslearnai.com/2018/03/10/what-are-l1-and-l2-loss-functions.html http://rishy.github.io ...

  5. 处理样本不平衡的LOSS—Focal Loss

    0 前言 Focal Loss是为了处理样本不平衡问题而提出的,经时间验证,在多种任务上,效果还是不错的.在理解Focal Loss前,需要先深刻理一下交叉熵损失,和带权重的交叉熵损失.然后我们从样本 ...

  6. C# 多个个Dictionary合并更优雅的写法

    Dictionary 现在有两个Dictionary的对象,想把两个对象的中数据合并成一个. 使用for循环的话觉得非常不合适,于是考虑是否有相应的方法,网上找了很多,都是for循环,最后终于找到了一 ...

  7. if else 更优雅的写法(转)

    https://www.cnblogs.com/y896926473/articles/9675819.html

  8. 损失函数(Loss Function) -1

    http://www.ics.uci.edu/~dramanan/teaching/ics273a_winter08/lectures/lecture14.pdf Loss Function 损失函数 ...

  9. Cross-Entropy Loss 与Accuracy的数值关系

    以分类任务为例, 假设要将样本分为\(n\)个类别. 先考虑单个样本\((X, z)\). 将标题\(z\)转化为一个\(n\)维列向量\(y = (y_1, \dots y_k, \dots, y_ ...

随机推荐

  1. 帝国cms底部代码哪里改?要修改版权和统计代码

    最近接手的几个站是用帝国cms做的,底部代码那边都有一个**设计的链接,还有一些不相关的东西,第一眼看到就想把那些帝国cms底部代码清理掉,这就是让别人建站的烦恼,让他们删除说要收费,坑就一个字,自己 ...

  2. 微软官方出的各种dll丢失的修复工具

    例如 :因为计算机中丢失 api-ms-win-crt-runtime-l1-1-0.dll.尝试重新安装该程序以解决此问题. 软件名称: Visual C++ Redistributable for ...

  3. 【剑指offer】栈的压入、弹出序列

    一.题目: 输入两个整数序列,第一个序列表示栈的压入顺序,请判断第二个序列是否可能为该栈的弹出顺序.假设压入栈的所有数字均不相等.例如序列1,2,3,4,5是某栈的压入顺序,序列4,5,3,2,1是该 ...

  4. 在sublime3中docblockr插件配置apidoc接口文档注释模板

    写在前面: 将进行3个步骤配置 1.在sublime3中安装插件docblockr,可以参考http://www.cnblogs.com/jiangxiaobo/p/8327709.html 2.安装 ...

  5. SQLyog恢复数据库报错解决方法【Error Code: 2006 - MySQL server has gone away】

    https://blog.csdn.net/niqinwen/article/details/8693044 导入数据库的时候 SQLyog 报错了 Error Code: 2006 – MySQL ...

  6. [py]python的私有变量

    参考 python中并没有真正意义上的私有成员,它提供了在成员前面添加双下划线的方法来模拟类似功能.具体来说: _xxx 表示模块级别的私有变量或函数 __xxx 表示类的私有变量或函数 这被称为na ...

  7. 怎么获得当前点击的按钮的id名?

    <body> <input id="t1" type="button" value='fff'> <input id=" ...

  8. Python + logging 输出到屏幕,将log日志写入文件

    日志 日志是跟踪软件运行时所发生的事件的一种方法.软件开发者在代码中调用日志函数,表明发生了特定的事件.事件由描述性消息描述,该描述性消息可以可选地包含可变数据(即,对于事件的每次出现都潜在地不同的数 ...

  9. jmeter 测试websocket接口(二)

    1.到https://github.com/maciejzaleski/JMeter-WebSocketSampler下载Jmeter的WebSocket协议的支持插件:JMeterWebSocket ...

  10. unittest之suite测试集(测试套件)

    suite 这个表示测试集,不要放在class内,否则会提示"没有这样的测试方法在类里面 ",我觉得它唯一的好处就是调试的时候可以单独调试某个class而已,我一般不用它,调试时可 ...