公号:码农充电站pro

主页:https://codeshellme.github.io

上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字

1,手写数字数据集

手写数字数据集是一个用于图像处理的数据集,这些数据描绘了 [0, 9] 的数字,我们可以用KNN 算法来识别这些数字。

MNIST 是完整的手写数字数据集,其中包含了60000 个训练样本和10000 个测试样本。

sklearn 中也有一个自带的手写数字数据集

  • 共包含 1797 个数据样本,每个样本描绘了一个 8*8 像素的 [0, 9] 的数字。
  • 每个样本由 65 个数字组成:
    • 前 64 个数字是特征数据,特征数据的范围是 [0, 16]
    • 最后一个数字是目标数据,目标数据的范围是 [0, 9]

我们抽出 5 个样本来看下:

0,0,5,13,9,1,0,0,0,0,13,15,10,15,5,0,0,3,15,2,0,11,8,0,0,4,12,0,0,8,8,0,0,5,8,0,0,9,8,0,0,4,11,0,1,12,7,0,0,2,14,5,10,12,0,0,0,0,6,13,10,0,0,0,0
0,0,0,12,13,5,0,0,0,0,0,11,16,9,0,0,0,0,3,15,16,6,0,0,0,7,15,16,16,2,0,0,0,0,1,16,16,3,0,0,0,0,1,16,16,6,0,0,0,0,1,16,16,6,0,0,0,0,0,11,16,10,0,0,1
0,0,0,4,15,12,0,0,0,0,3,16,15,14,0,0,0,0,8,13,8,16,0,0,0,0,1,6,15,11,0,0,0,1,8,13,15,1,0,0,0,9,16,16,5,0,0,0,0,3,13,16,16,11,5,0,0,0,0,3,11,16,9,0,2
0,0,7,15,13,1,0,0,0,8,13,6,15,4,0,0,0,2,1,13,13,0,0,0,0,0,2,15,11,1,0,0,0,0,0,1,12,12,1,0,0,0,0,0,1,10,8,0,0,0,8,4,5,14,9,0,0,0,7,13,13,9,0,0,3
0,0,0,1,11,0,0,0,0,0,0,7,8,0,0,0,0,0,1,13,6,2,2,0,0,0,7,15,0,9,8,0,0,5,16,10,0,16,6,0,0,4,15,16,13,16,1,0,0,0,0,3,15,10,0,0,0,0,0,2,16,4,0,0,4

使用该数据集,需要先加载:

>>> from sklearn.datasets import load_digits
>>> digits = load_digits()

查看第一个图像数据:

>>> digits.images[0]
array([[ 0., 0., 5., 13., 9., 1., 0., 0.],
[ 0., 0., 13., 15., 10., 15., 5., 0.],
[ 0., 3., 15., 2., 0., 11., 8., 0.],
[ 0., 4., 12., 0., 0., 8., 8., 0.],
[ 0., 5., 8., 0., 0., 9., 8., 0.],
[ 0., 4., 11., 0., 1., 12., 7., 0.],
[ 0., 2., 14., 5., 10., 12., 0., 0.],
[ 0., 0., 6., 13., 10., 0., 0., 0.]])

我们可以用 matplotlib 将该图像画出来:

>>> import matplotlib.pyplot as plt
>>> plt.imshow(digits.images[0])
>>> plt.show()

画出来的图像如下,代表 0

2,sklearn 对 KNN 算法的实现

sklearn 库的 neighbors 模块实现了KNN 相关算法,其中:

  • KNeighborsClassifier 类用于分类问题
  • KNeighborsRegressor 类用于回归问题

这两个类的构造方法基本一致,这里我们主要介绍 KNeighborsClassifier 类,原型如下:

KNeighborsClassifier(
n_neighbors=5,
weights='uniform',
algorithm='auto',
leaf_size=30,
p=2,
metric='minkowski',
metric_params=None,
n_jobs=None,
**kwargs)

来看下几个重要参数的含义:

  • n_neighbors:即 KNN 中的 K 值,一般使用默认值 5。
  • weights:用于确定邻居的权重,有三种方式:
    • weights=uniform,表示所有邻居的权重相同。
    • weights=distance,表示权重是距离的倒数,即与距离成反比。
    • 自定义函数,可以自定义不同距离所对应的权重,一般不需要自己定义函数。
  • algorithm:用于设置计算邻居的算法,它有四种方式:
    • algorithm=auto,根据数据的情况自动选择适合的算法。
    • algorithm=kd_tree,使用 KD 树 算法。
      • KD 树是一种多维空间的数据结构,方便对数据进行检索。
      • KD 树适用于维度较少的情况,一般维数不超过 20,如果维数大于 20 之后,效率会下降。
    • algorithm=ball_tree,使用球树算法。
      • KD 树一样都是多维空间的数据结构。
      • 球树更适用于维度较大的情况。
    • algorithm=brute,称为暴力搜索
      • 它和 KD 树相比,采用的是线性扫描,而不是通过构造树结构进行快速检索。
      • 缺点是,当训练集较大的时候,效率很低。
    • leaf_size:表示构造 KD 树球树时的叶子节点数,默认是 30。
      • 调整 leaf_size 会影响树的构造和搜索速度。

3,构造 KNN 分类器

首先加载数据集:

from sklearn.datasets import load_digits

digits = load_digits()
data = digits.data # 特征集
target = digits.target # 目标集

将数据集拆分为训练集(75%)和测试集(25%),

from sklearn.model_selection import train_test_split

train_x, test_x, train_y, test_y = train_test_split(
data, target, test_size=0.25, random_state=33)

构造KNN 分类器:

from sklearn.neighbors import KNeighborsClassifier

