k-近邻算法原理

像之前提到的那样,机器学习的一个要点就是分类,对于分类来说有许多不同的算法,所谓的物以聚类,分以群分。我们非常的清楚,一个地域的人群,不管在生活习惯,还是在习俗上都是非常相似的,也就是我们说的一类人。每一类人都会形成自己的一个中心,越靠近这个中心的人越为相似。k近邻算法就是为了找到这个中心点,把这中心点当成这类关键点,在有新的数据需要分类的话,就看离哪个中心点近,那么就属于哪一类。

假设我们有这样的一组数据,他代表一个人的地理坐标位置:

x坐标 y坐标 哪省人
4.035615117 4.920529835 0
4.665299994 4.702897321 0
1.711128297 1.031989236 1

根据这坐标在图上绘出图形:

两个蓝色的点互相靠近,它们的属性应该是相似的,而红色的点,离这两个蓝色的点有一定的距离,可能属于另一个聚合。

在这里导入一组数据,这一组数据中有三个分类,每一个分类就是一个群,组成了三个中心,具体的数据和图如下:

import numpy as np
import random
import matplotlib.pyplot as plt def read_clusters(clustersfile):
cl = []
tl = []
with open(clustersfile, 'r') as f:
for line in f:
line = line.strip()
if line != '':
line = line.split()
constraint = [float(line[0]), float(line[1])] cl.append(constraint)
tl.append(int(line[2]))
return cl,tl train_data,train_labels = read_clusters('clusters3.txt')
train_data = np.array(train_data)
key_name = {0:'red',1:'blue',2:'orange'} for i in range(train_data.shape[0]):
plt.scatter(train_data[i:i + 1, 0:1], train_data[i:i + 1, 1:2], c=key_name[train_labels[i]], marker='o',s=20) plt.savefig('clusters.png')

k-近邻算法步骤

k-近邻的一般步骤如下:

1.先随机的产生几个中心,中心点的确认来自于需要组建几个类群。

def _init_random_centroids(self, data):
n_samples, n_features = np.shape(data)
centroids = np.zeros((self.k, n_features))
for i in range(self.k):
centroid = data[np.random.choice(range(n_samples))]
centroids[i] = centroid
return centroids

2.接下来是把所有的数据点跟这几个中心点进行比较,数据点里哪个中心点近,那么这个点就属于哪个类群。

计算距离的公式如下:

def euclidean_distance(vec_1, vec_2):
if(len(vec_1) != len(vec_2)):
raise Exception("The two vectors do NOT have equal length") distance = 0
for i in range(len(vec_1)):
distance += pow((vec_1[i] - vec_2[i]), 2) return np.sqrt(distance)

根据距离查找属于哪个中心点。

def _closest_centroid(self, sample, centroids):
closest_i = None
closest_distance = float("inf")
for i, centroid in enumerate(centroids):
distance = ml_helpers.euclidean_distance(sample, centroid)
if distance < closest_distance:
closest_i = i
closest_distance = distance
return closest_i

3.通过中心点确定了类群,在通过类群更新中心点。中心点是这个类群所有点的均值点,计算均值更新中心点。

def _calculate_centroids(self, clusters, data):
n_features = np.shape(data)[1]
centroids = np.zeros((self.k, n_features))
for i, cluster in enumerate(clusters):
centroid = np.mean(data[cluster], axis=0)
centroids[i] = centroid
return centroids

4.不断的更新这一个过程,直到中心点不在变化。

整个过程如下:

