# -*- coding: utf-8 -*-
"""
Created on Mon Sep 10 11:21:27 2018 @author: zhen
"""
from sklearn.datasets import fetch_mldata
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import precision_recall_curve
import matplotlib
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.ensemble import RandomForestClassifier mnist = fetch_mldata('MNIST original', data_home='D:/AnalyseData学习资源库/人工智能开发/分类评估/资料/test_data_home') x, y = mnist['data'], mnist['target']
some_digit = x[36000] #获取第36000行数据 some_digit_image = some_digit.reshape(28, 28) plt.imshow(some_digit_image, cmap=matplotlib.cm.binary,
interpolation='nearest', vmin=0, vmax=1)
plt.axis('off')
plt.show() x_train, x_test, y_train, y_test = x[:60000], x[60000:], y[:60000], y[60000:]
shuffle_index = np.random.permutation(60000) x_train, y_train = x_train[shuffle_index], y_train[shuffle_index] y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5) sgd_clf = SGDClassifier(loss='log', random_state=42, max_iter=1000, tol=1e-4)
sgd_clf.fit(x_train, y_train_5) result = sgd_clf.predict([some_digit]) print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='accuracy'))
print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='precision'))
print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='recall')) sgd_clf.fit(x_train, y_train_5) y_scores = sgd_clf.decision_function([some_digit]) threshold = 0
y_some_digit_pred = (y_scores > threshold) threshold = 200000
y_some_digit_pred = (y_scores > threshold) # cv 数据集划分的个数
y_scores = cross_val_predict(sgd_clf, x_train, y_train_5, cv=3, method='decision_function') precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores) def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
plt.plot(thresholds, precisions[:-1], 'b--',label='Precision')
plt.plot(thresholds, recalls[:-1], 'r--', label='Recall')
plt.xlabel("Threshold")
plt.legend(loc='upper left')
plt.ylim([0, 1])
plt.show() def plot_roc_curve(fpr, tpr, label=None):
plt.plot(fpr, tpr, linewidth=2, label='roc')
plt.plot([0, 1], [0, 1], 'k--', label='mid')
plt.legend(loc='lower right')
# plt.axes([0, 1, 0, 1]) : 前两个参数表示坐标原点的位置,后两个表示x,y轴的长度
plt.xlabel('fpr')
plt.ylabel('tpr')
plt.show() plot_precision_recall_vs_threshold(precisions, recalls, thresholds) fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)
plot_roc_curve(fpr, tpr) print(roc_auc_score(y_train_5, y_scores)) forest_clf = RandomForestClassifier(random_state=42)
y_probas_forest = cross_val_predict(forest_clf, x_train, y_train_5, cv=3, method='predict_proba')
y_scores_forest = y_probas_forest[:, 1]
fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)
plt.plot(fpr, tpr, 'b:', label='SGD')
plt.plot(fpr_forest, tpr_forest, label='Random Forest')
plt.legend(loc='lower right')
plt.show() print(roc_auc_score(y_train_5, y_scores_forest))

          

总结:正向准确率和召回率在整体上成反比,可知在使用相同数据集,相同验证方式的情况下,随机森林要优于随机梯度下降!

