基于sk_learn的k近邻算法实现-mnist手写数字识别且要求97%以上精确率
1. 导入需要的库
from sklearn.datasets import fetch_openml
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
2. 设置随机种子,以获得可复现的结果。
np.random.seed(42)
3. 获取mnist数据集,并将数据集标签 由字符型转换为整数型
1 np.random.seed(42)
2 mnist = fetch_openml("mnist_784", version = 1, as_frame=False)
3 X, y = mnist['data'], mnist['target']
4 y = y.astype(np.uint8)
4. 划分训练集和测试集
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
5. 训练模型并测试
knn_clf = KNeighborsClassifier()
knn_clf.fit(X_train, y_train) y_test_pred = knn_clf.predict(X_test)
print(accuracy_score(y_test, y_test_pred))
如图我们得到了模型的准确率 0.9688
6. 训练模型中的超参数weights(默认值为'uniform')和n_neighbors(默认值为5)。由于超参数的连续性,所以n_neighbors的备选值可以为 3, 4, 6
from sklearn.datasets import fetch_openml
import numpy as np
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV
from sklearn.metrics import accuracy_score np.random.seed(42)
mnist = fetch_openml("mnist_784", version = 1, as_frame=False)
X, y = mnist['data'], mnist['target']
y = y.astype(np.uint8) X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
param_grid = [{'weights': ["uniform", "distance"], 'n_neighbors': [3, 4, 6]}] knn_clf = KNeighborsClassifier()
grid_search = GridSearchCV(knn_clf, param_grid, cv=5, verbose=3)
grid_search.fit(X_train, y_train)
y_pred = grid_search.predict(X_test)
print(accuracy_score(y_test, y_pred))
如图所示,在测试集上得到的准确率达到97.14%
通过如下命令可以获得选取的最合适的超参数以及在验证集上达到的最好结果
基于sk_learn的k近邻算法实现-mnist手写数字识别且要求97%以上精确率的更多相关文章
- 持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型
持久化的基于L2正则化和平均滑动模型的MNIST手写数字识别模型 觉得有用的话,欢迎一起讨论相互学习~Follow Me 参考文献Tensorflow实战Google深度学习框架 实验平台: Tens ...
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
- 基于TensorFlow的MNIST手写数字识别-初级
一:MNIST数据集 下载地址 MNIST是一个包含很多手写数字图片的数据集,一共4个二进制压缩文件 分别是test set images,test set labels,training se ...
- Android+TensorFlow+CNN+MNIST 手写数字识别实现
Android+TensorFlow+CNN+MNIST 手写数字识别实现 SkySeraph 2018 Email:skyseraph00#163.com 更多精彩请直接访问SkySeraph个人站 ...
- Tensorflow之MNIST手写数字识别:分类问题(1)
一.MNIST数据集读取 one hot 独热编码独热编码是一种稀疏向量,其中:一个向量设为1,其他元素均设为0.独热编码常用于表示拥有有限个可能值的字符串或标识符优点: 1.将离散特征的取值扩展 ...
- 深度学习之 mnist 手写数字识别
深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...
- TensorFlow——MNIST手写数字识别
MNIST手写数字识别 MNIST数据集介绍和下载:http://yann.lecun.com/exdb/mnist/ 一.数据集介绍: MNIST是一个入门级的计算机视觉数据集 下载下来的数据集 ...
- Tensorflow实现MNIST手写数字识别
之前我们讲了神经网络的起源.单层神经网络.多层神经网络的搭建过程.搭建时要注意到的具体问题.以及解决这些问题的具体方法.本文将通过一个经典的案例:MNIST手写数字识别,以代码的形式来为大家梳理一遍神 ...
- mnist手写数字识别——深度学习入门项目(tensorflow+keras+Sequential模型)
前言 今天记录一下深度学习的另外一个入门项目——<mnist数据集手写数字识别>,这是一个入门必备的学习案例,主要使用了tensorflow下的keras网络结构的Sequential模型 ...
随机推荐
- git alias all in one
git alias all in one workspace:工作区 staging area:暂存区/缓存区 local repository:或本地仓库 remote repository:远程仓 ...
- PM2 All In One
PM2 All In One https://pm2.keymetrics.io/ https://pm2.io/ $ yarn global add pm2 # OR $ npm install p ...
- Chrome DevTools & Slow 3G Network
Chrome DevTools & Slow 3G Network shortcuts https://developers.google.com/web/tools/chrome-devto ...
- V8 & ECMAScript & ES-Next
V8 & ECMAScript & ES-Next ES6, ES7, ES8, ES9, ES10, ES11, ES2015, ES2016, ES2017, ES2018, ES ...
- Flutter 使用高德地图定位
amap_location 包 获取debug SHA1 // 使用debug.keystore获取debug SHA1 C:\Users\ajanuw\.android>keytool -li ...
- 算法型稳定币USDN有什么价值和用途?
USDN的标签是"数字美元",与大多数稳定资产一样,USDN是一种金融服务产品.基于NGK公链发行的算法型稳定币USDN,USDN是和美元1:1锚定的加密数字货币,1USDN等于1 ...
- 从崩溃的选课系统,论为什么更安全的 HTTPS 协议没有被全面采用
尽人事,听天命.博主东南大学研究生在读,热爱健身和篮球,正在为两年后的秋招准备中,乐于分享技术相关的所见所得,关注公众号 @ 飞天小牛肉,第一时间获取文章更新,成长的路上我们一起进步 本文已收录于 C ...
- [转]ROS订阅激光数据
https://github.com/robopeak/rplidar_ros/blob/master/src/client.cpp /* * Copyright (c) 2014, RoboPe ...
- (1)MySQL进阶篇在linux环境下安装
1.概述 对于mysql二进制安装,优点是可以安装到任何路径下,灵活性好,一台服务器可以安装多个mysql.缺点是已经编译过,性能不如源码编译得好,不能灵活定制编译参数.如果用户即不想安装最简单却不够 ...
- 使用sun.net.ftp.FtpClient进行上传功能开发,在jdk1.7上不适用问题的解决
问题如下图片: 之前项目上开发了一个上传文件的功能,使用的是sun.net.ftp.FtpClient这个类 连接服务器的代码大概如下: public static FtpClient ftpClie ...