KNN算法介绍及源码实现
一.KNN算法介绍
邻近算法,或者说K最邻近(KNN,K-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K最近邻,就是K个最近的邻居的意思,说的是每个样本都可以用它最接近的K个邻近值来代表。近邻算法就是将数据集合中每一个记录进行分类的方法 。
k近邻法是一种基本的分类和回归方法,是监督学习方法里的一种常用方法。k近邻算法假设给定一个训练数据集,其中的实例类别已定。分类时,对新的实例,根据其k个最近邻的训练实例类别,通过多数表决等方式进行预测。
二.KNN算法核心思想
KNN算法的核心思想是,如果一个样本在特征空间中的K个最相邻的样本中的大多数属于某一个类别,则该样本也属于这个类别,并具有这个类别上样本的特性。该方法在确定分类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。KNN方法在类别决策时,只与极少量的相邻样本有关。由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合。
三.KNN算法三要素
k近邻法三要素:距离度量、k值的选择和分类决策规则。常用的距离度量是欧氏距离及更一般的pL距离。k值小时,k近邻模型更复杂,容易发生过拟合;k值大时,k近邻模型更简单,又容易欠拟合。因此k值得选择会对分类结果产生重大影响。k值的选择反映了对近似误差与估计误差之间的权衡,通常由交叉验证选择最优的k。
四.KNN算法优缺点
优点
1.简单,易于理解,易于实现,无需估计参数,无需训练;
2. 适合对稀有事件进行分类;
3.特别适合于多分类问题(multi-modal,对象具有多个类别标签), kNN比SVM的表现要好。缺点
1.该算法在分类时有个主要的不足是,当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数 。
2.该方法的另一个不足之处是计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点 。
五.源码简单实现
1.导包
import numpy as np
from math import sqrt
from collections import Counter
from metrics import accuracy_score
2.初始化kNN分类器
class KNNClassifier:
def __init__(self, k):
"""初始化kNN分类器"""
assert k >= 1, "k must be valid"
self.k = k
self._X_train = None
self._y_train = None
3.训练数据集
def fit(self, X_train, y_train):
"""根据训练数据集X_train和y_train训练kNN分类器"""
assert X_train.shape[0] == y_train.shape[0], \
"the size of X_train must be equal to the size of y_train"
assert self.k <= X_train.shape[0], \
"the size of X_train must be at least k."
self._X_train = X_train
self._y_train = y_train
return self
4.多个值的预测
def predict(self, X_predict):
"""给定待预测数据集X_predict,返回表示X_predict的结果向量"""
assert self._X_train is not None and self._y_train is not None, \
"must fit before predict!"
assert X_predict.shape[1] == self._X_train.shape[1], \
"the feature number of X_predict must be equal to X_train"
y_predict = [self._predict(x) for x in X_predict]
return np.array(y_predict)
5.单个值的预测
def _predict(self, x):
"""给定单个待预测数据x,返回x的预测结果值"""
assert x.shape[0] == self._X_train.shape[1], \
"the feature number of x must be equal to X_train"
distances = [sqrt(np.sum((x_train - x) ** 2))
for x_train in self._X_train]
nearest = np.argsort(distances)
topK_y = [self._y_train[i] for i in nearest[:self.k]]
votes = Counter(topK_y)
return votes.most_common(1)[0][0]
6.模型准确度
def score(self, X_test, y_test):
"""根据测试数据集 X_test 和 y_test 确定当前模型的准确度"""
y_predict = self.predict(X_test)
return accuracy_score(y_test, y_predict)
六.全部代码
import numpy as np
from math import sqrt
from collections import Counter
from metrics import accuracy_score
class KNNClassifier:
def __init__(self, k):
"""初始化kNN分类器"""
assert k >= 1, "k must be valid"
self.k = k
self._X_train = None
self._y_train = None
def fit(self, X_train, y_train):
"""根据训练数据集X_train和y_train训练kNN分类器"""
assert X_train.shape[0] == y_train.shape[0], \
"the size of X_train must be equal to the size of y_train"
assert self.k <= X_train.shape[0], \
"the size of X_train must be at least k."
self._X_train = X_train
self._y_train = y_train
return self
def predict(self, X_predict):
"""给定待预测数据集X_predict,返回表示X_predict的结果向量"""
assert self._X_train is not None and self._y_train is not None, \
"must fit before predict!"
assert X_predict.shape[1] == self._X_train.shape[1], \
"the feature number of X_predict must be equal to X_train"
y_predict = [self._predict(x) for x in X_predict]
return np.array(y_predict)
def _predict(self, x):
"""给定单个待预测数据x,返回x的预测结果值"""
assert x.shape[0] == self._X_train.shape[1], \
"the feature number of x must be equal to X_train"
distances = [sqrt(np.sum((x_train - x) ** 2))
for x_train in self._X_train]
nearest = np.argsort(distances)
topK_y = [self._y_train[i] for i in nearest[:self.k]]
votes = Counter(topK_y)
return votes.most_common(1)[0][0]
def score(self, X_test, y_test):
"""根据测试数据集 X_test 和 y_test 确定当前模型的准确度"""
y_predict = self.predict(X_test)
return accuracy_score(y_test, y_predict)
def __repr__(self):
return "KNN(k=%d)" % self.k
KNN算法介绍及源码实现的更多相关文章
- iOS下使用SHA1WithRSA算法加签源码
首先了解一下几个相关概念,以方便后面遇到的问题的解决: RSA算法:1977年由Ron Rivest.Adi Shamirh和LenAdleman发明的,RSA就是取自他们三个人的名字.算法基于一个数 ...
- OpenCV学习笔记(27)KAZE 算法原理与源码分析(一)非线性扩散滤波
http://blog.csdn.net/chenyusiyuan/article/details/8710462 OpenCV学习笔记(27)KAZE 算法原理与源码分析(一)非线性扩散滤波 201 ...
- Android IntentService使用介绍以及源码解析
版权声明:本文出自汪磊的博客,转载请务必注明出处. 一.IntentService概述及使用举例 IntentService内部实现机制用到了HandlerThread,如果对HandlerThrea ...
- IPerf——网络测试工具介绍与源码解析(4)
上篇随笔讲到了TCP模式下的客户端,接下来会讲一下TCP模式普通场景下的服务端,说普通场景则是暂时不考虑双向测试的可能,毕竟了解一项东西还是先从简单的情况下入手会快些. 对于服务端,并不是我们认为的直 ...
- 机器学习:IB1算法的weka源码详细解析(1NN)
机器学习的1NN最近邻算法,在weka里叫IB1,是因为Instance Base 1 ,也就是只基于一个最近邻的实例的惰性学习算法. 下面总结一下,weka中对IB1源码的学习总结. 首先需要把 ...
- KNN算法介绍
KNN算法全名为k-Nearest Neighbor,就是K最近邻的意思. 算法描述 KNN是一种分类算法,其基本思想是采用测量不同特征值之间的距离方法进行分类. 算法过程如下: 1.准备样本数据集( ...
- SM4密码算法(附源码)
SM4是我们自己国家的一个分组密码算法,是国家密码管理局于2012年发布的.网址戳→_→:http://www.cnnic.NET.cn/jscx/mixbz/sm4/ 具体的密码标准和算法官方有非常 ...
- JUC中Lock和ReentrantLock介绍及源码解析
Lock框架是jdk1.5新增的,作用和synchronized的作用一样,所以学习的时候可以和synchronized做对比.在这里先和synchronized做一下简单对比,然后分析下Lock接口 ...
- Android HandlerThread使用介绍以及源码解析
摘要: 版权声明:本文出自汪磊的博客,转载请务必注明出处. 一.HandlerThread的介绍及使用举例 HandlerThread是什么鬼?其本质就是一个线程,但是Han ...
随机推荐
- maven exclusion 理解
结论:exclusion 表示对传递性依赖进行排除,排除后当前项目的依赖jar中,就不会包含该传递性依赖. 扩展:项目中的jar 都会在classpath下,排除后的传递性依赖,相当于在classpa ...
- Java学习 (八)基础篇 运算符
目录 运算符 基本运算符 1.一元基础运算(重点) 一元运算符 (a++ / ++a) (a-- / --a) 2.二元基础运算 基础 计算返回值类型 关系运算 幂运算 3.三元运算符 4.逻辑运算符 ...
- 5.20 NOI 模拟
万年不更题解的鸽子来更题解了 \(T1\)矩阵 是个炒鸡恶心的推式子题 求\([x_1,x_2],[y_1,y_2]\)内部的数字和,把矩阵分成四份比较容易想到,差分也容易想到 \(Sum[x][y] ...
- 羽夏看Linux内核——中断与分页相关入门知识
写在前面 此系列是本人一个字一个字码出来的,包括示例和实验截图.如有好的建议,欢迎反馈.码字不易,如果本篇文章有帮助你的,如有闲钱,可以打赏支持我的创作.如想转载,请把我的转载信息附在文章后面,并 ...
- 使用 DolphinScheduler 调度 Kylin 构建
本文章经授权转载 Apache Kylin 上游通常有复杂的数据 ETL 过程,如 Hive 入库.数据清洗等:下游有报表刷新,邮件分发等.集成 Apache DolphinScheduler 后,K ...
- Luogu3803 【模板】多项式乘法(FFT)
为什么我这么弱 其实FFT也挺水的,一点数学基础加上细心即可.细节·技巧挺多. 递归 在TLE的边缘苦苦挣扎 #include <iostream> #include <cstdio ...
- Git 09 IDEA撤销提交
参考源 https://www.bilibili.com/video/BV1FE411P7B3?spm_id_from=333.999.0.0 版本 本文章基于 Git 2.35.1.2 如果提交了不 ...
- 急如闪电快如风,彩虹女神跃长空,Go语言高性能Web框架Iris项目实战-初始化项目ep00
在Golang Web编程的世界里,君不言高性能则已,言高性能必称Iris.彩虹女神的名号响彻寰宇.名动江湖,单论一个快字,无人能出其右,就连以简洁轻量著称于世的Gin也难以望其项背,只见彩虹女神Ir ...
- 搞定面试官 - 可以介绍一下在 MySQL 中你平时是怎么使用 COUNT() 的嘛?
大家好,我是程序员啊粥. 相信在大家的工作中,有很多的功能都需要用到 count(*) 来统计表中的数据行数.同时,对于一些大数据的表,用 count 都是瑟瑟发抖,往往会结合缓存等进行处理. 那么, ...
- 离线安装docker
一.安装步骤 1.下载Docker二进制文件(离线安装包) 下载地址:https://download.docker.com/linux/static/stable/x86_64/ 注:本文使用 do ...