转载请注明出处:

http://www.cnblogs.com/darkknightzh/p/8524937.html

论文:

SphereFace: Deep Hypersphere Embedding for Face Recognition

https://arxiv.org/abs/1704.08063

http://wyliu.com/papers/LiuCVPR17v3.pdf

官方代码:

https://github.com/wy1iu/sphereface

pytorch代码:

https://github.com/clcarwin/sphereface_pytorch

说明:没用过mxnet,下面的代码注释只是纯粹从代码的角度来分析并进行注释,如有错误之处,敬请谅解,并欢迎指出。

传统的交叉熵公式如下:

${{L}_{i}}=-\log \frac{{{e}^{W_{yi}^{T}{{x}_{i}}+{{b}_{yi}}}}}{\sum\nolimits_{j}{{{e}^{W_{j}^{T}{{x}_{i}}+{{b}_{j}}}}}}=-\log \frac{{{e}^{\left\| {{W}_{yi}} \right\|\left\| {{x}_{i}} \right\|\cos ({{\theta }_{yi}},i)+{{b}_{yi}}}}}{\sum\nolimits_{j}{{{e}^{\left\| {{W}_{j}} \right\|\left\| {{x}_{i}} \right\|\cos ({{\theta }_{j}},i)+{{b}_{j}}}}}}$

将W归一化到1,且不考虑偏置项,即${{b}_{j}}=0$,则上式变成:

${{L}_{\text{modified}}}=\frac{1}{N}\sum\limits_{i}{-\log (\frac{{{e}^{\left\| {{x}_{i}} \right\|\cos ({{\theta }_{yi}},i)}}}{\sum\nolimits_{j}{{{e}^{\left\| {{x}_{i}} \right\|\cos ({{\theta }_{j}},i)}}}}})$

其中θ为w和x的夹角。

为了进一步限制夹角的范围,使用mθ,上式变成

${{L}_{\text{ang}}}=\frac{1}{N}\sum\limits_{i}{-\log (\frac{{{e}^{\left\| {{x}_{i}} \right\|\cos (m{{\theta }_{yi}},i)}}}{{{e}^{\left\| {{x}_{i}} \right\|\cos (m{{\theta }_{yi}},i)}}+\sum\nolimits_{j\ne yi}{{{e}^{\left\| {{x}_{i}} \right\|\cos ({{\theta }_{j}},i)}}}}})$

其中θ范围为$\left[ 0,\frac{\pi }{m} \right]$。

为了使得上式单调,引入$\psi ({{\theta }_{yi,i}})$:

${{L}_{\text{ang}}}=\frac{1}{N}\sum\limits_{i}{-\log (\frac{{{e}^{\left\| {{x}_{i}} \right\|\psi ({{\theta }_{yi,i}})}}}{{{e}^{\left\| {{x}_{i}} \right\|\psi ({{\theta }_{yi,i}})}}+\sum\nolimits_{j\ne yi}{{{e}^{\left\| {{x}_{i}} \right\|\cos ({{\theta }_{j}},i)}}}}})$

其中

$\psi ({{\theta }_{yi,i}})={{(-1)}^{k}}\cos (m{{\theta }_{yi,i}})-2k$,${{\theta }_{yi,i}}\in \left[ \frac{k\pi }{m},\frac{(k+1)\pi }{m} \right]$,$k\in \left[ 0,m-1 \right]$,$m\ge 1$

代码中引入了超参数λ,为

$\lambda =\max ({{\lambda }_{\min }},\frac{{{\lambda }_{\max }}}{1+0.1\times iterator})$

其中,${{\lambda }_{\min }}=5$,${{\lambda }_{\max }}=1500$为程序中预先设定的值。

实际的$\psi (\theta )$为

$\psi ({{\theta }_{yi}})=\frac{{{(-1)}^{k}}\cos (m{{\theta }_{yi}})-2k+\lambda \cos ({{\theta }_{yi}})}{1+\lambda }$

对应下面代码为:

output = cos_theta * 1.0
output[index] -= cos_theta[index]*(1.0+0)/(1+self.lamb)
output[index] += phi_theta[index]*(1.0+0)/(1+self.lamb)

对于yi处的计算,

$output(yi)=\cos ({{\theta }_{yi}})-\frac{\cos ({{\theta }_{yi}})}{1+\lambda }+\frac{\psi ({{\theta }_{yi}})}{1+\lambda }=\frac{\psi ({{\theta }_{yi}})+\lambda \cos ({{\theta }_{yi}})}{1+\lambda }=\frac{{{(-1)}^{k}}\cos (m{{\theta }_{yi}})-2k+\lambda \cos ({{\theta }_{yi}})}{1+\lambda }$

和上面的公式对应。

具体的代码如下(完整的代码见参考网址):

 class AngleLinear(nn.Module):
