# -*- coding: utf-8 -*-
"""
Created on Mon Sep 10 11:21:27 2018 @author: zhen
"""
from sklearn.datasets import fetch_mldata
import numpy as np
from sklearn.linear_model import SGDClassifier
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import precision_recall_curve
import matplotlib
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve
from sklearn.metrics import roc_auc_score
from sklearn.ensemble import RandomForestClassifier mnist = fetch_mldata('MNIST original', data_home='D:/AnalyseData学习资源库/人工智能开发/分类评估/资料/test_data_home') x, y = mnist['data'], mnist['target']
some_digit = x[36000] #获取第36000行数据 some_digit_image = some_digit.reshape(28, 28) plt.imshow(some_digit_image, cmap=matplotlib.cm.binary,
interpolation='nearest', vmin=0, vmax=1)
plt.axis('off')
plt.show() x_train, x_test, y_train, y_test = x[:60000], x[60000:], y[:60000], y[60000:]
shuffle_index = np.random.permutation(60000) x_train, y_train = x_train[shuffle_index], y_train[shuffle_index] y_train_5 = (y_train == 5)
y_test_5 = (y_test == 5) sgd_clf = SGDClassifier(loss='log', random_state=42, max_iter=1000, tol=1e-4)
sgd_clf.fit(x_train, y_train_5) result = sgd_clf.predict([some_digit]) print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='accuracy'))
print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='precision'))
print(cross_val_score(sgd_clf, x_train, y_train_5, cv=3, scoring='recall')) sgd_clf.fit(x_train, y_train_5) y_scores = sgd_clf.decision_function([some_digit]) threshold = 0
y_some_digit_pred = (y_scores > threshold) threshold = 200000
y_some_digit_pred = (y_scores > threshold) # cv 数据集划分的个数
y_scores = cross_val_predict(sgd_clf, x_train, y_train_5, cv=3, method='decision_function') precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores) def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
plt.plot(thresholds, precisions[:-1], 'b--',label='Precision')
plt.plot(thresholds, recalls[:-1], 'r--', label='Recall')
plt.xlabel("Threshold")
plt.legend(loc='upper left')
plt.ylim([0, 1])
plt.show() def plot_roc_curve(fpr, tpr, label=None):
plt.plot(fpr, tpr, linewidth=2, label='roc')
plt.plot([0, 1], [0, 1], 'k--', label='mid')
plt.legend(loc='lower right')
# plt.axes([0, 1, 0, 1]) : 前两个参数表示坐标原点的位置,后两个表示x,y轴的长度
plt.xlabel('fpr')
plt.ylabel('tpr')
plt.show() plot_precision_recall_vs_threshold(precisions, recalls, thresholds) fpr, tpr, thresholds = roc_curve(y_train_5, y_scores)
plot_roc_curve(fpr, tpr) print(roc_auc_score(y_train_5, y_scores)) forest_clf = RandomForestClassifier(random_state=42)
y_probas_forest = cross_val_predict(forest_clf, x_train, y_train_5, cv=3, method='predict_proba')
y_scores_forest = y_probas_forest[:, 1]
fpr_forest, tpr_forest, thresholds_forest = roc_curve(y_train_5, y_scores_forest)
plt.plot(fpr, tpr, 'b:', label='SGD')
plt.plot(fpr_forest, tpr_forest, label='Random Forest')
plt.legend(loc='lower right')
plt.show() print(roc_auc_score(y_train_5, y_scores_forest))

          

总结:正向准确率和召回率在整体上成反比,可知在使用相同数据集,相同验证方式的情况下,随机森林要优于随机梯度下降!

