有一句话这样说:如果你想了解一个人,你可以从他身边的朋友开始。

如果与他交往的好友都是一些品行高尚的人,那么可以认为这个人的品行也差不了。

其实古人在这方面的名言警句,寓言故事有很多。例如:人以类聚,物以群分。近朱者赤近墨者黑

其实K-近邻算法和古人的智慧想通,世间万物息息相通,你中有我,我中有你。

K-近邻原理:

存在一个训练集,我们知道每一个样本的标签,例如训练样本是一群人,他们都有相应特征,例如,爱喝酒或爱看书或逛窑子或打架斗殴或乐于助人等等,并且知道他们是好人还是坏人,然后来了一个新人(新样本),然后把新样本的特征与样本集中数据对应的特征进行比较,然后算法提取集中特征最相似数据的分类标签,就是比较这个新人具有的品行与那一群人中谁的品行相近,选取出样本集中数据中前K个数据(这就是K的来历),然后查看这K个数据的标签,选取出现最多类作为新样本的分类。就是查看选出的这些人,看看是好人多还是坏人多,如果好人多,那么我们就确定这个新人是好人。

K-近邻算法没有训练过程,它直接对新样本进行分类。

代码来源机器学习实战,python3.7可用,详细注释:

#coding=utf-8
from numpy import *
import operator
import os,sys def createDataSet():
#数组转换成矩阵
group = array([[1.0,1.1],[1.0,1.0],[0,0],[0,0.1]])
labels = ['A','A','B','B']
return group,labels #inx为测试样本
def classify0(inx,dataSet,labels,k):
#shape[0]给出行数,shape[1]列数
dataSetSize = dataSet.shape[0]
#把inx矩阵的每一行复制dataSetSize次,列不复制
#为了把该样本与训练集中每一个样本计算出距离
#计算欧氏距离
diffMat = tile(inx,(dataSetSize,1)) - dataSet
#距离的平方差
sqDiffMat = diffMat**2
#把数组每一行求和
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
#argsort 从小到大排序,但是返回的是下标
sortedDistIndices = distances.argsort()
classCount = {}
#k是前k个最小距离
for i in range(k):
#把最小距离对应的标签赋值给voteIlabel
voteIlabel = labels[sortedDistIndices[i]]
#投票算法,统计前k个数据的标签类型及其出现的个数
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
#排序选出出现次数最多的标签,(注意:Python 3 renamed dict.iteritems() -> dict.items())
sortedClassCount = sorted(classCount.items(),
key=operator.itemgetter(1),reverse=True)
return sortedClassCount[0][0] def file2matrix(filename):
fr = open(filename)
#文件有多少行
arrayOLines = fr.readlines()
numberOfLines = len(arrayOLines)
#返回一个(numberOfLines,3)的零矩阵
returnMat = zeros((numberOfLines,3))
classLabelVector = []
index = 0
for line in arrayOLines:
#去除字符串的首尾的字符(空格,回车)
line = line.strip()
listFromLine = line.split('\t')
#复制行给returnMat
returnMat[index,:] = listFromLine[0:3]
#获取标签,这里需要把字符串类型转换成int类型
if listFromLine[-1] == 'largeDoses':
classLabelVector.append(3)
elif listFromLine[-1] == 'smallDoses':
classLabelVector.append(2)
elif listFromLine[-1] == 'didntLike':
classLabelVector.append(1)
else:
classLabelVector.append(int(listFromLine[-1]))
index += 1
return returnMat,classLabelVector #分析数据
'''
控制台输入
import matplotlib
import matplotlib.pyplot as plt
#定义一个图像窗口
fig = plt.figure()
#意思是窗口背划分成1*1个格子,使用第一个格子
ax = fig.add_subplot(111)
#描绘散点图
ax.scatter(datingDataMat[:,1],datingDataMat[:,2])
#使用颜色来分辨
ax.scatter(datingDataMat[:,1],datingDataMat[:,2],15.0*array(datingLabels),15.0*array(datingLabels))
plt.show() '''
#给出的数据集往往会遇见这样的问题,就是每一个特征值的取值不在
#同一个数量级,有的取值会很大,这样会严重影响结果的准确性
#所以要归一化特征值到0~1之间
#公式:newValue = (oldValue-min)/(max-min)
def autoNorm(dataSet):
#返回每一列最小值(1,m)
minVals = dataSet.min(0)
#返回每一列最大值
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(minVals,(m,1))
normDataSet = normDataSet/tile(ranges,(m,1))
return normDataSet,ranges,minVals #分类器针对约会网站的测试代码
def datingClassTest():
hoRatio = 0.10
datingDataMat,datingLabels = file2matrix('datingTestSet.txt')
#归一化
normMat,ranges,minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
#选取数据集的10%作为测试集
numTestVecs = int(m*hoRatio)
errorCount = 0.0
#循环对测试集进行分类,然后计算准确率
for i in range(numTestVecs):
classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3)
print ("the classifier came back with:%d,the real answer is:%d"%(classifierResult,datingLabels[i]))
if (classifierResult != datingLabels[i]):errorCount += 1.0
print ("the total error rate is:%f"%(errorCount/float(numTestVecs))) def classifyPerson():
resultList = ['not at all', 'in small doses', 'in large doses']
#python3 输入是input
percentTats = float(input("percentage of time spent playing video games?"))
ffMiles = float(input("frequent flier miles earned per year?"))
iceCream = float(input("liters of ice cream consumed per year?"))
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
inArr = array([ffMiles, percentTats, iceCream])
classifierResult = classify0((inArr - minVals)/ranges, normMat, datingLabels, 3)
print ("You will probably like this person: %s" % resultList[classifierResult - 1]) #识别手写数字
#把32*32的矩阵转换成1*1024矩阵
def img2vector(filename):
returnVect = zeros((1,1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0,32*i+j] = int(lineStr[j])
return returnVect def handwritingClassTest():
hwLabels= []
#获取目录的内容
trainingFileList = os.listdir('trainingDigits')
m = len(trainingFileList)
trainingMat = zeros((m,1024))
for i in range(m):
#从文本文件的名称中截取是什么数字
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr)
trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr)
testFileList = os.listdir('testDigits')
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr)
classifierResult = classify0(vectorUnderTest,trainingMat,hwLabels,3)
#计算精确性
print ("the classifier came back with:%d,the real answer is:%d" % (classifierResult,classNumStr))
if (classifierResult != classNumStr):errorCount += 1.0
print ("\n the total number of errors is:%d" % errorCount)
print ("\n the total error rate is:%f" % (errorCount/float(mTest)))

