思想:

# 1.现将所有样本用交叉验证方法或者(随机抽样方法) 得到 K对 训练集-验证集
# 2.依次对K个训练集,拿出数量不断增加的子集如m个,并在这些K*m个子集上训练模型。
# 3.依次在对应训练集子集、验证集上计算得分。
# 4.对每种大小下的子集,计算K次训练集得分均值和K次验证集得分均值,共得到m对值。
# 5.绘制学习率曲线。x轴训练集样本量,y轴模型得分或预测准确率。

用到的方法:

learning_curve #直接得到1个模型在不同训练集大小参数下:1.训练集大小 2.训练得分 3.测试得分

ShuffleSplit  #实现交叉验证、或 随机抽样划分不同的训练集合验证集

plt.fill_between

python代码:

import numpy as np
from sklearn.model_selection import learning_curve, ShuffleSplit
from sklearn.datasets import load_digits
from sklearn.naive_bayes import GaussianNB
from sklearn import svm
import matplotlib.pyplot as plt
def plot_learning_curve(estimator, title, X, y, ylim=None, cv=None,n_jobs=1, train_size=np.linspace(.1, 1.0, 5 )):
if __name__ == '__main__':
plt.figure()
plt.title(title)
if ylim is not None:
plt.ylim(*ylim)
plt.xlabel('Training example')
plt.ylabel('score')
train_sizes, train_scores, test_scores = learning_curve(estimator, X, y, cv=cv, n_jobs=n_jobs, train_sizes=train_size)
train_scores_mean = np.mean(train_scores, axis=1)
train_scores_std = np.std(train_scores, axis=1)
test_scores_mean = np.mean(test_scores, axis=1)
test_scores_std = np.std(test_scores, axis=1)
plt.grid()#区域
plt.fill_between(train_sizes, train_scores_mean - train_scores_std,
train_scores_mean + train_scores_std, alpha=0.1,
color="r")
plt.fill_between(train_sizes, test_scores_mean - test_scores_std,
test_scores_mean + test_scores_std, alpha=0.1,
color="g")
plt.plot(train_sizes, train_scores_mean, 'o-', color='r',
label="Training score")
plt.plot(train_sizes, test_scores_mean,'o-',color="g",
label="Cross-validation score")
plt.legend(loc="best")
return plt
digits = load_digits()
X = digits.data
y = digits.target
cv = ShuffleSplit(n_splits=100, test_size=0.2, random_state=0)#切割100ci
estimator = GaussianNB()
title = "Learning Curves(naive_bayes)"
plot_learning_curve(estimator, title, X, y, ylim=(0.7, 1.01), cv=cv, n_jobs=4) title = "Learning Curves(SVM,RBF kernel, $\gamma=0.001$)"
cv = ShuffleSplit(n_splits=10, test_size=0.2, random_state=0)#交叉验证传入别的方法,而不是默认的k折交叉验证
estimator = svm.SVC(gamma=0.001)
plot_learning_curve(estimator, title, X, y, (0.7, 1.01), cv=cv, n_jobs=4)
plt.show()

结果:

说明:

1. 贝叶斯模型上,训练集规模达到1100时已经比较合适,太小的话会导致过拟合。。总的来看,该模型的准确率趋向于0.85

2. SVM模型上,训练集规模到800时已经比较合适,太小的话会导致训练集无代表性,对验证集无预测能力。。。总的来看,该模型在这份数据上的表现比第一个好。

参考

http://www.360doc.com/content/18/0424/22/50223086_748481549.shtml

