手写算法-python代码实现KNN
原理解析
KNN-全称K-Nearest Neighbor,最近邻算法,可以做分类任务,也可以做回归任务,KNN是一种简单的机器学习方法,它没有传统意义上训练和学习过程,实现流程如下:
1、在训练数据集中,找到和需要预测样本最近邻的K个实例;
2、分别统计这K个实例所属的类别,最多的那个类别就是样本预测的类别(多数表决法);
对于回归任务而言,则是求这K个实例输出值的平均值(选择平均法);
因此,该算法的几个重点在于:
1、K值的选取,K值的不同直接会导致最终结果的不同;
选择较小的k值,就相当于用较小的领域中的训练实例进行预测,训练误差会减小,只有与输入实例较近或相似的训练实例才会对预测结果起作用,与此同时带来的问题是泛化误差会增大,换句话说,K值的减小就意味着整体模型变得复杂,容易发生过拟合;
选择较大的k值,就相当于用较大领域中的训练实例进行预测,其优点是可以减少泛化误差,但缺点是训练误差会增大。这时候,与输入实例较远(不相似的)训练实例也会对预测器作用,使预测发生错误,且K值的增大就意味着整体的模型变得简单,容易欠拟合;
一般的,最佳K值的选取,我们可以用交叉验证法来寻找,分类准确率最高的(回归问题中就是均方误差最小的),就是最佳的K值;
2、距离的计算,计算最近邻的K个样本时,用哪种度量方式,最常用的是欧氏距离;
3、决策规则,一般就是多数表决法或者选择平均法,但是,K个近邻数据,到样本的距离也不一样,都一视同仁,也不太合理;
代码实现
根据上面的KNN原理解析,我们来编写python代码(分类任务代码):
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.metrics import classification_report
class Knn():
#默认k=5,设置和sklearn中的一样
def __init__(self,k=5):
self.k = k
def fit(self,x,y):
self.x = x
self.y = y
def predict(self,x_test):
labels = []
#这里可以看出,KNN的计算复杂度很高,一个样本就是O(m * n)
for i in range(len(x_test)):
#初始化一个y标签的统计字典
dict_y = {}
#计算第i个测试数据到所有训练样本的欧氏距离
diff = self.x - x_test[i]
distances = np.sqrt(np.square(diff).sum(axis=1))
#对距离排名,取最小的k个样本对应的y标签
rank = np.argsort(distances)
rank_k = rank[:self.k]
y_labels = self.y[rank_k]
#生成类别字典,key为类别,value为样本个数
for j in y_labels:
if j not in dict_y:
dict_y.setdefault(j,1)
else:
dict_y[j] += 1
#取得y_labels里面,value值最大对应的类别标签即为测试样本的预测标签
#label = sorted(dict_y.items(),key = lambda x:x[1],reverse=True)[0][0]
#下面这种实现方式更加优雅
label = max(dict_y,key = dict_y.get)
labels.append(label)
return labels
实例展示
利用sklearn生成实验数据集:
#有同学私信我,为什么每次都是生成2维数据,因为2维数据方便画图,哈哈
x,y = make_classification(n_features=2,n_redundant=0,random_state=2019)
plt.scatter(x[:,0],x[:,1],c=y)
plt.show()

数据表现如上,来看看分类效果:
#预测
knn = Knn()
knn.fit(x,y)
labels = knn.predict(x)
#查看分类报告
print(classification_report(y,labels))

f1的值为94%,下面画图看看分类边界:
#画等高线图
x_min,x_max = x[:,0].min() - 1,x[:,0].max() + 1
y_min,y_max = x[:,1].min() - 1,x[:,1].max() + 1
xx = np.arange(x_min,x_max,0.02)
yy = np.arange(y_min,y_max,0.02)
xx,yy = np.meshgrid(xx,yy)
x_1 = np.c_[xx.ravel(),yy.ravel()]
y_1 = knn.predict(x_1)
#list没有reshape方法,转为np.array的格式
plt.contourf(xx,yy,np.array(y_1).reshape(xx.shape),cmap='GnBu')
plt.scatter(x[:,0],x[:,1],c=y)
plt.show()

看起来还是很好的。
sklearn对比
下面调用sklearn里面的KNN库对比效果:
from sklearn.neighbors import KNeighborsClassifier
clf = KNeighborsClassifier(n_neighbors=5)
clf.fit(x,y)
#输出分类报告
print(classification_report(y,clf.predict(x)))
#画图
y_pred = clf.predict(x_1)
plt.contourf(xx,yy,y_pred.reshape(xx.shape),cmap='GnBu')
plt.scatter(x[:,0],x[:,1],c=y)
plt.show()

