KNN实现手写数字识别


博客上显示这个没有Jupyter的好看,想看Jupyter Notebook的请戳KNN实现手写数字识别.ipynb

1 - 导入模块

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from ld_mnist import load_digits %matplotlib inline

2 - 导入数据及数据预处理

import tensorflow as tf

# Import MNIST data
from tensorflow.examples.tutorials.mnist import input_data def load_digits():
mnist = input_data.read_data_sets("path/", one_hot=True)
return mnist
mnist = load_digits()
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\train-images-idx3-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\train-labels-idx1-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\t10k-images-idx3-ubyte.gz
Extracting C:/Users/marsggbo/Documents/Code/ML/TF Tutorial/data/MNIST_data\t10k-labels-idx1-ubyte.gz

数据维度

print("Train: "+ str(mnist.train.images.shape))
print("Train: "+ str(mnist.train.labels.shape))
print("Test: "+ str(mnist.test.images.shape))
print("Test: "+ str(mnist.test.labels.shape))
Train: (55000, 784)
Train: (55000, 10)
Test: (10000, 784)
Test: (10000, 10)

mnist数据采用的是TensorFlow的一个函数进行读取的,由上面的结果可以知道训练集数据X_train有55000个,每个X的数据长度是784(28*28)。

x_train, y_train, x_test, y_test = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels

展示手写数字

nums = 6
for i in range(1,nums+1):
plt.subplot(1,nums,i)
plt.imshow(x_train[i].reshape(28,28), cmap="gray")

3 - 构建模型

class Knn():

    def __init__(self,k):
self.k = k
self.distance = {} def topKDistance(self, x_train, x_test):
'''
计算距离,这里采用欧氏距离
'''
print("计算距离...")
distance = {}
for i in range(x_test.shape[0]):
dis1 = x_train - x_test[i]
dis2 = np.sqrt(np.sum(dis1*dis1, axis=1))
distance[str(i)] = np.argsort(dis2)[:self.k]
if i%1000==0:
print(distance[str(i)])
return distance def predict(self, x_train, y_train, x_test):
'''
预测
'''
self.distance = self.topKDistance(x_train, x_test)
y_hat = []
print("选出每项最佳预测结果")
for i in range(x_test.shape[0]):
classes = {}
for j in range(self.k):
num = np.argmax(y_train[self.distance[str(i)][j]])
classes[num] = classes.get(num, 0) + 1
sortClasses = sorted(classes.items(), key= lambda x:x[1], reverse=True)
y_hat.append(sortClasses[0][0])
y_hat = np.array(y_hat).reshape(-1,1)
return y_hat def fit(self, x_train, y_train, x_test, y_test):
'''
计算准确率
'''
print("预测...")
y_hat = self.predict(x_train, y_train, x_test)
# index_hat =np.argmax(y_hat , axis=1)
print("计算准确率...")
index_test = np.argmax(y_test, axis=1).reshape(1,-1)
accuracy = np.sum(y_hat.reshape(index_test.shape) == index_test)*1.0/y_test.shape[0]
return accuracy, y_hat
clf = Knn(10)
accuracy, y_hat = clf.fit(x_train,y_train,x_test,y_test)
print(accuracy)
预测...
计算距离...
[48843 33620 11186 22059 42003 9563 39566 10260 35368 31395]
[54214 4002 11005 15264 49069 8791 38147 47304 51494 11053]
[46624 10708 22134 20108 48606 19774 7855 43740 51345 9308]
[ 8758 47844 50994 45610 1930 3312 30140 17618 910 51918]
[14953 1156 50024 26833 26006 38112 31080 9066 32112 41846]
[45824 14234 48282 28432 50966 22786 40902 52264 38552 44080]
[24878 4655 20258 36065 30755 15075 35584 12152 4683 43255]
[48891 20744 47822 53511 54545 27392 10240 3970 25721 30357]
[ 673 17747 33803 20960 25463 35723 969 50577 36714 35719]
[ 8255 42067 53282 14383 14073 52083 7233 8199 8963 12617]
选出每项最佳预测结果
计算准确率...
0.9672

准确率略高。



MARSGGBO♥原创





2017-8-21

