1. K-Means原理解析

2. K-Means的优化

3. sklearn的K-Means的使用

4. K-Means和K-Means++实现

1. 前言

前面3篇K-Means的博文从原理、优化、使用几个方面详细的介绍了K-Means算法,本文用python语言,详细的为读者实现一下K-Means。代码是本人修改完成,效率虽远不及sklearn,但是它的作用是在帮助同学们能从代码中去理解K-Means算法。后面我会慢慢的把所有的机器学习方面的算法,尽我所能的去实现一遍。

2. KMeans基本框架实现

先实现一个基本的kmeans,代码如下,需要查看完整代码的同学请移步至我的github

class KMeansBase(object):

    def __init__(self, n_clusters = 8, init = "random", max_iter = 300, random_state = None, n_init = 10, tol = 1e-4):
self.k = n_clusters # 聚类个数
self.init = init # 输出化方式
self.max_iter = max_iter # 最大迭代次数
self.random_state = check_random_state(random_state) #随机数
self.n_init = n_init # 进行多次聚类,选择最好的一次
self.tol = tol # 停止聚类的阈值 # fit对train建立模型
def fit(self, dataset):
self.tol = self._tolerance(dataset, self.tol) bestError = None
bestCenters = None
bestLabels = None
for i in range(self.n_init):
labels, centers, error = self._kmeans(dataset)
if bestError == None or error < bestError:
bestError = error
bestCenters = centers
bestLabels = labels
self.centers = bestCenters
return bestLabels, bestCenters, bestError # predict根据训练好的模型预测新的数据
def predict(self, X):
return self.update_labels_error(X, self.centers)[0] # 合并fit和predict
def fit_predict(self, dataset):
self.fit(dataset)
return self.predict(dataset) # kmeans的主要方法,完成一次聚类的过程
def _kmeans(self, dataset):
self.dataset = np.array(dataset)
bestError = None
bestCenters = None
bestLabels = None
centerShiftTotal = 0
centers = self._init_centroids(dataset) for i in range(self.max_iter):
oldCenters = centers.copy()
labels, error = self.update_labels_error(dataset, centers)
centers = self.update_centers(dataset, labels) if bestError == None or error < bestError:
bestLabels = labels.copy()
bestCenters = centers.copy()
bestError = error ## oldCenters和centers的偏移量
centerShiftTotal = np.linalg.norm(oldCenters - centers) ** 2
if centerShiftTotal <= self.tol:
break #由于上面的循环,最后一步更新了centers,所以如果和旧的centers不一样的话,再更新一次labels,error
if centerShiftTotal > 0:
bestLabels, bestError = self.update_labels_error(dataset, bestCenters) return bestLabels, bestCenters, bestError # k个数据点,随机生成
def _init_centroids(self, dataset):
n_samples = dataset.shape[0]
centers = []
if self.init == "random":
seeds = self.random_state.permutation(n_samples)[:self.k]
centers = dataset[seeds]
elif self.init == "k-means++":
pass
return np.array(centers) # 把tol和dataset相关联
def _tolerance(self, dataset, tol):
variances = np.var(dataset, axis=0)
return np.mean(variances) * tol # 更新每个点的标签,和计算误差
def update_labels_error(self, dataset, centers):
labels = self.assign_points(dataset, centers)
new_means = defaultdict(list)
error = 0
for assignment, point in zip(labels, dataset):
new_means[assignment].append(point) for points in new_means.values():
newCenter = np.mean(points, axis=0)
error += np.sqrt(np.sum(np.square(points - newCenter))) return labels, error # 更新中心点
def update_centers(self, dataset, labels):
new_means = defaultdict(list)
centers = []
for assignment, point in zip(labels, dataset):
new_means[assignment].append(point) for points in new_means.values():
newCenter = np.mean(points, axis=0)
centers.append(newCenter) return np.array(centers) # 分配每个点到最近的center
def assign_points(self, dataset, centers):
labels = []
for point in dataset:
shortest = float("inf") # 正无穷
shortest_index = 0
for i in range(len(centers)):
val = distance(point[np.newaxis], centers[i])
if val < shortest:
shortest = val
shortest_index = i
labels.append(shortest_index)
return labels

上面是我实现的基本的以EM算法为基础的一个KMeans的算法过程,我接口设计和参数形式尽量模范sklearn的方式,方面熟悉sklearn的同学接受起来比较快。

3. KMeans++实现

