1. 导入需要的库

from sklearn.datasets import fetch_openml
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

2. 设置随机种子,以获得可复现的结果。

np.random.seed(42)

3. 获取mnist数据集,并将数据集标签 由字符型转换为整数型

1 np.random.seed(42)
2 mnist = fetch_openml("mnist_784", version = 1, as_frame=False)
3 X, y = mnist['data'], mnist['target']
4 y = y.astype(np.uint8)

4. 划分训练集和测试集

X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]

5. 训练模型并测试

knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_train) y_test_pred = knn_clf.predict(X_test)
print(accuracy_score(y_test, y_test_pred))

如图我们得到了模型的准确率 0.9688

6. 训练模型中的超参数weights(默认值为'uniform')和n_neighbors(默认值为5)。由于超参数的连续性,所以n_neighbors的备选值可以为 3, 4,  6

from sklearn.datasets import fetch_openml
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score np.random.seed(42)
mnist = fetch_openml("mnist_784", version = 1, as_frame=False)
X, y = mnist['data'], mnist['target']
y = y.astype(np.uint8) X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
param_grid = [{'weights': ["uniform", "distance"], 'n_neighbors': [3, 4, 6]}] knn_clf = KNeighborsClassifier()
grid_search = GridSearchCV(knn_clf, param_grid, cv=5, verbose=3)
grid_search.fit(X_train, y_train)
y_pred = grid_search.predict(X_test)
print(accuracy_score(y_test, y_pred))

如图所示,在测试集上得到的准确率达到97.14%

通过如下命令可以获得选取的最合适的超参数以及在验证集上达到的最好结果

基于sk_learn的k近邻算法实现-mnist手写数字识别且要求97%以上精确率的更多相关文章

  1. 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型

    持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...

  2. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  3. 基于TensorFlow的MNIST手写数字识别-初级

    一:MNIST数据集    下载地址 MNIST是一个包含很多手写数字图片的数据集,一共4个二进制压缩文件 分别是test set images,test set labels,training se ...

  4. Android+TensorFlow+CNN+MNIST 手写数字识别实现

    Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...

  5. Tensorflow之MNIST手写数字识别:分类问题(1)

    一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点:   1.将离散特征的取值扩展 ...

  6. 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...

  7. TensorFlow——MNIST手写数字识别

    MNIST手写数字识别 MNIST数据集介绍和下载:http://yann.lecun.com/exdb/mnist/   一.数据集介绍: MNIST是一个入门级的计算机视觉数据集 下载下来的数据集 ...

  8. Tensorflow实现MNIST手写数字识别

    之前我们讲了神经网络的起源.单层神经网络.多层神经网络的搭建过程.搭建时要注意到的具体问题.以及解决这些问题的具体方法.本文将通过一个经典的案例:MNIST手写数字识别,以代码的形式来为大家梳理一遍神 ...

  9. mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)

    前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...

随机推荐

  1. npm publish bug & solution

    npm publish bug & solution npm ERR! Unexpected token < in JSON at position 0 while parsing ne ...

  2. 使用 js 和 Beacon API 实现一个简易版的前端埋点监控 npm 包

    使用 js 和 Beacon API 实现一个简易版的前端埋点监控 npm 包 前端监控,埋点,数据收集,性能监控 Beacon API https://caniuse.com/beacon 优点,请 ...

  3. 如何获取豆瓣电影 API Key

    如何获取豆瓣电影 API Key 豆瓣 API Key 不能使用了 ! solutions & !== ? https://frodo.douban.com/api/v2/subject_co ...

  4. Python3 & Decorators with arguments & @Decorators with arguments bug

    Python3 & Decorators with arguments & @Decorators with arguments bug @Decorators with argume ...

  5. how to disabled alert function in javascript

    how to disabled alert function in javascript alert 阻塞主线程 default alert; // ƒ alert() { [native code] ...

  6. 前端 & 技术团队 TL & 如何面试 & 如何带人

    前端 & 技术团队 TL & 如何面试 & 如何带人 面试 带人 作为 TL,深度了解你的团队非常重要,要去了解每个人的想法是什么,他的诉求是什么,他目前的状态怎么样,以及对他 ...

  7. NGK公链大事件盘点——回顾过去,展望未来!

    NGK公链构想广阔,愿景宏大,2020年10月NGK正式上线,同时NGK全球发布会正式启动,建立区块链生态体系. 早在这之前,NGK就经过了紧锣密鼓的数年缜密搭建. 2018年6月NGK底层系统技术原 ...

  8. HttpDns 原理是什么

    本文转载自HttpDns 原理是什么 什么是 DNS DNS(Domain Name System,域名系统),DNS 服务用于在网络请求时,将域名转为 IP 地址.能够使用户更方便的访问互联网,而不 ...

  9. .NET Core Swagger 的分组使, 以及相同Action能被多个分组公用,同时加载出尚未分组的数据出来

    1.本文章参考 点击链接跳转 改写的 一对多分组模式.需要一对一的可以参考 2.本文主要讲的是 一对多 分组公用, 同时把尚未分组的加载出来 3.效果演示GIF图: 具体操作代码如下: 1.在项目创建 ...

  10. PHP中间件

    定义 首先什么是php的中间件? 根据zend-framework中的定义: 所谓中间件是指提供在请求和响应之间的,能够截获请求,并在其基础上进行逻辑处理,与此同时能够完成请求的响应或传递到下一个中间 ...