评估指标【交叉验证&ROC曲线】的更多相关文章

  1. 【分类模型评判指标 二】ROC曲线与AUC面积

    转自:https://blog.csdn.net/Orange_Spotty_Cat/article/details/80499031 略有改动,仅供个人学习使用 简介 ROC曲线与AUC面积均是用来 ...

  2. 【机器学习】--模型评估指标之混淆矩阵,ROC曲线和AUC面积

    一.前述 怎么样对训练出来的模型进行评估是有一定指标的,本文就相关指标做一个总结. 二.具体 1.混淆矩阵 混淆矩阵如图:  第一个参数true,false是指预测的正确性.  第二个参数true,p ...

  3. 评价指标的局限性、ROC曲线、余弦距离、A/B测试、模型评估的方法、超参数调优、过拟合与欠拟合

    1.评价指标的局限性 问题1 准确性的局限性 准确率是分类问题中最简单也是最直观的评价指标,但存在明显的缺陷.比如,当负样本占99%时,分类器把所有样本都预测为负样本也可以获得99%的准确率.所以,当 ...

  4. 评估指标:ROC,AUC,Precision、Recall、F1-score

    一.ROC,AUC ROC(Receiver Operating Characteristic)曲线和AUC常被用来评价一个二值分类器(binary classifier)的优劣 . ROC曲线一般的 ...

  5. 召回率、AUC、ROC模型评估指标精要

    混淆矩阵 精准率/查准率,presicion 预测为正的样本中实际为正的概率 召回率/查全率,recall 实际为正的样本中被预测为正的概率 TPR F1分数,同时考虑查准率和查全率,二者达到平衡,= ...

  6. PR曲线,ROC曲线,AUC指标等,Accuracy vs Precision

    作为机器学习重要的评价指标,标题中的三个内容,在下面读书笔记里面都有讲: http://www.cnblogs.com/charlesblc/p/6188562.html 但是讲的不细,不太懂.今天又 ...

  7. 从TP、FP、TN、FN到ROC曲线、miss rate、行人检测评估

    从TP.FP.TN.FN到ROC曲线.miss rate.行人检测评估 想要在行人检测的evaluation阶段要计算miss rate,就要从True Positive Rate讲起:miss ra ...

  8. [机器学习] 性能评估指标(精确率、召回率、ROC、AUC)

    混淆矩阵 介绍这些概念之前先来介绍一个概念:混淆矩阵(confusion matrix).对于 k 元分类,其实它就是一个k x k的表格,用来记录分类器的预测结果.对于常见的二元分类,它的混淆矩阵是 ...

  9. 机器学习 - 案例 - 样本不均衡数据分析 - 信用卡诈骗 ( 标准化处理, 数据不均处理, 交叉验证, 评估, Recall值, 混淆矩阵, 阈值 )

    案例背景 银行评判用户的信用考量规避信用卡诈骗 ▒ 数据 数据共有 31 个特征, 为了安全起见数据已经向了模糊化处理无法读出真实信息目标 其中数据中的 class 特征标识为是否正常用户 (0 代表 ...

随机推荐

  1. 从On-Premise本地到On-Cloud云上运维的演进

    摘要: 从用户的声音中,我们听到用户对稳定.弹性.透明的诉求,我们也在不断升级ECS的运维能力和体验,助力用户建立主动运维体系,赋能业务永续运行.为了让大家更好的了解和用好ECS弹性计算服务,从本期开 ...

  2. C++对象生存期&&static

    生存期,即从诞生到消失的时间段,在生存期内,对象的值或保持不变,知道改变他的值为止.对象生存期分为静态生存期和动态生存期两种. 静态生存期 指对象的生存期与程序运行期相同.在namespace中声明的 ...

  3. SQL两列数据,行转列

    SQL中只有两列数据(字段1,字段2),将其相同字段1的行转列 转换前: 转换后: --测试数据 if not object_id(N'Tempdb..#T') is null drop table ...

  4. WinForm 工作流设计 1

    从事软件行业那么多年,一直很少写博.很多技术,长时间不用都慢慢淡忘. 把自己学到的用笔记下来,可以巩固和发现不足,也可以把自己对技术的一些 理解,分享出来供大家批评指正. 废话不多说,进入正题.工作流 ...

  5. c语言-自己写的库

    一.俗话说算法是程序的灵魂,这下面本人写了一部分常用算法,欢迎大家使用,并提出批评和指正,当然也可以改进或者添加. 1.这是自己实现的算法库头文件 #ifndef _INC_ALGORITHM #de ...

  6. Docker 删除&清理镜像

    文章首发自个人网站:https://www.exception.site/docker/docker-delete-image 本文中,您将学习 Docker 如何删除及清理镜像? 一.通过标签删除镜 ...

  7. ASP.NET Core WebApi使用Swagger生成api说明文档看这篇就够了

    引言 在使用asp.net core 进行api开发完成后,书写api说明文档对于程序员来说想必是件很痛苦的事情吧,但文档又必须写,而且文档的格式如果没有具体要求的话,最终完成的文档则完全取决于开发者 ...

  8. springboot~如何去掌握它(新手可以看看)

    springboot~如何去掌握它 主讲:仓储大叔 每讲40分钟 架构图 graph LR App-->A Web-->A A(zuul proxy)-->B(eureka serv ...

  9. Linux用户和权限管理看了你就会用啦

    前言 只有光头才能变强 回顾前面: 看完这篇Linux基本的操作就会了 没想到上一篇能在知乎获得千赞呀,Linux也快期末考试了,也有半个月没有写文章了.这篇主要将Linux下的用户和权限知识点再整理 ...

  10. 2.1命令行和JSON的配置「深入浅出ASP.NET Core系列」

    希望给你3-5分钟的碎片化学习,可能是坐地铁.等公交,积少成多,水滴石穿,谢谢关注. 命令行配置 1.新建控制台项目 2.nuget引入microsoft.aspnetcore.all 这里要注意版本 ...