混淆矩阵(Confusion matrix)的原理及使用(scikit-learn 和 tensorflow)
原理
在机器学习中, 混淆矩阵是一个误差矩阵, 常用来可视化地评估监督学习算法的性能. 混淆矩阵大小为 (n_classes, n_classes) 的方阵, 其中 n_classes 表示类的数量. 这个矩阵的每一行表示真实类中的实例, 而每一列表示预测类中的实例 (Tensorflow 和 scikit-learn 采用的实现方式). 也可以是, 每一行表示预测类中的实例, 而每一列表示真实类中的实例 (Confusion matrix From Wikipedia 中的定义). 通过混淆矩阵, 可以很容易看出系统是否会弄混两个类, 这也是混淆矩阵名字的由来.
混淆矩阵是一种特殊类型的列联表(contingency table)或交叉制表(cross tabulation or crosstab). 其有两维 (真实值 "actual" 和 预测值 "predicted" ), 这两维都具有相同的类("classes")的集合. 在列联表中, 每个维度和类的组合是一个变量. 列联表以表的形式, 可视化地表示多个变量的频率分布.
使用混淆矩阵( scikit-learn 和 Tensorflow)
下面先介绍在 scikit-learn 和 tensorflow 中计算混淆矩阵的 API (Application Programming Interface) 接口函数, 然后在一个示例中, 使用这两个 API 函数.
scikit-learn 混淆矩阵函数 sklearn.metrics.confusion_matrix API 接口
skearn.metrics.confusion_matrix(
y_true, # array, Gound true (correct) target values
y_pred, # array, Estimated targets as returned by a classifier
labels=None, # array, List of labels to index the matrix.
sample_weight=None # array-like of shape = [n_samples], Optional sample weights
)
在 scikit-learn 中, 计算混淆矩阵用来评估分类的准确度.
按照定义, 混淆矩阵 C 中的元素 Ci,j 等于真实值为组 i , 而预测为组 j 的观测数(the number of observations). 所以对于二分类任务, 预测结果中, 正确的负例数(true negatives, TN)为 C0,0; 错误的负例数(false negatives, FN)为 C1,0; 真实的正例数为 C1,1; 错误的正例数为 C0,1.
如果 labels 为 None, scikit-learn 会把在出现在 y_true 或 y_pred 中的所有值添加到标记列表 labels 中, 并排好序.
Tensorflow 混淆矩阵函数 tf.confusion_matrix API 接口
tf.confusion_matrix(
labels, # 1-D Tensor of real labels for the classification task
predictions, # 1-D Tensor of predictions for a givenclassification
num_classes=None, # The possible number of labels the classification task can have
dtype=tf.int32, # Data type of the confusion matrix
name=None, # Scope name
weights=None, # An optional Tensor whose shape matches predictions
)
Tensorflow tf.confusion_matrix 中的 num_classes 参数的含义, 与 scikit-learn sklearn.metrics.confusion_matrix 中的 labels 参数相近, 是与标记有关的参数, 表示类的总个数, 但没有列出具体的标记值. 在 Tensorflow 中一般是以整数作为标记, 如果标记为字符串等非整数类型, 则需先转为整数表示. 如果 num_classes 参数为 None, 则把 labels 和 predictions 中的最大值 + 1, 作为 num_classes 参数值.
tf.confusion_matrix 的 weights 参数和 sklearn.metrics.confusion_matrix 的 sample_weight 参数的含义相同, 都是对预测值进行加权, 在此基础上, 计算混淆矩阵单元的值.
使用示例
#!/usr/bin/env python
# -*- coding: utf8 -*-
"""
Author: klchang
Description:
A simple example for tf.confusion_matrix and sklearn.metrics.confusion_matrix.
Date: 2018.9.8
"""
from __future__ import print_function
import tensorflow as tf
import sklearn.metrics y_true = [1, 2, 4]
y_pred = [2, 2, 4] # Build graph with tf.confusion_matrix operation
sess = tf.InteractiveSession()
op = tf.confusion_matrix(y_true, y_pred)
op2 = tf.confusion_matrix(y_true, y_pred, num_classes=6, dtype=tf.float32, weights=tf.constant([0.3, 0.4, 0.3]))
# Execute the graph
print ("confusion matrix in tensorflow: ")
print ("1. default: \n", op.eval())
print ("2. customed: \n", sess.run(op2))
sess.close() # Use sklearn.metrics.confusion_matrix function
print ("\nconfusion matrix in scikit-learn: ")
print ("1. default: \n", sklearn.metrics.confusion_matrix(y_true, y_pred))
print ("2. customed: \n", sklearn.metrics.confusion_matrix(y_true, y_pred, labels=range(6), sample_weight=[0.3, 0.4, 0.3]))
参考资料
1. Confusion matrix. In Wikipedia, The Free Encyclopedia. https://en.wikipedia.org/wiki/Confusion_matrix
2. Contingency table. In Wikipedia, The Free Encyclopedia. https://en.wikipedia.org/wiki/Contingency_table
3. Tensorflow API - tf.confusion_matrix. https://www.tensorflow.com/api_docs/python/tf/confusion_matrix
4. scikit-learn API - sklearn.metrics.confusion_matrix. http://scikit-learn.org/stable/modules/generated/sklearn.metrics.confusion_matrix.html
混淆矩阵(Confusion matrix)的原理及使用(scikit-learn 和 tensorflow)的更多相关文章
- ML01 机器学习后利用混淆矩阵Confusion matrix 进行结果分析
目标: 快速理解什么是混淆矩阵, 混淆矩阵是用来干嘛的. 首先理解什么是confusion matrix 看定义,在机器学习领域,混淆矩阵(confusion matrix),又称为可能性表格或是 ...
- python画混淆矩阵(confusion matrix)
混淆矩阵(Confusion Matrix),是一种在深度学习中常用的辅助工具,可以让你直观地了解你的模型在哪一类样本里面表现得不是很好. 如上图,我们就可以看到,有一个样本原本是0的,却被预测成了1 ...
- 【分类模型评判指标 一】混淆矩阵(Confusion Matrix)
转自:https://blog.csdn.net/Orange_Spotty_Cat/article/details/80520839 略有改动,仅供个人学习使用 简介 混淆矩阵是ROC曲线绘制的基础 ...
- WEKA “Detailed Accuracy By Class”和“Confusion Matrix”含义
原文 === Summary ===(总结) Correctly Classified Instances(正确分类的实例) 45 90 % I ...
- 机器学习-Confusion Matrix混淆矩阵、ROC、AUC
本文整理了关于机器学习分类问题的评价指标——Confusion Matrix.ROC.AUC的概念以及理解. 混淆矩阵 在机器学习领域中,混淆矩阵(confusion matrix)是一种评价分类模型 ...
- 混淆矩阵、准确率、精确率/查准率、召回率/查全率、F1值、ROC曲线的AUC值
准确率.精确率(查准率).召回率(查全率).F1值.ROC曲线的AUC值,都可以作为评价一个机器学习模型好坏的指标(evaluation metrics),而这些评价指标直接或间接都与混淆矩阵有关,前 ...
- 评估分类器性能的度量,像混淆矩阵、ROC、AUC等
评估分类器性能的度量,像混淆矩阵.ROC.AUC等 内容概要¶ 模型评估的目的及一般评估流程 分类准确率的用处及其限制 混淆矩阵(confusion matrix)是如何表示一个分类器的性能 混淆矩阵 ...
- 混淆矩阵在Matlab中PRtools模式识别工具箱的应用
声明:本文用到的代码均来自于PRTools(http://www.prtools.org)模式识别工具箱,并以matlab软件进行实验. 混淆矩阵是模式识别中的常用工具,在PRTools工具箱中有直接 ...
- 机器学习入门-混淆矩阵-准确度-召回率-F1score 1.itertools.product 2. confusion_matrix(test_y, pred_y)
1. itertools.product 进行数据的多种组合 intertools.product(range(0, 1), range(0, 1)) 组合的情况[0, 0], [0, 1], [ ...
随机推荐
- 【Java并发编程】:并发新特性—塞队列和阻塞栈
阻塞队列 阻塞队列是Java5并发新特性中的内容,阻塞队列的接口是Java.util.concurrent.BlockingQueue,它有多个实现类:ArrayBlockingQueue.Delay ...
- 关于Spring配置的一些东西
Spring 配置的三种方式:JAVA配置,注解配置,和XML的配置 注解配置: @Service:标识服务层(业务层)组件 @Component:基本注解, 标识了一个受 Spring 管理的组件( ...
- Django的配置文件(settings.py)
初始项目的配置文件 新建项目默认settings.py的内容的 """ Django settings for ORM project. Generated by 'dj ...
- windows下安装并使用redis
一.安装前首先了解一下phpinfo里面的一些信息,能否正确安装非常有帮助. (下图是我的本机环境) compiler :编译器 Architecture :CPU架构 Configuration F ...
- Impala 使用的端口
下表中列出了 Impala 是用的 TCP 端口.在部署 Impala 之前,请确保每个系统上这些端口都是打开的. 组件 服务 端口 访问需求 备注 Impala Daemon Impala 守护进程 ...
- 深度剖析Dubbo源码
-----------------学习dubbo源码,能给你带来什么好处?----------- 1.提升SOA的微服务架构设计能力 通过读dubbo源码是一条非常不错的通往SOA架构设计之路,毕 ...
- ASP.NET Core 中的日志记录
目录 内置日志的使用 使用Nlog 集成ELK 参考 内置日志的使用 Logger 是 asp .net core 的内置 service,所以我们就不需要在ConfigureService里面注册了 ...
- rake aborted! You have already activated rake 10.1.0, but your Gemfile requires rake 10.0.3. Using bundle exec may solve this.
问题: wyy@wyy:~/moumentei-master$ rake db:createrake aborted!You have already activated rake 10.1.0, b ...
- 任务三十八:UI组件之排序表格
任务三十八:UI组件之排序表格 面向人群: 有一定JavaScript基础 难度: 低 重要说明 百度前端技术学院的课程任务是由百度前端工程师专为对前端不同掌握程度的同学设计.我们尽力保证课程内容的质 ...
- 使用 Flask 框架写用户登录功能的Demo时碰到的各种坑(四)——对 run.py 的调整
使用 Flask 框架写用户登录功能的Demo时碰到的各种坑(一)——创建应用 使用 Flask 框架写用户登录功能的Demo时碰到的各种坑(二)——使用蓝图功能进行模块化 使用 Flask 框架写用 ...