在上一个博文中,我们用learning_curve函数来确定应该拥有多少的训练集能够达到效果,就像一个人进行学习时需要做多少题目就能拥有较好的考试成绩了。

本次我们来看下如何调整学习中的参数,类似一个人是在早上7点钟开始读书好还是晚上8点钟读书好。

加载数据

数据仍然利用手写数字识别作为训练数据:

from sklearn.datasets import load_digits

# 加载数据
digits = load_digits()
X = digits.data
y = digits.target

调整参数

我们想要调整·SVC(gamma=0.001)·SVC中的gamma参数,看到底把gamma参数设置成哪个值是最优的。

因此需要定义测试的参数范围,这里设置了参数值的范围为从10的-6次方到10的-2.3次方,总共5个值:

import numpy as np
# 定义gamma参数的可能取值范围,从10**-6, 到10**-2.3,总共5个参数值
param_range = np.logspace(-6, -2.3, 5)

validation_curve不停尝试在不同参数值下的损失函数值:

from sklearn.model_selection import validation_curve
from sklearn.svm import SVC
# param_name中指定了修改SVC中的哪个参数值,这里修改的是gamma参数值;param_range参数指定了具体参数值的可选范围
train_loss, test_loss = validation_curve(SVC(), X, y, param_name="gamma", param_range=param_range, cv=10, scoring='neg_mean_squared_error')
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1)

可视化图形

可视化图形,横坐标为参数可选值的范围,纵坐标为在各参数下的损失函数值

# 可视化图形,横坐标为参数可选值的范围,纵坐标为在各参数下的损失函数值
import matplotlib.pyplot as plt
plt.plot(param_range, train_loss_mean, label="Train")
plt.plot(param_range, test_loss_mean, label="Test")
plt.legend()
plt.show()

图形显示为:

在这个图形中,我们发现gamma值有一个转折点,当其在0.001之后,测试集的误差值就开始扩大了,因此,从图形上看,一个比较好的学习参数值是gamma=0.001或者再往前一点点,大概在0.0007左右。

完整代码

完整的代码如下:

from sklearn.datasets import load_digits

# 加载数据
digits = load_digits()
X = digits.data
y = digits.target import numpy as np
# 定义gamma参数的可能取值范围,从10**-6, 到10**-2.3,总共5个参数值
param_range = np.logspace(-6, -2.3, 5) from sklearn.model_selection import validation_curve
from sklearn.svm import SVC
# param_name中指定了修改SVC中的哪个参数值,这里修改的是gamma参数值;param_range参数指定了具体参数值的可选范围
train_loss, test_loss = validation_curve(SVC(), X, y, param_name="gamma", param_range=param_range, cv=10, scoring='neg_mean_squared_error')
train_loss_mean = -np.mean(train_loss, axis=1)
test_loss_mean = -np.mean(test_loss, axis=1) # 可视化图形,横坐标为参数可选值的范围,纵坐标为在各参数下的损失函数值
import matplotlib.pyplot as plt
plt.plot(param_range, train_loss_mean, label="Train")
plt.plot(param_range, test_loss_mean, label="Test")
plt.legend()
plt.show()