kmeans++的原理在之前有介绍。这里为了配合代码,再介绍一遍。

  1. 从输入的数据点集合中随机选择一个点作为第一个聚类中心\(\mu_1\).
  2. 对于数据集中的每一个点\(x_i\),计算它与已选择的聚类中心中最近聚类中心的距离.

\[D(x_i) = arg\;min|x_i-\mu_r|^2\;\;r=1,2,...k_{selected}
\]

  1. 选择一个新的数据点作为新的聚类中心,选择的原则是:\(D(x)\)较大的点,被选取作为聚类中心的概率较大
  2. 重复2和3直到选择出k个聚类质心。
  3. 利用这k个质心来作为初始化质心去运行标准的K-Means算法。

其中比较关键的是第2、3步,请看具体实现过程:

# kmeans++的初始化方式,加速聚类速度
def _k_means_plus_plus(self, dataset):
n_samples, n_features = dataset.shape
centers = np.empty((self.k, n_features))
# n_local_trials是每次选择候选点个数
n_local_trials = None
if n_local_trials is None:
n_local_trials = 2 + int(np.log(self.k)) # 第一个随机点
center_id = self.random_state.randint(n_samples)
centers[0] = dataset[center_id] # closest_dist_sq是每个样本,到所有中心点最近距离
# 假设现在有3个中心点,closest_dist_sq = [min(样本1到3个中心距离),min(样本2到3个中心距离),...min(样本n到3个中心距离)]
closest_dist_sq = distance(centers[0, np.newaxis], dataset) # current_pot所有最短距离的和
current_pot = closest_dist_sq.sum() for c in range(1, self.k):
# 选出n_local_trials随机址,并映射到current_pot的长度
rand_vals = self.random_state.random_sample(n_local_trials) * current_pot
# np.cumsum([1,2,3,4]) = [1, 3, 6, 10],就是累加当前索引前面的值
# np.searchsorted搜索随机出的rand_vals落在np.cumsum(closest_dist_sq)中的位置。
# candidate_ids候选节点的索引
candidate_ids = np.searchsorted(np.cumsum(closest_dist_sq), rand_vals) # best_candidate最好的候选节点
# best_pot最好的候选节点计算出的距离和
# best_dist_sq最好的候选节点计算出的距离列表
best_candidate = None
best_pot = None
best_dist_sq = None
for trial in range(n_local_trials):
# 计算每个样本到候选节点的欧式距离
distance_to_candidate = distance(dataset[candidate_ids[trial], np.newaxis], dataset) # 计算每个候选节点的距离序列new_dist_sq, 距离总和new_pot
new_dist_sq = np.minimum(closest_dist_sq, distance_to_candidate)
new_pot = new_dist_sq.sum() # 选择最小的new_pot
if (best_candidate is None) or (new_pot < best_pot):
best_candidate = candidate_ids[trial]
best_pot = new_pot
best_dist_sq = new_dist_sq centers[c] = dataset[best_candidate]
current_pot = best_pot
closest_dist_sq = best_dist_sq return centers

4. 效果比较

用kmeans_base和kmeans++和sklearn的kmeans对sklearn中自带的数据集iris、boston房价、digits进行聚类,比较速度和聚类效果比较。









5. 总结

Kmeans的算法讲解靠一段落,有兴趣的同学们可以去实践下我在优化中提到的另外两个优化方法,elkan减少计算距离的次数,Mini Batch处理大样本的情况下,计算的速度。

