• 概述

前面几节讲的是linear regression的内容,这里咱们再讲一个非常常用的一种模型那就是classification,classification顾名思义就是分类的意思,在实际的情况是非常常用的,例如咱们可以定义房价是否过高,如果房价高于100万,则房价过高,设置成true;如果房价低于100万,则房价不高,target就可以设置成false。这里的target就只有2种,分别只有True和False,而不像咱们的的linear regression那样target是连续的。在实际的应用中,这是有非常广泛的应用的,这一节的第一部分主要是讲如何用TensorFlow来训练一个classifier模型来预测classification problems。第二部分主要解释一下measure classification模型的的方法,那就是ROC curve。在linear regression中咱们知道有MAE,MSE等等一些列的方式来判断咱们的模型的表现怎么样,那么在classification中,MAE和MSE都不适用的,那么咱们用什么measurement来判断咱们的模型好不好呢?这时候就需要介绍咱们的ROC curve了。

  • TensorFlow应用之Classification

如果咱们的target只有2个(True/False 或者 1/0等等),这种情况咱们一般称之为binary classification problem;如果咱们的target的数量大于2,咱们一般称之为multi_class classification problem。这两种方式无论是哪一种,在咱们用TensorFlow训练的时候,它的的API都是一样的,只是multi-class需要在定义模型的的时候设置一个n_classes参数而已,其他都一样。另外的建模过程跟前面章节说的一样,这一节主要介绍一下他在TensorFlow的应用中跟linear regression的区别,所以我就不会展示整个建模的过程,只会展示他们的不同。第一个不同就是模型定义的时候不同,那么现在来看一下吧

linear_classifier = tf.estimator.LinearClassifier(feature_columns = construct_feature_columns(trainning_features),
optimizer = my_optimizer
                              ) linear_classifier = tf.estimator.LinearClassifier(feature_columns = configure_feature_columns(),
n_classes = 10,
optimizer = my_optimizer,
)

上面咱们可以看出来有两种定义classifier的方式,他们用的是LinearClassifier()来实例化模型的,而不像linear regression那样用LinearRegressor(); 其次上面的第一种没有n_classes这个参数,则说明是binary classification,因为他的默认值就是2;上面第二种方式则说明这是一个multi_class classification的问题。所以综上所述,它也是一个非常简单的定义的过程;

其次当咱们用的这个classifier来predict的时候,咱们可以看出来它的结果的数据结构跟linear regressor是不同的,下面我把的的结构在Spyder中打开给大家看一下

咱们可以看出来,它的prediction的结果是一个list,list里面的element是dictionary,每一个dictionary都有6个key-value pairs。这个dictionary的信息也是非常丰富的,从上面的图片可以看出来咱们的模型计算出来的结果就是class_ids这个key对应的value,当然啦,classifier计算的结果是每一种class的概率,然后选择概率最大的那一个;概率的对于的key很显然是probabilities这个字段。在其他方面,TensorFlow在classification problem中的应用的流程基本跟linear regression是一样的。

  • Classifier measurement
  •   Accuracy

  我们知道在前面的linear regression中,咱们可以用MAE,MES等等measurements来判断一个模型是否好呢?这里对于binary classification的问题,咱们可以经常使用Accuracy, ROC等方式来判断一个模型是否合格,另外在multi-classes的场景中,咱们也可以使用Accuracy和logloss等方法来判断,但是英文accuracy有它固定的缺陷,所以咱们经常不拿它作为最终参考的对象,只起一个辅助的作用。好了,那么咱们接下来来分别讲述一下他们的细节部分,在正式将这些metric之前,先给大家看一个谷歌官方教程的matrix

咱先来看看几个概念分别是TP, TN, FP, FN; 在这个例子中我们定义Malignant 的值是True, Benign的值是False。那么很显然上面绿色的部分就是咱们的模型预测的跟实际的是一样的;红色的部分则是怎么预测错误了,就是预测的跟实际的不一样。根据Accuracy的的定义,咱们很容易就能得到下面的公式

咱们来看看上面的例子,咱们总共的数据有100条,分别有91个Benign(良性), 9个Malignant(恶性); 如果咱们的模型预测的结果如上的matrix所示,那么根据accuracy的公式,咱们可以看出咱的的准确率高达91%,看其实还不错哦,对吧?那么咱们能用这个模型来预测吗??答案是不可以!!!咱们来仔细分析一下哈,上面的数据一共有9个malignant 恶性肿瘤,可是咱们的模型竟然只准确的预测出一个malignant (TP),其他的8个malignant竟然都没有预测出来,很显然这是有很大问题的!!!!那么为什么咱们的accuracy还是这么高呢??这是因为咱们的数据target的分布是非常不均匀的,换句话说咱们的数据是class-imbalanced dataset, 例如咱们的数据中有高达91%的Benign, 只有19%的的malignant,这个数据是非常不平衡的;咱们再举个极端的例子,上面的情况,即使咱们的模型prediction始终等于Negative,即无论什么数据送进来,咱们的预测结果始终都是Negative, 咱们的accuracy也是高达91%,这是不是很不合理??所以在判断咱们模型的时候,一定要慎用accuracy,尤其是在class-imbalanced dataset中。

  •   AUC-ROC Curve