sklearn交叉验证3-【老鱼学sklearn】的更多相关文章

  1. sklearn交叉验证-【老鱼学sklearn】

    交叉验证(Cross validation),有时亦称循环估计, 是一种统计学上将数据样本切割成较小子集的实用方法.于是可以先在一个子集上做分析, 而其它子集则用来做后续对此分析的确认及验证. 一开始 ...

  2. sklearn交叉验证2-【老鱼学sklearn】

    过拟合 过拟合相当于一个人只会读书,却不知如何利用知识进行变通. 相当于他把考试题目背得滚瓜烂熟,但一旦环境稍微有些变化,就死得很惨. 从图形上看,类似下图的最右图: 从数学公式上来看,这个曲线应该是 ...

  3. sklearn保存模型-【老鱼学sklearn】

    训练好了一个Model 以后总需要保存和再次预测, 所以保存和读取我们的sklearn model也是同样重要的一步. 比如,我们根据房源样本数据训练了一下房价模型,当用户输入自己的房子后,我们就需要 ...

  4. sklearn数据库-【老鱼学sklearn】

    在做机器学习时需要有数据进行训练,幸好sklearn提供了很多已经标注好的数据集供我们进行训练. 本节就来看看sklearn提供了哪些可供训练的数据集. 这些数据位于datasets中,网址为:htt ...

  5. sklearn模型的属性与功能-【老鱼学sklearn】

    本节主要讲述模型中的各种属性及其含义. 例如上个博文中,我们有用线性回归模型来拟合房价. # 创建线性回归模型 model = LinearRegression() # 训练模型 model.fit( ...

  6. sklearn标准化-【老鱼学sklearn】

    在前面的一篇博文中关于计算房价中我们也大致提到了标准化的概念,也就是比如对于影响房价的参数中有面积和户型,面积的取值范围可以很广,它可以从0-500平米,而户型一般也就1-5. 标准化就是要把这两种参 ...

  7. 二分类问题续 - 【老鱼学tensorflow2】

    前面我们针对电影评论编写了二分类问题的解决方案. 这里对前面的这个方案进行一些改进. 分批训练 model.fit(x_train, y_train, epochs=20, batch_size=51 ...

  8. tensorflow卷积神经网络-【老鱼学tensorflow】

    前面我们曾有篇文章中提到过关于用tensorflow训练手写2828像素点的数字的识别,在那篇文章中我们把手写数字图像直接碾压成了一个784列的数据进行识别,但实际上,这个图像是2828长宽结构的,我 ...

  9. 机器学习- Sklearn (交叉验证和Pipeline)

    前面一节咱们已经介绍了决策树的原理已经在sklearn中的应用.那么这里还有两个数据处理和sklearn应用中的小知识点咱们还没有讲,但是在实践中却会经常要用到的,那就是交叉验证cross_valid ...

随机推荐

  1. libstdc++.so.6: cannot open shared object file: No such file or directory

    sudo apt-get install lib32stdc++6 sudo apt-get install lib32z1

  2. HashMap初认识

    什么是HashSet? 它实现了Set接口,HashSet是Set集合的子类 有哈希表支持的,元素不可重复的哈希码值(实际上是一个HashMap的实例). 它不保证set的迭代顺序(遍历元素的顺序), ...

  3. 分享4个运维平台工具OSSIM、Ansible的学习思路

    对于当今企业安全来说,真正价值不在于亡羊补牢,也不在于一个或多个高危漏洞.企业在乎的是如何防患于未然,如何快速定位攻击,如何快速解决安全问题.OSSIM作为开源的安全信息管理平台,对于企业的需求来说毋 ...

  4. Hadoop记录- Yarn scheduler队列采集

    #!/bin/sh ip=10.116.100.11 port=8088 export HADOOP_HOME=/app/hadoop/bin rmstate1=$($HADOOP_HOME/yarn ...

  5. VS2012/2013/2015/Visual Studio 2017 关闭单击文件进行预览的功能

    Visual Studio在2010版本后推出了点击项目管理器预览文件的功能,但是对于配置不咋地的旧电脑总是觉得有点卡,下面是解决方案. 英文版方法:Tools->Options->Env ...

  6. Geometric regularity criterion for NSE: the cross product of velocity and vorticity 1: $u\times \om$

    在 [Chae, Dongho. On the regularity conditions of suitable weak solutions of the 3D Navier-Stokes equ ...

  7. tr1

    tr echo 12345 | tr '0-9' '9876543210' echo 87654 | tr '9876543210' '0-9' ROT13 echo "tr came, t ...

  8. 计算int数组中的最大,最小,平均值

    public static void testNumber(int[] arr) { int max = arr[0]; int min = arr[0]; int avg = 0; int sum ...

  9. 【洛谷P1706全排列问题】

    题目描述 输出自然数1到n所有不重复的排列,即n的全排列,要求所产生的任一数字序列中不允许出现重复的数字. 代码如下: #include<iostream>#include<cstd ...

  10. 【转】利用apktool反编译apk,并且重新签名打包

    网站:https://ibotpeaches.github.io/Apktool,下载安装好apktool. 我的安装在 C:\Users\Administrator\Downloads\apktoo ...