KNN是一种常见的监督学习算法,工作机制很好理解:给定测试样本,基于某种距离度量找出训练集中与其最靠近的k个训练样本,然后基于这k个“邻居”的信息来进行预测。总结一句话就是“近朱者赤,近墨者黑”。

KNN可用作分类也可用于回归,在分类任务中可使用“投票法”,即选择这k个样本中出现最多的类别标记作为测试结果;在回归任务中可使用“平均法”将这k个样本的标记平均值作为预测结果;还可以基于距离远近进行加权平均或加权投票,距离越近的样本权重越大。

       KNN和之前介绍的监督学习算法有一个很大的不同,它没有前期的训练过程,是一种“懒惰学习”的算法,只有收到测试样本后,再和训练样本进行比较处理。

       初学者容易把KNN和K-means搞混淆,虽然都有K,:-)但这是两种不同的算法,二者区别如下:

  KNN K-Means
不同点 是一种分类算法,属于监督学习的范畴,训练数据是带有label的 是一种聚类算法,属于非监督学习的范畴,训练数据没有label,杂乱无章的
没有明显的训练过程,属于lazy learning 有明确的训练过程
K的含义:与预测样本距离最近的K个样本 K的含义:K是事前人工定好的参数,假设数据集可分为K个簇
相同点 都用到了NN(nearst Neighbor)算法,一般用KD树来实现。

--KNN算法基本原理

KNN算法简单的步骤如下:

(1)计算距离:给定测试对象,计算它与训练集中每个对象的距离,空间距离的计算方法有多种,有欧式距离、夹角余弦(多在文本分类中使用)等。

(2)找邻居:圈定距离最近的k个对象,作为测试对象的近邻。

(3)做分类:根据这k个近邻归属的主要类别,对测试对象进行分类。

下面通过一个简单的示例说明下KNN算法是怎么进行分类的:

上图的蓝色方块和红色三角是已经打好label的数据,绿色圆圈是待分类的测试数据。

如果我们让K=3,那么上图实心圆圈中的两个三角和一个方块就是离测试数据最近的3个点,那么通过投票法则,测试数据会被分类为红色三角;

如果我们让K=5,那么上图虚线圆圈中的两个三角和三个方块就是离测试数据最近的5个点,通过投票法则,测试数据则会被分类为蓝色方块;

整个算法的原理是不是很简单?但实际上并没有那么简单,K如何选择?数据之间的距离怎么计算?

--K值的选择

如果K值太小,整体模型会变得复杂,容易发生过拟合,容易将一些噪声学习进来,二忽略数据的真实分布。

如果K值过大,模型会变得相对简单,可以减少学习的估计误差,但近似误差会变大,比如极端情况下K=N(N维训练样本数),则不论预测对象是什么,预测结果都将是训练集中最多的类型,这显然是一个过渡简化的模型,无法实际应用。

k值一般采用交叉验证或者Grid Search的方法确定。

--距离计算

提取数据的特征值,根据特征值组成一个n维实数向量空间(特征空间),然后计算向量之间的空间距离,如欧式距离、余弦相似度等。

对于数据,其特征空间为n维实数向量空间:

欧式距离计算公式为:

余弦相似度计算公式为:

余弦相似度的值越接近1表示其越相似,接近0表示其差异越大。余弦相似度更多应用在文本类任务中。

--代码示例

依旧以sklearn中的cancer数据集为例,做一个通过30维特征判断是否患癌症的示例,示例中数据量很少,只有569条数据,每条数据各有30个特征数值。采用sklearn中的KNN分类器,除k外都采用默认参数,距离度量采用欧式距离 。通过交叉验证法来确定最佳的K值,从下图可见,K=14时,验证准确率最高。

-Python 代码

__author__ = 'z00421185'

import pandas as pd
from sklearn import datasets
import matplotlib.pyplot as plt
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from sklearn.neighbors import KNeighborsClassifier breast_data = datasets.load_breast_cancer()
data = pd.DataFrame(datasets.load_breast_cancer().data)
data.columns = breast_data['feature_names'] data_np = breast_data['data']
target_np = breast_data['target']
print(data_np.shape) x_train, x_test, y_train, y_test = train_test_split(data_np, target_np, test_size=0.3, random_state=0) # 设定交叉验证k的范围,一般从1~样本数的开方
k_range = range(1, 24)
scores = []
for k in k_range:
knn = KNeighborsClassifier(k, metric='euclidean')
score = cross_val_score(knn, x_train, y_train, cv=10, scoring='accuracy')
scores.append(score.mean()) # 从折线图上看最佳K取值
plt.plot(k_range, scores)
plt.xlabel('K')
plt.ylabel('Accuracy')
plt.show() model = KNeighborsClassifier(n_neighbors=13)
model.fit(x_train, y_train)
y_pred = model.predict(x_test)
print(accuracy_score(y_test, y_pred))
---------------------------------
0.9649122807017544

作者:华为云专家 周捷