import numpy as np
import random
import sys import matplotlib.pyplot as plt def euclidean_distance(vec_1, vec_2):
if(len(vec_1) != len(vec_2)):
raise Exception("The two vectors do NOT have equal length") distance = 0
for i in range(len(vec_1)):
distance += pow((vec_1[i] - vec_2[i]), 2) return np.sqrt(distance) def read_clusters(clustersfile):
cl = []
tl = []
with open(clustersfile, 'r') as f:
for line in f:
line = line.strip()
if line != '':
line = line.split()
constraint = [float(line[0]), float(line[1])] cl.append(constraint)
tl.append(int(line[2]))
return cl,tl class KMeans():
def __init__(self, k=2, max_iterations=500):
self.k = k
self.max_iterations = max_iterations
self.kmeans_centroids = [] def _init_random_centroids(self, data):
n_samples, n_features = np.shape(data)
centroids = np.zeros((self.k, n_features))
for i in range(self.k):
centroid = data[np.random.choice(range(n_samples))]
centroids[i] = centroid
return centroids def _closest_centroid(self, sample, centroids):
closest_i = None
closest_distance = float("inf")
for i, centroid in enumerate(centroids):
distance = euclidean_distance(sample, centroid)
if distance < closest_distance:
closest_i = i
closest_distance = distance
return closest_i def _create_clusters(self, centroids, data):
n_samples = np.shape(data)[0]
clusters = [[] for _ in range(self.k)]
for sample_i, sample in enumerate(data):
centroid_i = self._closest_centroid(sample, centroids)
clusters[centroid_i].append(sample_i)
return clusters def _calculate_centroids(self, clusters, data):
n_features = np.shape(data)[1]
centroids = np.zeros((self.k, n_features))
for i, cluster in enumerate(clusters):
centroid = np.mean(data[cluster], axis=0)
centroids[i] = centroid
return centroids def _get_cluster_labels(self, clusters, data):
y_pred = np.zeros(np.shape(data)[0])
for cluster_i, cluster in enumerate(clusters):
for sample_i in cluster:
y_pred[sample_i] = cluster_i
return y_pred def fit(self, data):
centroids = self._init_random_centroids(data) for iteration in range(self.max_iterations): clusters = self._create_clusters(centroids, data) prev_centroids = centroids centroids = self._calculate_centroids(clusters, data) diff = centroids - prev_centroids
if not diff.any():
break self.kmeans_centroids = centroids
return centroids def predict(self, data): if not self.kmeans_centroids.any():
raise Exception("K-Means centroids have not yet been determined.\nRun the K-Means 'fit' function first.") clusters = self._create_clusters(self.kmeans_centroids, data) predicted_labels = self._get_cluster_labels(clusters, data) return predicted_labels key_name = {0:'red',1:'blue',2:'orange'} clf = KMeans(k=3, max_iterations=3000) train_data,train_labels = read_clusters('clusters3.txt')
train_data = np.array(train_data)
centroids = clf.fit(train_data)
print centroids

中心点不断更新的过程如下:

算法误差估计

检验算法的好坏,简单的办法是把一部分的数据用来训练,一部分的数据用来检验,查看算法的结果跟预计的数据相差多少?

下面是算法的效果估计:

Accuracy = 0
for index in range(len(train_labels)):
# Cluster the data using K-Means
current_label = train_labels[index]
predicted_label = predicted_labels[index] if current_label == int(predicted_label):
Accuracy += 1 Accuracy /= len(train_labels) print Accuracy

输出的结果为

1

准确率达到100%。

sklearn 下的k-近邻算法

在学习算法的时候知道了原理,通过自己的代码对算法的原理进行编写,通常来讲这很方便学习,在知道了如何编写算法以后,可以直接使用现成的开源库,直接使用该算法,sklearn 就非常方便使用。

clf = cluster.KMeans(n_clusters=3, max_iter=3000, n_init=10)
kmeans = clf.fit(train_data) Accuracy = 0
for index in range(len(train_labels)):
# Cluster the data using K-Means
current_sample = train_data[index].reshape(1,-1)
current_label = train_labels[index]
predicted_label = kmeans.predict(current_sample)
if current_label == predicted_label:
Accuracy += 1 Accuracy /= len(train_labels)

算法的应用

k-近邻算法用来找到中心点,同时算法也可以用来进行去重,把重复的附近的点都把他近似为中心点。

转载请标明来之:http://www.bugingcode.com/

更多教程:阿猫学编程