SKlearn库学习曲线的更多相关文章

  1. 2.sklearn库中的标准数据集与基本功能

    sklearn库中的标准数据集与基本功能 下面我们详细介绍几个有代表性的数据集: 当然同学们也可以用sklearn机器学习函数来挖掘这些数据,看看可不可以捕捉到一些有趣的想象或者是发现: 波士顿房价数 ...

  2. 1.sklearn库的安装

    sklearn库 sklearn是scikit-learn的简称,是一个基于Python的第三方模块.sklearn库集成了一些常用的机器学习方法,在进行机器学习任务时,并不需要实现算法,只需要简单的 ...

  3. day-10 sklearn库实现SVM支持向量算法

    学习了SVM分类器的简单原理,并调用sklearn库,对40个线性可分点进行训练,并绘制出图形画界面. 一.问题引入 如下图所示,在x,y坐标轴上,我们绘制3个点A(1,1),B(2,0),C(2,3 ...

  4. 复盘一篇讲sklearn库学习文章(上)

    认识 sklearn 官网地址: https://scikit-learn.gor/stable/ 从2007年发布以来, scikit-learn已成为重要的Python机器学习库, 简称sklea ...

  5. Python: sklearn库——数据预处理

    Python: sklearn库 —— 数据预处理 数据集转换之预处理数据:      将输入的数据转化成机器学习算法可以使用的数据.包含特征提取和标准化.      原因:数据集的标准化(服从均值为 ...

  6. Python机器学习笔记:sklearn库的学习

    网上有很多关于sklearn的学习教程,大部分都是简单的讲清楚某一方面,其实最好的教程就是官方文档. 官方文档地址:https://scikit-learn.org/stable/ (可是官方文档非常 ...

  7. python3安装pandas执行pip3 install pandas命令后卡住不动的问题及安装scipy、sklearn库的numpy.distutils.system_info.NotFoundError: no lapack/blas resources found问题

    一直尝试在python3中安装pandas等一系列软件,但每次执行pip3 install pandas后就卡住不动了,一直停在那,开始以为是pip命令的版本不对,还执行过 python -m pip ...

  8. day-9 sklearn库和python自带库实现最近邻KNN算法

    K最近邻(k-Nearest Neighbor,KNN)分类算法,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一.该方法的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的 ...

  9. Sklearn库例子——决策树分类

    Sklearn上关于决策树算法使用的介绍:http://scikit-learn.org/stable/modules/tree.html 1.关于决策树:决策树是一个非参数的监督式学习方法,主要用于 ...

随机推荐

  1. 收藏:SQL Server 数据库改名

    SQL SERVER 2005以前通常使用sp_renameDB存储过程. EXEC sp_renameDB 'oldDB','newDB'   或者:数据库先分离,然后再附加也可以改名. SQL S ...

  2. mongodb集群故障转移实践

    简介 NOSQL有这些优势: 大数据量,可以通过廉价服务器存储大量的数据,轻松摆脱传统mysql单表存储量级限制. 高扩展性,Nosql去掉了关系数据库的关系型特性,很容易横向扩展,摆脱了以往老是纵向 ...

  3. vcf格式文件转化为Excel(csv)格式文件(R语言的write.csv,write.table功能,Excel表的文件导入功能)

    最近在整理文件,准备把vcf文件转化为Excel格式,或者CSV格式,网上搜了一堆资料,还真有人专门开发出转化格式的工具:叫vcf2csv(下载地址http://vcf2csv.sourceforge ...

  4. gcc-linaro-arm-linux-gnueabihf交叉编译器配置

    系统Ubuntu14.04 版本:gcc 版本 4.7.3 20130328 (prerelease) (crosstool-NG linaro-1.13.1-4.7-2013.04-20130415 ...

  5. 深入flask中的request

    缘起 在使用flask的时候一直比较纳闷request是什么原理,他是如何保证多线程情况下对不同请求参数的隔离的. 准备知识 在讲request之前首先需要先理解一下werkzeug.local中的几 ...

  6. 监听INPUT值的即时变化

    <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/ ...

  7. Python基础【day03】:集合进阶(四)

    本节内容 1.关系测试(特殊符号) 1.交集2.并集3.差集4.对称差集5.是否是子集6.是否是父集 2.基本操作 1.add2.update3.remove VS pop vs discard4.l ...

  8. ajax跨域原理以及jsonp使用

    jsonp介绍: JSONP(JSON with Padding)是JSON的一种“使用模式”,可用于解决主流浏览器的跨域数据访问的问题. 由于同源策略,一般来说位于 server1.example. ...

  9. 查看windows下指定的端口是否开放

    有时候会出现ip  ping的通   但是就是连接不上的情况.这时候我们需要检测一下这个端口是否被开放 netstat -ano -p tcp | find >nul && ec ...

  10. 流媒体技术学习笔记之(一)nginx+nginx-rtmp-module+ffmpeg搭建流媒体服务器

    参照网址: [1]http://blog.csdn.net/redstarofsleep/article/details/45092147 [2]HLS介绍:http://www.cnblogs.co ...