1. K近邻算法(KNN)

2. KNN和KdTree算法实现

1. 前言

KNN一直是一个机器学习入门需要接触的第一个算法,它有着简单,易懂,可操作性强的一些特点。今天我久带领大家先看看sklearn中KNN的使用,在带领大家实现出自己的KNN算法。

2. KNN在sklearn中的使用

knn在sklearn中是放在sklearn.neighbors的包中的,我们今天主要介绍KNeighborsClassifier的分类器。

KNeighborsClassifier的主要参数是:

参数 意义
n_neighbors K值的选择与样本分布有关,一般选择一个较小的K值,可以通过交叉验证来选择一个比较优的K值,默认值是5
weights ‘uniform’是每个点权重一样,‘distance’则权重和距离成反比例,即距离预测目标更近的近邻具有更高的权重
algorithm ‘brute’对应第一种蛮力实现,‘kd_tree’对应第二种KD树实现,‘ball_tree’对应第三种的球树实现, ‘auto’则会在上面三种算法中做权衡,选择一个拟合最好的最优算法。
leaf_size 这个值控制了使用KD树或者球树时, 停止建子树的叶子节点数量的阈值。
metric K近邻法和限定半径最近邻法类可以使用的距离度量较多,一般来说默认的欧式距离(即p=2的闵可夫斯基距离)就可以满足我们的需求。
p p是使用距离度量参数 metric 附属参数,只用于闵可夫斯基距离和带权重闵可夫斯基距离中p值的选择,p=1为曼哈顿距离, p=2为欧式距离。默认为2

我个人认为这些个参数,比较重要的应该属n_neighbors、weights了,其他默认的也都没太大问题。

3. KNN基础版实现

直接看代码如下,完整代码GitHub

def fit(self, X_train, y_train):
self.X_train = X_train
self.y_train = y_train def predict(self, X):
# 取出n个点
knn_list = []
for i in range(self.n):
dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
knn_list.append((dist, self.y_train[i])) for i in range(self.n, len(self.X_train)):
max_index = knn_list.index(max(knn_list, key=lambda x: x[0]))
dist = np.linalg.norm(X - self.X_train[i], ord=self.p)
if knn_list[max_index][0] > dist:
knn_list[max_index] = (dist, self.y_train[i]) # 统计
knn = [k[-1] for k in knn_list]
return Counter(knn).most_common()[0][0]

我的接口设计都是按照sklearn的样子设计的,fit方法其实主要用来接收参数了,没有进行任何的处理。所有的操作都在predict中,着就会导致,我们对每个点预测的时候,时间消耗比较大。这个基础版本大家看看就好,我想大家自己去写,肯定也没问题的。

4. KdTree版本实现

kd树算法包括三步,第一步是建树,第二部是搜索最近邻,最后一步是预测。

4.1 构建kd树

kd树是一种对n维空间的实例点进行存储,以便对其进行快速检索的树形结构。kd树是二叉树,构造kd树相当于不断的用垂直于坐标轴的超平面将n维空间进行划分,构成一系列的n维超矩阵区域。

下面的流程图更加清晰的描述了kd树的构建过程:

kdtree树的生成代码:

# 建立kdtree
def create(self, dataSet, label, depth=0):
if len(dataSet) > 0:
m, n = np.shape(dataSet)
self.n = n
axis = depth % self.n
mid = int(m / 2)
dataSetcopy = sorted(dataSet, key=lambda x: x[axis])
node = Node(dataSetcopy[mid], label[mid], depth)
if depth == 0:
self.KdTree = node
node.lchild = self.create(dataSetcopy[:mid], label, depth+1)
node.rchild = self.create(dataSetcopy[mid+1:], label, depth+1)
return node
return None

4.2 kd树搜索最近邻和预测

当我们生成kd树以后,就可以去预测测试集里面的样本目标点了。预测的过程如下:

  1. 对于一个目标点,我们首先在kd树里面找到包含目标点的叶子节点。以目标点为圆心,以目标点到叶子节点样本实例的距离为半径,得到一个超球体,最近邻的点一定在这个超球体内部。
  2. 然后返回叶子节点的父节点,检查另一个子节点包含的超矩形体是否和超球体相交,如果相交就到这个子节点寻找是否有更加近的近邻,有的话就更新最近邻,并且更新超球体。如果不相交那就简单了,我们直接返回父节点的父节点,在另一个子树继续搜索最近邻。
  3. 当回溯到根节点时,算法结束,此时保存的最近邻节点就是最终的最近邻。

    kdtree树的搜索代码:
# 搜索kdtree的前count个近的点
def search(self, x, count = 1):
nearest = []
for i in range(count):
nearest.append([-1, None])
# 初始化n个点,nearest是按照距离递减的方式
self.nearest = np.array(nearest) def recurve(node):
if node is not None:
# 计算当前点的维度axis
axis = node.depth % self.n
# 计算测试点和当前点在axis维度上的差
daxis = x[axis] - node.data[axis]
# 如果小于进左子树,大于进右子树
if daxis < 0:
recurve(node.lchild)
else:
recurve(node.rchild)
# 计算预测点x到当前点的距离dist
dist = np.sqrt(np.sum(np.square(x - node.data)))
for i, d in enumerate(self.nearest):
# 如果有比现在最近的n个点更近的点,更新最近的点
if d[0] < 0 or dist < d[0]:
# 插入第i个位置的点
self.nearest = np.insert(self.nearest, i, [dist, node], axis=0)
# 删除最后一个多出来的点
self.nearest = self.nearest[:-1]
break # 统计距离为-1的个数n
n = list(self.nearest[:, 0]).count(-1)
'''
self.nearest[-n-1, 0]是当前nearest中已经有的最近点中,距离最大的点。
self.nearest[-n-1, 0] > abs(daxis)代表以x为圆心,self.nearest[-n-1, 0]为半径的圆与axis
相交,说明在左右子树里面有比self.nearest[-n-1, 0]更近的点
'''
if self.nearest[-n-1, 0] > abs(daxis):
if daxis < 0:
recurve(node.rchild)
else:
recurve(node.lchild) recurve(self.KdTree) # nodeList是最近n个点的
nodeList = self.nearest[:, 1] # knn是n个点的标签
knn = [node.label for node in nodeList]
return self.nearest[:, 1], Counter(knn).most_common()[0][0]

