由于KNN的计算量太大,还没有使用KD-tree进行优化,所以对于60000训练集,10000测试集的数据计算比较慢。这里只是想测试观察一下KNN的效果而已,不调参。

K选择之前看过貌似最好不要超过20,因此,此处选择了K=10,距离为欧式距离。如果需要改进,可以再调整K来选择最好的成绩。

先跑了一遍不经过scale的,也就是直接使用像素灰度值来计算欧式距离进行比较。发现开始基本稳定在95%的正确率上,吓了一跳。因为本来觉得KNN算是没有怎么“学习”的机器学习算法了,猜测它的特点可能会是在任何情况下都可以用,但都表现的不是最好。所以估计在60%~80%都可以接受。没想到能基本稳定在95%上,确定算法和代码没什么问题后,突然觉得是不是这个数据集比较没挑战性。。。

去MNIST官网(http://yann.lecun.com/exdb/mnist/),上面挂了以该数据集为数据的算法的结果比较。查看了一下KNN,发现有好多,而且错误率基本都在5%以内,甚至能做到1%以内。唔。

跑的结果是,正确率:96.687%。也就是说,错误率error rate为3.31%左右。

再跑一下经过scale的数据,即对灰度数据归一化到[0,1]范围内。看看效果是否有所提升。

经过scale,最终跑的结果是,正确率:竟然也是96.687%! 也就是说,对于该数据集下,对KNN的数据是否进行归一化并无效果!

在跑scale之前,个人猜测:由于一般对数据进行处理之前都进行归一化,防止高维诅咒(在784维空间中很容易受到高维诅咒)。因此,预测scale后会比前者要好一些的。但是,现在看来二者结果相同。也就是说,对于K=10的KNN算法中,对MNIST的预测一样的。

对scale前后的正确率相同的猜测:由于在训练集合中有60000个数据点,因此0-9每个分类平均都有6000个数据点,在这样的情况下,对于测试数据集中的数据点,相临近的10个点中大部分都是其他分类而导致分类错误的概率会比较地(毕竟10相对与6000来说很小),所以,此时,KNN不仅可以取得较好的分类效果,而且对于是否scale并不敏感,效果相同。

代码如下:

  1. #KNN for MNIST
  2. from numpy import *
  3. import operator
  4. def line2Mat(line):
  5. line = line.strip().split(' ')
  6. label = line[0]
  7. mat = []
  8. for pixel in line[1:]:
  9. pixel = pixel.split(':')[1]
  10. mat.append(float(pixel))
  11. return mat, label
  12. #matrix should be type: array. Or classify() will get error.
  13. def file2Mat(fileName):
  14. f = open(fileName)
  15. lines = f.readlines()
  16. matrix = []
  17. labels = []
  18. for line in lines:
  19. mat, label = line2Mat(line)
  20. matrix.append(mat)
  21. labels.append(label)
  22. print 'Read file '+str(fileName) + ' to matrix done!'
  23. return array(matrix), labels
  24. #classify mat with trained data: matrix and labels. With KNN's K set.
  25. def classify(mat, matrix, labels, k):
  26. diffMat = tile(mat, (shape(matrix)[0], 1)) - matrix
  27. #diffMat = array(diffMat)
  28. sqDiffMat = diffMat ** 2
  29. sqDistances = sqDiffMat.sum(axis=1)
  30. distances = sqDistances ** 0.5
  31. sortedDistanceIndex = distances.argsort()
  32. classCount = {}
  33. for i in range(k):
  34. voteLabel = labels[sortedDistanceIndex[i]]
  35. classCount[voteLabel] = classCount.get(voteLabel,0) + 1
  36. sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1),reverse=True)
  37. return sortedClassCount[0][0]
  38. def classifyFiles(trainMatrix, trainLabels, testMatrix, testLabels, K):
  39. rightCnt = 0
  40. for i in range(len(testMatrix)):
  41. if i % 100 == 0:
  42. print 'num '+str(i)+'. ratio: '+ str(float(rightCnt)/(i+1))
  43. label = testLabels[i]
  44. predictLabel = classify(testMatrix[i], trainMatrix, trainLabels, K)
  45. if label == predictLabel:
  46. rightCnt += 1
  47. return float(rightCnt)/len(testMatrix)
  48. trainFile = 'train_60k.txt'
  49. testFile = 'test_10k.txt'
  50. trainMatrix, trainLabels = file2Mat(trainFile)
  51. testMatrix, testLabels = file2Mat(testFile)
  52. K = 10
  53. rightRatio = classifyFiles(trainMatrix, trainLabels, testMatrix, testLabels, K)
  54. print 'classify right ratio:' +str(right)

