在上一个博文中,我们用learning_curve函数来确定应该拥有多少的训练集能够达到效果,就像一个人进行学习时需要做多少题目就能拥有较好的考试成绩了。

本次我们来看下如何调整学习中的参数,类似一个人是在早上7点钟开始读书好还是晚上8点钟读书好。

加载数据

数据仍然利用手写数字识别作为训练数据:

from sklearn.datasets import load_digits

# 加载数据
digits = load_digits()
X = digits.data
y = digits.target

调整参数

我们想要调整·SVC(gamma=0.001)·SVC中的gamma参数,看到底把gamma参数设置成哪个值是最优的。

因此需要定义测试的参数范围,这里设置了参数值的范围为从10的-6次方到10的-2.3次方,总共5个值:

import numpy as np
# 定义gamma参数的可能取值范围,从10**-6, 到10**-2.3,总共5个参数值
param_range = np.logspace(-6, -2.3, 5)

validation_curve不停尝试在不同参数值下的损失函数值:

from sklearn.model_selection import validation_curve
from sklearn.svm import SVC
# param_name中指定了修改SVC中的哪个参数值,这里修改的是gamma参数值;param_range参数指定了具体参数值的可选范围
train_loss, test_loss = validation_curve(SVC(), X, y, param_name="gamma", param_range=param_range, cv=10, scoring='neg_mean_squared_error')
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)

可视化图形

可视化图形,横坐标为参数可选值的范围,纵坐标为在各参数下的损失函数值

# 可视化图形,横坐标为参数可选值的范围,纵坐标为在各参数下的损失函数值
import matplotlib.pyplot as plt
plt.plot(param_range, train_loss_mean, label="Train")
plt.plot(param_range, test_loss_mean, label="Test")
plt.legend()
plt.show()

图形显示为:

在这个图形中,我们发现gamma值有一个转折点,当其在0.001之后,测试集的误差值就开始扩大了,因此,从图形上看,一个比较好的学习参数值是gamma=0.001或者再往前一点点,大概在0.0007左右。

完整代码

完整的代码如下:

from sklearn.datasets import load_digits

# 加载数据
digits = load_digits()
X = digits.data
y = digits.target import numpy as np
# 定义gamma参数的可能取值范围,从10**-6, 到10**-2.3,总共5个参数值
param_range = np.logspace(-6, -2.3, 5) from sklearn.model_selection import validation_curve
from sklearn.svm import SVC
# param_name中指定了修改SVC中的哪个参数值,这里修改的是gamma参数值;param_range参数指定了具体参数值的可选范围
train_loss, test_loss = validation_curve(SVC(), X, y, param_name="gamma", param_range=param_range, cv=10, scoring='neg_mean_squared_error')
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1) # 可视化图形,横坐标为参数可选值的范围,纵坐标为在各参数下的损失函数值
import matplotlib.pyplot as plt
plt.plot(param_range, train_loss_mean, label="Train")
plt.plot(param_range, test_loss_mean, label="Test")
plt.legend()
plt.show()