这段代码其实比较好的实现了上面搜索的思想。如果读者对递归的过程想不太清楚,可以画下图,或者debug下我完整的代码GitHub

5. 总结

本文实现了KNN的基础版和KdTree版本,读者可以去尝试下ballTree的版本,理论上效率比KdTree还要好一些。

2. KNN和KdTree算法实现的更多相关文章

  1. Kd-Tree算法原理和开源实现代码

    本文介绍一种用于高维空间中的高速近期邻和近似近期邻查找技术--Kd-Tree(Kd树). Kd-Tree,即K-dimensional tree,是一种高维索引树形数据结构,经常使用于在大规模的高维数 ...

  2. KNN邻近分类算法

    K邻近(k-Nearest Neighbor,KNN)分类算法是最简单的机器学习算法了.它采用测量不同特征值之间的距离方法进行分类.它的思想很简单:计算一个点A与其他所有点之间的距离,取出与该点最近的 ...

  3. [机器学习] ——KNN K-最邻近算法

    KNN分类算法,是理论上比较成熟的方法,也是最简单的机器学习算法之一. 该方法的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别 ...

  4. Kd-tree算法原理

    参考资料: Kd Tree算法原理 Kd-Tree,即K-dimensional tree,是一棵二叉树,树中存储的是一些K维数据.在一个K维数据集合上构建一棵Kd-Tree代表了对该K维数据集合构成 ...

  5. 基本分类方法——KNN(K近邻)算法

    在这篇文章 http://www.cnblogs.com/charlesblc/p/6193867.html 讲SVM的过程中,提到了KNN算法.有点熟悉,上网一查,居然就是K近邻算法,机器学习的入门 ...

  6. KNN及其改进算法的python实现

    一. 马氏距离 我们熟悉的欧氏距离虽然很有用,但也有明显的缺点.它将样品的不同属性(即各指标或各变量)之间的差别等同看待,这一点有时不能满足实际要求.例如,在教育研究中,经常遇到对人的分析和判别,个体 ...

  7. k近邻法(KNN)和KMeans算法

    k近邻算法(KNN): 三要素:k值的选择,距离的度量和分类决策规则 KMeans算法,是一种无监督学习聚类方法: 通过上述过程可以看出,和EM算法非常类似.一个简单例子, k=2: 畸变函数(dis ...

  8. [机器学习笔记]kNN进邻算法

    K-近邻算法 一.算法概述 (1)采用测量不同特征值之间的距离方法进行分类 优点: 精度高.对异常值不敏感.无数据输入假定. 缺点: 计算复杂度高.空间复杂度高. (2)KNN模型的三个要素 kNN算 ...

  9. kNN进邻算法

    一.算法概述 (1)采用测量不同特征值之间的距离方法进行分类 优点: 精度高.对异常值不敏感.无数据输入假定. 缺点: 计算复杂度高.空间复杂度高. (2)KNN模型的三个要素 kNN算法模型实际上就 ...

随机推荐

  1. Docker容器相互访问

    原文地址:https://blog.csdn.net/subfate/article/details/81396532?utm_source=copy 很多时候,同一台机器上,需要运行多个docker ...

  2. ASP.NET Core Linux环境安装并运行项目

    原文地址:https://blog.csdn.net/u014368040/article/details/79192622 一 安装环境 1.  从微软官网下载 Linux版本的.NetCoreSd ...

  3. golang学习 ---并发获取多个URL

    package main import ( "fmt" "io" "io/ioutil" "net/http" &quo ...

  4. MySQL子查询的优化

    本文基于MySQL5.7.19测试 创建四张表,pt1.pt2表加上主键 mysql> create table t1 (a1 int, b1 int); mysql> create ta ...

  5. CentOS 7.3 系统安装配置图解教程

    一.安装CentOS 7.3 CentOS 7.x系列只有64位系统,没有32位.生产服务器建议安装CentOS-7-x86_64-Minimal-1611.iso版本 成功引导系统后,会出现下面的界 ...

  6. windows 中安装及使用 SSH Key

    转自 简书技术博客:https://www.jianshu.com/p/a3b4f61d4747 联系管理员开通ssh功能: 重新创建环境: 下载工具包到本地机器wsCli 0.4 解压后,把相应的w ...

  7. Java – Stream has already been operated upon or closed

    Java – Stream has already been operated upon or closed package com.mkyong.java8; import java.util.Ar ...

  8. java打印条形码Code128C

    生成编码类型为Code128C的条形码的javaCODE: package test; import java.awt.Color; import java.awt.Graphics; import ...

  9. spring+mybatis的插件【shardbatis2.0】+mysql+java自定义注解实现分表

    一.业务场景分析 只有大表才需要分表,而且这个大表还会有经常需要读的需要,即使经过sql服务器优化和sql调优,查询也会非常慢.例如共享汽车的定位数据表等. 二.实现步骤 1.准备pom依赖 < ...

  10. 在 Visual Studio 2017 中找回消失的“在浏览器中查看”命令

    不知为何,在新安装 Visual Studio 2017 后,发现所有 Web 项目上右键菜单的"在浏览器中查看"命令消失了,只能以调试模式启动网站,非常别扭. 最后在 Stack ...