算法主要有两个主要的步骤:

(1)求解两向量之间的距离来比较相似性:

  

(2) 排序选出前K个相似点,筛选出出现频率最高的类别

  代码中直接调用排序算法,如果对于大量数据,排序会很耗费时间,所以可以优化排序算法:Kd树

筛选评论最高的是通过投票的方式。

上面代码中包括了识别手写体的代码,依然用的是欧氏距离,之前做过一个使用神经网络训练做的手写体数字识别,我想比较这两个算法的准确性。

kNN算法没有训练过程,算法也十分简单,但是在实践的过程中我发现,KNN具有局限性。我的做法是

kNN识别手写体:

先把数字的灰度图转换成32*32的字符文件的格式,然后使用kNN算法,发现不同的测试集的准确性相差很大,如果使用和训练集相近的测试集去测试,所谓相近就是说数字的大小,粗细都会影响识别的准确性,所以我用不同的测试集得到的结果完全不同,如果用训练集去作为测试集使用,准确率会达到99%,但是换一个不同的测试集,准确率就会降到34%左右(比蒙的好一点点)。如果要提高准确性,必须加大

训练集(尽量包含所有的手写体类型),再调整K的取值,如果那样的话,做一次分类,就要对大量的数据集进行比对,排序选出相近的,这样效率非常低。

神经网络识别手写体:

在训练的过程中会消耗时间,但是一旦模型训练完毕,准确率会很高。

所以说kNN算法适合数据集较小的情况的分类。

注意:K-近邻是监督学习,K-Means是无监督学习

