KNN算法 0基础小白也能懂(附代码)

原文链接

1.K近邻是啥

1968年,Cover 和 Hart 提出了最初的近邻法,思路是——未知的豆离哪种豆最近,就认为未知豆和该豆是同一种类。

近邻算法的定义:为了判定未知样本的类别,以全部训练样本作为代表点计算未知样本与所有训练样本的距离,并以最近邻者的类别作为决策未知样本类别的唯一依据。

说人话就是,看这玩意离哪个东西距离最近,越近越像。

最近邻算法的缺陷是对噪声数据过于敏感。从图中可以得到,一个圈起来的蓝点和两个圈起来的红点到绿点的距离是相等的,根据最近邻算法,该点的形状无法判断。其实也就是东西太多太杂的话判断不清

为了解决这个问题,我们可以把位置样本周边的多个最近样本计算在内,扩大参与决策的样本量,以避免个别数据直接决定决策结果。也就是增大数据量减少误差。

引进K近邻算法——选择未知样本一定范围内确定个数的K个样本,该K个样本大多数属于某一类型,则未知样本判定为该类型。K近邻算法是最近邻算法的一个延伸。根据K近邻算法,离绿点最近的三个点中有两个是红点,一个是蓝点,红点的样本数量多于蓝点的样本数量,因此绿点的类别被判定为红点。

2.KNN算法步骤

一般来说,只选择样本数据集中前N个最相似的数据.K一般不大于20,最后,选择K个中出现次数最多的分类,作为新数据的分类。

那K值到底如何选择呢?

  1. 数据特性

    噪声与非相关特征:如果数据中存在较多噪声或非相关特征,较大的K值可以平滑分类结果,减小噪声的影响。然而,这也可能导致类别之间的界限变得模糊,特别是在数据分布复杂或类别间存在明显界限的情况下。

    特征选择与缩放:为减小噪声和非相关特征的影响,可以通过特征选择和特征缩放来优化输入数据。例如,利用进化算法或互信息进行特征选择,从而提高KNN算法的性能。

  2. 奇数K值

    避免平票:在二元分类问题中,选择奇数的K值有助于避免分类器投票时出现平票的情况,从而提高分类器的确定性。这一点在K值较小、数据分布相对均匀时尤为重要。

  3. 超参数优化

    启发式方法:K值的选择可以通过各种启发式技术来优化。交叉验证(Cross-validation)是一种常用的方法,它通过将数据集划分为训练集和验证集,尝试不同的K值,并选择在验证集上表现最佳的K值。

    自助法(Bootstrap):在二元分类问题中,自助法可以用于评估不同K值的性能,并帮助选择最佳的K值。自助法通过多次重复采样训练集,计算每次采样的分类准确率,从而估计K值的期望性能。(从原始数据集中随机有放回地抽样,生成多个新的子数据集。每个子数据集都与原始数据集大小相同,但是由于是有放回抽样,因此某些样本可能会在一个子数据集中出现多次,而另一些样本可能根本没有出现(OOB)。)

3.实战实现KNN算法

3.1 背景

假如一套房子打算出租,但不知道市场价格,可以根据房子的规格(面积、房间数量、厕所数量、容纳人数等),在已有数据集中查找相似(K近邻)规格的房子价格,看别人的相同或相似户型租了多少钱。

数据集在这,CardioGoodFitness 数据集 提取码:show

3.2数据分类

已知的数据集中,每个已出租住房都有房间数量、厕所数量、容纳人数等字段,并有对应出租价格。将预计出租房子数据与数据集中每条记录比较计算欧式距离(坐标系里的距离),取出距离最小的5条记录,将其价格取平均值,可以将其看做预计出租房子的市场平均价格。

import pandas as pd
import numpy as np
from scipy.spatial import distance#用于计算欧式距离
from sklearn.preprocessing import StandardScaler#用于对数据进行标准化操作
from sklearn.neighbors import KNeighborsRegressor#KNN算法
from sklearn.metrics import mean_squared_error#用于计算均方根误差

上面导入包的作用原理会在后面一一指出,先导入数据并提取目标字段

#导入数据并提取目标字段
path = r'rent_price.csv'
file = open(path, encoding = 'gb18030', errors = 'ignore')
dc_listings = pd.read_csv(file)
features = ['accommodates','bedrooms','bathrooms','beds','price','minimum_nights','maximum_nights','number_of_reviews']
dc_listings = dc_listings[features]

dc_listings长下面这样

3.3 初步数据清洗

数据集中非数值类型的字段需要转换,替换掉美元$符号和千分位逗号

#数据初步清洗
our_acc_value = 3
dc_listings['distance'] = np.abs(dc_listings.accommodates - our_acc_value)
dc_listings = dc_listings.sample(frac=1, random_state=0) #抽取 100% 的样本重排
dc_listings = dc_listings.sort_values('distance')
dc_listings['price'] = dc_listings['price'].str.replace(r'[\$,]', '', regex=True).astype(float)
dc_listings = dc_listings.dropna() #删除包含空值(NaN)的行

