学习笔记26— roc曲线(python)
一、概念:
准确率(Accuracy), 精确率(Precision), 召回率(Recall)和F1-Measure
机器学习(ML), 自然语言处理(NLP), 信息检索(IR)等领域, 评估(Evaluation)是一个必要的工作, 而其评价指标往往有如下几点: 准确率(Accuracy), 精确率(Precision), 召回率(Recall) 和 F1-Measure.(注:相对来说,IR 的 ground truth 很多时候是一个 Ordered List, 而不是一个 Bool 类型的 Unordered Collection,在都找到的情况下,排在第三名还是第四名损失并不是很大,而排在第一名和第一百名,虽然都是“找到了”,但是意义是不一样的,因此更多可能适用于 MAP 之类评估指标.)
本文将简单介绍其中几个概念. 中文中这几个评价指标翻译各有不同, 所以一般情况下推荐使用英文.
现在我先假定一个具体场景作为例子.
假如某个班级有男生 80 人, 女生20人, 共计 100 人. 目标是找出所有女生. 现在某人挑选出 50 个人, 其中 20 人是女生, 另外还错误的把 30 个男生也当作女生挑选出来了. 作为评估者的你需要来评估(evaluation)下他的工作
首先我们可以计算准确率(accuracy), 其定义是: 对于给定的测试数据集,分类器正确分类的样本数与总样本数之比. 也就是损失函数是0-1损失时测试数据集上的准确率[1].
这样说听起来有点抽象,简单说就是,前面的场景中,实际情况是那个班级有男的和女的两类,某人(也就是定义中所说的分类器)他又把班级中的人分为男女两类. accuracy 需要得到的是此君分正确的人占总人数的比例. 很容易,我们可以得到:他把其中70(20女+50男)人判定正确了, 而总人数是100人,所以它的 accuracy 就是70 %(70 / 100).
由准确率,我们的确可以在一些场合,从某种意义上得到一个分类器是否有效,但它并不总是能有效的评价一个分类器的工作. 举个例子, google 抓取了 argcv 100个页面,而它索引中共有10,000,000个页面, 随机抽一个页面,分类下, 这是不是 argcv 的页面呢?如果以 accuracy 来判断我的工作,那我会把所有的页面都判断为"不是 argcv 的页面", 因为我这样效率非常高(return false, 一句话), 而 accuracy 已经到了99.999%(9,999,900/10,000,000), 完爆其它很多分类器辛辛苦苦算的值, 而我这个算法显然不是需求期待的, 那怎么解决呢?这就是 precision, recall 和 f1-measure 出场的时间了.
再说 precision, recall 和 f1-measure 之前, 我们需要先需要定义 TP, FN, FP, TN 四种分类情况.
按照前面例子, 我们需要从一个班级中的人中寻找所有女生, 如果把这个任务当成一个分类器的话, 那么女生就是我们需要的, 而男生不是, 所以我们称女生为"正类", 而男生为"负类".
相关(Relevant), 正类 | 无关(NonRelevant), 负类 | |
---|---|---|
被检索到(Retrieved) | true positives (TP 正类判定为正类, 例子中就是正确的判定"这位是女生") | false positives (FP 负类判定为正类,"存伪", 例子中就是分明是男生却判断为女生, 当下伪娘横行, 这个错常有人犯) |
未被检索到(Not Retrieved) | false negatives (FN 正类判定为负类,"去真", 例子中就是, 分明是女生, 这哥们却判断为男生--梁山伯同学犯的错就是这个) | true negatives (TN 负类判定为负类, 也就是一个男生被判断为男生, 像我这样的纯爷们一准儿就会在此处) |
或者
可以很容易看出, 所谓 TRUE/FALSE 表示从结果是否分对了, Positive/Negative 表示我们认为的是"是"还是"不是".
通过这张表, 我们可以很容易得到这几个值:
- TP=20
- FP=30
- FN=0
- TN=50
精确率(precision)的公式是P=TPTP+FPP = \frac{TP}{TP+FP}P=TP+FPTP, 它计算的是所有"正确被检索的结果(TP)"占所有"实际被检索到的(TP+FP)"的比例.
在例子中就是希望知道此君得到的所有人中, 正确的人(也就是女生)占有的比例. 所以其 precision 也就是40%(20女生/(20女生+30误判为女生的男生)).
召回率(recall)的公式是R=TPTP+FNR = \frac{TP}{TP+FN}R=TP+FNTP, 它计算的是所有"正确被检索的结果(TP)"占所有"应该检索到的结果(TP+FN)"的比例.
在例子中就是希望知道此君得到的女生占本班中所有女生的比例, 所以其 recall 也就是100%(20女生/(20女生+ 0 误判为男生的女生))
F1值就是精确值和召回率的调和均值, 也就是
2F1=1P+1R\frac{2}{F_1} = \frac{1}{P} + \frac{1}{R} F12=P1+R1
调整下也就是
F1=2PRP+R=2TP2TP+FP+FNF_1 = \frac{2PR}{P+R} = \frac{2TP}{2TP + FP + FN} F1=P+R2PR=2TP+FP+FN2TP
例子中 F1-measure 也就是约为 57.143%(2∗0.4∗10.4+1\frac{2 * 0.4 * 1}{0.4 + 1}0.4+12∗0.4∗1).
需要说明的是, 有人[3]列了这样个公式, 对非负实数β\betaβ
Fβ=(β2+1)∗PRβ2∗P+RF_\beta = (\beta^2 + 1) * \frac{PR}{\beta^2*P+R}Fβ=(β2+1)∗β2∗P+RPR
将 F-measure 一般化.
F1-measure 认为精确率和召回率的权重是一样的, 但有些场景下, 我们可能认为精确率会更加重要, 调整参数 β\betaβ , 使用 Fβ_\betaβ-measure 可以帮助我们更好的 evaluate 结果.
注意:召回率就是tpr(TP/(TP+FN)) (命中率),fpr(FP/(FP+TN)):特异性(specificity),TNR(TN/(FP+TN)):敏感性
参考链接:http://mlwiki.org/index.php/ROC_Analysis#ROC_Analysis
https://blog.argcv.com/articles/1036.c
http://www.cnblogs.com/pinard/p/5993450.html
1、简单(源码):
print(__doc__)
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle
from sklearn import svm, datasets
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from scipy import interp
# Import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target
# Binarize the output
y = label_binarize(y, classes=[0, 1, 2])
n_classes = y.shape[1]
# Add noisy features to make the problem harder
random_state = np.random.RandomState(0)
n_samples, n_features = X.shape
X = np.c_[X, random_state.randn(n_samples, 200 * n_features)]
# shuffle and split training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=.5,
random_state=0)
# Learn to predict each class against the other
classifier = OneVsRestClassifier(svm.SVC(kernel='linear', probability=True,
random_state=random_state))
y_score = classifier.fit(X_train, y_train).decision_function(X_test)
# Compute ROC curve and ROC area for each class
fpr = dict()
tpr = dict()
roc_auc = dict()
for i in range(n_classes):
fpr[i], tpr[i], _ = roc_curve(y_test[:, i], y_score[:, i])
roc_auc[i] = auc(fpr[i], tpr[i])
# Compute micro-average ROC curve and ROC area
fpr["micro"], tpr["micro"], _ = roc_curve(y_test.ravel(), y_score.ravel())
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
#Plot of a ROC curve for a specific class
plt.figure()
lw = 2
plt.plot(fpr[2], tpr[2], color='darkorange',
lw=lw, label='ROC curve (area = %0.2f)' % roc_auc[2])
plt.plot([0, 1], [0, 1], color='navy', lw=lw, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()
2、复杂(源码):
print(__doc__) import numpy as np
from scipy import interp
import matplotlib.pyplot as plt
from itertools import cycle from sklearn import svm, datasets
from sklearn.metrics import roc_curve, auc
from sklearn.model_selection import StratifiedKFold # #############################################################################
# Data IO and generation # Import some data to play with
iris = datasets.load_iris()
X = iris.data
y = iris.target
X, y = X[y != 2], y[y != 2]
n_samples, n_features = X.shape # Add noisy features
random_state = np.random.RandomState(0)
X = np.c_[X, random_state.randn(n_samples, 200 * n_features)] # #############################################################################
# Classification and ROC analysis # Run classifier with cross-validation and plot ROC curves
cv = StratifiedKFold(n_splits=6)
classifier = svm.SVC(kernel='linear', probability=True,
random_state=random_state) tprs = []
aucs = []
mean_fpr = np.linspace(0, 1, 100) i = 0
for train, test in cv.split(X, y):
probas_ = classifier.fit(X[train], y[train]).predict_proba(X[test])
# Compute ROC curve and area the curve
fpr, tpr, thresholds = roc_curve(y[test], probas_[:, 1])
tprs.append(interp(mean_fpr, fpr, tpr))
tprs[-1][0] = 0.0
roc_auc = auc(fpr, tpr)
aucs.append(roc_auc)
plt.plot(fpr, tpr, lw=1, alpha=0.3,
label='ROC fold %d (AUC = %0.2f)' % (i, roc_auc)) i += 1
plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='r',
label='Chance', alpha=.8) mean_tpr = np.mean(tprs, axis=0)
mean_tpr[-1] = 1.0
mean_auc = auc(mean_fpr, mean_tpr)
std_auc = np.std(aucs)
plt.plot(mean_fpr, mean_tpr, color='b',
label=r'Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc, std_auc),
lw=2, alpha=.8) std_tpr = np.std(tprs, axis=0)
tprs_upper = np.minimum(mean_tpr + std_tpr, 1)
tprs_lower = np.maximum(mean_tpr - std_tpr, 0)
plt.fill_between(mean_fpr, tprs_lower, tprs_upper, color='grey', alpha=.2,
label=r'$\pm$ 1 std. dev.') plt.xlim([-0.05, 1.05])
plt.ylim([-0.05, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic example')
plt.legend(loc="lower right")
plt.show()
Best threshold:
- we know the slope of the accuracy line: it's 1
- the best classifier for this slope is the 6th one
- threshold value θ
- so we take the score obtained on the 6th record
- and use it as the threshold value θ
- i.e. predict positive if θ⩾0.54
- if we check, we see that indeed we have accuracy = 0.7
学习笔记26— roc曲线(python)的更多相关文章
- python3.4学习笔记(二十六) Python 输出json到文件,让json.dumps输出中文 实例代码
python3.4学习笔记(二十六) Python 输出json到文件,让json.dumps输出中文 实例代码 python的json.dumps方法默认会输出成这种格式"\u535a\u ...
- python3.4学习笔记(二十五) Python 调用mysql redis实例代码
python3.4学习笔记(二十五) Python 调用mysql redis实例代码 #coding: utf-8 __author__ = 'zdz8207' #python2.7 import ...
- python3.4学习笔记(二十四) Python pycharm window安装redis MySQL-python相关方法
python3.4学习笔记(二十四) Python pycharm window安装redis MySQL-python相关方法window安装redis,下载Redis的压缩包https://git ...
- python3.4学习笔记(二十二) python 在字符串里面插入指定分割符,将list中的字符转为数字
python3.4学习笔记(二十二) python 在字符串里面插入指定分割符,将list中的字符转为数字在字符串里面插入指定分割符的方法,先把字符串变成list然后用join方法变成字符串str=' ...
- openresty 学习笔记番外篇:python的一些扩展库
openresty 学习笔记番外篇:python的一些扩展库 要写一个可以使用的python程序还需要比如日志输出,读取配置文件,作为守护进程运行等 读取配置文件 使用自带的ConfigParser模 ...
- openresty 学习笔记番外篇:python访问RabbitMQ消息队列
openresty 学习笔记番外篇:python访问RabbitMQ消息队列 python使用pika扩展库操作RabbitMQ的流程梳理. 客户端连接到消息队列服务器,打开一个channel. 客户 ...
- python学习笔记(三)---python关键字及其用法
转载出处:https://www.cnblogs.com/ECJTUACM-873284962/p/7576959.html 前言 最近在学习Java Sockst的时候遇到了一些麻烦事,我觉得我很有 ...
- [原创]java WEB学习笔记26:MVC案例完整实践(part 7)---修改的设计和实现
本博客为原创:综合 尚硅谷(http://www.atguigu.com)的系统教程(深表感谢)和 网络上的现有资源(博客,文档,图书等),资源的出处我会标明 本博客的目的:①总结自己的学习过程,相当 ...
- python学习笔记:安装boost python库以及使用boost.python库封装
学习是一个累积的过程.在这个过程中,我们不仅要学习新的知识,还需要将以前学到的知识进行回顾总结. 前面讲述了Python使用ctypes直接调用动态库和使用Python的C语言API封装C函数, C+ ...
随机推荐
- js DOM常见事件
js事件命名为on+动词 1.onclick事件,点击鼠标时触发,ondbclick双击事件 <h1 onclick="this.innerHTML='点击后文本'"> ...
- python pip
如果pip的版本较低,可能导致pip时安装出错,所以我们要更新pip版本-- 查询pip版本 pip -V -- Linux and OS X 升级 pip install -U pip -- Win ...
- The POM for XXX is invalid, transitive dependencies (if any) will not be available解决方案
今天,某个开发的环境在编译的时候提示警告The POM for XXX is invalid, transitive dependencies (if any) will not be availab ...
- 【题解】Luogu P4391 [BOI2009]Radio Transmission 无线传输
原题传送门 这题需要用到kmp匹配 推导发现: 设循环节的长度为x,那么kmp数组前x个都是0,后面kmp[x+n]=n 先求出kmp数组 答案实际就是len-kmp[len] #include &l ...
- LVS群集配置
第一步:网络环境配置内网网段:10.0.0.0/24DR:10.0.0.254rs1:10.0.0.1rs2:10.0.0.2nfs:10.0.0.3 第二步:nfs和web服务搭建 nfs服务器:安 ...
- QT5下的caffe项目属性
TEMPLATE = app CONFIG += console c++11 CONFIG -= app_bundle CONFIG -= qt SOURCES += /home/aimhabo/ca ...
- topcoder srm 679 div1
problem1 link $f[u][0],f[u][1]$表示$u$节点表示的子树去掉和不去掉节点$u$的最大权值. problem2 link 首先预处理计算任意三个蓝点组成的三角形中的蓝点个数 ...
- 深刻理解Python中的元类(metaclass)以及元类实现单例模式
在理解元类之前,你需要先掌握Python中的类.Python中类的概念借鉴于Smalltalk,这显得有些奇特.在大多数编程语言中,类就是一组用来描述如何生成一个对象的代码段.在Python中这一点仍 ...
- luoguP4072 [SDOI2016]征途
[SDOI2016]征途 大体 大概就是推推公式,发现很傻逼的\(n^3\)DP get60 进一步我们发现状态不能入手,考虑优化转移 套个斜率优化板子 每一层转移来一次斜率优化 思路 先便便式子 \ ...
- 深度学习课程笔记(五)Ensemble
深度学习课程笔记(五)Ensemble 2017.10.06 材料来自: 首先提到的是 Bagging 的方法: 我们可以利用这里的 Bagging 的方法,结合多个强分类器,来提升总的结果.例如: ...