kMeans算法原理见我的上一篇文章。这里介绍K-Means的Java实现方法,参考了Python的实现方法。

一、数据点的实现

  1. package com.meachine.learning.kmeans;
  2.  
  3. import java.util.ArrayList;
  4.  
  5. /**
  6. * 数据点,有n维数据
  7. *
  8. */
  9. public class Point {
  10. private static int num;
  11. private int id;
  12. private int dimensioNum; // 维度
  13. private ArrayList<Double> values;
  14. private int clusterId = -1;
  15. private double minDist = Integer.MAX_VALUE;
  16.  
  17. public Point() {
  18. id = ++num;
  19. values = new ArrayList<>();
  20. }
  21.  
  22. public void add(double e) {
  23. values.add(e);
  24. dimensioNum++;
  25. }
  26. //------set与get省略----------
  27. }

二、数据簇的实现

  1. package com.meachine.learning.kmeans;
  2.  
  3. import lombok.EqualsAndHashCode;
  4. import lombok.Getter;
  5. import lombok.Setter;
  6. import lombok.ToString;
  7.  
  8. /**
  9. * 簇<br>
  10. * 数据集合的基本信息
  11. *
  12. */
  13. public class Cluster {
  14. // 簇id
  15. private int clusterId;
  16. // 属于该簇的点的个数
  17. private int numOfPoints;
  18. // 簇中心点的信息
  19. private Point center;
  20.  
  21. public Cluster(int id) {
  22. this.clusterId = id;
  23. numOfPoints = 0;
  24. }
  25.  
  26. public Cluster(int id, Point center) {
  27. this.clusterId = id;
  28. this.center = center;
  29. }
  30. //----------set与get省略----------------
  31. }

三、计算数据点距离

  1. package com.meachine.learning.kmeans;
  2.  
  3. import java.util.List;
  4.  
  5. /**
  6. * 计算距离接口
  7. *
  8. */
  9. public interface IDistance<T> {
  10. public double getDis(List<T> p1, List<T> p2);
  11. }

  

  1. package com.meachine.learning.kmeans;
  2.  
  3. import java.util.List;
  4.  
  5. /**
  6. * 欧式距离
  7. *
  8. */
  9. public class OujilidDistance<T extends Number> implements IDistance<T> {
  10.  
  11. public double getDis(List<T> a, List<T> b) {
  12. if (a.size() != b.size()) {
  13. throw new IllegalArgumentException("Size not compatible!");
  14. }
  15. double result = 0;
  16. for (int i = 0; i < a.size(); i++) {
  17. result += Math.pow((a.get(i).doubleValue() - b.get(i).doubleValue()), 2);
  18. }
  19. return Math.sqrt(result);
  20. }
  21.  
  22. }

