KNN实现手写数字识别


博客上显示这个没有Jupyter的好看,想看Jupyter Notebook的请戳KNN实现手写数字识别.ipynb

1 - 导入模块

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from ld_mnist import load_digits %matplotlib inline

2 - 导入数据及数据预处理

import tensorflow as tf

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data def load_digits():
mnist = input_data.read_data_sets("path/", one_hot=True)
return mnist
mnist = load_digits()
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\train-images-idx3-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\train-labels-idx1-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\t10k-images-idx3-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\t10k-labels-idx1-ubyte.gz

数据维度

print("Train: "+ str(mnist.train.images.shape))
print("Train: "+ str(mnist.train.labels.shape))
print("Test: "+ str(mnist.test.images.shape))
print("Test: "+ str(mnist.test.labels.shape))
Train: (55000, 784)
Train: (55000, 10)
Test: (10000, 784)
Test: (10000, 10)

mnist数据采用的是TensorFlow的一个函数进行读取的,由上面的结果可以知道训练集数据X_train有55000个,每个X的数据长度是784(28*28)。

x_train, y_train, x_test, y_test = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels

展示手写数字

nums = 6
for i in range(1,nums+1):
plt.subplot(1,nums,i)
plt.imshow(x_train[i].reshape(28,28), cmap="gray")

3 - 构建模型

class Knn():

    def __init__(self,k):
self.k = k
self.distance = {} def topKDistance(self, x_train, x_test):
'''
计算距离,这里采用欧氏距离
'''
print("计算距离...")
distance = {}
for i in range(x_test.shape[0]):
dis1 = x_train - x_test[i]
dis2 = np.sqrt(np.sum(dis1*dis1, axis=1))
distance[str(i)] = np.argsort(dis2)[:self.k]
if i%1000==0:
print(distance[str(i)])
return distance def predict(self, x_train, y_train, x_test):
'''
预测
'''
self.distance = self.topKDistance(x_train, x_test)
y_hat = []
print("选出每项最佳预测结果")
for i in range(x_test.shape[0]):
classes = {}
for j in range(self.k):
num = np.argmax(y_train[self.distance[str(i)][j]])
classes[num] = classes.get(num, 0) + 1
sortClasses = sorted(classes.items(), key= lambda x:x[1], reverse=True)
y_hat.append(sortClasses[0][0])
y_hat = np.array(y_hat).reshape(-1,1)
return y_hat def fit(self, x_train, y_train, x_test, y_test):
'''
计算准确率
'''
print("预测...")
y_hat = self.predict(x_train, y_train, x_test)
# index_hat =np.argmax(y_hat , axis=1)
print("计算准确率...")
index_test = np.argmax(y_test, axis=1).reshape(1,-1)
accuracy = np.sum(y_hat.reshape(index_test.shape) == index_test)*1.0/y_test.shape[0]
return accuracy, y_hat
clf = Knn(10)
accuracy, y_hat = clf.fit(x_train,y_train,x_test,y_test)
print(accuracy)
预测...
计算距离...
[48843 33620 11186 22059 42003 9563 39566 10260 35368 31395]
[54214 4002 11005 15264 49069 8791 38147 47304 51494 11053]
[46624 10708 22134 20108 48606 19774 7855 43740 51345 9308]
[ 8758 47844 50994 45610 1930 3312 30140 17618 910 51918]
[14953 1156 50024 26833 26006 38112 31080 9066 32112 41846]
[45824 14234 48282 28432 50966 22786 40902 52264 38552 44080]
[24878 4655 20258 36065 30755 15075 35584 12152 4683 43255]
[48891 20744 47822 53511 54545 27392 10240 3970 25721 30357]
[ 673 17747 33803 20960 25463 35723 969 50577 36714 35719]
[ 8255 42067 53282 14383 14073 52083 7233 8199 8963 12617]
选出每项最佳预测结果
计算准确率...
0.9672

准确率略高。



MARSGGBO♥原创





2017-8-21