机器学习入门教程-k-近邻的更多相关文章

  1. 机器学习03:K近邻算法

    本文来自同步博客. P.S. 不知道怎么显示数学公式以及排版文章.所以如果觉得文章下面格式乱的话请自行跳转到上述链接.后续我将不再对数学公式进行截图,毕竟行内公式截图的话排版会很乱.看原博客地址会有更 ...

  2. 机器学习 Python实践-K近邻算法

    机器学习K近邻算法的实现主要是参考<机器学习实战>这本书. 一.K近邻(KNN)算法 K最近邻(k-Nearest Neighbour,KNN)分类算法,理解的思路是:如果一个样本在特征空 ...

  3. 机器学习实战python3 K近邻(KNN)算法实现

    台大机器技法跟基石都看完了,但是没有编程一直,现在打算结合周志华的<机器学习>,撸一遍机器学习实战, 原书是python2 的,但是本人感觉python3更好用一些,所以打算用python ...

  4. 02机器学习实战之K近邻算法

    第2章 k-近邻算法 KNN 概述 k-近邻(kNN, k-NearestNeighbor)算法是一种基本分类与回归方法,我们这里只讨论分类问题中的 k-近邻算法. 一句话总结:近朱者赤近墨者黑! k ...

  5. 机器学习算法之K近邻算法

    0x00 概述   K近邻算法是机器学习中非常重要的分类算法.可利用K近邻基于不同的特征提取方式来检测异常操作,比如使用K近邻检测Rootkit,使用K近邻检测webshell等. 0x01 原理   ...

  6. 机器学习实战笔记--k近邻算法

    #encoding:utf-8 from numpy import * import operator import matplotlib import matplotlib.pyplot as pl ...

  7. 机器学习PR:k近邻法分类

    k近邻法是一种基本分类与回归方法.本章只讨论k近邻分类,回归方法将在随后专题中进行. 它可以进行多类分类,分类时根据在样本集合中其k个最近邻点的类别,通过多数表决等方式进行预测,因此不具有显式的学习过 ...

  8. 机器学习随笔01 - k近邻算法

    算法名称: k近邻算法 (kNN: k-Nearest Neighbor) 问题提出: 根据已有对象的归类数据,给新对象(事物)归类. 核心思想: 将对象分解为特征,因为对象的特征决定了事对象的分类. ...

  9. 机器学习-- 入门demo1 k临近算法

    1.k-近邻法简介 k近邻法(k-nearest neighbor, k-NN)是1967年由Cover T和Hart P提出的一种基本分类与回归方法. 它的工作原理是:存在一个样本数据集合,也称作为 ...

随机推荐

  1. PAT Basic 1020 ⽉饼 (25) [贪⼼算法]

    题目 ⽉饼是中国⼈在中秋佳节时吃的⼀种传统⻝品,不同地区有许多不同⻛味的⽉饼.现给定所有种类⽉饼的库存量.总售价.以及市场的最⼤需求量,请你计算可以获得的最⼤收益是多少. 注意:销售时允许取出⼀部分库 ...

  2. Linux中PATH、 LIBRARY_PATH、 LD_LIBRARY_PATH和ROS_PACKAGE_PATH

    PATH 保存可执行文件程序路径,我们命令行中每一句能运行的命令都是系统先通过PATH来找到命令执行文件所在的位置,再运行这个命令. 实验:执行echo $PATH 能看到当前环境PATH都是配置了哪 ...

  3. 支付宝H5支付demo

    支付宝H5支付 首先我们必须注册一个支付宝应用(本案例就直接用支付宝的沙箱环境,这个沙箱也就是支付宝提供给开发者的一个测试环境) 登录地址:https://open.alipay.com/platfo ...

  4. Python list 字符串 注册 登录

    #list #列表 python中 数组.array等都用列表 list表示#创建一个liststu = ['xiaoming','xiaoli','xiaohuang','alex','lily', ...

  5. Python语言学习前提:条件语句

    一.条件语句 1.条件语句:通过一条或多条语句的执行结果(True或False)来决定执行额代码块.python程序语言指定任何非0或非空(null)的值为true,0或null为false. 2. ...

  6. Oracle不同版本中序列的注意点

    <span style="font-size:14px;">create table manager ( userid NUMBER(10), username VAR ...

  7. js 判断元素的display是否为block或者none

    if($(this).css("display")=="none"){ //隐藏的 }else{ //显示的 }

  8. crm项目-业务实现

    ###############  crm业务    ############### """ 校区管理,部门管理,课程管理, 这三个都比较简单 1,只需要展示校区名称,这是 ...

  9. EXAM-2018-7-29

    EXAM-2018-7-29 未完成 [ ] H [ ] A D 莫名TLE 不在循环里写strlen()就行了 F 相减特判 水题 J 模拟一下就可以发现规律,o(n) K 每个数加一减一不变,用m ...

  10. 算法之匹配:KMP

    public static int getIndexOf(String str1, String str2) { if (str1 == null || str2 == null || str1.le ...