四、K-Means算法

  

  1. package com.meachine.learning.kmeans;
  2.  
  3. import java.io.BufferedReader;
  4. import java.io.File;
  5. import java.io.FileReader;
  6. import java.io.IOException;
  7. import java.util.ArrayList;
  8. import java.util.List;
  9. import java.util.Random;
  10.  
  11. /**
  12. * K-Means算法
  13. *
  14. * @author Cang
  15. *
  16. */
  17. public class KMeans {
  18. // 簇的个数
  19. private int k;
  20. // 维度,即多少个变量
  21. private int dimensioNum;
  22. // 最大迭代次数
  23. private int maxItrNum = 100;
  24. private IDistance<Double> distance;
  25. private List<Point> points;
  26. private List<Cluster> clusters = new ArrayList<Cluster>();
  27. private String dataFileName = "D:/testSet.txt";
  28.  
  29. public KMeans(int k) {
  30. this.k = k;
  31. }
  32.  
  33. /**
  34. * 初始化数据
  35. */
  36. public void init() {
  37. points = loadDataSet(dataFileName);
  38. distance = new OujilidDistance<Double>();
  39. initCluster();
  40. }
  41.  
  42. /**
  43. * 加载数据集
  44. *
  45. * @param fileName
  46. * @return
  47. */
  48. private List<Point> loadDataSet(String fileName) {
  49. List<Point> points = new ArrayList<>();
  50. File file = new File(fileName);
  51. BufferedReader reader = null;
  52. try {
  53. reader = new BufferedReader(new FileReader(file));
  54. String tempString = null;
  55. int i = 0;
  56. while ((tempString = reader.readLine()) != null) {
  57. Point point = new Point();
  58. dimensioNum = tempString.split("\t").length;
  59. for (String data : tempString.split("\t")) {
  60. point.add(Double.parseDouble(data));
  61. }
  62. points.add(point);
  63. }
  64. reader.close();
  65. } catch (IOException e) {
  66. e.printStackTrace();
  67. }
  68. return points;
  69. }
  70.  
  71. /**
  72. * 初始化簇中心
  73. *
  74. * @return
  75. */
  76. private void initCluster() {
  77. Random ran = new Random();
  78. int id = 0;
  79. while (id < k) {
  80. Cluster c = new Cluster(++id);
  81. int temp = ran.nextInt(points.size());
  82. c.setCenter(points.get(temp));
  83. clusters.add(c);
  84. }
  85. }
  86.  
  87. /**
  88. * kMeans 具体算法
  89. */
  90. public void clustering() {
  91. boolean finished = false;
  92. int count = 0;
  93. while (!finished) {
  94. // 寻找最近的中心
  95. finished = true;
  96. for (Point point : points) {
  97. for (Cluster cluster : clusters) {
  98.  
  99. double minLen = distance.getDis(cluster.getCenter().getValues(),
  100. point.getValues());
  101. // 更新最小距离
  102. if (minLen < point.getMinDist()) {
  103. if (cluster.getClusterId() != point.getClusterId()) {
  104. finished = false;
  105. point.setClusterId(cluster.getClusterId());
  106. }
  107. point.setMinDist(minLen);
  108. }
  109. }
  110. }
  111. System.out.println("Cluster center info:");
  112. for (Cluster string : clusters) {
  113. System.out.println(string.getCenter().getValues());
  114. }
  115. // 更改中心的位置
  116. changeCentroids();
  117. // 超过循环次数,则跳出循环
  118. if (++count > maxItrNum) {
  119. finished = true;
  120. }
  121. }
  122. }
  123.  
  124. /**
  125. * 改变簇中心
  126. */
  127. private void changeCentroids() {
  128. for (Cluster cluster : clusters) {
  129. ArrayList<Double> newCenterValue = new ArrayList<Double>();
  130. Point newCenterPoint = new Point();
  131. double result = 0;
  132. for (int i = 0; i < dimensioNum; i++) {
  133. for (Point point : points) {
  134. if (point.getClusterId() == cluster.getClusterId()) {
  135. result += point.getValues().get(i);
  136. }
  137. }
  138. newCenterValue.add(result / points.size());
  139. }
  140. newCenterPoint.setClusterId(cluster.getClusterId());
  141. newCenterPoint.setValues(newCenterValue);
  142. cluster.setCenter(newCenterPoint);
  143. }
  144. }
  145.  
  146. public static void main(String[] args) {
  147. KMeans kmeans = new KMeans(4);
  148. kmeans.init();
  149. kmeans.clustering();
  150. }
  151. }

  