这个是咱们在binary classification problem中判断一个模型好坏的一个最常用的一种方式, AUC是Aera Under Curve的缩写,很显然它是一个计算一条曲线下面的面积的函数, ROC是Receiver Operating Characteristics的缩写;那么根据名字咱们就知道AUC-ROC curve就是计算ROC curve下面的面积的一个方法。那么ROC curve到底长什么样呢?首先ROC的纵坐标是TPR(Ture Positive Rate), 横坐标是FPR(False Positive Rate). 那么具体的TPR和FPR又是什么意思呢,咱们看下面的公式

上面的公式分别表达TPR和FPR的定义的意思,其中Specificity咱们可以将它看成是Ture Negative Rate。那么这么看,咱们还是有点懵懵懂懂的不知道他们的具体含有,咱们可以结合下面的图来理解TPR, FPR

结合上面的图片,咱们可以认为TPR就是在所有的Positive的数据中,咱们正确预测出的Positive占咱们整个Positive数据数量的比例; FPR就是在咱们所有的Negative的数据中, 咱们错误的预测的数量占咱们整个Negative数据的比例。上面一句话一定一定要理解,否则什么都白瞎。那么咱们最终的ROC长什么样呢?看下图

咱们每选择一个threshold,咱ROC上面就绘制一个点,通过选择多个threshold最终画出了上面的ROC curve。 那么接着问题又来了,咱们既然已经绘制了ROC curve,咱们如何用它来判断咱们模型的好坏呢?其实就是通过计算ROC curve下面的面积(AUC)来判断的, AUC的意思是代表这这个模型分辨咱们classes的能力!!!!记住这句话,一定要记住。AUC->1代表着咱们的模型能够完全分辨出classes,AUC->0则说明咱们的模型预测的classes完全是相反的,其实这种情况也非常好,咱们只需要通过简单的取反就能够达到几乎完美的模型;最差的一种情况是AUC-> 0.5,这个时候意味着咱们的模型一点分辨能力都没有,跟咱们胡乱猜的结果是一样的。

  • Log Loss

根据咱们的分析,上面的ROC的方式只适用于binary classes的情况,那么如果咱们的classes有很多怎么办,例如有10个classes,这时候咱们就无法通过计算AUC-ROC的方式来判断咱们的模型了,咱们就得通过另外一种方式来判断咱们的模型好坏了,那么这个就是Log Loss了。具体LogLoss的数学意义以及原理,我会在下一节来解释,这里我就用最简单的方式演示一下他的应用,其实在classification problem中,咱们的loss function就是Log Loss, 在linear regression中咱们的loss function是 Mean Squared Error. 具体它们的意义,我会在后面的一节详细的展示它们的意义和推导过程。好了,那么现在咱们来看一下,咱们如何计算出咱们的log loss从而来判断出咱们classfier是不是一个好的可用的模型

trainning_predictions_one_hot = tf.keras.utils.to_categorical(trainning_class_id,10)
metrics.log_loss(trainning_targets, trainning_predictions_one_hot)
  • 总结

上面咱们介绍了一些classification 模型在训练中和linear regression不一样的地方,以及用什么metrics来最终判断咱们的classification模型,这里介绍了一下Accuracy, AUC-ROC和Log Loss. 其中的重点是AUC-ROC的含义和过程,然后知道Accuracy的一些应用场景,以为为什么有时候不能用它。最后了解一下Log loss是用来干什么的以及如何用它就行了。

