主要内容:

一.算法概述

二.距离度量

三.k值的选择

四.分类决策规则

五.利用KNN对约会对象进行分类

六.利用KNN构建手写识别系统

七.KNN之线性扫描法的不足

八.KD树

一.算法概述

1.k近邻算法,简而言之,就是选取k个与输入点的特征距离最近的数据点中出现最多的一种分类,作为输入点的类别。

2.如下面一个例子,有六部电影,可用“打斗镜头”和“接吻镜头”作为每一部电影的特征值,且已知每一部电影的类别,即“爱情片”还是“动作片”。此外,还有一部电影,只知道其特征,但不知道其类别。如下:

为了方便研究,可以将其放到二维平面上:

为了得出?的类别,可以选择与之距离最近的k部电影,然后将这k部电影中出现次数最多的类别作为该部电影的类别。

?与每一部电影的距离为:

假如选取k为3,而前面3部电影的类别均为爱情片,所以可以认为?的类别为爱情片。

3.通过例子可以看出,KNN算法的三个基本要素为:距离度量k值的选择分类决策规则,下面将一一讲解。

二.距离度量

1.特征空间中两个实例点的距离反应了两个实例点的相似程度,k近邻模型的特征空间是n维的实数向量空间。其中使用的距离是欧式距离,即我们平常所说的“直线距离”,但也可以是其他距离。或者可以归于一个类别,即Lp距离。其基本介绍如下:

三.k值的选择

从直觉上可得出:k值的选择对模型的有效性影响很大。

1.如果k值选得比较小,那么预测结果会对临近的点十分敏感。假如附近的点刚好是噪声,那么预测结果就会出错。总体而言,容易发生过拟合。

2.假如k值选得比较大,那么预测结果就很容易受到数量大的类别的干扰,特别地,当k=N时,那么类别就永远为数量最大的那个类别,算法就没有意义的。

3.综上,k过大或者过小,预测结果都可能变得糟糕。所以可以通过交叉验证法来选取最优值k。

思考:在选取了k个最近点之后,每个点对于预测结果的影响所占的权值都是一样的,即都是“一票”,但可不可以设置权值:越靠近的点权值越大呢?这样做会不会好一点?不过这个问题好像归类于下面一节的。

四.分类决策规则

分类决策规则,即得到k个最近点之后,通过什么方式去决定最终的分类。从直觉上可感觉到选取数量最多的那个类别作为输入点的类别或许是比较合理的。下面是具体的数学解释:

五.利用KNN对约会对象进行分类

海伦最近在约会网站上寻找适合自己的约会对象。经过一番总结,她将约会对象分为三种类别:

...不喜欢的人

...魅力一般的人

...极具魅力的人

此外,每个约会对象还有三种特征,分别是:

...每年获得的飞行常客里程数

...玩视频游戏所消耗时间百分比

...每周消费的冰淇淋公升数

为了帮助海伦预测她没有约会过的对象属于那种类别,我们需要根据已有的数据(即已经约会过的对象),利用KNN算法来构建一个预测系统。

基本流程如下:

Python代码:

 # coding:utf-8
