KNN实现手写数字识别


博客上显示这个没有Jupyter的好看,想看Jupyter Notebook的请戳KNN实现手写数字识别.ipynb

1 - 导入模块

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from ld_mnist import load_digits %matplotlib inline

2 - 导入数据及数据预处理

import tensorflow as tf

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data def load_digits():
mnist = input_data.read_data_sets("path/", one_hot=True)
return mnist
mnist = load_digits()
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\train-images-idx3-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\train-labels-idx1-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\t10k-images-idx3-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\t10k-labels-idx1-ubyte.gz

数据维度

print("Train: "+ str(mnist.train.images.shape))
print("Train: "+ str(mnist.train.labels.shape))
print("Test: "+ str(mnist.test.images.shape))
print("Test: "+ str(mnist.test.labels.shape))
Train: (55000, 784)
Train: (55000, 10)
Test: (10000, 784)
Test: (10000, 10)

mnist数据采用的是TensorFlow的一个函数进行读取的,由上面的结果可以知道训练集数据X_train有55000个,每个X的数据长度是784(28*28)。

x_train, y_train, x_test, y_test = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels

展示手写数字

nums = 6
for i in range(1,nums+1):
plt.subplot(1,nums,i)
plt.imshow(x_train[i].reshape(28,28), cmap="gray")

3 - 构建模型

class Knn():

    def __init__(self,k):
self.k = k
self.distance = {} def topKDistance(self, x_train, x_test):
'''
计算距离,这里采用欧氏距离
'''
print("计算距离...")
distance = {}
for i in range(x_test.shape[0]):
dis1 = x_train - x_test[i]
dis2 = np.sqrt(np.sum(dis1*dis1, axis=1))
distance[str(i)] = np.argsort(dis2)[:self.k]
if i%1000==0:
print(distance[str(i)])
return distance def predict(self, x_train, y_train, x_test):
'''
预测
'''
self.distance = self.topKDistance(x_train, x_test)
y_hat = []
print("选出每项最佳预测结果")
for i in range(x_test.shape[0]):
classes = {}
for j in range(self.k):
num = np.argmax(y_train[self.distance[str(i)][j]])
classes[num] = classes.get(num, 0) + 1
sortClasses = sorted(classes.items(), key= lambda x:x[1], reverse=True)
y_hat.append(sortClasses[0][0])
y_hat = np.array(y_hat).reshape(-1,1)
return y_hat def fit(self, x_train, y_train, x_test, y_test):
'''
计算准确率
'''
print("预测...")
y_hat = self.predict(x_train, y_train, x_test)
# index_hat =np.argmax(y_hat , axis=1)
print("计算准确率...")
index_test = np.argmax(y_test, axis=1).reshape(1,-1)
accuracy = np.sum(y_hat.reshape(index_test.shape) == index_test)*1.0/y_test.shape[0]
return accuracy, y_hat
clf = Knn(10)
accuracy, y_hat = clf.fit(x_train,y_train,x_test,y_test)
print(accuracy)
预测...
计算距离...
[48843 33620 11186 22059 42003 9563 39566 10260 35368 31395]
[54214 4002 11005 15264 49069 8791 38147 47304 51494 11053]
[46624 10708 22134 20108 48606 19774 7855 43740 51345 9308]
[ 8758 47844 50994 45610 1930 3312 30140 17618 910 51918]
[14953 1156 50024 26833 26006 38112 31080 9066 32112 41846]
[45824 14234 48282 28432 50966 22786 40902 52264 38552 44080]
[24878 4655 20258 36065 30755 15075 35584 12152 4683 43255]
[48891 20744 47822 53511 54545 27392 10240 3970 25721 30357]
[ 673 17747 33803 20960 25463 35723 969 50577 36714 35719]
[ 8255 42067 53282 14383 14073 52083 7233 8199 8963 12617]
选出每项最佳预测结果
计算准确率...
0.9672

准确率略高。



MARSGGBO♥原创





2017-8-21