def __init__(self, in_features, out_features, m = 4, phiflag=True):
super(AngleLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.Tensor(in_features,out_features))
self.weight.data.uniform_(-1, 1).renorm_(2,1,1e-5).mul_(1e5)
self.phiflag = phiflag
self.m = m
self.mlambda = [
lambda x: x**0, # cos(0*theta)=1
lambda x: x**1, # cos(1*theta)=cos(theta)
lambda x: 2*x**2-1, # cos(2*theta)=2*cos(theta)**2-1
lambda x: 4*x**3-3*x,
lambda x: 8*x**4-8*x**2+1,
lambda x: 16*x**5-20*x**3+5*x
] def forward(self, input): # input为输入的特征,(B, C),B为batchsize,C为图像的类别总数
x = input # size=(B,F),F为特征长度,如512
w = self.weight # size=(F,C) ww = w.renorm(2,1,1e-5).mul(1e5) #对w进行归一化,renorm使用L2范数对第1维度进行归一化,将大于1e-5的截断,乘以1e5,使得最终归一化到1.如果1e-5设置的过大,裁剪时某些很小的值最终可能小于1。注意,第0维度只对每一行进行归一化(每行平方和为1),第1维度指对每一列进行归一化。由于w的每一列为x的权重,因而此处需要对每一列进行归一化。如果要对x归一化,需要对每一行进行归一化,此时第二个参数应为0
xlen = x.pow(2).sum(1).pow(0.5) # 对输入x求平方,而后对不同列求和,再开方,得到每行的模,最终大小为第0维的,即B(由于对x不归一化,但是计算余弦时需要归一化,因而可以先计算模。但是对于w,不太懂为何不直接使用这种方式,而是使用renorm函数?)
wlen = ww.pow(2).sum(0).pow(0.5) # 对权重w求平方,而后对不同行求和,再开方,得到每列的模(理论上之前已经归一化,此处应该是1,但第一次运行到此处时,并不是1,不太懂),最终大小为第1维的,即C cos_theta = x.mm(ww) # 矩阵相乘(B,F)*(F,C)=(B,C),得到cos值,由于此处只是乘加,故未归一化
cos_theta = cos_theta / xlen.view(-1,1) / wlen.view(1,-1) # 对每个cos值均除以B和C,得到归一化后的cos值
cos_theta = cos_theta.clamp(-1,1) #将cos值截断到[-1,1]之间,理论上不截断应该也没有问题,毕竟w和x都归一化后,cos值不可能超出该范围 if self.phiflag:
cos_m_theta = self.mlambda[self.m](cos_theta) # 通过cos_theta计算cos_m_theta,mlambda为cos_m_theta展开的结果
theta = Variable(cos_theta.data.acos()) # 通过反余弦,计算角度theta,(B,C)
k = (self.m*theta/3.14159265).floor() # 通过公式,计算k,(B,C)。此处为了保证theta大于k*pi/m,转换过来就是m*theta/pi,再向上取整
n_one = k*0.0 - 1 # 通过k的大小,得到同样大小的-1矩阵,(B,C)
phi_theta = (n_one**k) * cos_m_theta - 2*k # 通过论文中公式,得到phi_theta。(B,C)
else:
theta = cos_theta.acos() # 得到角度theta,(B, C),每一行为当前特征和w的每一列的夹角
phi_theta = myphi(theta,self.m) #
phi_theta = phi_theta.clamp(-1*self.m,1) cos_theta = cos_theta * xlen.view(-1,1) # 由于实际上不对x进行归一化,此处cos_theta需要乘以B。(B,C)
phi_theta = phi_theta * xlen.view(-1,1) # 由于实际上不对x进行归一化,此处phi_theta需要乘以B。(B,C)
output = (cos_theta,phi_theta)
return output # size=(B,C,2) class AngleLoss(nn.Module):
def __init__(self, gamma=0):
super(AngleLoss, self).__init__()
self.gamma = gamma
self.it = 0
self.LambdaMin = 5.0
self.LambdaMax = 1500.0
self.lamb = 1500.0 def forward(self, input, target):
self.it += 1
cos_theta,phi_theta = input # cos_theta,(B,C)。 phi_theta,(B,C)
target = target.view(-1,1) #size=(B,1) index = cos_theta.data * 0.0 #得到和cos_theta相同大小的全0矩阵。(B,C)
index.scatter_(1,target.data.view(-1,1),1) # 得到一个one-hot矩阵,第i行只有target[i]的值为1,其他均为0
index = index.byte() # index为float的,转换成byte类型
index = Variable(index) self.lamb = max(self.LambdaMin,self.LambdaMax/(1+0.1*self.it)) # 得到lamb
output = cos_theta * 1.0 #size=(B,C) # 如果直接使用output=cos_theta,可能不收敛(未测试,但其他程序中碰到过直接对输入使用[index]无法收敛,加上*1.0可以收敛的情况)
output[index] -= cos_theta[index]*(1.0+0)/(1+self.lamb) # 此行及下一行将target[i]的值通过公式得到最终输出
output[index] += phi_theta[index]*(1.0+0)/(1+self.lamb) logpt = F.log_softmax(output) # 得到概率
logpt = logpt.gather(1,target) # 下面为交叉熵的计算(和focal loss的计算有点类似,当gamma为0时,为交叉熵)。
logpt = logpt.view(-1)
pt = Variable(logpt.data.exp()) loss = -1 * (1-pt)**self.gamma * logpt
loss = loss.mean() # target = target.view(-1) # 若要简化,理论上可直接使用这两行计算交叉熵(此处未测试,在其他程序中使用后可以正常训练)
# loss = F.cross_entropy(cos_theta, target) return loss