使用KNN对MNIST数据集进行实验的更多相关文章

  1. 使用libsvm对MNIST数据集进行实验

    使用libsvm对MNIST数据集进行实验 在学SVM中的实验环节,老师介绍了libsvm的使用.当时看完之后感觉简单的说不出话来. 1. libsvm介绍 虽然原理要求很高的数学知识等,但是libs ...

  2. 使用libsvm对MNIST数据集进行实验---浅显易懂!

    原文:http://blog.csdn.net/arthur503/article/details/19974057 在学SVM中的实验环节,老师介绍了libsvm的使用.当时看完之后感觉简单的说不出 ...

  3. 使用PCA + KNN对MNIST数据集进行手写数字识别

    首先引入需要的包 %matplotlib inline import numpy as np import scipy as sp import pandas as pd import matplot ...

  4. 使用Decision Tree对MNIST数据集进行实验

    使用的Decision Tree中,对MNIST中的灰度值进行了0/1处理,方便来进行分类和计算熵. 使用较少的测试数据测试了在对灰度值进行多分类的情况下,分类结果的正确率如何.实验结果如下. #Te ...

  5. 机器学习(2) - KNN识别MNIST

    代码 https://github.com/s055523/MNISTTensorFlowSharp 数据的获得 数据可以由http://yann.lecun.com/exdb/mnist/下载.之后 ...

  6. 机器学习(1) - TensorflowSharp 简单使用与KNN识别MNIST流程

    机器学习是时下非常流行的话题,而Tensorflow是机器学习中最有名的工具包.TensorflowSharp是Tensorflow的C#语言表述.本文会对TensorflowSharp的使用进行一个 ...

  7. 【转载】用Scikit-Learn构建K-近邻算法,分类MNIST数据集

    原帖地址:https://www.jiqizhixin.com/articles/2018-04-03-5 K 近邻算法,简称 K-NN.在如今深度学习盛行的时代,这个经典的机器学习算法经常被轻视.本 ...

  8. 机器学习:PCA(实例:MNIST数据集)

    一.数据 获取数据 import numpy as np from sklearn.datasets import fetch_mldata mnist = fetch_mldata("MN ...

  9. Caffe初试(二)windows下的cafee训练和测试mnist数据集

    一.mnist数据集 mnist是一个手写数字数据库,由Google实验室的Corinna Cortes和纽约大学柯朗研究院的Yann LeCun等人建立,它有60000个训练样本集和10000个测试 ...

随机推荐

  1. TNS-01251: Cannot set trace/log directory under ADR

    试图改变监听日志的名称时,报出TNS-01251错误: $ lsnrctl LSNRCTL - Production on -JUN- :: Copyright (c) , , Oracle. All ...

  2. Java基础之处理事件——实现低级事件监听器(Sketcher 2 implementing a low-level listener)

    控制台程序. 定义事件监听器的类必须实现监听器接口.所有的事件监听器接口都扩展了java.util.EventListener接口.这个接口没有声明任何方法,仅仅用于表示监听器对象.使用EventLi ...

  3. Angular.js+Bootstrap实现手风琴菜单

    说是Angular.js+Bootstrap实现手风琴菜单,其实就是用了Bootstrap的样式而已. 上一篇实现了表格+分页,接着学习实现的Demo. 主要练习自定义指令,向指令中传递参数,老规矩先 ...

  4. 通用窗口类 Inventory Pro 2.1.2 Demo1(下续篇 ),物品消耗扇形显示功能

    本篇想总结的是Inventory Pro中通用窗口的具体实现,但还是要强调下该插件的重点还是装备系统而不是通用窗口系统,所以这里提到的通用窗口类其实是通用装备窗口类(其实该插件中也有非装备窗口比如No ...

  5. EL表达<%@page isELIgnored="false"%>问题

    上网查找资料后得知:主要原因是EL表达式无法被解析到. 其实从后台取值并传值到前台来根本就没有错,而前台JSP页面EL表达式无效,解析不到EL表达式,引起的原因是web.xml中: <web-a ...

  6. tomcat部署方法总结

    可以参考之前的:http://www.cnblogs.com/youxin/archive/2013/01/18/2865814.html 在Tomcat中部署Java Web应用程序有两种方式:静态 ...

  7. C++Primer 第十章

    //1.标准库算法不仅可以应用于容器,还可以应用于内置数组,指针. //2.大多数算法都定义在头文件algorithm中.标准库还在头文件numeric中定义了一组数值泛型算法. //3.算法本身不会 ...

  8. 刨根问底U3D---关于Delegate的方方面面

    我是否需要阅读这篇文章 Code1: private delegate void BloodNumDelegate (); public delegate void ExpNumChangeDeleg ...

  9. html 字体加粗

    <font style="font-weight: bold;">无敌小昆虫</font> <font>无敌小昆虫</font> f ...

  10. 转:Python一些特殊用法(map、reduce、filter、lambda、列表推导式等)

    Map函数: 原型:map(function, sequence),作用是将一个列表映射到另一个列表, 使用方法: def f(x): return x**2 l = range(1,10) map( ...