理想情况下,数据集中每个字段取值范围都相同,但实际上这是几乎不可能的,如果计算时直接用原数数据计算,则会造成较大训练误差,所以需要对各列数据进行标准化或归一化操作,尽量减少不必要的训练误差。归一化的目的就是使得预处理的数据被限定在一定的范围内(比如[0,1]或者[-1,1]),从而消除奇异样本数据导致的不良影响。

#数据标准化
dc_listings[features] = StandardScaler().fit_transform(dc_listings[features]) #都变成标准正态分布
normalized_listings = dc_listings

最好不要将所有数据全部拿来测试,需要分出训练集和测试集具体划分比例按数据集确定。

#取得训练集和测试集
norm_train_df = normalized_listings[:2792]
norm_test_df = normalized_listings[2792:]

3.4 计算欧式距离并预测房屋价格

#scipy包distance模块计算欧式距离
first_listings = normalized_listings.iloc[0][['accommodates', 'bathrooms']]
fifth_listings = normalized_listings.iloc[20][['accommodates', 'bathrooms']]
#用python方法做多变量KNN模型
def predict_price_multivariate(new_listing_value, feature_columns):
temp_df = norm_train_df
#distance.cdist计算两个集合的距离
temp_df['distance'] = distance.cdist(temp_df[feature_columns], [new_listing_value[feature_columns]])
temp_df = temp_df.sort_values('distance')#temp_df按distance排序
knn_5 = temp_df.price.iloc[:5] #选择距离最近的前5个样本
predicted_price = knn_5.mean()
return predicted_price
cols = ['accommodates', 'bathrooms']
norm_test_df['predicted_price'] = norm_test_df[cols].apply(predict_price_multivariate, feature_columns=cols, axis=1)
norm_test_df['squared_error'] = (norm_test_df['predicted_price'] - norm_test_df['price']) ** 2
mse = norm_test_df['squared_error'].mean()
rmse = mse ** (1/2)
print(rmse) #利用sklearn完成KNN
col = ['accommodates', 'bedrooms']
knn = KNeighborsRegressor()
#将自变量和因变量放入模型训练,并用测试数据测试
knn.fit(norm_train_df[cols], norm_train_df['price'])
two_features_predictions = knn.predict(norm_test_df[cols])
#计算预测值与实际值的均方根误差
two_features_mse = mean_squared_error(norm_test_df['price'], two_features_predictions)
two_features_rmse = two_features_mse ** (1/2)
print(two_features_rmse)

输出为

1.4667825805653032
......(一堆报错,表示你正在对一个可能是原 DataFrame 的切片的数据进行修改,不过不影响结果)
1.5356457412450537

总结:K近邻算法的核心要素

K的大小

在实际的应用中,一般采用一个比较小的K值。并采用交叉验证的方法,选取一个最优的K值。比如在之前的代码中,手动实现的K值选的就是5,同时sklearn包里K值默认也是5.

距离度量准则

  1. 欧氏距离(Euclidean Distance)

    欧氏距离是最常用的距离度量准则之一,适用于连续型变量。它表示两个点之间的“直线”距离。

    \(d(p,q)=\sqrt {\sum_{i=1}^{n}(p_i-q_i)^2}\)

    优点:简单易懂,计算方便。

    缺点:对特征尺度敏感,需要进行特征缩放(如标准化)。

  2. 曼哈顿距离(Manhattan Distance)

    曼哈顿距离,也称为“城市街区距离”或“L1距离”,表示两个点在各维度上的绝对差值的和。

    \(d(p,q)=\sum_{i=1}^{n}|p_i-q_i|\)

    优点:对特征缩放不太敏感,适用于高维空间。

    缺点:不能反映“直线”距离,可能导致某些情况下误差较大。

  3. 切比雪夫距离(Chebyshev Distance)

    切比雪夫距离,也称为L∞距离,表示两个点之间在所有坐标轴上最大差值的距离。

    \(d(p,q)=max_i|p_i-q_i|\)

    优点:适用于棋盘格状的网格空间,特别适合某些特殊情况下的度量。

    缺点:在某些应用中可能不够精确。