KNN实现手写数字识别的更多相关文章

  1. 机器学习(二)-kNN手写数字识别

    一.kNN算法是机器学习的入门算法,其中不涉及训练,主要思想是计算待测点和参照点的距离,选取距离较近的参照点的类别作为待测点的的类别. 1,距离可以是欧式距离,夹角余弦距离等等. 2,k值不能选择太大 ...

  2. 一看就懂的K近邻算法(KNN),K-D树,并实现手写数字识别!

    1. 什么是KNN 1.1 KNN的通俗解释 何谓K近邻算法,即K-Nearest Neighbor algorithm,简称KNN算法,单从名字来猜想,可以简单粗暴的认为是:K个最近的邻居,当K=1 ...

  3. kaggle 实战 (1): PCA + KNN 手写数字识别

    文章目录 加载package read data PCA 降维探索 选择50维度, 拆分数据为训练集,测试机 KNN PCA降维和K值筛选 分析k & 维度 vs 精度 预测 生成提交文件 本 ...

  4. Kaggle竞赛丨入门手写数字识别之KNN、CNN、降维

    引言 这段时间来,看了西瓜书.蓝皮书,各种机器学习算法都有所了解,但在实践方面却缺乏相应的锻炼.于是我决定通过Kaggle这个平台来提升一下自己的应用能力,培养自己的数据分析能力. 我个人的计划是先从 ...

  5. 基于OpenCV的KNN算法实现手写数字识别

    基于OpenCV的KNN算法实现手写数字识别 一.数据预处理 # 导入所需模块 import cv2 import numpy as np import matplotlib.pyplot as pl ...

  6. K近邻实战手写数字识别

    1.导包 import numpy as np import operator from os import listdir from sklearn.neighbors import KNeighb ...

  7. C#中调用Matlab人工神经网络算法实现手写数字识别

    手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化  投影  矩阵  目标定位  Matlab 手写数字图像识别简介: 手写 ...

  8. CNN 手写数字识别

    1. 知识点准备 在了解 CNN 网络神经之前有两个概念要理解,第一是二维图像上卷积的概念,第二是 pooling 的概念. a. 卷积 关于卷积的概念和细节可以参考这里,卷积运算有两个非常重要特性, ...

  9. 【深度学习系列】PaddlePaddle之手写数字识别

    上周在搜索关于深度学习分布式运行方式的资料时,无意间搜到了paddlepaddle,发现这个框架的分布式训练方案做的还挺不错的,想跟大家分享一下.不过呢,这块内容太复杂了,所以就简单的介绍一下padd ...

随机推荐

  1. 《Linux内核分析》第一周学习笔记

    <Linux内核分析>第一周学习笔记 计算机是如何工作的 郭垚 原创作品转载请注明出处 <Linux内核分析>MOOC课程http://mooc.study.163.com/c ...

  2. Linux第七周学习总结——可执行程序的装载

    Linux第七周学习总结--可执行程序的装载 作者:刘浩晨 [原创作品转载请注明出处] <Linux内核分析>MOOC课程http://mooc.study.163.com/course/ ...

  3. 【壹拾壹周】final_review

    项目名:俄罗斯方块 组名:新蜂 组长:武志远 组员:宫成荣 杨柳 谢孝淼 李桥 final Review会议 时间:2016.12.3 会议内容 设想和目标 1.在final阶段发布时的预期目标是什么 ...

  4. 转帖 IBM要推POWER9,来了解一下POWER处理器的前世今生

    https://blog.csdn.net/kwame211/article/details/76669555 先来说一下最新的POWER 9 在Hot Chips会议上首次提到的IBM Power ...

  5. golang 实现线程池

    package main import ( "fmt" "time" ) type Pool struct { Queue chan func() error; ...

  6. 使用 TListView 控件(4)

    本例效果图: 代码文件: unit Unit1; interface uses   Windows, Messages, SysUtils, Variants, Classes, Graphics, ...

  7. ubuntu14.04如何设置静态IP的方法

    第一步: 配置静态IP地址: 打开/etc/network/interfaces文件,内容为 auto lo iface lo inet loopback auto eth0 iface eth0 i ...

  8. Maven父子项目配置-多模块(multi-modules)结构

    Maven创建父子项目,这个项目指的是eclipse中的project,idea中的module.使用idea创建的话很简单,可以直接选择项目的父亲,这些网上有很多资料的. 这里说一下创建父子项目时, ...

  9. 【bzoj3091】 城市旅行

    http://www.lydsy.com/JudgeOnline/problem.php?id=3091 (题目链接) 题意 给出一棵无根树,维护四个操作.link,cut,路径加法,路径期望查询. ...

  10. spring hibernate实现动态替换表名(分表)

    1.概述 其实最简单的办法就是使用原生sql,如 session.createSQLQuery("sql"),或者使用jdbcTemplate.但是项目中已经使用了hql的方式查询 ...