论文提出用于out-of-distributions输入检测的energy-based方案,通过非概率的energy score区分in-distribution数据和out-of-distribution数据。不同于softmax置信度,energy score能够对齐输入数据的密度,提升OOD检测的准确率,对算法的实际应用有很大的意义

来源:晓飞的算法工程笔记 公众号

论文: Energy-based Out-of-distribution Detection

Introduction


 今天给大家分享一篇基于能量函数来区分非训练集相关(Out-of-distribution, OOD)输入的文章,连LeCun看了都说好。虽然文章是2020年的,但里面的内容对算法实际应用比较有用。这篇文章提出的能量模型想法在之前分享的《OWOD:开放世界目标检测,更贴近现实的检测场景 | CVPR 2021 Oral》也有应用,有兴趣的可以去看看。

 现实世界是开放且未知的,OOD由于与训练集差异很大,使用通过特定训练集训练出来的模型进行预测的话,往往会出现不可控的结果。因此,确定输入是否为OOD并过滤掉,对算法在高安全要求场景下的应用是十分重要的。

 大部分OOD研究依赖softmax置信度来过滤OOD输入,将低置信度的认定为OOD。然而,由于网络通常已经过拟合了输入空间,softmax对跟训练集差异较大的输入时常会不稳定地返回高置信度,所以softmax并不是OOD检测的最佳方法。还有部分OOD研究则从生成模型的角度来产生输入的似然分数\(logp(x)\),但这种方法在实践中难以实现而且很不稳定,因为需要估计整个输入空间的归一化密度。

 为此,论文提出energy-based方法来检测OOD输入,将输入映射为energy score,能直接应用到当前的网络中。论文还提供了理论证明和实验验证,表明这种energy-based方法比softmax-based和generative-based方法更优。

Background: Energy-based Models


 EBM(energy-based model)的核心是建立一个函数\(E(x): \mathbb{R}^D\to\mathbb{R}\),将输入\(x\)映射为一个叫energy的常量。

 一组energy常量可以通过Gibbs分布转为概率分布\(p(x)\):

 这里将\((x,y)\)作为输入,对于分类场景,\(E(x,y)\)可认为是数据与标签相关的energy常量。分母\(\int_{y^{'}}e^{-E(x,y^{'})/T}\)是配分函数,即所有标签energy的整合,\(T\)是温度系数。

 输入数据\(x\in\mathbb{R}^D\)的Helmholtz free energy \(E(x)\)可表示为配分函数的负对数:

 结合公式1和公式2就构建了一个跟分类模型十分类似的EBM,可以通过Gibbs分布将多个energy输出转换为概率输出,还可以通过Helmholtz free energy得出最终的energy。

 对于分类模型,分类器\(f(x):\mathbb{R}^D\to\mathbb{R}^K\)将输入映射为K个值(logits),随后通过softmax函数将其转换为类别分布:

 其中\(f_y(x)\)为\(f(x)\)的第\(y\)个输出。

 通过关联公式1和公式3,可在不改动网络\(f(x)\)的情况下将分类网络转换为EBM。定义输入\((x,y)\)的energy为softmax的对应输入值\(E(x,y)=-f_y(x)\),再定义\(x\in\mathbb{R}^D\)的free energy为:

Energy-based Out-of-distribution Detection


Energy as Inference-time OOD Score

 Out-of-distribution detection是个二分类问题,评价函数需要产生一个能够判定ID(in-distribution)数据和OOD(out-of-distribution)数据的分数。因此,论文尝试在分类模型上接入energy函数,通过energy进行OOD检测。energy较小的为ID数据,energy较大的为OOD数据。

 实际上,通过负对数似然(negative log-likelihood,NLL))损失训练的模型本身就倾向于拉低ID数据的energy,负对数似然损失可表示为:

 定义energy函数\(E(x,y)=-f_y(x)\)并将\(log\)里面的分数展开,NLL损失可转换为:

 从损失值越低越好的优化角度看,公式6的第一项倾向于拉低目标类别\(y\)的energy,而公式6第二项从形式来看相当于输入数据的free energy。第二项导致整体损失函数倾向于拉低目标类别\(y\)的energy,同时拉高其它标签的energy,可以从梯度的角度进行解释:

 上述式子是对两项的梯度进行整合,分为目标类别相关的梯度和非目标相关的梯度。可以看到,目标类别相关的梯度是倾向于更小的energy,而非目标类别相关的梯度由于前面有负号,所以是倾向于更大的energy。另外,由于energy近似为\(-f_y(x)=E(x,y)\),通常都是目标类别的值比较大,所以NLL损失整体倾向于拉低ID数据的energy。

 由于上述的energy特性,就可以基于energy函数\(E(x;f)\)进行OOD检测:

 其中\(\tau\)为energy阈值,energy高于该阈值的被认定为OOD数据。在实际测试中,使用ID数据计算阈值,保证大部分的训练数据能被\(g(x)\)正确地区分。另外,需要注意的是,这里用了负energy分数\(-E(x;f)\),是为了遵循正样本有更高分数的常规定义。

