1. 导入需要的库

from sklearn.datasets import fetch_openml
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

2. 设置随机种子,以获得可复现的结果。

np.random.seed(42)

3. 获取mnist数据集,并将数据集标签 由字符型转换为整数型

1 np.random.seed(42)
2 mnist = fetch_openml("mnist_784", version = 1, as_frame=False)
3 X, y = mnist['data'], mnist['target']
4 y = y.astype(np.uint8)

4. 划分训练集和测试集

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

5. 训练模型并测试

knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_train) y_test_pred = knn_clf.predict(X_test)
print(accuracy_score(y_test, y_test_pred))

如图我们得到了模型的准确率 0.9688

6. 训练模型中的超参数weights(默认值为'uniform')和n_neighbors(默认值为5)。由于超参数的连续性,所以n_neighbors的备选值可以为 3, 4,  6

from sklearn.datasets import fetch_openml
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score np.random.seed(42)
mnist = fetch_openml("mnist_784", version = 1, as_frame=False)
X, y = mnist['data'], mnist['target']
y = y.astype(np.uint8) X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
param_grid = [{'weights': ["uniform", "distance"], 'n_neighbors': [3, 4, 6]}] knn_clf = KNeighborsClassifier()
grid_search = GridSearchCV(knn_clf, param_grid, cv=5, verbose=3)
grid_search.fit(X_train, y_train)
y_pred = grid_search.predict(X_test)
print(accuracy_score(y_test, y_pred))

如图所示,在测试集上得到的准确率达到97.14%

通过如下命令可以获得选取的最合适的超参数以及在验证集上达到的最好结果

基于sk_learn的k近邻算法实现-mnist手写数字识别且要求97%以上精确率的更多相关文章

  1. 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型

    持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...

  2. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  3. 基于TensorFlow的MNIST手写数字识别-初级

    一:MNIST数据集    下载地址 MNIST是一个包含很多手写数字图片的数据集,一共4个二进制压缩文件 分别是test set images,test set labels,training se ...

  4. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  5. Tensorflow之MNIST手写数字识别:分类问题(1)

    一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点:   1.将离散特征的取值扩展 ...

  6. 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...

  7. TensorFlow——MNIST手写数字识别

    MNIST手写数字识别 MNIST数据集介绍和下载:http://yann.lecun.com/exdb/mnist/   一.数据集介绍: MNIST是一个入门级的计算机视觉数据集 下载下来的数据集 ...

  8. Tensorflow实现MNIST手写数字识别

    之前我们讲了神经网络的起源.单层神经网络.多层神经网络的搭建过程.搭建时要注意到的具体问题.以及解决这些问题的具体方法.本文将通过一个经典的案例:MNIST手写数字识别,以代码的形式来为大家梳理一遍神 ...

  9. mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)

    前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...

随机推荐

  1. Apple 订单系统 bug

    Apple 订单系统 bug 看不到最近的购买信息 https://secure1.www.apple.com.cn/shop/order/list refs xgqfrms 2012-2020 ww ...

  2. GitHub Ribbons : 谈网站的安全性-资源链接如何 预防/实现 爬虫的批量下载!

    GitHub Ribbons : 谈网站的安全性-资源链接如何 预防/实现 爬虫的批量下载! 预防方法: 1. 使用随机数字符串,拼接URL! https://camo.githubuserconte ...

  3. 使用 js 实现十大排序算法: 快速排序

    使用 js 实现十大排序算法: 快速排序 QuickSort 快速排序 /** * * @author xgqfrms * @license MIT * @copyright xgqfrms * @c ...

  4. record terminal sessions

    record terminal sessions asciinema https://asciinema.org/ # install $ brew install asciinema # Start ...

  5. JavaScript & Error Types

    JavaScript & Error Types JavaScript提供了8个错误对象,这些错误对象会根据错误类型在try / catch表达式中引发: Error EvalError Ra ...

  6. npm & package.json & directories & files

    npm & package.json & directories & files package.json https://docs.npmjs.com/files/packa ...

  7. 在浏览器上播放m3u8视频

    在edge上有效 <video width="600" controls> <source src="https://www.gentaji.com/2 ...

  8. BGV劝早买内存

    12月3日,BGV全球首发,上线AOFEX交易所(A网),全球区块链爱好者震惊.很多人争相抢挖BGV,希望能够及早获取BGV带来的红利.有趣的是,随着BGV抢挖人数的增多,NGK内存也迎来了暴涨,在1 ...

  9. Docker Tips: 关于/var/run/docker.sock

    本文转载自Docker Tips: 关于/var/run/docker.sock 导语 你可能已经运行过docker hub上的container并且注意到其中的一些需要绑定挂载(mount)/var ...

  10. 开源OA办公平台功能介绍:应用市场-固定资产管理(一)功能设计

    概述 应用市场-固定资产管理,是用来维护管理企业固定资产的一个功能.其整个功能包括对固定资产的台账信息.领用.调拨.借用.维修.盘点.报废等一整个生命周期的动态管理过程.力求客户安装就可以使用. 本应 ...