近邻算法是机器学习算法中的入门算法,该算法用于针对已有数据集对未知数据进行分类。

该算法核心思想是通过计算预測数据与已有数据的相似度猜測结果。

举例:

如果有例如以下一组数据(在下面我们统一把该数据作为训练数据):

身高 年龄 国籍
170 23 中国
180 21 美国
185 22 俄国
175 24 中国
120 23 日本

我们将该组数据的前两列作为一个特征值。最后一项作为一个待分类结果。当数据量足够大时,能够通过该数据便可得出较为精确的分类。

比如我们拿到測试数据。年龄23岁,身高172,须要预測其国籍。

一般预測的方法是通过特征值求測试数据与训练数据中各项的差异(求点与点之间的距离),当有x,y两项特征值时,则距离为

watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQv/font/5a6L5L2T/fontsize/400/fill/I0JBQkFCMA==/dissolve/70/gravity/Center" alt="">

注:xA0为測试数据第一特征值(身高),xB0为训练数据第一特征值(身高),xA1为測试数据第二特征值(年龄),xB1为训练数据第二特征值(年龄)

同理,当特征项有3项时,则为(xA0 - xB0)的平方加上(xA1 - xB1)的平方再加上(xA2 - xB2)的平台,再开平方根。这个就是数学中求空间内两个点的距离

距离近期的训练数据则为特征最为相似的数据,该训练数据的分类则是分类可能性最大的结果

由于数据过少,我们简单能够看出。通过春节龄和身高进行匹配。最为相似的数据就是下面这一条。

170 23 中国

在训练数据较大时,须要程序进行比对分类,程序的一般逻辑是通过训练数据集中的各项数据与測试数据进行距离计算,求出最为相似的数据。可是这样对数据准确性的依赖太强,所以引入了近邻的概念。则是通过最为相似的N个数据,求出这N个数据中出现概率最高的分类则是近邻算法的结果,在本文中使用python和numpy库实现近邻算法。

# train_data为训练数据集的特征值(在本次训练数据集中为[[170, 23], [180, 21], [185, 22], [175, 24], [120, 23]])

# type_set为训练数据集分类。与train_data顺序相相应的(在本次训练数据集中为[中国,美国,俄国,中国,日本])

# test_data为须要分类的測试数据([172, 23])

# k则为上面所述的N值

import numpy as np

def forecast_data_type(train_data, type_set, test_data, k):

# 求出train_data长度,matrix类型的shape是矩阵各维度长度。比如[[1,2,3],[4,5,6]]为(2,3)

train_data_size = train_data.shape[0]

# 首先把test_data转换成与train_data一样的格式[[172,23], [172,23],[172,23], [172,23],[172,23]]

test_data_set = np.tile(test_data, (train_data_size,1))

# 求出train_data和test_data_set的差(结果为[[x1A0-x1A1,x1B0-x1B1], ......, [x4A0-x4A1,x4B0-x4B1]])

data_diff = train_data - test_data_set

# 对差值求平方(这个地方我用matrix来求平方不得行,所以我先转成了array)

data_diff_pow = np.mat(np.asarray(data_diff) ** 2)

# 将平方值相加

data_diff_pow_sum = data_diff_pow.sum(axis=1)

# 求平方根,a的平方根就是a的1/2次方。得出距离(类似[[1.22], [0.31], [0.444]......])

distances = np.mat(np.asarray(data_diff_pow_sum) ** 0.5)

# 对结果进行排序,注意。argsort的返回值是原数据索引列表

# 比如old_data是[1, 3, 5, 2, 4],sorted_data是[1, 2, 3, 4, 5],

# sorted_data[0]相应old_data[0]

# sorted_data[1]相应old_data[3]

# sorted_data[2]相应old_data[1]

# sorted_data[3]相应old_data[4]

# sorted_data[4]相应old_data[2]

# 则argsort返回 [0, 3, 1, 4, 2],这个地方的axis=0是由于原来的数据不是[1, 2, 3, 4]而是[[1], [2], [3], [4]],假设不加这个參数则会返回[[0], [0], [0]]这样的

sorted_distance = distances.argsort(axis=0)

# 通过给定的k值选择最为相似的k个数据,我这边用了collections库的Counter

list_result = []

for i in range(k):

# sorted_distance 是 [[2], [1], [4]...]

list_result.append(type_set[sorted_distance [i][0]])

count_result = Counter(list_result)



count_result是一个有序dict,依照count的大小进行数据排序,比如{'a': 3, 'b': 2, 'c': 2}

count_result的第一项就是分类的结果

注意:

1.该算法较为依赖训练数据集的大小。在一定范围内,训练数据量越大得到的结果最准确。

2.k值比較关键。当k值过大和过小时数据准确性都会受到影响。

3.当特征值的某一项差异太大时,比如a特征的值为1,2,3,4这样。b特征的值为1000,2000,3000这样。b特征对总体推断的影响较大,这个时候就应该对全部特征值做归一化处理,归一化方法例如以下

归一化值 = (数据特征值 - 最小特征值) / (最大特征值 - 最小特征值) ------这样得出的特征值会<=1

比如c特征为 1000, 2000, 3000。 4000, 5000

最小特征值为1000,最大特征值为5000

那么假设值为3000,那么归一化后的特征值为 (3000 - 1000) / (5000 - 1000) 为 0.5

