实战一:kNN手写识别系统

本文将一步步地构造使用K-近邻分类器的手写识别系统。由于能力有限,这里构造的系统只能识别0-9。需要识别的数字已经使用图形处理软件,处理成具有相同的色彩和大小:32像素*32像素的黑白图像。

当前使用文本格式存储图像,即使不能有效的利用空间,但是为了方便理解,还是将图像转换成文本格式。

示例:使用k-近邻算法的手写识别系统

(1)收集数据:提供文本文件。

(2)处理数据:编写img2vector()函数,将图像格式转换成分类器使用的向量格式。

(3)分析数据:在Python命令提示符中检查数据,确保它符合要求。

(4)训练算法:此步骤不适用于k-近邻算法。

(5)测试算法:编写函数使用提供的部分数据集作为测试样本,对学习算法进行测试。

(6)使用算法:本例没有完成此步骤

准备数据:将图像转换为测试向量

我们所使用的两个文件trainingDigits中包含了大约2000个例子,每个数字大约有200个样本;测试文件testDigits中包含了大约900个测试数据。两组数据没有重叠。为了使用kNN算法分类器必须将一个

32*32的二进制矩阵转换为1*1024的向量,以便我们使用分类器处理数字图像信息。

首先我们定义img2vector()函数,将32*32的二进制矩阵转换成1*1024的矩阵并返回:

def img2vector(filename):
returnVector = zeros((,))
with open(filename) as fr:
for i in range():
lineStr = fr.readline()
for j in range():
returnVector[,*i+j] = lineStr[j]
return returnVector

执行下述代码:

testVector = img2vector("testDigits/0_13.txt")
print(testVector[0,0:31])
print(testVector[0,32:61])

得到结果:

[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 0. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0. 0. 0.]
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 0. 0. 0. 0. 0.
0. 0. 0. 0. 0.]

k-近邻算法

k-近邻算法的一般流程

        1. 收集数据:可使用任何方法
        2. 准备数据:距离计算所需要的数值,最好是结构化的数据格式。
        3. 分析数据:可以使用任何方法。
        4. 训练算法:此步骤不适用于K-近邻算法
        5. 使用算法:首先需要输入样本数据和节后话的输出结果,然后运行k-近邻算法判定输入数据分别属于哪个分类,最后应用对计算出的分类执行后续的处理
#kNN分类器
def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0] #得到数据总量
diffMat = tile(inX,(dataSetSize,1)) - dataSet #将输入数据扩充成与数据集同样大小的矩阵并作差
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1) #axis = 1 参数是维度参数等于1在此处表示将一个矩阵的每一行向量相加
distances = sqDistances** 0.5
sortedDistancesIndicies = distances .argsort() #将列表值进行对比返回一个按照数值升序的下标值
classCount={}
for i in range(k):
voteIlabel = labels[sortedDistancesIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0)+1
#dict.get("key") 返回value dict.get("key",default= None)如果能找到就返回对应的value找不到返回默认值
sortedClassCount = sorted(classCount.items(),key = operator.itemgetter(1),reverse=True)
#sorted 返回一个list operator.itemgetter(x,y)表示根据x+1维度的第y+1维度
return sortedClassCount[0][0]

测试算法:使用k-近邻算法识别手写数字

现在我们得到处理完成的数据还有分类算法,现在我们需要构造handwritingClassTest()函数进行分类器测试。为了处理大量的文本文件我们需要from os import listdir用于列出指定目录的文件名,读取多个

数字文本文件。

def handwritingClassTest():
hwLabels = [] #训练数据真实值数组
trainingFileList = listdir("trainingDigits") #获取trainingDigits文件子目录的列表
m = len(trainingFileList) #获得训练数据总数
trainingMat = zeros((m,1024)) #初始化训练数据矩阵
for i in range(m): #循环将trainingDigits文件下的训练数据文本文件放入矩阵traningMat中,真实值放入hwLabels中
fileNameStr = trainingFileList[i] #获取该次循环的文件名字符串
fileStr = fileNameStr.split('.')[0] #将获得的字符串按分隔符'.'分隔并取第一个即去拓展名的文件名
classNumber = int(fileStr.split('_')[0]) #获取训练数据的真实值 非numpy数据需要指定数据类型int
hwLabels.append(classNumber) #将得到的单个真实值按顺序加入到真实值列表hwLabels中
trainingMat[i,:] = img2vector("trainingDigits/%s"%fileNameStr) #把32*32的二进制文本文件转换成1*1024矩阵并按行存储到训练数据总矩阵中
testFileList = listdir("testDigits")
errorCount = 0.0 #错误预测计数器
mTest = len(testFileList) #测试数据总量
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumber = int(fileStr.split('_')[0])
vectorUnderTest = img2vector("testDigits/%s"%fileNameStr)
classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3) #用kNN分类算法分类
if(classifierResult != classNumber) : #判断预测是否正确,不正确计数器+1打印错误预测
errorCount +=1.0
print("预测值为:%d ,真实值为:%d " % (classifierResult, classNumber))
print("测试总数:%d,预测错误总数:%d ,错误率为:%f"%(mTest,errorCount,errorCount/float(mTest)))
handwritingClassTest()

执行效果:

预测值为:7 ,真实值为:1
预测值为:9 ,真实值为:3
预测值为:3 ,真实值为:5
预测值为:6 ,真实值为:5
预测值为:6 ,真实值为:8
预测值为:3 ,真实值为:8
预测值为:1 ,真实值为:8
预测值为:1 ,真实值为:8
预测值为:1 ,真实值为:9
预测值为:7 ,真实值为:9
测试总数:946,预测错误总数:10 ,错误率为:0.010571

