http://blog.csdn.net/zouxy09/article/details/17590137

机器学习算法与Python实践之(六)二分k均值聚类

zouxy09@qq.com

http://blog.csdn.net/zouxy09

机器学习算法与Python实践这个系列主要是参考《机器学习实战》这本书。因为自己想学习Python,然后也想对一些机器学习算法加深下了解,所以就想通过Python来实现几个比较常用的机器学习算法。恰好遇见这本同样定位的书籍,所以就参考这本书的过程来学习了。

在上一个博文中,我们聊到了k-means算法。但k-means算法有个比较大的缺点就是对初始k个质心点的选取比较敏感。有人提出了一个二分k均值(bisecting k-means)算法,它的出现就是为了一定情况下解决这个问题的。也就是说它对初始的k个质心的选择不太敏感。那下面我们就来了解和实现下这个算法。

一、二分k均值(bisecting k-means)算法

二分k均值(bisecting k-means)算法的主要思想是:首先将所有点作为一个簇,然后将该簇一分为二。之后选择能最大程度降低聚类代价函数(也就是误差平方和)的簇划分为两个簇。以此进行下去,直到簇的数目等于用户给定的数目k为止。

以上隐含着一个原则是:因为聚类的误差平方和能够衡量聚类性能,该值越小表示数据点月接近于它们的质心,聚类效果就越好。所以我们就需要对误差平方和最大的簇进行再一次的划分,因为误差平方和越大,表示该簇聚类越不好,越有可能是多个簇被当成一个簇了,所以我们首先需要对这个簇进行划分。

二分k均值算法的伪代码如下:

***************************************************************

将所有数据点看成一个簇

当簇数目小于k时

对每一个簇

计算总误差

在给定的簇上面进行k-均值聚类(k=2)

计算将该簇一分为二后的总误差

选择使得误差最小的那个簇进行划分操作

***************************************************************

二、Python实现

我使用的Python是2.7.5版本的。附加的库有Numpy和Matplotlib。具体的安装和配置见前面的博文。在代码中已经有了比较详细的注释了。不知道有没有错误的地方,如果有,还望大家指正(每次的运行结果都有可能不同)。里面我写了个可视化结果的函数,但只能在二维的数据上面使用。直接贴代码:

