机器学习的评估

PR曲线用于positive类数据占比比较小,或者你更加在意false postion(相比于false negative);其他情况采用ROC曲线;比如Demo中手写体5的判断,因为只有少量5,所以从ROC上面来看分类效果不错,但是从PR曲线可以看到分类器效果不佳。

 

y_scores = sgd_clf.decision_function([some_digit])

decision_function代表的是参数实例到各个类所代表的超平面的距离;在梯度下滑里面特有的(随机森林里面没有decision_function),这个返回的距离,或者说是分值;后续的对于这个值的利用方式是指定阈值来进行过滤:

>>> y_scores = sgd_clf.decision_function([some_digit])

>>> y_scores

array([ 161855.74572176])

>>> threshold = 0

>>> y_some_digit_pred = (y_scores > threshold)

array([ True], dtype=bool)

 

>>> threshold = 200000

>>> y_some_digit_pred = (y_scores > threshold)

>>> y_some_digit_pred

array([False], dtype=bool)

通过上面例子看到了,通过decision_function可以获得一种"分值",这个分值的几何意义就是当前点到超平面(hyperplane)的距离;然后,你可以利用这个分值来和某个阈值做比较(距离的阈值),超过阈值则通过,低于阈值则不通过。再举一个例子:

>>> sgd_clf.fit(X_train, y_train) # y_train, not y_train_5

>>> sgd_clf.predict([some_digit])

array([ 5.])

some_digit_scores=sdg_clf.decision_function([some_digit])

some_digit_scores

array([[-311402.62954431, -363517.28355739, -446449.5306454 ,

-183226.61023518, -414337.15339485, 161855.74572176,

-452576.39616343, -471957.14962573, -518542.33997148,

-536774.63961222]])

 

sgd_clf.fit(X_train, y_train)这个梯度下降算法学习的对象是说有手写训练样本以及0-9的分类标签,基于学习的模型调用decision_function之后,获取是[some_digit]所有的标签到超平面的距离,其中只有5是正值,所以如果调用predict的话返回的就是5。但是,如果我们训练的分类器是二元分类器(True,false),那么情况又不同:

y_train_5 =(y_train==5)

>>> sgd_clf.fit(X_train,y_train_5) # y_train, not y_train_5

>>> sgd_clf.predict([some_digit])

array([ True])

因为y_train_5这个标签集合只有True和False两种标签,所以训练之后的模型预测的知识True和false;所以到底是二元分类还是多元分类完全取决于训练的时候的标签集。

 

predict:用于分类模型的预测分类;

fit:对于线性回归的模型学习,称之为"拟合";

y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3, method="predict_proba")

cross_val_predict是交叉获取分类概率(注意,这里的method参数设置为"predict_proba",代表返回值返回的是预期分类的概率)

参考:

http://scikit-learn.org/stable/modules/generated/sklearn.svm.SVC.html

 

?这有一个问题其实每太搞懂,就是scores和predict的关系到底什么,cross_val_score的机制和cross_val_predit之间的差别是什么,文中代码如下:

from sklearn.ensemble import RandomForestClassifier

forest_clf = RandomForestClassifier(random_state=42)

y_probas_forest = cross_val_predict(forest_clf, X_train, y_train_5, cv=3,

method="predict_proba")

 

But to plot a ROC curve, you need scores, not probabilities. A simple solution is to

use the positive class's probability as the score:

 

y_scores_forest = y_probas_forest[:, 1] # score = proba of positive class

fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5,y_scores_forest)

 

不过这里的代码可以看出一些端倪:

from sklearn.ensemble import RandomForestClassifier

forest_clf=RandomForestClassifier(random_state=42)

y_probas_forest=cross_val_predict(forest_clf, X_train, y_train_5, cv=3, method="predict_proba")

print y_probas_forest

[[0.9 0.1] [1. 0. ] [1. 0. ] ... [1. 0. ] [0.9 0.1] [1. 0. ]]

y_scores_forest=y_probas_forest[:, 1]

print y_scores_forest

[0.1 0. 0. ... 0. 0.1 0. ]

你可以看到,scores是probas的二维数组的第二维的值。那么问题来了,作为cross_val_predict里面的数据,二维数组中这个二维到底是什么?这个二维数组其实代表的是各个分类的概率,对于二分类而言,就是为negative的概率以及position概率;对于scores其实就是为position的分类信息。那就意味着如果N个分类(classification),那么就是N维数组了。

另外对于森林分类器里面有一个method的参数,例子中值是"predict_proba",这个代表着预测各个分类的概率;他还有很多其他选项:

predict:代表的是预测的分类,就是会挑选概率最大的分类返回;

predict_log_proba:算法和predic_proba是一样的,但是最后对于结果会取对数运算,目的是放大值,避免在概率的相乘中会产生一些极小值,然后会因为舍入问题导致误差;另外一些机器算法(比如散度KL)本身就是基于对数运算的。最后,贝叶斯的分类算法需要通过对数运算(log)来实现稳定性;