(原)SphereFace及其pytorch代码的更多相关文章

  1. 目标检测之Faster-RCNN的pytorch代码详解(数据预处理篇)

    首先贴上代码原作者的github:https://github.com/chenyuntc/simple-faster-rcnn-pytorch(非代码作者,博文只解释代码) 今天看完了simple- ...

  2. (转载)PyTorch代码规范最佳实践和样式指南

    A PyTorch Tools, best practices & Styleguide 中文版:PyTorch代码规范最佳实践和样式指南 This is not an official st ...

  3. PyTorch代码调试利器: 自动print每行代码的Tensor信息

    本文介绍一个用于 PyTorch 代码的实用工具 TorchSnooper.作者是TorchSnooper的作者,也是PyTorch开发者之一. GitHub 项目地址: https://github ...

  4. 如何将tensorflow1.x代码改写为pytorch代码(以图注意力网络(GAT)为例)

    之前讲解了图注意力网络的官方tensorflow版的实现,由于自己更了解pytorch,所以打算将其改写为pytorch版本的. 对于图注意力网络还不了解的可以先去看看tensorflow版本的代码, ...

  5. pointnet.pytorch代码解析

    pointnet.pytorch代码解析 代码运行 Training cd utils python train_classification.py --dataset <dataset pat ...

  6. 残差网络resnet理解与pytorch代码实现

    写在前面 ​ 深度残差网络(Deep residual network, ResNet)自提出起,一次次刷新CNN模型在ImageNet中的成绩,解决了CNN模型难训练的问题.何凯明大神的工作令人佩服 ...

  7. 记录下pytorch代码从0.3版本迁移到0.4版本要做的一些更改。

    1. UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to in ...

  8. 运行pytorch代码遇到的error解决办法

    1.no CUDA-capable device is detected 首先考虑的是cuda的驱动问题,查看gpu显示是否正常,然后更新最新的cuda驱动: 第二个考虑的是cuda设备的默认参数是否 ...

  9. 目标检测之Faster-RCNN的pytorch代码详解(模型训练篇)

    本文所用代码gayhub的地址:https://github.com/chenyuntc/simple-faster-rcnn-pytorch  (非本人所写,博文只是解释代码) 好长时间没有发博客了 ...

随机推荐

  1. [PowerShell Utils] Automatically Change DNS and then Join Domain

    I would like to start a series of blog posts sharing PowerShell scripts to speed up our solution ope ...

  2. ASP.NET 网站管理工具介绍

    有没有感觉对 web.config 的操作很烦呢? 老是手动来编辑 web.config 确实挺麻烦的, 不过自 ASP.NET 2.0 起便有了 ASP.NET 网站管理工具, 这个工具呢,其实就是 ...

  3. Laravel Composer 脚本

    composer update --no-scripts 执行静态文件 composer dump-autoload 文件映射

  4. 转: H264码流分析 --264分析两大利器:264VISA和Elecard StreamEye Tools

    转码: http://www.360doc.com/content/13/0225/19/21412_267854467.shtml ESEYE视频工具全称是什么: Elecard StreamEye ...

  5. 【html5】HTML5中canvas怎样画虚线

    在canvas API中,我们发现仅仅提供了画实线的方法实现,并没有虚线的相关方法,那么怎样实现画虚线呢? 现实中,虚线是由一小段小段的实线线段组成,那么仅仅要我们通过画出等长度的线段就能够组成我们想 ...

  6. Docker worker nodes shown as “Down” after re-start

    After docker is shutdown, the worker node  changes its status to Down, but availability remains at A ...

  7. oauth2-server-php-docs 存储

    PDO 概观 PDO存储类使用 PHP 的PDO扩展.这允许连接到MySQL,SQLite,PostgreSQL 等等. 安装 PDO是默认安装的php 5.1+,这个库已经是必需的了,所以你会很好的 ...

  8. 保密员(baomi)

    #include<iostream> #include<string> #include<stdio.h> #include<algorithm> #i ...

  9. 015-Go 数据库操作注意事项

    1.Query.Exec(1)Exec(update.insert.delete等无结果集返回的操作)调用完后会自动释放连接:(2)Query(返回sql.Rows)则不会释放连接,调用完后仍然占有连 ...

  10. 转:清理系统垃圾的BAT代码

    @echo off title @echo off color 2 echo. echo. echo 请不要关闭此窗口! echo. echo 开始清理垃圾文件,请稍等...... echo. ech ...