KNN实现手写数字识别的更多相关文章

  1. 机器学习(二)-kNN手写数字识别

    一.kNN算法是机器学习的入门算法,其中不涉及训练,主要思想是计算待测点和参照点的距离,选取距离较近的参照点的类别作为待测点的的类别. 1,距离可以是欧式距离,夹角余弦距离等等. 2,k值不能选择太大 ...

  2. 一看就懂的K近邻算法(KNN),K-D树,并实现手写数字识别!

    1. 什么是KNN 1.1 KNN的通俗解释 何谓K近邻算法,即K-Nearest Neighbor algorithm,简称KNN算法,单从名字来猜想,可以简单粗暴的认为是:K个最近的邻居,当K=1 ...

  3. kaggle 实战 (1): PCA + KNN 手写数字识别

    文章目录 加载package read data PCA 降维探索 选择50维度, 拆分数据为训练集,测试机 KNN PCA降维和K值筛选 分析k & 维度 vs 精度 预测 生成提交文件 本 ...

  4. Kaggle竞赛丨入门手写数字识别之KNN、CNN、降维

    引言 这段时间来,看了西瓜书.蓝皮书,各种机器学习算法都有所了解,但在实践方面却缺乏相应的锻炼.于是我决定通过Kaggle这个平台来提升一下自己的应用能力,培养自己的数据分析能力. 我个人的计划是先从 ...

  5. 基于OpenCV的KNN算法实现手写数字识别

    基于OpenCV的KNN算法实现手写数字识别 一.数据预处理 # 导入所需模块 import cv2 import numpy as np import matplotlib.pyplot as pl ...

  6. K近邻实战手写数字识别

    1.导包 import numpy as np import operator from os import listdir from sklearn.neighbors import KNeighb ...

  7. C#中调用Matlab人工神经网络算法实现手写数字识别

    手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化  投影  矩阵  目标定位  Matlab 手写数字图像识别简介: 手写 ...

  8. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  9. 【深度学习系列】PaddlePaddle之手写数字识别

    上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...

随机推荐

  1. Python中的字典详解

    https://www.cnblogs.com/yjd_hycf_space/p/6880026.html

  2. Java中的基本数据据类型

    1.整数类型 类型 字节数 表示范围 byte 1 -128~127 short 2 -32768 ~ 32767 int 4 -2147483648~2147483647 long 8 -92233 ...

  3. ElasticSearch 2 (11) - 节点调优(ElasticSearch性能)

    ElasticSearch 2 (11) - 节点调优(ElasticSearch性能) 摘要 一个ElasticSearch集群需要多少个节点很难用一种明确的方式回答,但是,我们可以将问题细化成一下 ...

  4. HDU 2097 Sky数

    http://acm.hdu.edu.cn/showproblem.php?pid=2097 Problem Description Sky从小喜欢奇特的东西,而且天生对数字特别敏感,一次偶然的机会, ...

  5. spring动态数据源+事务

    今天在尝试配置spring的动态数据源和事务管理的时候,遇到了几处配置上的问题,在此记录下: 1.使用了spring的aop思想,实现了动态数据源的切换. 2.spring的事务管理,是基于数据源的, ...

  6. require.js text 插件使用

    相比于使用script构建DOM结构,使用HTML标签来构建html是一个很好的方式.然而, 并没有很好的方式可以在js文件中嵌入 HTML .最好的方式是使用 HTML字符串, 但这很难管理,尤其实 ...

  7. 2013长春网赛 1006 hdu 4764 Stone(巴什博弈)

    题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4764 题意:Tang 和 Jiang 玩一个游戏,轮流写下一个数,Tang先手,第一次Tang只能写[ ...

  8. TCP 协议连接与关闭的握手

     原文链接 http://blog.csdn.net/oney139/article/details/8103223   TCP头部: 其中 ACK   SYN  序号  这三个部分在以下会用到,它们 ...

  9. BZOJ 4569 [Scoi2016]萌萌哒 | ST表 并查集

    传送门 BZOJ 4569 题解 ST表和并查集是我认为最优雅(其实是最好写--)的两个数据结构. 然鹅!他俩加一起的这道题,我却--没有做出来-- 咳咳. 正解是这样的: 类似ST表有\(\log ...

  10. 【转】STM32 - 程序跳转、中断、开关总中断

    程序跳转注意: 1.如果跳转之前的程序A里有些中断没有关,在跳转之后程序B的中断触发,但程序B里没有定义中断响应函数,找不到地址会导致死机. 2.程序跳转前关总中断,程序跳转后开总中断(关总中断,只是 ...