机器学习笔记(十)---- KNN(K Nearst Neighbor)的更多相关文章

  1. 机器学习笔记(5) KNN算法

    这篇其实应该作为机器学习的第一篇笔记的,但是在刚开始学习的时候,我还没有用博客记录笔记的打算.所以也就想到哪写到哪了. 你在网上搜索机器学习系列文章的话,大部分都是以KNN(k nearest nei ...

  2. Machine Learning for hackers读书笔记(十)KNN:推荐系统

    #一,自己写KNN df<-read.csv('G:\\dataguru\\ML_for_Hackers\\ML_for_Hackers-master\\10-Recommendations\\ ...

  3. K NEAREST NEIGHBOR 算法(knn)

    K Nearest Neighbor算法又叫KNN算法,这个算法是机器学习里面一个比较经典的算法, 总体来说KNN算法是相对比较容易理解的算法.其中的K表示最接近自己的K个数据样本.KNN算法和K-M ...

  4. 机器学习(四) 分类算法--K近邻算法 KNN (上)

    一.K近邻算法基础 KNN------- K近邻算法--------K-Nearest Neighbors 思想极度简单 应用数学知识少 (近乎为零) 效果好(缺点?) 可以解释机器学习算法使用过程中 ...

  5. 机器学习实战 之 KNN算法

    现在 机器学习 这么火,小编也忍不住想学习一把.注意,小编是零基础哦. 所以,第一步,推荐买一本机器学习的书,我选的是Peter harrigton 的<机器学习实战>.这本书是基于pyt ...

  6. Python机器学习笔记:sklearn库的学习

    网上有很多关于sklearn的学习教程,大部分都是简单的讲清楚某一方面,其实最好的教程就是官方文档. 官方文档地址:https://scikit-learn.org/stable/ (可是官方文档非常 ...

  7. Python机器学习笔记:不得不了解的机器学习面试知识点(1)

    机器学习岗位的面试中通常会对一些常见的机器学习算法和思想进行提问,在平时的学习过程中可能对算法的理论,注意点,区别会有一定的认识,但是这些知识可能不系统,在回答的时候未必能在短时间内答出自己的认识,因 ...

  8. K Nearest Neighbor 算法

    文章出处:http://coolshell.cn/articles/8052.html K Nearest Neighbor算法又叫KNN算法,这个算法是机器学习里面一个比较经典的算法, 总体来说KN ...

  9. Python机器学习笔记:不得不了解的机器学习知识点(2)

    之前一篇笔记: Python机器学习笔记:不得不了解的机器学习知识点(1) 1,什么样的资料集不适合用深度学习? 数据集太小,数据样本不足时,深度学习相对其它机器学习算法,没有明显优势. 数据集没有局 ...

随机推荐

  1. 「考试」 Or

    不得不说是一道多项式神题了. 虽然说颓代码颓的很厉害不过最终A掉了. 好好讲一讲这道题. 涉及的知识点是:高阶导数,NTT,指数型母函数,泰勒公式,以及意志力和数学推导能力. 那就开始了. 一个测试点 ...

  2. CSPS模拟 52

    我貌似曾经说过我是个只会做水题的巨型辣鸡.. 这次证明我水题都不会做.. T1 平均数 区间数$n^2$ 枚举是不可能了 可是好像没有无用的计算量.. 刚想到这里,此时开考15min 看见天皇比手势说 ...

  3. uboot启动完成,kernel启动时lcd屏幕出现杂色解决办法

    先说说开发环境吧: 1 内核:linux2.6.xx 2 uboot:买开发板带的 注释:在最后我又添加了问题得到完美解决的办法. 问题:uboot启动完成,kernel启动时lcd屏幕出现杂色(比如 ...

  4. css3的过渡和动画的属性介绍

    一.过渡 什么是过渡? 过渡是指:某元素的css属性值在一段时间内,平滑过渡到另外一个值,过渡主要观察的是过程和结果. 设置能够过渡的属性: 支持过渡的样式属性,颜色的属性,取值为数值,transfo ...

  5. js调用浏览器“打印”与“打印预览”

    用到html <object>标签,具体做法如下: 1.在html文档任意位置添加<object>标签: <div style="border: 1px sol ...

  6. C# II: Class ViewModelBase and RelayCommand in MVVM

    好久不写WPF和MVVM,新建一个Project后,想起来ViewModelBase和RelayCommand没有.以下Code摘自MSDN上的Article:Patterns - WPF Apps ...

  7. ios input输入不了

    在项目中遇到了一个问题就是input输入框在安卓可以输入,而在ios输入不了 经过百度,调试发现,在ios中input默认是有user-select: none;属性把input输入框禁用了,将其删除 ...

  8. nyoj 111-分数加减法 (gcd, switch, 模拟,数学)

    111-分数加减法 内存限制:64MB 时间限制:1000ms 特判: No 通过数:20 提交数:54 难度:2 题目描述: 编写一个C程序,实现两个分数的加减法 输入描述: 输入包含多行数据 每行 ...

  9. nyoj 324-猴子吃桃问题 (m[i] = (m[i-1] + 1) * 2)

    324-猴子吃桃问题 内存限制:64MB 时间限制:3000ms 特判: No 通过数:20 提交数:21 难度:0 题目描述: 有一堆桃子不知数目,猴子第一天吃掉一半,又多吃了一个,第二天照此方法, ...

  10. go中的关键字-defer

    1. defer的使用 defer 延迟调用.我们先来看一下,有defer关键字的代码执行顺序: func main() { defer func() { fmt.Println("1号输出 ...