我眼中的K-近邻算法的更多相关文章

  1. 机器学习实战笔记--k近邻算法

    #encoding:utf-8 from numpy import * import operator import matplotlib import matplotlib.pyplot as pl ...

  2. k近邻算法的Java实现

    k近邻算法是机器学习算法中最简单的算法之一,工作原理是:存在一个样本数据集合,即训练样本集,并且样本集中的每个数据都存在标签,即我们知道样本集中每一数据和所属分类的对应关系.输入没有标签的新数据之后, ...

  3. 基本分类方法——KNN(K近邻)算法

    在这篇文章 http://www.cnblogs.com/charlesblc/p/6193867.html 讲SVM的过程中,提到了KNN算法.有点熟悉,上网一查,居然就是K近邻算法,机器学习的入门 ...

  4. 从K近邻算法谈到KD树、SIFT+BBF算法

    转自 http://blog.csdn.net/v_july_v/article/details/8203674 ,感谢july的辛勤劳动 前言 前两日,在微博上说:“到今天为止,我至少亏欠了3篇文章 ...

  5. 机器学习之K近邻算法(KNN)

    机器学习之K近邻算法(KNN) 标签: python 算法 KNN 机械学习 苛求真理的欲望让我想要了解算法的本质,于是我开始了机械学习的算法之旅 from numpy import * import ...

  6. k近邻算法

    k 近邻算法是一种基本分类与回归方法.我现在只是想讨论分类问题中的k近邻法.k近邻算法的输入为实例的特征向量,对应于特征空间的点,输出的为实例的类别.k邻近法假设给定一个训练数据集,其中实例类别已定. ...

  7. KNN K~近邻算法笔记

    K~近邻算法是最简单的机器学习算法.工作原理就是:将新数据的每一个特征与样本集中数据相应的特征进行比較.然后算法提取样本集中特征最相似的数据的分类标签.一般来说.仅仅提取样本数据集中前K个最相似的数据 ...

  8. 机器学习03:K近邻算法

    本文来自同步博客. P.S. 不知道怎么显示数学公式以及排版文章.所以如果觉得文章下面格式乱的话请自行跳转到上述链接.后续我将不再对数学公式进行截图,毕竟行内公式截图的话排版会很乱.看原博客地址会有更 ...

  9. 机器学习——KNN算法(k近邻算法)

    一 KNN算法 1. KNN算法简介 KNN(K-Nearest Neighbor)工作原理:存在一个样本数据集合,也称为训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分 ...

  10. [机器学习] k近邻算法

    算是机器学习中最简单的算法了,顾名思义是看k个近邻的类别,测试点的类别判断为k近邻里某一类点最多的,少数服从多数,要点摘录: 1. 关键参数:k值 && 距离计算方式 &&am ...

随机推荐

  1. 【温故而知新】HTTP 报文

    HTTP 报文是在 HTTP 应用程序之间发送的数据块.这些数据块以一些文本形式的元信息开头,这些信息描述了报文的内容及含义. 报文流 报文在客户端.服务器和代理之间的流动称为报文流. HTTP 使用 ...

  2. MySQL中Identifier Case Sensitivity

    在MySQL当中,有可能遇到表名大小写敏感的问题.其实这个跟平台(操作系统)有关,也跟系统变量lower_case_table_names有关系.下面总结一下,有兴趣可以查看官方文档"Ide ...

  3. MySQL 修改账号的IP限制条件

    今天遇到一个需求:修改MySQL用户的权限,需要限制特定IP地址才能访问,第一次遇到这类需求,结果在测试过程,使用更新系统权限报发现出现了一些问题, 具体演示如下. 下面测试环境为MySQL 5.6. ...

  4. SQL Server -- 回忆笔记(四):case函数,索引,子查询,分页查询,视图,存储过程

    SQL Server知识点回忆篇(四):case函数,索引,子查询,分页查询,视图,存储过程 1. CASE函数(相当于C#中的Switch) then '未成年人' else '成年人' end f ...

  5. Centos6系列Bond配置方法

    在Windows Server平台因业务需求经常会用到NIC双网卡绑定,同样Linux平台下用于网络负载均衡及网络冗余会用到bond模式. Bond模式:0-6,即7种模式. 模式一:mod=0 ,即 ...

  6. 自动化测试之路3-selenium3+python3环境搭建

    1.首先安装火狐浏览器  有单独文章分享怎么安装 2.搭建python环境 安装python,安装的时候把path选好,就不用自己在配置,安装方法有单独文档分享 安装好以后cmd打开输入python查 ...

  7. Spring注解定时器使用

    一.首先要配置我们的spring-service.xml 1.xmlns 多加下面的内容 xmlns:task="http://www.springframework.org/schema/ ...

  8. identity server4 证书

    我们需要对token进行签名, 这意味着identity server需要一对public和private key. 幸运的是, 我们可以告诉identity server在程序的运行时候对这项工作进 ...

  9. sql server 压缩数据库

    收缩日志 ALTER DATABASE 数据库名称 SET RECOVERY SIMPLEDBCC SHRINKDATABASE(数据库名称, 0) 压缩数据库ALTER DATABASE 数据库名称 ...

  10. 前端——JavaScript

    何谓JavaScript?它与Java有什么关系? JavaScript与HTML.CSS组合使用应用于前端开发,JavaScript是一门独立的语言,浏览器内置了JS的解释器.它除了和Java名字长 ...