机器学习-TensorFlow应用之classification和ROC curve的更多相关文章

  1. 机器学习入门12 - 分类 (Classification)

    原文链接:https://developers.google.com/machine-learning/crash-course/classification/ 1- 指定阈值 为了将逻辑回归值映射到 ...

  2. 机器学习之类别不平衡问题 (2) —— ROC和PR曲线

    机器学习之类别不平衡问题 (1) -- 各种评估指标 机器学习之类别不平衡问题 (2) -- ROC和PR曲线 完整代码 ROC曲线和PR(Precision - Recall)曲线皆为类别不平衡问题 ...

  3. 【AUC】二分类模型的评价指标ROC Curve

    AUC是指:从一堆样本中随机抽一个,抽到正样本的概率比抽到负样本的概率大的可能性! AUC是一个模型评价指标,只能用于二分类模型的评价,对于二分类模型,还有很多其他评价指标,比如logloss,acc ...

  4. Area Under roc Curve(AUC)

    AUC是一种用来度量分类模型好坏的一个标准. ROC分析是从医疗分析领域引入了一种新的分类模型performance评判方法. ROC的全名叫做Receiver Operating Character ...

  5. AUC(Area Under roc Curve)学习笔记

    AUC是一种用来度量分类模型好坏的一个标准. ROC分析是从医疗分析领域引入了一种新的分类模型performance评判方法. ROC的全名叫做Receiver Operating Character ...

  6. AUC(Area Under roc Curve )计算及其与ROC的关系

    转载: http://blog.csdn.net/chjjunking/article/details/5933105 让我们从头说起,首先AUC是一种用来度量分类模型好坏的一个标准.这样的标准其实有 ...

  7. 【转】AUC(Area Under roc Curve )计算及其与ROC的关系

    让我们从头说起,首先AUC是一种用来度量分类模型好坏的一个标准.这样的标准其实有很多,例如:大约10年前在machine learning文献中一统天下的标准:分类精度:在信息检索(IR)领域中常用的 ...

  8. 机器学习-Confusion Matrix混淆矩阵、ROC、AUC

    本文整理了关于机器学习分类问题的评价指标——Confusion Matrix.ROC.AUC的概念以及理解. 混淆矩阵 在机器学习领域中,混淆矩阵(confusion matrix)是一种评价分类模型 ...

  9. iOS机器学习-TensorFlow

    人工智能.机器学习都已走进了我们的日常,尤其是愈演愈热的大数据更是跟我们的生活息息相关,做 人工智能.数据挖掘的人在其他人眼中感觉是很高大上的,总有一种遥不可及的感觉,在我司也经常会听到数据科学部的同 ...

随机推荐

  1. freemarker<三>

    前两篇博客介绍了freemaker是什么以及简单的语法规则,下面我们通过实现一个demo来看在实际应用中如何使用freemaker,本篇博客主要介绍freemaker与spring的整合. 需要的Ja ...

  2. 利用脚本运行APP

    1.电脑安装Xcode(iOS)/Androidsdk(Android),连接手机,并在手机上安装相应代理,下图为iOS的Xcode代理样式: 2.打开Appium,点击搜索图标,添加并设置该手机信息 ...

  3. world 文档中表格旋转180°

    一个好朋友给我打电话,说是有个wps操作把他难住了,他常年跟wps 形影不离,你都搞不定,我都不怎么用.听完他说的以后,我才明白他要的效果是怎么样的,贴图来看: 其实像直接转化成这种效果没有办法,但是 ...

  4. Linux(Centos)安装node及anyproxy

    一.安装node //下载 wget https://nodejs.org/dist/v10.9.0/node-v10.9.0-linux-x64.tar.xz //解压 tar xf node-v1 ...

  5. Mongdb的基本操作及java中用法

    Mongdb中所有数据以Bson(类似JSON)的格式存在,可以存储集合,map,二进制文件等多种数据类型. 数据库的常用操作 use [数据库名称];//有就选中,没有就添加并选中show dbs; ...

  6. Elasticsearch系列---实战搜索语法

    概要 本篇介绍Query DSL的语法案例,查询语句的调试,以及排序的相关内容. 基本语法 空查询 最简单的搜索命令,不指定索引和类型的空搜索,它将返回集群下所有索引的所有文档(默认显示10条): G ...

  7. Flink State Backends (状态后端)

    State Backends 的作用 有状态的流计算是Flink的一大特点,状态本质上是数据,数据是需要维护的,例如数据库就是维护数据的一种解决方案.State Backends 的作用就是用来维护S ...

  8. $NIM$游戏小总结

    $umm$可能之后会写个博弈论总结然后就直接把这个复制粘贴上去就把这个删了 但因为还没学完所以先随便写个$NIM$游戏总结趴$QAQ$ 首先最基础的$NIM$游戏:有$n$堆石子,每次可以从一堆中取若 ...

  9. 1076 Wifi密码 (15 分)C语言

    下面是微博上流传的一张照片:"各位亲爱的同学们,鉴于大家有时需要使用 wifi,又怕耽误亲们的学习,现将 wifi 密码设置为下列数学题答案:A-1:B-2:C-3:D-4:请同学们自己作答 ...

  10. 深度学习论文翻译解析(六):MobileNets:Efficient Convolutional Neural Networks for Mobile Vision Appliications

    论文标题:MobileNets:Efficient Convolutional Neural Networks for Mobile Vision Appliications 论文作者:Andrew ...