KNN算法 0基础小白也能懂(附代码)的更多相关文章

  1. Docker_入门?只要这篇就够了!(纯干货适合0基础小白)

    与sgy一起开启你的Docker之路 关键词: Docker; mac; Docker中使用gdb无法进入断点,无法调试; 更新1: 看起来之前那一版博文中参考资料部分引用的外站链接太多,被系统自动屏 ...

  2. 0基础小白怎么学好Java?

    自身零基础,我们应该先学好Java,小编给大家介绍一下Java的特性: Java语言是简单的 Java语言的语法与C语言和C++语言很接近,使得大多数程序员很容易学习和使用Java.Java丢弃了C+ ...

  3. 大一0基础小白用最基础C写哥德巴赫猜想

    #include <stdio.h>int main (){ int a,b,c,k,count1,count2; for(a=4;a<=1200;a=a+2){ for(b=2;b ...

  4. MySQL下载,安装,配置环境变量【0基础小白用】

    一,下载 选择社区版的,下载地址:https://dev.mysql.com/downloads/installer/  ,选择离线安装包 二,安装 1,双击安装包文件,这里选择服务模式,会安装在默认 ...

  5. (五)SpringBoot2.0基础篇- Mybatis与插件生成代码

    SpringBoot与Mybatis合并 一.创建SpringBoot项目,引入相关依赖包: <?xml version="1.0" encoding="UTF-8 ...

  6. (六)SpringBoot2.0基础篇- Redis整合(JedisCluster集群连接)

    一.环境 Redis:4.0.9 SpringBoot:2.0.1 Redis安装:Linux(Redhat)安装Redis 二.SpringBoot整合Redis 1.项目基本搭建: 我们基于(五) ...

  7. 【机器学习算法基础+实战系列】KNN算法

    k 近邻法(K-nearest neighbor)是一种基本的分类方法 基本思路: 给定一个训练数据集,对于新的输入实例,在训练数据集中找到与该实例最邻近的k个实例,这k个实例多数属于某个类别,就把输 ...

  8. 0基础的小白怎么学习Java?

    自身零基础,那么我们应该先学好Java,首先我们来了解下Java的特性: Java语言是简单的 Java语言的语法与C语言和C++语言很接近,使得大多数程序员很容易学习和使用Java.另一方面,Jav ...

  9. 0基础算法基础学算法 第八弹 递归进阶,dfs第一讲

    最近很有一段时间没有更新了,主要是因为我要去参加一个重要的考试----小升初!作为一个武汉的兢兢业业的小学生当然要去试一试我们那里最好的几个学校的考试了,总之因为很多的原因放了好久的鸽子,不过从今天开 ...

  10. 机器学习实战 之 KNN算法

    现在 机器学习 这么火,小编也忍不住想学习一把.注意,小编是零基础哦. 所以,第一步,推荐买一本机器学习的书,我选的是Peter harrigton 的<机器学习实战>.这本书是基于pyt ...

随机推荐

  1. [WPF]用HtmlTextBlock实现消息对话框的内容高亮和跳转

    动手写一个简单的消息对话框一文介绍了如何实现满足常见应用场景的消息对话框.但是内容区域的文字仅仅起到信息展示作用,对于需要部分关键字高亮,或者部分内容有交互性的场景(例如下图提示信息中的"w ...

  2. P7687 题解

    考场上数组开大了直接 MLE 了,气. 考虑把 A,B 两种服务分开算,一个边双连通分量内的点如过有一个有服务,那么整个联通分量就都有服务. 然后按边双联通分量缩点后原图变成树,一条边是关键路线当且仅 ...

  3. LVS介绍与配置

    目录 LVS(Linux Virtual Server) 1. 概述 1.1 LVS简介 1.2 LVS架构 2. LVS工作模式 2.1 NAT模式(Network Address Translat ...

  4. node.js (原生模板引擎模板)

    app01 // 引入http模块 const http = require('http'); //连接数据库 require('./model/connects'); // 创建网站服务器 cons ...

  5. SpringMVC面试题及答案

    SpringMvc 的控制器是不是单例模式,如果是,有什么问题,怎么解决? 问题:单例模式,在多线程访问时有线程安全问题 解决方法:不要用同步,在控制器里面不能写字段 SpringMvc 中控制器的注 ...

  6. JDK9之后 Eureka依赖

    <!--Eureka添加依赖开始--> <dependency> <groupId>javax.xml.bind</groupId> <artif ...

  7. 存储系列DAS,SAN,NAS常见网络架构

    随着主机.磁盘.网络等技术的发展,对于承载大量数据存储的服务器来说,服务器内置存储空间,或者说内置磁盘往往不足以满足存储需要.因此,在内置存储之外,服务器需要采用外置存储的方式扩展存储空间,今天在这里 ...

  8. 操作系统|SPOOLing(假脱机)技术

    什么是假脱机技术,它可以解决什么问题? 什么是脱机技术 要回答什么是假脱机技术,首先我们需要知道什么是脱机技术.<计算机操作系统(第四版)>写道: 为了解决人机矛盾及CPU和I/O设备之间 ...

  9. oeasy教您玩转vim - 46 - # 范围控制

    ​ 范围控制 回忆上节课内容 这次我们主要就是看命令行 首先是选择一个 [range] 这个范围 然后进行相应的操作 :11,30d :2,7y 还可以指定寄存器 :"a3,40y :&qu ...

  10. Jmeter函数助手14-TestPlanName

    TestPlanName函数获取当前测试计划保存的文件名称.该函数没有参数,直接引用即可${__TestPlanName}.