总结

k-近邻算法识别手写数据集,错误率为1%。改变kNN分类函数中的k值、修改训练样本的内容和数目都会对错误率产生影响,可以改变这些数值观察错误率的变化。实际使用这个算法的时候,算法的执行

效率并不高。我们需要进行2000次距离计算,每个距离计算包括了1024个维度的浮点数,总计执行900次,此外,我们还要为测试向量准备存储空间。期待有更好的算法能够改进。

机器学习实战一:kNN手写识别系统的更多相关文章

  1. 第三篇:基于K-近邻分类算法的手写识别系统

    前言 本文将继续讲解K-近邻算法的项目实例 - 手写识别系统. 该系统在获取用户的手写输入后,判断用户写的是什么. 为了突出核心,简化细节,本示例系统中的输入为32x32矩阵,分类结果也均为数字.但对 ...

  2. 人工智能-深度学习(3)TensorFlow 实战一:手写图片识别

    http://gitbook.cn/gitchat/column/59f7e38160c9361563ebea95/topic/59f7e86d60c9361563ebeee5 wiki.jikexu ...

  3. 【Machine Learning in Action --2】K-近邻算法构造手写识别系统

    为了简单起见,这里构造的系统只能识别数字0到9,需要识别的数字已经使用图形处理软件,处理成具有相同的色彩和大小:宽高是32像素的黑白图像.尽管采用文本格式存储图像不能有效地利用内存空间,但是为了方便理 ...

  4. k-近邻算法-手写识别系统

    手写数字是32x32的黑白图像.为了能使用KNN分类器,我们需要把32x32的二进制图像转换为1x1024 1. 将图像转化为向量 from numpy import * # 导入科学计算包numpy ...

  5. 《机器学习实战》之k-近邻算法(手写识别系统)

    这个玩意和改进约会网站的那个差不多,它是提前把所有数字转换成了32*32像素大小的黑白图,然后转换成字符图(用0,1表示),将所有1024个像素点用一维矩阵保存下来,这样就可以通过knn计算欧几里得距 ...

  6. knn手写识别

    import numpy as np import operator import os #KNN算法 def knn(k,testdata,traindata,labels):#(k,测试样本,训练 ...

  7. 机器学习实战kNN之手写识别

    kNN算法算是机器学习入门级绝佳的素材.书上是这样诠释的:“存在一个样本数据集合,也称作训练样本集,并且样本集中每个数据都有标签,即我们知道样本集中每一条数据与所属分类的对应关系.输入没有标签的新数据 ...

  8. python 实现 KNN 分类器——手写识别

    1 算法概述 1.1 优劣 优点:进度高,对异常值不敏感,无数据输入假定 缺点:计算复杂度高,空间复杂度高 应用:主要用于文本分类,相似推荐 适用数据范围:数值型和标称型 1.2 算法伪代码 (1)计 ...

  9. 10分钟搞懂Tensorflow 逻辑回归实现手写识别

    1. Tensorflow 逻辑回归实现手写识别 1.1. 逻辑回归原理 1.1.1. 逻辑回归 1.1.2. 损失函数 1.2. 实例:手写识别系统 1.1. 逻辑回归原理 1.1.1. 逻辑回归 ...

随机推荐

  1. 【luogu T24743 [愚人节题目5]永世隔绝的理想乡】 题解

    题意翻译 我们来说说王的故事吧. 星之内海,瞭望之台.从乐园的角落告知汝等.汝等的故事充满了祝福.只有无罪之人可以进入——『永世隔绝的理想乡(Garden of Avalon)』! 题目背景 zcy入 ...

  2. 在react中实现CSS模块化

    react中使用普通的css样式表会造成作用域的冲突,css定义的样式的作用域是全局,在Vue 中我们还可以使用scope来定义作用域,但是在react中并没有指令一说,所以只能另辟蹊径了.下面我将简 ...

  3. Spring知识点总结(三)之Spring DI

    1. IOC(DI) - 控制反转(依赖注入) 所谓的IOC称之为控制反转,简单来说就是将对象的创建的权利及对象的生命周期的管理过程交由Spring框架来处理,从此在开发过程中不再需要关注对象的创建和 ...

  4. PThread 学习笔记

    POSIX 线程,也被称为Pthreads,是一个线程的POSIX标准: pthread.h int pthread_create(pthread_t * thread, pthread_attr_t ...

  5. ABAP术语-Business Connector

    Business Connector 原文:http://www.cnblogs.com/qiangsheng/archive/2007/12/27/1016379.html XML-based st ...

  6. thinkphp3.2.3实现多条件查询实例.

    $data = M("datainfo"); $projectsname = I('get.projectsname');//前台提交的模糊查询字段 // 查询条件 $where ...

  7. sftp上传到远程服务器

    开发遇到一个需求,需要将图片通过sftp上传到远程服务器上,之前没用过这个功能,折腾了我好几天才搞定,下面记录下我的处理方法: $sftp = 'ssh2.sftp://';//连接sftp $con ...

  8. js 节点

    var chils= s.childNodes; //得到s的全部子节点 var par=s.parentNode;  //得到s的父节点 var ns=s.nextSbiling;  //获得s的下 ...

  9. pyspider -- 禁止请求非200响应码抛异常

    在pyspider中若crawl()网址时出现非200的异常信息,会抛出一个异常. 可以在对应的回调函数上面通过@catch_status_code_error 进行修饰,这样就能不抛出异常正常进入回 ...

  10. mysql5.7数据库与5.7之前版本比较

    数据库初始化方式变更 <5.7 版本 mysql_install_db >5.7 版本 bin/mysqld --initialize --user =mysql --basedir=/u ...