from numpy import *
import operator
from os import listdir def file2matrix(filename): #从文件中提取数据
fr = open(filename)
numberOfLines = len(fr.readlines()) #数据的条数
returnMat = zeros((numberOfLines,3)) #特征数组X
classLabelVector = [] #每条数据对应的分类Y
fr = open(filename)
index = 0
for line in fr.readlines(): #读取每一条数据
line = line.strip()
listFromLine = line.split('\t')
returnMat[index,:] = listFromLine[0:3] #读取特征x
classLabelVector.append(int(listFromLine[-1])) #读取分类y
index += 1
return returnMat,classLabelVector #返回特征数组X和分类数组Y def autoNorm(dataSet): #特征归一化,作用是:使得每个特征的权重相等。范围[0,1]
minVals = dataSet.min(0)
maxVals = dataSet.max(0)
ranges = maxVals - minVals
normDataSet = zeros(shape(dataSet))
m = dataSet.shape[0]
normDataSet = dataSet - tile(minVals, (m,1))
normDataSet = normDataSet/tile(ranges, (m,1))
return normDataSet, ranges, minVals #返回归一化矩阵、范围、最小值 def classify0(inX, dataSet, labels, k): #使用KNN进行分类
dataSetSize = dataSet.shape[0] #训练数据集的大小
diffMat = tile(inX, (dataSetSize,1)) - dataSet #从此步起到第四步为计算欧氏距离
sqDiffMat = diffMat**2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances**0.5
sortedDistIndicies = distances.argsort() #对距离进行排序,得到的是排序后的下标,而不是数据本身
classCount={} #记录k近邻中每种类别出现的次数
for i in range(k): #枚举k近邻
voteIlabel = labels[sortedDistIndicies[i]] #获取该数据点的类别
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1 #累加
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) #排序
return sortedClassCount[0][0] #返回数量最多的类别 def datingClassTest(): #使用KNN对约会对象进行分类的测试
hoRatio = 0.50 #用于测试的数据所占的比例
datingDataMat,datingLabels = file2matrix('datingTestSet2.txt') #读取数据
normMat, ranges, minVals = autoNorm(datingDataMat) #特征归一化
m = normMat.shape[0] #数据总量:训练数据+测试数据
numTestVecs = int(m*hoRatio) #训练数据的总量
errorCount = 0.0 #分类错误的总数
for i in range(numTestVecs): #利用KNN为每个测试数据进行分类
classifierResult = classify0(normMat[i,:],normMat[numTestVecs:m,:],datingLabels[numTestVecs:m],3) #得到分类结果
print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, datingLabels[i]) #输出结果
if (classifierResult != datingLabels[i]): errorCount += 1.0 #如果分类错误,则累计
print "the total error rate is: %f" % (errorCount/float(numTestVecs)) #最后计算错误率
print "the total error count is: %d"%errorCount if __name__ == "__main__":
datingClassTest()

运行结果如下:

错误率为6.4%,效果还是挺好的。

六.利用KNN构建手写识别系统

KNN算法还可用于识别手写字。为了方便,这里构造的识别系统自能识别0~9的的数字。

首先,我们可以将手写字投影到一个矩阵中,有墨水的地方就设为1,空白的地方设为0,如图:

(分别是:9 2 7)

这是一个32*32的矩阵,我们将其转换为1*1024的一维向量以方便操作。之后,就可以利用KNN进行识别了,这里选取的k为3。

Python代码:

 # coding:utf-8
from numpy import *
import operator
from os import listdir def file2matrix(filename): # 从文件中提取数据
fr = open(filename)
numberOfLines = len(fr.readlines()) # 数据的条数
returnMat = zeros((numberOfLines, 3)) # 特征数组X
classLabelVector = [] # 每条数据对应的分类Y
fr = open(filename)
index = 0
for line in fr.readlines(): # 读取每一条数据
line = line.strip()
listFromLine = line.split('\t')
returnMat[index, :] = listFromLine[0:3] # 读取特征x
classLabelVector.append(int(listFromLine[-1])) # 读取分类y
index += 1
return returnMat, classLabelVector # 返回特征数组X和分类数组Y def classify0(inX, dataSet, labels, k): # 使用KNN进行分类
dataSetSize = dataSet.shape[0] # 训练数据集的大小
diffMat = tile(inX, (dataSetSize, 1)) - dataSet # 从此步起到第四步为计算欧氏距离
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances ** 0.5
sortedDistIndicies = distances.argsort() # 对距离进行排序,得到的是排序后的下标,而不是数据本身
classCount = {} # 记录k近邻中每种类别出现的次数
for i in range(k): # 枚举k近邻
voteIlabel = labels[sortedDistIndicies[i]] # 获取该数据点的类别
classCount[voteIlabel] = classCount.get(voteIlabel, 0) + 1 # 累加
sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True) # 排序
return sortedClassCount[0][0] # 返回数量最多的类别 def img2vector(filename): # 将32*32的二维数组转换成1*1024的一维数组
returnVect = zeros((1, 1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0, 32 * i + j] = int(lineStr[j])
return returnVect def handwritingClassTest():
hwLabels = []
trainingFileList = listdir('trainingDigits') # 读取训练数据
m = len(trainingFileList) # m为数据的条数
trainingMat = zeros((m, 1024)) # 特征矩阵X
for i in range(m):
fileNameStr = trainingFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
hwLabels.append(classNumStr) # 读取类别y
trainingMat[i, :] = img2vector('trainingDigits/%s' % fileNameStr) # 读取特征x
testFileList = listdir('testDigits') # 读取测试数据
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest):
fileNameStr = testFileList[i]
fileStr = fileNameStr.split('.')[0]
classNumStr = int(fileStr.split('_')[0])
vectorUnderTest = img2vector('testDigits/%s' % fileNameStr) # 读取特征x
classifierResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) # 利用KNN进行分类
print "the classifier came back with: %d, the real answer is: %d" % (classifierResult, classNumStr)
if (classifierResult != classNumStr): errorCount += 1.0 # 如果分类错误,则累加
print "\nthe total number of errors is: %d" % errorCount
print "\nthe total error rate is: %f" % (errorCount / float(mTest)) # 最后输出错误率 if __name__ == "__main__":
handwritingClassTest()