biKmeans.py

  1. #################################################
  2. # kmeans: k-means cluster
  3. # Author : zouxy
  4. # Date   : 2013-12-25
  5. # HomePage : http://blog.csdn.net/zouxy09
  6. # Email  : zouxy09@qq.com
  7. #################################################
  8. from numpy import *
  9. import time
  10. import matplotlib.pyplot as plt
  11. # calculate Euclidean distance
  12. def euclDistance(vector1, vector2):
  13. return sqrt(sum(power(vector2 - vector1, 2)))
  14. # init centroids with random samples
  15. def initCentroids(dataSet, k):
  16. numSamples, dim = dataSet.shape
  17. centroids = zeros((k, dim))
  18. for i in range(k):
  19. index = int(random.uniform(0, numSamples))
  20. centroids[i, :] = dataSet[index, :]
  21. return centroids
  22. # k-means cluster
  23. def kmeans(dataSet, k):
  24. numSamples = dataSet.shape[0]
  25. # first column stores which cluster this sample belongs to,
  26. # second column stores the error between this sample and its centroid
  27. clusterAssment = mat(zeros((numSamples, 2)))
  28. clusterChanged = True
  29. ## step 1: init centroids
  30. centroids = initCentroids(dataSet, k)
  31. while clusterChanged:
  32. clusterChanged = False
  33. ## for each sample
  34. for i in xrange(numSamples):
  35. minDist  = 100000.0
  36. minIndex = 0
  37. ## for each centroid
  38. ## step 2: find the centroid who is closest
  39. for j in range(k):
  40. distance = euclDistance(centroids[j, :], dataSet[i, :])
  41. if distance < minDist:
  42. minDist  = distance
  43. minIndex = j
  44. ## step 3: update its cluster
  45. if clusterAssment[i, 0] != minIndex:
  46. clusterChanged = True
  47. clusterAssment[i, :] = minIndex, minDist**2
  48. ## step 4: update centroids
  49. for j in range(k):
  50. pointsInCluster = dataSet[nonzero(clusterAssment[:, 0].A == j)[0]]
  51. centroids[j, :] = mean(pointsInCluster, axis = 0)
  52. print 'Congratulations, cluster using k-means complete!'
  53. return centroids, clusterAssment
  54. # bisecting k-means cluster
  55. def biKmeans(dataSet, k):
  56. numSamples = dataSet.shape[0]
  57. # first column stores which cluster this sample belongs to,
  58. # second column stores the error between this sample and its centroid
  59. clusterAssment = mat(zeros((numSamples, 2)))
  60. # step 1: the init cluster is the whole data set
  61. centroid = mean(dataSet, axis = 0).tolist()[0]
  62. centList = [centroid]
  63. for i in xrange(numSamples):
  64. clusterAssment[i, 1] = euclDistance(mat(centroid), dataSet[i, :])**2
  65. while len(centList) < k:
  66. # min sum of square error
  67. minSSE = 100000.0
  68. numCurrCluster = len(centList)
  69. # for each cluster
  70. for i in range(numCurrCluster):
  71. # step 2: get samples in cluster i
  72. pointsInCurrCluster = dataSet[nonzero(clusterAssment[:, 0].A == i)[0], :]
  73. # step 3: cluster it to 2 sub-clusters using k-means
  74. centroids, splitClusterAssment = kmeans(pointsInCurrCluster, 2)
  75. # step 4: calculate the sum of square error after split this cluster
  76. splitSSE = sum(splitClusterAssment[:, 1])
  77. notSplitSSE = sum(clusterAssment[nonzero(clusterAssment[:, 0].A != i)[0], 1])
  78. currSplitSSE = splitSSE + notSplitSSE
  79. # step 5: find the best split cluster which has the min sum of square error
  80. if currSplitSSE < minSSE:
  81. minSSE = currSplitSSE
  82. bestCentroidToSplit = i
  83. bestNewCentroids = centroids.copy()
  84. bestClusterAssment = splitClusterAssment.copy()
  85. # step 6: modify the cluster index for adding new cluster
  86. bestClusterAssment[nonzero(bestClusterAssment[:, 0].A == 1)[0], 0] = numCurrCluster
  87. bestClusterAssment[nonzero(bestClusterAssment[:, 0].A == 0)[0], 0] = bestCentroidToSplit
  88. # step 7: update and append the centroids of the new 2 sub-cluster
  89. centList[bestCentroidToSplit] = bestNewCentroids[0, :]
  90. centList.append(bestNewCentroids[1, :])
  91. # step 8: update the index and error of the samples whose cluster have been changed
  92. clusterAssment[nonzero(clusterAssment[:, 0].A == bestCentroidToSplit), :] = bestClusterAssment
  93. print 'Congratulations, cluster using bi-kmeans complete!'
  94. return mat(centList), clusterAssment
  95. # show your cluster only available with 2-D data
  96. def showCluster(dataSet, k, centroids, clusterAssment):
  97. numSamples, dim = dataSet.shape
  98. if dim != 2:
  99. print "Sorry! I can not draw because the dimension of your data is not 2!"
  100. return 1
  101. mark = ['or', 'ob', 'og', 'ok', '^r', '+r', 'sr', 'dr', '<r', 'pr']
  102. if k > len(mark):
  103. print "Sorry! Your k is too large! please contact Zouxy"
  104. return 1
  105. # draw all samples
  106. for i in xrange(numSamples):
  107. markIndex = int(clusterAssment[i, 0])
  108. plt.plot(dataSet[i, 0], dataSet[i, 1], mark[markIndex])
  109. mark = ['Dr', 'Db', 'Dg', 'Dk', '^b', '+b', 'sb', 'db', '<b', 'pb']
  110. # draw the centroids
  111. for i in range(k):
  112. plt.plot(centroids[i, 0], centroids[i, 1], mark[i], markersize = 12)
  113. plt.show()

三、测试结果

测试数据是二维的,共80个样本。有4个类。具体见上一个博文

测试代码:

test_biKmeans.py

  1. #################################################
  2. # kmeans: k-means cluster
  3. # Author : zouxy
  4. # Date   : 2013-12-25
  5. # HomePage : http://blog.csdn.net/zouxy09
  6. # Email  : zouxy09@qq.com
  7. #################################################
  8. from numpy import *
  9. import time
  10. import matplotlib.pyplot as plt
  11. ## step 1: load data
  12. print "step 1: load data..."
  13. dataSet = []
  14. fileIn = open('E:/Python/Machine Learning in Action/testSet.txt')
  15. for line in fileIn.readlines():
  16. lineArr = line.strip().split('\t')
  17. dataSet.append([float(lineArr[0]), float(lineArr[1])])
  18. ## step 2: clustering...
  19. print "step 2: clustering..."
  20. dataSet = mat(dataSet)
  21. k = 4
  22. centroids, clusterAssment = biKmeans(dataSet, k)
  23. ## step 3: show the result
  24. print "step 3: show the result..."
  25. showCluster(dataSet, k, centroids, clusterAssment)

这里贴出两次的运行结果:

不同的类用不同的颜色来表示,其中的大菱形是对应类的均值质心点。

事实上,这个算法在初始质心选择不同时运行效果也会不同。我没有看初始的论文,不确定它究竟是不是一定会收敛到全局最小值。《机器学习实战》这本书说是可以的,但因为每次运行的结果不同,所以我有点怀疑,自己去找资料也没找到相关的说明。对这个算法有了解的还望您不吝指点下,谢谢。