评估指标【交叉验证&ROC曲线】的更多相关文章

  1. 【分类模型评判指标 二】ROC曲线与AUC面积

    转自:https://blog.csdn.net/Orange_Spotty_Cat/article/details/80499031 略有改动,仅供个人学习使用 简介 ROC曲线与AUC面积均是用来 ...

  2. 【机器学习】--模型评估指标之混淆矩阵,ROC曲线和AUC面积

    一.前述 怎么样对训练出来的模型进行评估是有一定指标的,本文就相关指标做一个总结. 二.具体 1.混淆矩阵 混淆矩阵如图:  第一个参数true,false是指预测的正确性.  第二个参数true,p ...

  3. 评价指标的局限性、ROC曲线、余弦距离、A/B测试、模型评估的方法、超参数调优、过拟合与欠拟合

    1.评价指标的局限性 问题1 准确性的局限性 准确率是分类问题中最简单也是最直观的评价指标,但存在明显的缺陷.比如,当负样本占99%时,分类器把所有样本都预测为负样本也可以获得99%的准确率.所以,当 ...

  4. 评估指标:ROC,AUC,Precision、Recall、F1-score

    一.ROC,AUC ROC(Receiver Operating Characteristic)曲线和AUC常被用来评价一个二值分类器(binary classifier)的优劣 . ROC曲线一般的 ...

  5. 召回率、AUC、ROC模型评估指标精要

    混淆矩阵 精准率/查准率,presicion 预测为正的样本中实际为正的概率 召回率/查全率,recall 实际为正的样本中被预测为正的概率 TPR F1分数,同时考虑查准率和查全率,二者达到平衡,= ...

  6. PR曲线,ROC曲线,AUC指标等,Accuracy vs Precision

    作为机器学习重要的评价指标,标题中的三个内容,在下面读书笔记里面都有讲: http://www.cnblogs.com/charlesblc/p/6188562.html 但是讲的不细,不太懂.今天又 ...

  7. 从TP、FP、TN、FN到ROC曲线、miss rate、行人检测评估

    从TP.FP.TN.FN到ROC曲线.miss rate.行人检测评估 想要在行人检测的evaluation阶段要计算miss rate,就要从True Positive Rate讲起:miss ra ...

  8. [机器学习] 性能评估指标(精确率、召回率、ROC、AUC)

    混淆矩阵 介绍这些概念之前先来介绍一个概念:混淆矩阵(confusion matrix).对于 k 元分类,其实它就是一个k x k的表格,用来记录分类器的预测结果.对于常见的二元分类,它的混淆矩阵是 ...

  9. 机器学习 - 案例 - 样本不均衡数据分析 - 信用卡诈骗 ( 标准化处理, 数据不均处理, 交叉验证, 评估, Recall值, 混淆矩阵, 阈值 )

    案例背景 银行评判用户的信用考量规避信用卡诈骗 ▒ 数据 数据共有 31 个特征, 为了安全起见数据已经向了模糊化处理无法读出真实信息目标 其中数据中的 class 特征标识为是否正常用户 (0 代表 ...

随机推荐

  1. Python猜数小游戏

    使用random变量随机生成一个1到100之间的数 采集用户所输入的数字,如果输入的不符合要求会让用户重新输入. 输入符合要求,游戏开始.如果数字大于随机数,输出数字太大:如果小于随机数,输出数字太小 ...

  2. apache-jmeter-5.0的简单压力测试使用方法

    同事交接工作,压测部分交给我,记录一下使用方法 我将下载下来的压缩包解压后放置在E盘 然后配置环境变量: 变量名JMETER_HOME,变量值 E:\javatool\apache-jmeter-5. ...

  3. 关于微信JS-SDK 分享接口的两个报错记录

    一.前提: 微信测试号,用微信开发者工具测试 二.简单复述文档: 1.引入JS文件 在需要调用JS接口的页面引入如下JS文件,(支持https):http://res.wx.qq.com/open/j ...

  4. Win7 + CentOS7 双系统

    记录一下更改系统启动菜单的方法. 前提: 1. 先安装 Win7 在硬盘第一分区,其它分区在 Win7 下处于未分配状态. 2. 再安装 CentOS 到上述未分配分区.(注意:手动分区时,可以留一定 ...

  5. PyQt5 api 帮助文档

    学习PyQt5的帮助文档是通过,使用help(PyQt5 class)的方式在console端输出帮助内容,常用的方法和属性查找起来不是很方便,现在放在网上以方便大家使用. QWidget Qt QM ...

  6. Python爬虫入门教程 51-100 Python3爬虫通过m3u8文件下载ts视频-Python爬虫6操作

    什么是m3u8文件 M3U8文件是指UTF-8编码格式的M3U文件. M3U文件是记录了一个索引纯文本文件, 打开它时播放软件并不是播放它,而是根据它的索引找到对应的音视频文件的网络地址进行在线播放. ...

  7. 从壹开始前后端分离 [ Vue2.0+.NET Core2.1] 二十五║初探SSR服务端渲染(个人博客二)

    缘起 时间真快,现在已经是这个系列教程的下半部 Vue 第 12 篇了,昨天我也简单思考了下,可能明天再来一篇,Vue 就基本告一段落了,因为什么呢,这里给大家说个题外话,当时写博文的时候,只是想给大 ...

  8. qml demo分析(rssnews-常见新闻布局)

    一.效果展示 今儿来分析一篇常见的ui布局,完全使用qml编写,ui交互效果友好,如图1所示,是一个常见的客户端新闻展示效果,左侧是一个列表,右侧是新闻详情. 图1 新闻效果图 二.源码分析 首先先来 ...

  9. windows系统下用python更新svn和Git

    转载请标明出处:http://www.cnblogs.com/zblade/ 最近在思考怎么实现python的一键打包,利用python的跨平台特性,可以实现在windows和mac下均可执行的特点. ...

  10. SmartSql 动态代理仓储

    SmartSql 动态代理仓储,一个高生产力的组件.该组件看似很难懂,实际上仅做了映射Statement,转发请求的功能.但却意义重大. SmartSql提供了一个通用泛型仓储接口 SmartSql. ...