Energy Score vs. Softmax Score

 论文先通过公式推导,来证明energy可以简单又高效地在任意训练好的模型上代替softmax置信度。将sofmax置信度进行对数展开,结合公式4以及\(T=1\)进行符号转换:

 从上述式子可以看出,softmax置信度的对数实际上是free energy的特例,先将每个energy减去最大的energy进行偏移(shift),再进行free energy的计算,导致置信度与输入的概率密度不匹配。随着训练的进行,通常\(f^{max}(x)\)会变高,而\(E(x; f)\)则变低,所以softmax是有偏评价函数,置信度也不适用于OOD检测。

 论文也通过真实的例子来进行对比,ID数据和OOD数据的softmax置信度差别很小(1.0 vs 0.99),而负energy则更有区分度(11.19 vs. 7.11)。因此,网络输出的值(energy)比偏移后的值(softmax score)包含更多有用的信息。

Energy-bounded Learning for OOD Detection

 尽管energy score能够直接应用于训练好的模型,但ID数据和OOD数据的区分度可能还不够明显。为此,论文提出energy-bounded学习目标,对训练好的网络进行fine-tuned训练,显示地扩大ID数据和OOD数据之间的energy差异:

 \(F(x)\)为softmax输出,\(D^{train}_{in}\)为ID数据。整体的训练目标包含标准交叉熵损失,以及基于energy的正则损失:

 \(D^{train}_{out}\)为无标签的辅助OOD数据集,通过两个平方hinge损失对energy进行正则化,惩罚energy大于间隔参数\(m_{in}\)的ID数据以及energy小于间隔参数\(m_{out}\)的OOD数据。当模型fine-tuned好后,即可根据公式7进行OOD检测。

Experiment


 ID数据集包含CIFA-10、CIFAR-100,并且分割训练集和测试集。OOD测试数据集包含Textures、SVHN、Places365、LSUN-Crop、LSUN_Resize和iSUN。辅助用的OOD数据集则采用80 Million Tiny Images,去掉CIFAR里面出现的类别。

 评价指标采用以下:1)在ID数据95%正确的\(\tau\)阈值下的OOD数据错误率。2)ROC曲线下的区域大小(AUROC)。3)PR曲线下的区域大小(AUPR)。

 从结果可以看出,energy score比softmax score的表现要好,而经过fine-tuned之后,错误就很低了。

 与其他OOD方法进行比较。

 可视化对比。

Conclustion


 论文提出用于out-of-distributions输入检测的energy-based方案,通过非概率的energy score区分in-distribution数据和out-of-distribution数据。不同于softmax置信度,energy score能够对齐输入数据的密度,提升OOD检测的准确率,对算法的实际应用有很大的意义。





如果本文对你有帮助,麻烦点个赞或在看呗~

更多内容请关注 微信公众号【晓飞的算法工程笔记】