參考资料:

1.<<机器学习实战>>

machine_learning-knn算法具体解释(近邻算法)的更多相关文章

  1. 机器学习——KNN算法(k近邻算法)

    一 KNN算法 1. KNN算法简介 KNN(K-Nearest Neighbor)工作原理:存在一个样本数据集合,也称为训练样本集,并且样本集中每个数据都存在标签,即我们知道样本集中每一数据与所属分 ...

  2. 机器学习算法之K近邻算法

    0x00 概述   K近邻算法是机器学习中非常重要的分类算法.可利用K近邻基于不同的特征提取方式来检测异常操作,比如使用K近邻检测Rootkit,使用K近邻检测webshell等. 0x01 原理   ...

  3. Python实现机器学习算法:K近邻算法

    ''' 数据集:Mnist 训练集数量:60000 测试集数量:10000(实际使用:200) ''' import numpy as np import time def loadData(file ...

  4. 机器学习(四) 分类算法--K近邻算法 KNN (上)

    一.K近邻算法基础 KNN------- K近邻算法--------K-Nearest Neighbors 思想极度简单 应用数学知识少 (近乎为零) 效果好(缺点?) 可以解释机器学习算法使用过程中 ...

  5. 【机器学习】k近邻算法(kNN)

    一.写在前面 本系列是对之前机器学习笔记的一个总结,这里只针对最基础的经典机器学习算法,对其本身的要点进行笔记总结,具体到算法的详细过程可以参见其他参考资料和书籍,这里顺便推荐一下Machine Le ...

  6. 从K近邻算法、距离度量谈到KD树、SIFT+BBF算法

    转载自:http://blog.csdn.net/v_july_v/article/details/8203674/ 从K近邻算法.距离度量谈到KD树.SIFT+BBF算法 前言 前两日,在微博上说: ...

  7. 02-16 k近邻算法

    目录 k近邻算法 一.k近邻算法学习目标 二.k近邻算法引入 三.k近邻算法详解 3.1 k近邻算法三要素 3.1.1 k值的选择 3.1.2 最近邻算法 3.1.3 距离度量的方式 3.1.4 分类 ...

  8. KNN-k近邻算法

    目录 KNN-k近邻算法 一.KNN基础 二.自己写一个knn函数 三.使用sklearn中的KNN 四.自己写一个面向对象的KNN 五.分割数据集 六.使用sklearn中的鸢尾花数据测试KNN 七 ...

  9. 1.K近邻算法

    (一)K近邻算法基础 K近邻(KNN)算法优点 思想极度简单 应用数学知识少(近乎为0) 效果好 可以解释机器学习算法使用过程中的很多细节问题 更完整的刻画机器学习应用的流程 图解K近邻算法 上图是以 ...

随机推荐

  1. SQL Server的thread scheduling(线程调度)

      https://www.zhihu.com/question/53125711/answer/134461670 https://www.zhihu.com/question/53125711   ...

  2. CRUD using Spring MVC 4.0 RESTful Web Services and AngularJS

    国内私募机构九鼎控股打造APP,来就送 20元现金领取地址:http://jdb.jiudingcapital.com/phone.html内部邀请码:C8E245J (不写邀请码,没有现金送)国内私 ...

  3. Arcgis license 服务无法启动的解决问题

    来自:http://blog.csdn.net/u013719339/article/details/51240312 1.检查服务开没开.打开资源管理器然后按照下面就出现了.也可以打开运行——ser ...

  4. mysql中避免使用保留字和关键字做列的名字

    设计数据表时,应尽量避免使用MySQL的关键字和保留字作为表名或列名. 比如key和keys为保留字,如果不小心使用关键字或者保留字作为列名字,执行下面的语句会出现语法错误: select * fro ...

  5. 使用HttpClient发送请求接收响应

    1.一般需要如下几步:(1) 创建HttpClient对象.(2)创建请求方法的实例,并指定请求URL.如果需要发送GET请求,创建HttpGet对象:如果需要发送POST请求,创建HttpPost对 ...

  6. HTML:基本的标签

    概述: <html></html>标准的语言格式,回环标签,有头和躯体部分,头里面一般显示标题title,躯体部分显示内容:背景色.文字.图片.超链接.表格.表单等. 可以直接 ...

  7. 关于一道面试题,使用C#实现字符串反转算法

    关于一道面试题,使用C#实现字符串反转算法. 题目见http://student.csdn.net/space.php?do=question&ac=detail&qid=490 详细 ...

  8. thinkphp问题

    这几天组里有个php系统报安全漏洞,负责的厂商跑了,没办法,被组长丢过来改漏洞,记录一下部分内容. 配置php的环境  参考https://blog.csdn.net/u011415782/artic ...

  9. 解决eclipse中运行web项目时弹出的"Port 8080 required by Tomcat 9.0 Server at localhost is already in use...

    1.tomcat默认端口是8080,可以修改通过tomcat的端口 修改tomcat\conf\server.xml     结果运行程序,还是报"Port 8080 required by ...

  10. 多个Mapper和Reducer的Job

    多个Mapper和Reducer的Job @(Hadoop) 对于复杂的mr任务来说,只有一个map和reduce往往是不能够满足任务需求的,有可能是需要n个map之后进行reduce,reduce之 ...