KNN实现手写数字识别的更多相关文章

  1. 机器学习(二)-kNN手写数字识别

    一.kNN算法是机器学习的入门算法,其中不涉及训练,主要思想是计算待测点和参照点的距离,选取距离较近的参照点的类别作为待测点的的类别. 1,距离可以是欧式距离,夹角余弦距离等等. 2,k值不能选择太大 ...

  2. 一看就懂的K近邻算法(KNN),K-D树,并实现手写数字识别!

    1. 什么是KNN 1.1 KNN的通俗解释 何谓K近邻算法,即K-Nearest Neighbor algorithm,简称KNN算法,单从名字来猜想,可以简单粗暴的认为是:K个最近的邻居,当K=1 ...

  3. kaggle 实战 (1): PCA + KNN 手写数字识别

    文章目录 加载package read data PCA 降维探索 选择50维度, 拆分数据为训练集,测试机 KNN PCA降维和K值筛选 分析k & 维度 vs 精度 预测 生成提交文件 本 ...

  4. Kaggle竞赛丨入门手写数字识别之KNN、CNN、降维

    引言 这段时间来,看了西瓜书.蓝皮书,各种机器学习算法都有所了解,但在实践方面却缺乏相应的锻炼.于是我决定通过Kaggle这个平台来提升一下自己的应用能力,培养自己的数据分析能力. 我个人的计划是先从 ...

  5. 基于OpenCV的KNN算法实现手写数字识别

    基于OpenCV的KNN算法实现手写数字识别 一.数据预处理 # 导入所需模块 import cv2 import numpy as np import matplotlib.pyplot as pl ...

  6. K近邻实战手写数字识别

    1.导包 import numpy as np import operator from os import listdir from sklearn.neighbors import KNeighb ...

  7. C#中调用Matlab人工神经网络算法实现手写数字识别

    手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化  投影  矩阵  目标定位  Matlab 手写数字图像识别简介: 手写 ...

  8. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  9. 【深度学习系列】PaddlePaddle之手写数字识别

    上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...

随机推荐

  1. LeetCode 463. Island Perimeter岛屿的周长 (C++)

    题目: You are given a map in form of a two-dimensional integer grid where 1 represents land and 0 repr ...

  2. 四则运算 SPEC 20160911

    本文档随时可能修改,并且没有另行通知. 请确保每一次在开始修改你的代码前,读标题中的日期,如果晚于你上次阅读, 请重读一次. 教师节你去探望初中数学老师,她感叹你当年真是个优秀学生啊,从来不报怨作 业 ...

  3. [2017BUAA软工助教]个人得分总表(beta阶段)

    一.表 学号 b团队 b团队得分 b贡献分 阅读作业 提问回顾 总分 14011100 hotcode5 228 60 6 7.5 301.5 14061213 PM="PokeMon&qu ...

  4. HashMap相关总结

    1.HashMap:根据键值hashCode值存储数据,大多数情况下可以直接定位到它的值,但是遍历顺序不确定.所有哈希值相同的值存储到同一个链表中                         Ha ...

  5. nil Nil NULL NSNull 之间的区别

    nil -> Null-pointer to objective- c objectNIL -> Null-pointer to objective- c class  表示对类进行赋空值 ...

  6. Mysql存储引擎federated

    Mysql数据库存储引擎federated(联盟) 意思就是把两个不同区域的数据库联系起来,以致可以访问在远程数据库的表中的数据,而不是本地的表.->专门针对远程数据库的实现->一般情况下 ...

  7. python设计模式-单例模式

    单例模式应用场景 代码的设计模式共有25种,设计模式其实是代码无关的.其目的是基于OOP的思想,不同应用场景应用不同的设计模式,从而达到简化代码.利于扩展.提示性能等目的.本文简述Python实现的单 ...

  8. CentOS下搭建Hive

    目录 下载解压hive mysql驱动 配置文件 hive-env.sh hive-site.xml 首次启动hive 使用schemaTool初始化mysql数据库 错误总结 警告汇总 参考:htt ...

  9. IntelliJ IDEA2017 修改缓存文件的路径

    IDEA的缓存文件夹.IntelliJIdea2017.1,存放着IDEA的破解密码,各个项目的缓存,默认是在C盘的用户目录下,目前有1.5G大小.现在想要把它从C盘移出. 在IDEA的安装路径下中, ...

  10. MT【165】分段函数

    (2018浙江省赛12题改编)设$a\in R$,且对任意的实数$b$均有$\max\limits_{x\in[0,1]}|x^2+ax+b|\ge\dfrac{1}{4}$求$a$ 的范围. 提示: ...