sklearn交叉验证3-【老鱼学sklearn】的更多相关文章

  1. sklearn交叉验证-【老鱼学sklearn】

    交叉验证(Cross validation),有时亦称循环估计, 是一种统计学上将数据样本切割成较小子集的实用方法.于是可以先在一个子集上做分析, 而其它子集则用来做后续对此分析的确认及验证. 一开始 ...

  2. sklearn交叉验证2-【老鱼学sklearn】

    过拟合 过拟合相当于一个人只会读书,却不知如何利用知识进行变通. 相当于他把考试题目背得滚瓜烂熟,但一旦环境稍微有些变化,就死得很惨. 从图形上看,类似下图的最右图: 从数学公式上来看,这个曲线应该是 ...

  3. sklearn保存模型-【老鱼学sklearn】

    训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步. 比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要 ...

  4. sklearn数据库-【老鱼学sklearn】

    在做机器学习时需要有数据进行训练,幸好sklearn提供了很多已经标注好的数据集供我们进行训练. 本节就来看看sklearn提供了哪些可供训练的数据集. 这些数据位于datasets中,网址为:htt ...

  5. sklearn模型的属性与功能-【老鱼学sklearn】

    本节主要讲述模型中的各种属性及其含义. 例如上个博文中,我们有用线性回归模型来拟合房价. # 创建线性回归模型 model = LinearRegression() # 训练模型 model.fit( ...

  6. sklearn标准化-【老鱼学sklearn】

    在前面的一篇博文中关于计算房价中我们也大致提到了标准化的概念,也就是比如对于影响房价的参数中有面积和户型,面积的取值范围可以很广,它可以从0-500平米,而户型一般也就1-5. 标准化就是要把这两种参 ...

  7. 二分类问题续 - 【老鱼学tensorflow2】

    前面我们针对电影评论编写了二分类问题的解决方案. 这里对前面的这个方案进行一些改进. 分批训练 model.fit(x_train, y_train, epochs=20, batch_size=51 ...

  8. tensorflow卷积神经网络-【老鱼学tensorflow】

    前面我们曾有篇文章中提到过关于用tensorflow训练手写2828像素点的数字的识别,在那篇文章中我们把手写数字图像直接碾压成了一个784列的数据进行识别,但实际上,这个图像是2828长宽结构的,我 ...

  9. 机器学习- Sklearn (交叉验证和Pipeline)

    前面一节咱们已经介绍了决策树的原理已经在sklearn中的应用.那么这里还有两个数据处理和sklearn应用中的小知识点咱们还没有讲,但是在实践中却会经常要用到的,那就是交叉验证cross_valid ...

随机推荐

  1. LCD学习

    LCD简介(1)显示器,常见显示器(2)LCD(Liquid Crystal Display),液晶显示器,原理介绍(3)LCD应用领域(4)LED OLED1.17.1.2.电子显示器的原理(1)像 ...

  2. 解决win10隔几分钟自动黑屏睡眠的方法

    来源:win10总是很快自动休眠怎么解决? - 风格不空格的回答 - 知乎 https://www.zhihu.com/question/39263412/answer/87430653 1.运行注册 ...

  3. JS类型

    1.查看类型 可以使用 typeof 操作符来检测变量的数据类型. typeof "John" // 返回 string typeof 3.14 // 返回 number type ...

  4. LOJ2340 [WC2018] 州区划分 【FMT】【欧拉回路】

    题目分析: 这题是WC的题??? 令 $g[S] = (\sum_{x \in S}w_x)^p$ $h[S] = g[S]$如果$S$不是欧拉回路 $d[S] = \frac{f[S]}{g[All ...

  5. Scrapy 框架,爬虫文件相关

    Spiders 介绍 由一系列定义了一个网址或一组网址类如何被爬取的类组成 具体包括如何执行爬取任务并且如何从页面中提取结构化的数据. 简单来说就是帮助你爬取数据的地方 内部行为 #1.生成初始的Re ...

  6. Codeforces 1077D Cutting Out(二分答案)

    题目链接:Cutting Out 题意:给定一个n长度的数字序列s,要求得到一个k长度的数字序列t,每次从s序列中删掉完整的序列t,求出能删次数最多的那个数字序列t. 题解:数字序列s先转换成不重复的 ...

  7. Python【第三篇】文件操作、字符编码

    一.文件操作 文件操作分为三个步骤:文件打开.操作文件.关闭文件,但是,我们可以用with来管理文件操作,这样就不需要手动来关闭文件. 实现原理: import contextlib @context ...

  8. ant在windows及linux环境下安装

    ant下载 http://ant.apache.org/ https://ant.apache.org/bindownload.cgi 历史版本 ant在windows下安装 解压到D盘 新建系统变量 ...

  9. gprof性能测试工具

    1.编译时加-pg选项,例如:gcc -pg test.c-o test_gprof.其中test后的_gprof一定要加上.会生成gmon.out. 2.运行程序.gprof test_gprof ...

  10. MySQL 导出数据库,出现 “mysqldump: Got error: 1146”

    出现场景 在 cmd 导出数据库时: mysqldump -hlocalhost -uroot -p student_db > C:\student_db.sql 出现: mysqldump: ...