机器学习算法与Python实践之(六)二分k均值聚类的更多相关文章

  1. 机器学习算法与Python实践之(四)支持向量机(SVM)实现

    机器学习算法与Python实践之(四)支持向量机(SVM)实现 机器学习算法与Python实践之(四)支持向量机(SVM)实现 zouxy09@qq.com http://blog.csdn.net/ ...

  2. 机器学习算法与Python实践之(三)支持向量机(SVM)进阶

    机器学习算法与Python实践之(三)支持向量机(SVM)进阶 机器学习算法与Python实践之(三)支持向量机(SVM)进阶 zouxy09@qq.com http://blog.csdn.net/ ...

  3. 机器学习算法与Python实践之(二)支持向量机(SVM)初级

    机器学习算法与Python实践之(二)支持向量机(SVM)初级 机器学习算法与Python实践之(二)支持向量机(SVM)初级 zouxy09@qq.com http://blog.csdn.net/ ...

  4. 机器学习算法与Python实践之(五)k均值聚类(k-means)

    机器学习算法与Python实践这个系列主要是参考<机器学习实战>这本书.因为自己想学习Python,然后也想对一些机器学习算法加深下了解,所以就想通过Python来实现几个比较常用的机器学 ...

  5. 机器学习算法与Python实践之(七)逻辑回归(Logistic Regression)

    http://blog.csdn.net/zouxy09/article/details/20319673 机器学习算法与Python实践之(七)逻辑回归(Logistic Regression) z ...

  6. 机器学习理论与实战(十)K均值聚类和二分K均值聚类

    接下来就要说下无监督机器学习方法,所谓无监督机器学习前面也说过,就是没有标签的情况,对样本数据进行聚类分析.关联性分析等.主要包括K均值聚类(K-means clustering)和关联分析,这两大类 ...

  7. 机器学习实战5:k-means聚类:二分k均值聚类+地理位置聚簇实例

    k-均值聚类是非监督学习的一种,输入必须指定聚簇中心个数k.k均值是基于相似度的聚类,为没有标签的一簇实例分为一类. 一 经典的k-均值聚类 思路: 1 随机创建k个质心(k必须指定,二维的很容易确定 ...

  8. Bisecting KMeans (二分K均值)算法讲解及实现

    算法原理 由于传统的KMeans算法的聚类结果易受到初始聚类中心点选择的影响,因此在传统的KMeans算法的基础上进行算法改进,对初始中心点选取比较严格,各中心点的距离较远,这就避免了初始聚类中心会选 ...

  9. spark Bisecting k-means(二分K均值算法)

    Bisecting k-means(二分K均值算法) 二分k均值(bisecting k-means)是一种层次聚类方法,算法的主要思想是:首先将所有点作为一个簇,然后将该簇一分为二.之后选择能最大程 ...

随机推荐

  1. Codeforces Round #397 by Kaspersky Lab and Barcelona Bootcamp (Div. 1 + Div. 2 combined) B. Code obfuscation 水题

    B. Code obfuscation 题目连接: http://codeforces.com/contest/765/problem/B Description Kostya likes Codef ...

  2. 初始化collectionViewCell

    #import <UIKit/UIKit.h> @interface TonyCollectionViewCell : UICollectionViewCell @property UII ...

  3. Linux学习笔记03—初识Linux

    命令介绍 忘记root密码的处理方法 系统安装盘的救援模式的使用 一.命令介绍 1.LS命令 ls 查看当前目录下的文件 Ls –l 等同于ll 查看目录的详细信息 Ls –a 查看当前目录下的所有文 ...

  4. 2018 dnc .NET Core、.NET开发的大型网站列表、各大公司.NET职位精选,C#王者归来

    简洁.优雅.高效的C#语言,神一样的C#创始人Anders Hejlsberg,async/await编译器级异步语法,N年前就有的lambda表达式,.NET Native媲美C++的原生编译性能, ...

  5. 使用 IntraWeb (12) - 基本控件之 TIWGradButton、TIWImageButton

    TIWGradButton.TIWImageButton 分别是有颜色梯度变化按钮和图像按钮. TIWGradButton 所在单元及继承链: IWCompGradButton.TIWGradButt ...

  6. @Transactional导致AbstractRoutingDataSource动态数据源无法切换的解决办法

    上午花了大半天排查一个多数据源主从切换的问题,记录一下: 背景: 项目的数据库采用了读写分离多数据源,采用AOP进行拦截,利用ThreadLocal及AbstractRoutingDataSource ...

  7. android应用程序签名(转)

    概述 Android系统要求,所有的程序经过数字签名后才能安装.Android系统使用这个证书来识别应用程序的作者,并且建立程序间的信任关系.证书不是用于用户控制哪些程序可以安装.证书不需要授权中心来 ...

  8. STM32的CRC32 测试代码

    // STM32 CRC32 Test App - sourcer32@gmail.com #include <windows.h> #include <stdio.h> DW ...

  9. IIS、Asp.net 编译时的临时文件路径

    IIS上部署的ASP.NET站点都会在一个.Net Framework的特定目录下生成临时编译文件增加ASP.NET站点的访问性能,有时候需要手动去删除这些临时编译文件,特别是发布新版本代码到IIS后 ...

  10. 【Go命令教程】6. go doc 与 godoc

    go doc 命令可以打印附于Go语言程序 实体 上的文档.我们可以通过把程序实体的标识符作为该命令的参数来达到查看其文档的目的. 插播:所谓 Go语言的 程序实体,是指变量.常量.函数.结构体以及接 ...