对于cv=3,代表采用三折交叉验证,即将数据随机分为三份(或者尽量保持数据的均匀分布性),每次拿其中的一份来做测试集(另外两份做训练集),然后将三次的结果(每个测试样本各个分类的概率)做一下平均值;

 

参考

https://stackoverflow.com/questions/20335944/why-use-log-probability-estimates-in-gaussiannb-scikit-learn

https://www.reddit.com/r/MLQuestions/comments/5lzv9o/sklearn_why_predict_log_proba/

https://baike.baidu.com/item/%E5%AF%B9%E6%95%B0%E5%85%AC%E5%BC%8F

https://stats.stackexchange.com/questions/329857/what-is-the-difference-between-decision-function-predict-proba-and-predict-fun

https://stackoverflow.com/questions/36543137/whats-the-difference-between-predict-proba-and-decision-function-in-scikit-lear

Decision_function:scores,predict以及其他的更多相关文章

  1. [LeetCode] Predict the Winner 预测赢家

    Given an array of scores that are non-negative integers. Player 1 picks one of the numbers from eith ...

  2. [Swift]LeetCode486. 预测赢家 | Predict the Winner

    Given an array of scores that are non-negative integers. Player 1 picks one of the numbers from eith ...

  3. Predict the Winner LT486

    Given an array of scores that are non-negative integers. Player 1 picks one of the numbers from eith ...

  4. Minimax-486. Predict the Winner

    Given an array of scores that are non-negative integers. Player 1 picks one of the numbers from eith ...

  5. 动态规划-Predict the Winner

    2018-04-22 19:19:47 问题描述: Given an array of scores that are non-negative integers. Player 1 picks on ...

  6. 486. Predict the Winner

    Given an array of scores that are non-negative integers. Player 1 picks one of the numbers from eith ...

  7. LeetCode Predict the Winner

    原题链接在这里:https://leetcode.com/problems/predict-the-winner/description/ 题目: Given an array of scores t ...

  8. LN : leetcode 486 Predict the Winner

    lc 486 Predict the Winner 486 Predict the Winner Given an array of scores that are non-negative inte ...

  9. [LeetCode] 486. Predict the Winner 预测赢家

    Given an array of scores that are non-negative integers. Player 1 picks one of the numbers from eith ...

  10. LC 486. Predict the Winner

    Given an array of scores that are non-negative integers. Player 1 picks one of the numbers from eith ...

随机推荐

  1. java中4种修饰符访问权限的区别

    访问权限 类 本包 子类 其他包 public √ √ √ √ protected √ √ √ x default(缺省) √ √ x x private √ x x x

  2. Python学习之路day3-集合

    一.概述 集合(set)是一种无序且不重复的序列. 无序不重复的特点决定它存在以下的应用场景: 去重处理 关系测试 差集.并集.交集等,下文详述. 二.创建集合 创建集合的方法与创建字典类似,但没有键 ...

  3. 快速切题 sgu104. Little shop of flowers DP 难度:0

    104. Little shop of flowers time limit per test: 0.25 sec. memory limit per test: 4096 KB PROBLEM Yo ...

  4. 跟我一起学习ASP.NET 4.5 MVC4.0(六)

    这一系列文章跨度有点大,由于最近忙于其他事情,没有更新,今天重新安装了下Win8系统,VS2012和SQLServer 2012,顺便抽空继续一篇.随着VS2012 RC版本的放出,ASP.NET M ...

  5. Cetus

    转自:https://github.com/Lede-Inc/cetus Cetus 简介 Cetus是由C语言开发的关系型数据库MySQL的中间件,主要提供了一个全面的数据库访问代理功能.Cetus ...

  6. Django小示例

    创建项目,在命令行中输入:django-admin startproject mysite 则会创建一个名为mysite的项目.项目结构如下: +mysite |--+ugo |          | ...

  7. 谷歌Gmail诞生记:十年回首

    美国<时代>周刊网络版今天刊登题为<Gmail诞生记:10年前鲜为人知的故事>(How Gmail Happened: The Inside Story of Its Laun ...

  8. React 源码剖析系列 - 不可思议的 react diff

      简单点的重复利用已有的dom和其他REACT性能快的原理. key的作用和虚拟节点 目前,前端领域中 React 势头正盛,使用者众多却少有能够深入剖析内部实现机制和原理. 本系列文章希望通过剖析 ...

  9. Vue.js 源码学习笔记 -- 分析前准备1 -- vue三大利器

    主体 实例方法归类:   先看个作者推荐, 清晰易懂的  23232 简易编译器   重点: 最简单的订阅者模式 // Observer class Observer { constructor (d ...

  10. REST easy with kbmMW #15 – Handling HTTP POST

    我被问到有关如何通过基于kbmMW智能服务(Smart Service)的REST处理POST的问题. 这篇博客文章解释了典型的POST各种形式的访问,以及如何在kbmMW中处理它们. POST变种W ...