运行结果如下:

七.KNN之线性扫描法的不足

KNN最简单的实现方法就是线性扫描。但是,该做法需要求出输入点与每个训练点的距离,且还需要进行排序、统计。假如训练集很大,且特征的维度很高,那么计算量将会变得十分庞大,这时,线性扫描法将不可行。为了提高k近邻的搜索效率,可以使用特殊的数据结构来存储训练集,以减少计算距离的次数,于是就引入了KD树。下一篇博客进行详细介绍。

八.KD树

1.KD树的构造

例子:

2.搜索KD树

例子:

3.KD树Python代码实现(来自《机器学习—K近邻,KD树算法python实现》

代码:

 # -*- coding: utf-8 -*-
"""
Created on Thu Dec 14 17:46:52 2017 @author: Q
"""
import numpy as np
import matplotlib.pyplot as plt def createKDTree(dataSet,depth): #构造kd树
n = np.shape(dataSet)[0]
if n == 0: #列表为空,则返回空值
return None treeNode = {} #当前节点
n, m = np.shape(dataSet) #n为实例点的个数,m为维度
split_axis = depth % m #轮流选取特征,作为空间切割的依据
treeNode['split'] = split_axis #记录切割空间的特征
dataSet = sorted(dataSet, key=lambda a: a[split_axis]) #在选取特征数对实例点进行排序
num = n // 2
treeNode['median'] = dataSet[num] #选取特征是中位数的实例点作为该节点
treeNode['left'] = createKDTree(dataSet[:num], depth + 1) #递归左右子树继续进行切割空间、构造kd树
treeNode['right'] = createKDTree(dataSet[num + 1:], depth + 1)
return treeNode def searchTree(tree,point): #在KD树中搜索point的最近邻
k = len(point) #k为维度
if tree is None: #如果当前节点为空,则直接返回“距离无限大”表示不可能
return [0]*k, float('inf') '''在切割特征上,根据大小进入相应的子树'''
split_axis = tree['split'] #获取切割特征
median_point = tree['median'] #获取该节点的实例点
if point[split_axis] <= median_point[split_axis]: #在切割特征上,根据大小进入相应的子树
nearestPoint,nearestDistance = searchTree(tree['left'],point)
else:
nearestPoint,nearestDistance = searchTree(tree['right'],point)
nowDistance = np.linalg.norm(point-median_point) #计算point与当前实例点的距离
if nowDistance < nearestDistance: #如果两者距离小于最近距离,则更新
nearestDistance = nowDistance
nearestPoint = median_point.copy() '''检测最近点是否可能出现在另外一颗子树所表示的超平面'''
splitDistance = abs(point[split_axis] - median_point[split_axis]) #计算point与另一个子树所表示的超平面的距离
if splitDistance > nearestDistance: #如果两者距离小于当前的最近距离,则最近点必定不可能落在另一棵子树所表示的平面上,直接返回
return nearestPoint,nearestDistance
else: #否则,最近点有可能落在另一棵子树所表示的平面上,继续搜索
if point[split_axis] <= median_point[split_axis]:
nextTree = tree['right']
else:
nextTree = tree['left']
nearPoint,nearDistanc = searchTree(nextTree,point) #进入另一棵子树继续搜索
if nearDistanc < nearestDistance: #更新
nearestDistance = nearDistanc
nearestPoint = nearPoint.copy()
return nearestPoint,nearestDistance #返回当前结果 def loadData(fileName):
dataSet = []
with open(fileName) as fd:
for line in fd.readlines():
data = line.strip().split()
data = [float(item) for item in data]
dataSet.append(data)
dataSet = np.array(dataSet)
label = dataSet[:,2]
dataSet = dataSet[:,:2]
return dataSet,label if __name__ == "__main__":
'''加载数据,并绘制离散图'''
dataSet,label = loadData('testSet.txt')
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(dataSet[:,0],dataSet[:,1],c = label,cmap = plt.cm.Paired)
'''构造KD树'''
tree = createKDTree(dataSet, 0)
'''搜索最近邻'''
point = [3,9.8]
nearpoint,neardis = searchTree(tree,point)
'''将结果标示于离散图上'''
ax.scatter(point[0],point[1],c = 'g',s=50)
ax.scatter(nearpoint[0],nearpoint[1],c = 'r',s=50)
plt.show()

训练数据:

-0.017612    14.053064    0
-1.395634 4.662541 1
-0.752157 6.538620 0
-1.322371 7.152853 0
0.423363 11.054677 0
0.406704 7.067335 1
0.667394 12.741452 0
-2.460150 6.866805 1
0.569411 9.548755 0
-0.026632 10.427743 0
0.850433 6.920334 1
1.347183 13.175500 0
1.176813 3.167020 1
-1.781871 9.097953 0
-0.566606 5.749003 1
0.931635 1.589505 1
-0.024205 6.151823 1
-0.036453 2.690988 1
-0.196949 0.444165 1
1.014459 5.754399 1
1.985298 3.230619 1
-1.693453 -0.557540 1
-0.576525 11.778922 0
-0.346811 -1.678730 1
-2.124484 2.672471 1
1.217916 9.597015 0
-0.733928 9.098687 0
-3.642001 -1.618087 1
0.315985 3.523953 1
1.416614 9.619232 0
-0.386323 3.989286 1
0.556921 8.294984 1
1.224863 11.587360 0
-1.347803 -2.406051 1
1.196604 4.951851 1
0.275221 9.543647 0
0.470575 9.332488 0
-1.889567 9.542662 0
-1.527893 12.150579 0
-1.185247 11.309318 0
-0.445678 3.297303 1
1.042222 6.105155 1
-0.618787 10.320986 0
1.152083 0.548467 1
0.828534 2.676045 1
-1.237728 10.549033 0
-0.683565 -2.166125 1
0.229456 5.921938 1
-0.959885 11.555336 0
0.492911 10.993324 0
0.184992 8.721488 0
-0.355715 10.325976 0
-0.397822 8.058397 0
0.824839 13.730343 0
1.507278 5.027866 1
0.099671 6.835839 1
-0.344008 10.717485 0
1.785928 7.718645 1
-0.918801 11.560217 0
-0.364009 4.747300 1
-0.841722 4.119083 1
0.490426 1.960539 1
-0.007194 9.075792 0
0.356107 12.447863 0
0.342578 12.281162 0
-0.810823 -1.466018 1
2.530777 6.476801 1
1.296683 11.607559 0
0.475487 12.040035 0
-0.783277 11.009725 0
0.074798 11.023650 0
-1.337472 0.468339 1
-0.102781 13.763651 0
-0.147324 2.874846 1
0.518389 9.887035 0
1.015399 7.571882 0
-1.658086 -0.027255 1
1.319944 2.171228 1
2.056216 5.019981 1
-0.851633 4.375691 1
-1.510047 6.061992 0
-1.076637 -3.181888 1
1.821096 10.283990 0
3.010150 8.401766 1
-1.099458 1.688274 1
-0.834872 -1.733869 1
-0.846637 3.849075 1
1.400102 12.628781 0
1.752842 5.468166 1
0.078557 0.059736 1
0.089392 -0.715300 1
1.825662 12.693808 0
0.197445 9.744638 0
0.126117 0.922311 1
-0.679797 1.220530 1
0.677983 2.556666 1
0.761349 10.693862 0
-2.168791 0.143632 1
1.388610 9.341997 0
0.317029 14.739025 0

《机器学习实战》学习笔记第二章 —— K-近邻算法的更多相关文章

  1. 《机实战》第2章 K近邻算法实战(KNN)

    1.准备:使用Python导入数据 1.创建kNN.py文件,并在其中增加下面的代码: from numpy import * #导入科学计算包 import operator #运算符模块,k近邻算 ...

  2. 《机器学习实战》---第二章 k近邻算法 kNN

    下面的代码是在python3中运行, # -*- coding: utf-8 -*- """ Created on Tue Jul 3 17:29:27 2018 @au ...

  3. 【机器学习实战学习笔记(1-1)】k-近邻算法原理及python实现

    笔者本人是个初入机器学习的小白,主要是想把学习过程中的大概知识和自己的一些经验写下来跟大家分享,也可以加强自己的记忆,有不足的地方还望小伙伴们批评指正,点赞评论走起来~ 文章目录 1.k-近邻算法概述 ...

  4. 【机器学习实战学习笔记(1-2)】k-近邻算法应用实例python代码

    文章目录 1.改进约会网站匹配效果 1.1 准备数据:从文本文件中解析数据 1.2 分析数据:使用Matplotlib创建散点图 1.3 准备数据:归一化特征 1.4 测试算法:作为完整程序验证分类器 ...

  5. 《DOM Scripting》学习笔记-——第二章 js语法

    <Dom Scripting>学习笔记 第二章 Javascript语法 本章内容: 1.语句. 2.变量和数组. 3.运算符. 4.条件语句和循环语句. 5.函数和对象. 语句(stat ...

  6. The Road to learn React书籍学习笔记(第二章)

    The Road to learn React书籍学习笔记(第二章) 组件的内部状态 组件的内部状态也称为局部状态,允许保存.修改和删除在组件内部的属性,使用ES6类组件可以在构造函数中初始化组件的状 ...

  7. [HeadFrist-HTMLCSS学习笔记]第二章深入了解超文本:认识HTML中的“HT”

    [HeadFrist-HTMLCSS学习笔记]第二章深入了解超文本:认识HTML中的"HT" 敲黑板!!! 创建HTML超链接 <a>链接文本(此处会有下划线,可以单击 ...

  8. 机器学习实战 - 读书笔记(07) - 利用AdaBoost元算法提高分类性能

    前言 最近在看Peter Harrington写的"机器学习实战",这是我的学习笔记,这次是第7章 - 利用AdaBoost元算法提高分类性能. 核心思想 在使用某个特定的算法是, ...

  9. 【机器学习实战学习笔记(2-2)】决策树python3.6实现及简单应用

    文章目录 1.ID3及C4.5算法基础 1.1 计算香农熵 1.2 按照给定特征划分数据集 1.3 选择最优特征 1.4 多数表决实现 2.基于ID3.C4.5生成算法创建决策树 3.使用决策树进行分 ...

随机推荐

  1. IDEA破解 2017 IDEA license server 激活(可用)

    进入ide主页面,help-register-license server,然后输入 http://idea.iteblog.com/key.PHP(注意:php要小写)即可~

  2. 【SharePoint】SharePoint2013中使用客户端对象模型给用户控件赋初值

    本文要实现的功能:新建一条列表记录,打开新建记录画面时,自动给[申请人]赋值为当前登录用户. 在SharePoint2010中,可以使用SPServices的SPFindPeoplePicker方法来 ...

  3. ExtJS4 自己主动生成控制grid的列显示、隐藏的checkbox

    因为某种原因.须要做一个控制grid列显示的checkboxgroup,尽管EXTJS4中的gridpanel自带列表能够来控制列的显示隐藏,可是有这种需求(须要一目了然) 以下先上图 waterma ...

  4. binary-tree-zigzag-level-order-traversal——二叉树分层输出

    Given a binary tree, return the zigzag level order traversal of its nodes' values. (ie, from left to ...

  5. CentOS6.5下Oracle11G-R2安装、卸载

    CentOS6.5下Oracle11G-R2安装.卸载 资源下载地址(包含本人全部安装过程中,系统备份文件):http://download.csdn.net/detail/attagain/7700 ...

  6. git是一种分布式代码管理工具,git通过树的形式记录文件的更改历史,比如: base'<--base<--A<--A' ^ | --- B<--B' 小米工程师常常需要寻找两个分支最近的分割点,即base.假设git 树是多叉树,请实现一个算法,计算git树上任意两点的最近分割点。 (假设git树节点数为n,用邻接矩阵的形式表示git树:字符串数组matrix包含n个字符串,每个字符串由字符'0

    // ConsoleApplication10.cpp : 定义控制台应用程序的入口点. // #include "stdafx.h" #include <iostream& ...

  7. Tsung 初步介绍安装

    tsung是erlang的一个开源的一个压力测试工具,可以测试包括HTTP, WebDAV, Mysql, PostgreSQL, LDAP, and XMPP/Jabber等服务器.针对 HTTP ...

  8. linux crontab 定时任务解析

    -----------crontab定时任务---------------------- 检查crontab工具是否安装 crontab -l 检查crontab服务是否启动 service cron ...

  9. ssh无密码登陆屌丝指南

    [0]写在前面 由于ssh 实现的是免密码登陆,大致步骤是: 0.1) client通过ssh登陆到server: 0.2) server检查家目录下的.ssh文件, 并发送公钥文件 authoriz ...

  10. Android锁屏状态下弹出activity,如新版qq的锁屏消息提示

    在接收消息广播的onReceive里,跳转到你要显示的界面.如: Intent intent = new Intent(arg0,MainActivity.class); intent.addFlag ...