K-Means 算法(Java)的更多相关文章

  1. k近邻算法-java实现

    最近在看<机器学习实战>这本书,因为自己本身很想深入的了解机器学习算法,加之想学python,就在朋友的推荐之下选择了这本书进行学习. 一 . K-近邻算法(KNN)概述 最简单最初级的分 ...

  2. KNN 与 K - Means 算法比较

    KNN K-Means 1.分类算法 聚类算法 2.监督学习 非监督学习 3.数据类型:喂给它的数据集是带label的数据,已经是完全正确的数据 喂给它的数据集是无label的数据,是杂乱无章的,经过 ...

  3. K-means算法

    K-means算法很简单,它属于无监督学习算法中的聚类算法中的一种方法吧,利用欧式距离进行聚合啦. 解决的问题如图所示哈:有一堆没有标签的训练样本,并且它们可以潜在地分为K类,我们怎么把它们划分呢?  ...

  4. k近邻算法的Java实现

    k近邻算法是机器学习算法中最简单的算法之一,工作原理是:存在一个样本数据集合,即训练样本集,并且样本集中的每个数据都存在标签,即我们知道样本集中每一数据和所属分类的对应关系.输入没有标签的新数据之后, ...

  5. KNN算法java实现代码注释

    K近邻算法思想非常简单,总结起来就是根据某种距离度量检测未知数据与已知数据的距离,统计其中距离最近的k个已知数据的类别,以多数投票的形式确定未知数据的类别. 一直想自己实现knn的java实现,但限于 ...

  6. Floyd算法java实现demo

    Floyd算法java实现,如下: https://www.cnblogs.com/Halburt/p/10756572.html package a; /** * ┏┓ ┏┓+ + * ┏┛┻━━━ ...

  7. k-means算法Java一维实现

    这里的程序稍微有点变形.k_means方法返回K-means聚类的若干中心点.代码: import java.util.ArrayList; import java.util.Collections; ...

  8. 感知机学习算法Java实现

    感知机学习算法Java实现. Perceptron类用于实现感知机, 其中的perceptronOriginal()方法用于实现感知机学习算法的原始形式: perceptronAnother()方法用 ...

  9. 一致哈希算法Java实现

    一致哈希算法(Consistent Hashing Algorithms)是一个分布式系统中经常使用的算法. 传统的Hash算法当槽位(Slot)增减时,面临全部数据又一次部署的问题.而一致哈希算法确 ...

  10. 机器学习实战笔记--k近邻算法

    #encoding:utf-8 from numpy import * import operator import matplotlib import matplotlib.pyplot as pl ...

随机推荐

  1. mysql导出csv文件excel打开后数字用科学计数法显示且低位变0的解决方法

    Excel显示数字时,如果数字大于12位,它会自动转化为科学计数法:如果数字大于15位,它不仅用于科学技术费表示,还会只保留高15位,其他位都变0. Excel打开csv文件时,只要字段值都是数字,它 ...

  2. oneThink添加成功,返回到当前请求地址!

    其实没什么,就一行代码: $this->success('已采纳',$_SERVER['HTTP_REFERER']);

  3. 最小树形图(poj3164)

    Command Network Time Limit: 1000MS   Memory Limit: 131072K Total Submissions: 12834   Accepted: 3718 ...

  4. rest_framework之规范详解 00

    接口开发 方式1:缺点:如果有10张表,则需要40个url. urls.py views.py 缺点:如果有10张表,则需要40个url.    接下来就出现了resrful 规范,比较简洁 方式2: ...

  5. instanceof 用于确定一个 PHP 变量是否属于某一类 class 的实例 , 返回true或者false

    <?phpclass MyClass{} class NotMyClass{}$a = new MyClass; var_dump($a instanceof MyClass);var_dump ...

  6. Service学习笔记

    一 什么是Service Service作为安卓四大组件之一,拥有重要的地位.Service和Activity级别相同,只是没有界面,是运行于后台的服务.这个运行“后台”是指不可见,不是指在后台线程中 ...

  7. #pragma预处理命令详解

    #pragma预处理命令 #pragma可以说是C++中最复杂的预处理指令了,下面是最常用的几个#pragma指令: #pragma comment(lib,"XXX.lib") ...

  8. ubuntu16.04下笔记本自带摄像头编译运行PTAM

    ubuntu16.04下笔记本自带摄像头编译运行PTAM 转载请注明链接:https://i.cnblogs.com/EditPosts.aspx?postid=9014147 个人邮箱:feifan ...

  9. ORA-39006、ORA-39065、ORA-01403、ORA-39097错误解决办法

    今天有同事找说是expdp到出数据时报错: 处理方法:sys用户下执行如下语句重新生成DATAPUMP API用到的视图问题就解决了. SQL> @?/rdbms/admin/catmeta.s ...

  10. python 类 __module__ __class__

    __module__ 和  __class__  __module__ 表示当前操作的对象在那个模块 __class__     表示当前操作的对象的类是什么 创建一个目录lib 在day7 目录下创 ...