结果基本上是一样的。
总结
上面我们写了python代码,来实现基本的KNN分类,实际上sklearn里面的KNeighborsClassifier分类器,封装的内容则很多:
我们python代码中采用的办法称为蛮力实现(brute-force),即需要计算每一个测试样本到所有训练样本的距离,才能确定最终的预测标签,当数据集很大,特征很多,并且测试样本也很多时,需要的计算量大家可以想象一下,基本上跑不出来结果,这点我自己是有实际案例的;
而sklearn里面则对这点做出了优化,除了蛮力实现(brute-force),还有KD树实现(KDTree)和球树(BallTree)实现,后两者则大大提高了处理大数据集时的效率(感兴趣的同学可自行去查找这两者的资料),对于这三个算法,sklearn会根据输入的样本,自动选择一种算法(默认参数algorithm=‘auto’)。
对于距离的计算,我们直接用的欧氏距离,在sklearn里面,封装了很多种距离的度量方式,比如欧氏距离、曼哈顿距离、马氏距离、闵可夫斯基距离等,默认的是p=2的闵可夫斯基距离,也就是欧氏距离。
本文的文字及图片来源于网络,仅供学习、交流使用,不具有任何商业用途,如有问题请及时联系我们以作处理
想要获取更多Python学习资料可以加
QQ:2955637827私聊
或加Q群630390733
大家一起来学习讨论吧!
手写算法-python代码实现KNN的更多相关文章
- 4.redis 的过期策略都有哪些?内存淘汰机制都有哪些?手写一下 LRU 代码实现?
作者:中华石杉 面试题 redis 的过期策略都有哪些?内存淘汰机制都有哪些?手写一下 LRU 代码实现? 面试官心理分析 如果你连这个问题都不知道,上来就懵了,回答不出来,那线上你写代码的时候,想当 ...
- 将自己写的Python代码打包放到PyPI上
如果是开源的Python代码,为了能够让大家更方便的使用,放到PyPI上也许是个非常不错的主意(PyPI:Python Package Index).刚开始我以为要将代码打包放到PyPI上是一件非常复 ...
- 我写的python代码的规则
1.Python文件的命名: 采用每个单词的首字母大写,不使用下划线 2.Python类的命名: 采用每个单词的首字母大写,不使用下划线 3.Python包名的命名:采用每个单词都是小写,不使用下划线 ...
- redis的过期策略都有哪些?内存淘汰机制都有哪些?手写一下LRU代码实现?
redis的过期策略都有哪些? 设置过期时间: set key 的时候,使用expire time,就是过期时间.指定这个key比如说只能存活一个小时?10分钟?指定缓存到期就会失效. redis的过 ...
- 【MaixPy3文档】写好 Python 代码!
本文是给有一点 Python 基础但还想进一步深入的同学,有经验的开发者建议跳过. 前言 上文讲述了如何认识开源项目和一些编程方法的介绍,这节主要来说说 Python 代码怎么写的一些演化过程和可以如 ...
- 我写的 Python 代码,同事都说好
原文链接: 我写的 Python 代码,同事都说好 人生苦短,我用 Python. 程序员的追求就是不写代码,早日财务自由.不对,一不小心把实话说出来了,应该是将代码写得简洁,优雅. Python 程 ...
- k-近邻算法python代码实现(非常全)
1.k近邻算法是学习机器学习算法最为经典和简单的算法,它是机器学习算法入门最好的算法之一,可以非常好并且快速地理解机器学习的算法的框架与应用.它是一种经典简单的分类算法,当然也可以用来解决回归问题.2 ...
- 如何优雅的写好python代码?
Python与其他语言(比如 java或者 C ++ )相比有较大的区别,其中最大的特点就是非常简洁,如果按照其他语言的思路老师写Python代码,则会使得代码繁琐复杂,并且容易出现bug,在Pyth ...
- 手写神经网络Python深度学习
import numpy import scipy.special import matplotlib.pyplot as plt import scipy.misc import glob impo ...
随机推荐
- vulnhub: DC 3
通过nmap扫描,只开放了80端口,并且该web服务是基于Joomla搭建: root@kali:~# nmap -A 192.168.74.140 Starting Nmap 7.80 ( http ...
- Spring MVC系列-(5) AOP
5 AOP 5.1 什么是AOP AOP(Aspect-Oriented Programming,面向切面编程),可以说是OOP(Object-Oriented Programing,面向对象编程)的 ...
- uni-app 封装接口request请求
我们知道一个项目中对于前期架构的搭建工作对于后期的制作有多么重要,所以不管做什么项目我们拿到需求后一定要认真的分析一下,要和产品以及后台沟通好,其中尤为重要的一个环节莫过于封装接口请求了.因为前期封装 ...
- Prafab Varient 预制体变体
预制体与类的类比思维: 预制体相当于一个类,当它应用到场景当中,就是一个实例. 类的继承特性也充分运用到预制体中,即预制体变体. 相似预制体的需求场景: 例子1:多个游戏的窗口 ...
- 第一次UML作业
这个作业属于哪个课程 https://edu.cnblogs.com/campus/fzzcxy/2018SE2/ 这个作业要求在哪里 https://edu.cnblogs.com/campus/f ...
- MySQL慢查询日志(SLOW LOG)
慢查询日志可以帮助DBA或开发人员定位可能存在问题的SQL语句,从而进行优化. 如何开启 默认情况下,MySQL是不开启慢查询日志的.可以通过以下命令查看是否开启: mysql> SHOW VA ...
- 第九章 Python文件操作
前一阵子写类相关的内容,把老猿写得心都累了,本来准备继续介绍一些类相关的知识的,如闭包.装饰器.描述符.枚举类.异常等,现在实在不想继续,以后再开章节吧.本章弄点开胃的小菜提提神,介绍Python中文 ...
- PyQt(Python+Qt)学习随笔:QTreeView树形视图的wordWrap属性
老猿Python博文目录 专栏:使用PyQt开发图形界面Python应用 老猿Python博客地址 QTreeView树形视图的wordWrap属性用于控制视图展示数据项文本的单词换行原则,如果该值为 ...
- pandas 移动列的方法
import pandas as pd df = pd.DataFrame(np.random.randn(3,4),columns=['a','b','c','d']) k = df.pop(&qu ...
- Making Games with Python & Pygame 中文翻译
Making Games with Python & Pygame 用Pygame做游戏 第1章-安装python和pygame 原文作者:Al Sweigart 翻译:bigbigli/李超 ...