4. K-Means和K-Means++实现的更多相关文章

  1. lintcode 中等题:k Sum ii k数和 II

    题目: k数和 II 给定n个不同的正整数,整数k(1<= k <= n)以及一个目标数字. 在这n个数里面找出K个数,使得这K个数的和等于目标数字,你需要找出所有满足要求的方案. 样例 ...

  2. 今天遇到的面试题for(j=0,i=0;j<6,i<10;j++,i++) { k=i+j; } k 值最后是多少?

    for(j=0,i=0;j<6,i<10;j++,i++) { k=i+j; } k 值最后是多少? <script type="text/javascript" ...

  3. 设子数组A[0:k]和A[k+1:N-1]已排好序(0≤K≤N-1)。试设计一个合并这2个子数组为排好序的数组A[0:N-1]的算法。

    设子数组A[0:k]和A[k+1:N-1]已排好序(0≤K≤N-1).试设计一个合并这2个子数组为排好序的数组A[0:N-1]的算法.要求算法在最坏情况下所用的计算时间为O(N),只用到O(1)的辅助 ...

  4. 有两个序列A和B,A=(a1,a2,...,ak),B=(b1,b2,...,bk),A和B都按升序排列。对于1<=i,j<=k,求k个最小的(ai+bj)。要求算法尽量高效。

    有两个序列A和B,A=(a1,a2,...,ak),B=(b1,b2,...,bk),A和B都按升序排列.对于1<=i,j<=k,求k个最小的(ai+bj).要求算法尽量高效. int * ...

  5. Python交互K线工具 K线核心功能+指标切换

    Python交互K线工具 K线核心功能+指标切换 aiqtt团队量化研究,用vn.py回测和研究策略.基于vnpy开源代码,刚开始接触pyqt,开发界面还是很痛苦,找了很多案例参考,但并不能完全满足我 ...

  6. 给定一个非负索引 k,其中 k ≤ 33,返回杨辉三角的第 k 行。

    从第0行开始,输出第k行,传的参数为第几行,所以在方法中先将所传参数加1,然后将最后一行加入集合中返回. 代码如下: public static List<Integer> generat ...

  7. [leetcode]692. Top K Frequent Words K个最常见单词

    Given a non-empty list of words, return the k most frequent elements. Your answer should be sorted b ...

  8. [leetcode]347. Top K Frequent Elements K个最常见元素

    Given a non-empty array of integers, return the k most frequent elements. Example 1: Input: nums = [ ...

  9. imshow(K)和imshow(K,[]) 的区别

    参考文献 imshow(K)直接显示K:imshow(K,[])显示K,并将K的最大值和最小值分别作为纯白(255)和纯黑(0),中间的K值映射为0到255之间的标准灰度值.

  10. spine 所有动画的第一帧必须把所有能K的都K上

    spine 所有动画的第一帧必须把所有能K的都K上.否则在快速切换动画时会出问题.

随机推荐

  1. 【Servlet】web.xml中url-pattern的用法

    目录结构: contents structure [+] url-pattern的三种写法 servlet匹配原则 filter匹配原则 语法错误的后果 参考文章 一.url-pattern的三种写法 ...

  2. 【struts2】值栈(后篇)

    在值栈(前篇)我们学习了值栈的基本知识,接下来,来看看在程序中具体如何使用值栈. 1 ActionContext的基本使用 1.1 如何获取? 要获取ActionContext有两个基本的方法,如果在 ...

  3. logstash_output_mongodb插件用途及安装详解

    安装详情参见:http://mojijs.com/2017/03/222639/index.html http://www.jianshu.com/p/8516e51e105d

  4. Nginx https证书部署

    1 获取证书 Nginx文件夹内获得SSL证书文件 1_www.domain.com_bundle.crt 和私钥文件 2_www.domain.com.key,1_www.domain.com_bu ...

  5. sqlserver 在尝试加载程序集 ID 65537 时 Microsoft .NET Framework 出错.服务器可能资源不足

    报错信息: 处理报表时出错. 对数据集“query”执行查询失败. 在尝试加载程序集 ID 65536 时 Microsoft .NET Framework 出错.服务器可能资源不足,或者不信任该程序 ...

  6. 反射已经"Out",动态编译才能"Hold"住

    Net支持反射功能以后,确实使我们Net程序眼前一亮啊,真是太神奇了,只需要传入字符串就可以完成功能.可以说,反射功能的引入,使我们在处理某些问题上更加得心应手. 传统的Db管理软件中,数据库字段的频 ...

  7. 跟我学SharePoint 2013视频培训课程——网站导航及页面元素(2)

    课程简介 第2天,介绍SharePoint 2013 网站导航及页面元素 视频 SharePoint 2013 交流群 41032413

  8. SharePoint CAML In Action——Part II

    在SharePoint中,相对于Linq to SharePoint而言,CAML是轻量化的.当然缺点也是显而易见的,"Hard Code"有时会让你抓狂.在实际场景中,经常会根据 ...

  9. 关于埃博拉病毒的基本知识(ABC)

    科学研究表明.埃博拉病毒的存在历史很久远,可能有两千多万年的历史,在类人猿出现的时期就已存在. 埃博拉病毒呈现一种"蚕丝状",又细又长,无色透明.直径有80纳米,长短不等,在14微 ...

  10. CentOS 7 设置iptables防火墙开放proftpd端口

    由于ftp的被动模式是这样的,客户端跟服务器端的21号端口交互信令,服务器端开启21号端口能够使客户端登录以及查看目录.但是ftp被动模式用于传输数据的端口却不是21,而是大于1024的随机或配置文件 ...