基于energy score的out-of-distribution数据检测,LeCun都说好 | NerulPS 2020的更多相关文章

  1. TOP100summit:【分享实录-Microsoft】基于Kafka与Spark的实时大数据质量监控平台

    本篇文章内容来自2016年TOP100summit Microsoft资深产品经理邢国冬的案例分享.编辑:Cynthia 邢国冬(Tony Xing):Microsoft资深产品经理.负责微软应用与服 ...

  2. 基于 WebSocket 实现 WebGL 3D 拓扑图实时数据通讯同步(二)

    我们上一篇<基于 WebSocket 实现 WebGL 3D 拓扑图实时数据通讯同步(一)>主要讲解了如何搭建一个实时数据通讯服务器,客户端与服务端是如何通讯的,相信通过上一篇的讲解,再配 ...

  3. Mybatis框架基于注解的方式,实对数据现增删改查

    编写Mybatis代码,与spring不一样,不需要导入插件,只需导入架包即可: 在lib下 导入mybatis架包:mybatis-3.1.1.jarmysql驱动架包:mysql-connecto ...

  4. 一款基于jQuery饼状图比例分布数据报表

    今天给大家带来一款基于jQuery饼状图比例分布数据报表.这款报表插件适用浏览器:IE8.360.FireFox.Chrome.Safari.Opera.傲游.搜狗.世界之窗.效果图如下: 在线预览  ...

  5. Python 基于Python从mysql表读取千万数据实践

    基于Python 从mysql表读取千万数据实践   by:授客 QQ:1033553122 场景:   有以下两个表,两者都有一个表字段,名为waybill_no,我们需要从tl_waybill_b ...

  6. 基于Thinkphp5+phpQuery 网络爬虫抓取数据接口,统一输出接口数据api

    TP5_Splider 一个基于Thinkphp5+phpQuery 网络爬虫抓取数据接口 统一输出接口数据api.适合正在学习Vue,AngularJs框架学习 开发demo,需要接口并保证接口不跨 ...

  7. 项目一:第四天 1、快递员的条件分页查询-noSession,条件查询 2、快递员删除(逻辑删除) 3、基于Apache POI实现批量导入区域数据 a)Jquery OCUpload上传文件插件使用 b)Apache POI读取excel文件数据

    1. 快递员的条件分页查询-noSession,条件查询 2. 快递员删除(逻辑删除) 3. 基于Apache POI实现批量导入区域数据 a) Jquery OCUpload上传文件插件使用 b) ...

  8. (转) 基于Arcgis for Js的web GIS数据在线采集简介

    http://blog.csdn.net/gisshixisheng/article/details/44310765 在前一篇博文“Arcgis for js之WKT和geometry转换”中实现了 ...

  9. 应用层级时空记忆模型(HTM)实现对实时异常流时序数据检测

    应用层级时空记忆模型(HTM)实现对实时异常流时序数据检测 Real-Time Anomaly Detection for Streaming Analytics Subutai Ahmad SAHM ...

  10. segMatch:基于3D点云分割的回环检测

    该论文的地址是:https://arxiv.org/pdf/1609.07720.pdf segmatch是一个提供车辆的回环检测的技术,使用提取和匹配分割的三维激光点云技术.分割的例子可以在下面的图 ...

随机推荐

  1. win32 - QueryDisplayConfig的使用

    QueryDisplayConfig函数检索关于所有显示设备的所有可能的显示路径,或视图,在当前设置的信息. C++样本: (开箱即用) 代码列出了所有显示器的名称和拓展模式 #include < ...

  2. 分层架构设计模式总结-MVC,洋葱架构,整洁架构,六边形架构,DDD等等

    一.单层结构不分层 最开始开发项目时,由于需求较少,用一个单独的工程文件就可以满足开发的需求了,不需要进行划分. 二.MVC 分层和三层 到后面需求越来越多,于是就把文件进行分解,怎么分解?有人提出了 ...

  3. [Android 逆向]frida 破解 切水果大战原版.apk

    1. 手机安装该apk,运行,点击右上角礼物 提示 支付失败,请稍后重试 2. apk拖入到jadx中,待加载完毕后,搜素失败,找到疑似目标类MymmPay的关键方法payResultFalse 4. ...

  4. 商店销售预测(回归&随机森林)

    ​ 目录 一.题目概要 二.导入包和数据集 三.数据处理 四.描述性分析 五.探索性数据分析 六.模型一:线性回归 七.模型2:随机森林 一.题目概要 在Kaggle竞赛中,要求我们应用时间序列预测, ...

  5. 新零售SaaS架构:订单履约系统的应用架构梳理

    订单履约系统的核心能力 通过分析订单履约的全流程和各个业务活动,我们可以梳理出订单履约的核心业务链路,基于业务链路,我们抽象出订单履约系统的三大系统能力,分别为履约服务表达.履约调度.物流配送. 履约 ...

  6. python基础安装虚拟环境

    1.pip install virtualenv或者pip3 install virtualenv 2.在要存放虚拟环境的地方创建一个venv文件夹,用来存放所有创建的虚拟环境,方便查找与管理 3.m ...

  7. 在Windows环境中配置使用我们搭建的DNS服务器

    1.修改网卡的设置,首选DNS用我们自己的 2.在命令行中测试 专业的nslookup 3.已知的问题 每次在DNS服务器的web界面中,修改了解析,必须用docker restart dns命令,把 ...

  8. 【Azure Spring Cloud】部署Azure spring cloud 失败

    问题描述 使用Azure CLI指令部署Azure Spring Cloud项目失败,错误消息提示没有安装"azure.storage.blob"模块 问题分析 根据错误提示,是p ...

  9. 【Azure 应用服务】在Azure App Service多实例的情况下,如何在应用中通过代码获取到实例名(Instance ID)呢?

    问题描述 App Service开启多实例后,如何在代码中获取当前请求所真实到达的实例ID(Instance ID)呢? 问题答案 App Service 通过 环境变量的方式 显示的输出实例ID等信 ...

  10. Java 接口:比较对象的大小

    1 package com.bytezreo.interfacetest; 2 3 /** 4 * 5 * @Description 接口:比较对象的大小 6 * @author Bytezero·z ...