使用KNN对MNIST数据集进行实验
由于KNN的计算量太大,还没有使用KD-tree进行优化,所以对于60000训练集,10000测试集的数据计算比较慢。这里只是想测试观察一下KNN的效果而已,不调参。
K选择之前看过貌似最好不要超过20,因此,此处选择了K=10,距离为欧式距离。如果需要改进,可以再调整K来选择最好的成绩。
先跑了一遍不经过scale的,也就是直接使用像素灰度值来计算欧式距离进行比较。发现开始基本稳定在95%的正确率上,吓了一跳。因为本来觉得KNN算是没有怎么“学习”的机器学习算法了,猜测它的特点可能会是在任何情况下都可以用,但都表现的不是最好。所以估计在60%~80%都可以接受。没想到能基本稳定在95%上,确定算法和代码没什么问题后,突然觉得是不是这个数据集比较没挑战性。。。
去MNIST官网(http://yann.lecun.com/exdb/mnist/),上面挂了以该数据集为数据的算法的结果比较。查看了一下KNN,发现有好多,而且错误率基本都在5%以内,甚至能做到1%以内。唔。
跑的结果是,正确率:96.687%。也就是说,错误率error rate为3.31%左右。
再跑一下经过scale的数据,即对灰度数据归一化到[0,1]范围内。看看效果是否有所提升。
经过scale,最终跑的结果是,正确率:竟然也是96.687%! 也就是说,对于该数据集下,对KNN的数据是否进行归一化并无效果!
在跑scale之前,个人猜测:由于一般对数据进行处理之前都进行归一化,防止高维诅咒(在784维空间中很容易受到高维诅咒)。因此,预测scale后会比前者要好一些的。但是,现在看来二者结果相同。也就是说,对于K=10的KNN算法中,对MNIST的预测一样的。
对scale前后的正确率相同的猜测:由于在训练集合中有60000个数据点,因此0-9每个分类平均都有6000个数据点,在这样的情况下,对于测试数据集中的数据点,相临近的10个点中大部分都是其他分类而导致分类错误的概率会比较地(毕竟10相对与6000来说很小),所以,此时,KNN不仅可以取得较好的分类效果,而且对于是否scale并不敏感,效果相同。
代码如下:
- #KNN for MNIST
- from numpy import *
- import operator
- def line2Mat(line):
- line = line.strip().split(' ')
- label = line[0]
- mat = []
- for pixel in line[1:]:
- pixel = pixel.split(':')[1]
- mat.append(float(pixel))
- return mat, label
- #matrix should be type: array. Or classify() will get error.
- def file2Mat(fileName):
- f = open(fileName)
- lines = f.readlines()
- matrix = []
- labels = []
- for line in lines:
- mat, label = line2Mat(line)
- matrix.append(mat)
- labels.append(label)
- print 'Read file '+str(fileName) + ' to matrix done!'
- return array(matrix), labels
- #classify mat with trained data: matrix and labels. With KNN's K set.
- def classify(mat, matrix, labels, k):
- diffMat = tile(mat, (shape(matrix)[0], 1)) - matrix
- #diffMat = array(diffMat)
- sqDiffMat = diffMat ** 2
- sqDistances = sqDiffMat.sum(axis=1)
- distances = sqDistances ** 0.5
- sortedDistanceIndex = distances.argsort()
- classCount = {}
- for i in range(k):
- voteLabel = labels[sortedDistanceIndex[i]]
- classCount[voteLabel] = classCount.get(voteLabel,0) + 1
- sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1),reverse=True)
- return sortedClassCount[0][0]
- def classifyFiles(trainMatrix, trainLabels, testMatrix, testLabels, K):
- rightCnt = 0
- for i in range(len(testMatrix)):
- if i % 100 == 0:
- print 'num '+str(i)+'. ratio: '+ str(float(rightCnt)/(i+1))
- label = testLabels[i]
- predictLabel = classify(testMatrix[i], trainMatrix, trainLabels, K)
- if label == predictLabel:
- rightCnt += 1
- return float(rightCnt)/len(testMatrix)
- trainFile = 'train_60k.txt'
- testFile = 'test_10k.txt'
- trainMatrix, trainLabels = file2Mat(trainFile)
- testMatrix, testLabels = file2Mat(testFile)
- K = 10
- rightRatio = classifyFiles(trainMatrix, trainLabels, testMatrix, testLabels, K)
- print 'classify right ratio:' +str(right)
使用KNN对MNIST数据集进行实验的更多相关文章
- 使用libsvm对MNIST数据集进行实验
使用libsvm对MNIST数据集进行实验 在学SVM中的实验环节,老师介绍了libsvm的使用.当时看完之后感觉简单的说不出话来. 1. libsvm介绍 虽然原理要求很高的数学知识等,但是libs ...
- 使用libsvm对MNIST数据集进行实验---浅显易懂!
原文:http://blog.csdn.net/arthur503/article/details/19974057 在学SVM中的实验环节,老师介绍了libsvm的使用.当时看完之后感觉简单的说不出 ...
- 使用PCA + KNN对MNIST数据集进行手写数字识别
首先引入需要的包 %matplotlib inline import numpy as np import scipy as sp import pandas as pd import matplot ...
- 使用Decision Tree对MNIST数据集进行实验
使用的Decision Tree中,对MNIST中的灰度值进行了0/1处理,方便来进行分类和计算熵. 使用较少的测试数据测试了在对灰度值进行多分类的情况下,分类结果的正确率如何.实验结果如下. #Te ...
- 机器学习(2) - KNN识别MNIST
代码 https://github.com/s055523/MNISTTensorFlowSharp 数据的获得 数据可以由http://yann.lecun.com/exdb/mnist/下载.之后 ...
- 机器学习(1) - TensorflowSharp 简单使用与KNN识别MNIST流程
机器学习是时下非常流行的话题,而Tensorflow是机器学习中最有名的工具包.TensorflowSharp是Tensorflow的C#语言表述.本文会对TensorflowSharp的使用进行一个 ...
- 【转载】用Scikit-Learn构建K-近邻算法,分类MNIST数据集
原帖地址:https://www.jiqizhixin.com/articles/2018-04-03-5 K 近邻算法,简称 K-NN.在如今深度学习盛行的时代,这个经典的机器学习算法经常被轻视.本 ...
- 机器学习:PCA(实例:MNIST数据集)
一.数据 获取数据 import numpy as np from sklearn.datasets import fetch_mldata mnist = fetch_mldata("MN ...
- Caffe初试(二)windows下的cafee训练和测试mnist数据集
一.mnist数据集 mnist是一个手写数字数据库,由Google实验室的Corinna Cortes和纽约大学柯朗研究院的Yann LeCun等人建立,它有60000个训练样本集和10000个测试 ...
随机推荐
- SpringJUnit4ClassRunner拉起来的单元测试怎么装配Container实例
由于历史代码的原因,产品中部分spring装配的实例需要通过Container的实现类(自定义的)去获取.那么当在单元测试中怎么实例化这个Container实现呢? 实例化Container实现需要A ...
- Struts2上传图片时报404错误
可能是struts配置文件中定义的拦截器导致的,后缀拦截导致,将该拦截器去掉,在action类里判断后缀 public String upload()throws Exception{ ActionC ...
- Excel公式无法重算,暂无法解决
一份复杂的excel报表,某些单元格是用求和公式算出来的值,但生成之后,用excel打开,无法显示公式结果,按F9也没有用,只能在单元格公式双击后回车才会显示.而在WPS2010按F9就可以重算,WP ...
- layoutsubviews什么时候调用
layoutSubviews在以下情况下会被调用:1.init初始化不会触发layoutSubviews2.addSubview会触发layoutSubviews3.设置view的Frame会触发la ...
- 转:python webdriver API 之cookie 处理
有时候我们需要验证浏览器中是否存在某个 cookie,因为基于真实的 cookie 的测试是无法通过白盒和集成测试完成的.webdriver 可以读取.添加和删除 cookie 信息.webdrive ...
- Spring之我见
Spring 是什么(1) •Spring 是一个开源框架. •Spring 为简化企业级应用开发而生. 使用 Spring 可以使简单的 JavaBean 实现以前只有 EJB 才能实现的功能. • ...
- struts_24_基于XML校验的规则、特点
当为某个action提供了ActionClassName-validation.xml和ActionClassName-ActionName-validation.xml两种规则的校验文件时,系统按下 ...
- paper 58 :机器视觉学习笔记(1)——OpenCV配置
开始学习opencv! 1.什么是OpenCV OpenCV的全称是:Open Source Computer Vision Library.OpenCV是一个基于(开源)发行的跨平台计算机视觉库,可 ...
- java post请求
package com.jfbank.loan.intf.util; import java.io.IOException;import java.util.ArrayList;import java ...
- 夺命雷公狗---DEDECMS----26dedecms面包屑导航的实现
我们在很多项目里面都会用到面包屑导航,而dedecms里面也是给我们封装好面包屑导航的了,如下图所示: 在dede里面实现面包屑导航主要用到{dede:field.position/}标签,我们首先来 ...