# 采用默认参数
knn = KNeighborsClassifier()

拟合模型:

knn.fit(train_x, train_y)

预测数据:

predict_y = knn.predict(test_x)

计算模型准确度:

from sklearn.metrics import accuracy_score

score = accuracy_score(test_y, predict_y)
print score # 0.98

最终计算出来模型的准确度是 98%,准确度还是不错的。

4,总结

本篇文章使用KNN 算法处理了一个实际的分类问题,主要介绍了以下几点:

  • 介绍了sklearn 中自带的手写数字集,并用 matplotlib 模块画出了数字图像。
  • 介绍了sklearnneighbors.KNeighborsClassifier 类的用法。
  • 使用 KNeighborsClassifier 来识别手写数字。

(本节完。)


推荐阅读:

KNN 算法-理论篇-如何给电影进行分类

决策树算法-理论篇-如何计算信息纯度

决策树算法-实战篇-鸢尾花及波士顿房价预测

朴素贝叶斯分类-理论篇-如何通过概率解决分类问题

朴素贝叶斯分类-实战篇-如何进行文本分类


欢迎关注作者公众号,获取更多技术干货。

KNN 算法-实战篇-如何识别手写数字的更多相关文章

  1. 一文全解:利用谷歌深度学习框架Tensorflow识别手写数字图片(初学者篇)

    笔记整理者:王小草 笔记整理时间2017年2月24日 原文地址 http://blog.csdn.net/sinat_33761963/article/details/56837466?fps=1&a ...

  2. TensorFlow实战之Softmax Regression识别手写数字

         关于本文说明,本人原博客地址位于http://blog.csdn.net/qq_37608890,本文来自笔者于2018年02月21日 23:10:04所撰写内容(http://blog.c ...

  3. 使用神经网络来识别手写数字【译】(三)- 用Python代码实现

    实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...

  4. 学习笔记TF024:TensorFlow实现Softmax Regression(回归)识别手写数字

    TensorFlow实现Softmax Regression(回归)识别手写数字.MNIST(Mixed National Institute of Standards and Technology ...

  5. python手写神经网络实现识别手写数字

    写在开头:这个实验和matlab手写神经网络实现识别手写数字一样. 实验说明 一直想自己写一个神经网络来实现手写数字的识别,而不是套用别人的框架.恰巧前几天,有幸从同学那拿到5000张已经贴好标签的手 ...

  6. 用BP人工神经网络识别手写数字

    http://wenku.baidu.com/link?url=HQ-5tZCXBQ3uwPZQECHkMCtursKIpglboBHq416N-q2WZupkNNH3Gv4vtEHyPULezDb5 ...

  7. python机器学习使用PCA降维识别手写数字

    PCA降维识别手写数字 关注公众号"轻松学编程"了解更多. PCA 用于数据降维,减少运算时间,避免过拟合. PCA(n_components=150,whiten=True) n ...

  8. 3 TensorFlow入门之识别手写数字

    ------------------------------------ 写在开头:此文参照莫烦python教程(墙裂推荐!!!) ---------------------------------- ...

  9. KNN (K近邻算法) - 识别手写数字

    KNN项目实战——手写数字识别 1. 介绍 k近邻法(k-nearest neighbor, k-NN)是1967年由Cover T和Hart P提出的一种基本分类与回归方法.它的工作原理是:存在一个 ...

随机推荐

  1. 【Luogu】P1436 棋盘分割 题解

    嗯,点开题目,哇!是一道闪亮亮的蓝题! 不要被吓到了,其实,这道题就是一个简单的DP啦! 我们设 \(f[x1][y1][x2][y2][c]\) 为以 \((x1,y1)\) 为左上角,以 \((x ...

  2. 按揭贷款的计算原理与java实现

    Number部分(6) Mortgage Calculator--按揭贷款计算器 题目描述: Mortgage Calculator – Calculate the monthly payments ...

  3. 腾讯云--对象存储cos绑定自定义域名

    1.登录腾讯云控制台,找到对象存储一栏 2.选择一个你想绑定域名的存储桶 3.进入你选择的存储桶,点击域名管理 4.选择自定义源站域名.在域名处填写你要设置的自定义域名,在源站类型处选择静态网站源站, ...

  4. layui导航

    关于导航 首先看一下官网的样式: <!DOCTYPE html><html><head> <meta charset="utf-8" /& ...

  5. .Net Newtonsoft.Json 转json时将枚举转为字符串

    1:非列表类型枚举 [JsonConverter(typeof(StringEnumConverter))] public SubjectTypeEnum subject_type { get; se ...

  6. python菜鸟教程学习1:背景性学习

    https://www.runoob.com/python3/python3-intro.html 优点 简单 -- Python 是一种代表简单主义思想的语言.阅读一个良好的 Python 程序就感 ...

  7. 第05组 Alpha冲刺 (1/6)

    .th1 { font-family: 黑体; font-size: 25px; color: rgba(0, 0, 255, 1) } #ka { margin-top: 50px } .aaa11 ...

  8. 92. Reverse Linked List II 翻转链表II

    Reverse a linked list from position m to n. Do it in one-pass. Note: 1 ≤ m ≤ n ≤ length of list. Exa ...

  9. 快速熟悉 Oracle AWR 报告解读

    目录 AWR报告简介 AWR报告结构 基本信息 Report Summary Main Report RAC statistics Wait Event Statistics 参考资料 本文面向没有太 ...

  10. mysql mybatis Date java时间和写入数据库时间不符差一秒问题

    1,java的数据库实体定义 private Timestamp createTime:2,非常重要!ddl语句建表字段的单位 datetime要手动设置保留3位毫秒数,不然就四舍五入了! ALTER ...