K-Means 算法(Java)
kMeans算法原理见我的上一篇文章。这里介绍K-Means的Java实现方法,参考了Python的实现方法。
一、数据点的实现
- package com.meachine.learning.kmeans;
- import java.util.ArrayList;
- /**
- * 数据点,有n维数据
- *
- */
- public class Point {
- private static int num;
- private int id;
- private int dimensioNum; // 维度
- private ArrayList<Double> values;
- private int clusterId = -1;
- private double minDist = Integer.MAX_VALUE;
- public Point() {
- id = ++num;
- values = new ArrayList<>();
- }
- public void add(double e) {
- values.add(e);
- dimensioNum++;
- }
- //------set与get省略----------
- }
二、数据簇的实现
- package com.meachine.learning.kmeans;
- import lombok.EqualsAndHashCode;
- import lombok.Getter;
- import lombok.Setter;
- import lombok.ToString;
- /**
- * 簇<br>
- * 数据集合的基本信息
- *
- */
- public class Cluster {
- // 簇id
- private int clusterId;
- // 属于该簇的点的个数
- private int numOfPoints;
- // 簇中心点的信息
- private Point center;
- public Cluster(int id) {
- this.clusterId = id;
- numOfPoints = 0;
- }
- public Cluster(int id, Point center) {
- this.clusterId = id;
- this.center = center;
- }
- //----------set与get省略----------------
- }
三、计算数据点距离
- package com.meachine.learning.kmeans;
- import java.util.List;
- /**
- * 计算距离接口
- *
- */
- public interface IDistance<T> {
- public double getDis(List<T> p1, List<T> p2);
- }
- package com.meachine.learning.kmeans;
- import java.util.List;
- /**
- * 欧式距离
- *
- */
- public class OujilidDistance<T extends Number> implements IDistance<T> {
- public double getDis(List<T> a, List<T> b) {
- if (a.size() != b.size()) {
- throw new IllegalArgumentException("Size not compatible!");
- }
- double result = 0;
- for (int i = 0; i < a.size(); i++) {
- result += Math.pow((a.get(i).doubleValue() - b.get(i).doubleValue()), 2);
- }
- return Math.sqrt(result);
- }
- }
四、K-Means算法
- package com.meachine.learning.kmeans;
- import java.io.BufferedReader;
- import java.io.File;
- import java.io.FileReader;
- import java.io.IOException;
- import java.util.ArrayList;
- import java.util.List;
- import java.util.Random;
- /**
- * K-Means算法
- *
- * @author Cang
- *
- */
- public class KMeans {
- // 簇的个数
- private int k;
- // 维度,即多少个变量
- private int dimensioNum;
- // 最大迭代次数
- private int maxItrNum = 100;
- private IDistance<Double> distance;
- private List<Point> points;
- private List<Cluster> clusters = new ArrayList<Cluster>();
- private String dataFileName = "D:/testSet.txt";
- public KMeans(int k) {
- this.k = k;
- }
- /**
- * 初始化数据
- */
- public void init() {
- points = loadDataSet(dataFileName);
- distance = new OujilidDistance<Double>();
- initCluster();
- }
- /**
- * 加载数据集
- *
- * @param fileName
- * @return
- */
- private List<Point> loadDataSet(String fileName) {
- List<Point> points = new ArrayList<>();
- File file = new File(fileName);
- BufferedReader reader = null;
- try {
- reader = new BufferedReader(new FileReader(file));
- String tempString = null;
- int i = 0;
- while ((tempString = reader.readLine()) != null) {
- Point point = new Point();
- dimensioNum = tempString.split("\t").length;
- for (String data : tempString.split("\t")) {
- point.add(Double.parseDouble(data));
- }
- points.add(point);
- }
- reader.close();
- } catch (IOException e) {
- e.printStackTrace();
- }
- return points;
- }
- /**
- * 初始化簇中心
- *
- * @return
- */
- private void initCluster() {
- Random ran = new Random();
- int id = 0;
- while (id < k) {
- Cluster c = new Cluster(++id);
- int temp = ran.nextInt(points.size());
- c.setCenter(points.get(temp));
- clusters.add(c);
- }
- }
- /**
- * kMeans 具体算法
- */
- public void clustering() {
- boolean finished = false;
- int count = 0;
- while (!finished) {
- // 寻找最近的中心
- finished = true;
- for (Point point : points) {
- for (Cluster cluster : clusters) {
- double minLen = distance.getDis(cluster.getCenter().getValues(),
- point.getValues());
- // 更新最小距离
- if (minLen < point.getMinDist()) {
- if (cluster.getClusterId() != point.getClusterId()) {
- finished = false;
- point.setClusterId(cluster.getClusterId());
- }
- point.setMinDist(minLen);
- }
- }
- }
- System.out.println("Cluster center info:");
- for (Cluster string : clusters) {
- System.out.println(string.getCenter().getValues());
- }
- // 更改中心的位置
- changeCentroids();
- // 超过循环次数,则跳出循环
- if (++count > maxItrNum) {
- finished = true;
- }
- }
- }
- /**
- * 改变簇中心
- */
- private void changeCentroids() {
- for (Cluster cluster : clusters) {
- ArrayList<Double> newCenterValue = new ArrayList<Double>();
- Point newCenterPoint = new Point();
- double result = 0;
- for (int i = 0; i < dimensioNum; i++) {
- for (Point point : points) {
- if (point.getClusterId() == cluster.getClusterId()) {
- result += point.getValues().get(i);
- }
- }
- newCenterValue.add(result / points.size());
- }
- newCenterPoint.setClusterId(cluster.getClusterId());
- newCenterPoint.setValues(newCenterValue);
- cluster.setCenter(newCenterPoint);
- }
- }
- public static void main(String[] args) {
- KMeans kmeans = new KMeans(4);
- kmeans.init();
- kmeans.clustering();
- }
- }
K-Means 算法(Java)的更多相关文章
- k近邻算法-java实现
最近在看<机器学习实战>这本书,因为自己本身很想深入的了解机器学习算法,加之想学python,就在朋友的推荐之下选择了这本书进行学习. 一 . K-近邻算法(KNN)概述 最简单最初级的分 ...
- KNN 与 K - Means 算法比较
KNN K-Means 1.分类算法 聚类算法 2.监督学习 非监督学习 3.数据类型:喂给它的数据集是带label的数据,已经是完全正确的数据 喂给它的数据集是无label的数据,是杂乱无章的,经过 ...
- K-means算法
K-means算法很简单,它属于无监督学习算法中的聚类算法中的一种方法吧,利用欧式距离进行聚合啦. 解决的问题如图所示哈:有一堆没有标签的训练样本,并且它们可以潜在地分为K类,我们怎么把它们划分呢? ...
- k近邻算法的Java实现
k近邻算法是机器学习算法中最简单的算法之一,工作原理是:存在一个样本数据集合,即训练样本集,并且样本集中的每个数据都存在标签,即我们知道样本集中每一数据和所属分类的对应关系.输入没有标签的新数据之后, ...
- KNN算法java实现代码注释
K近邻算法思想非常简单,总结起来就是根据某种距离度量检测未知数据与已知数据的距离,统计其中距离最近的k个已知数据的类别,以多数投票的形式确定未知数据的类别. 一直想自己实现knn的java实现,但限于 ...
- Floyd算法java实现demo
Floyd算法java实现,如下: https://www.cnblogs.com/Halburt/p/10756572.html package a; /** * ┏┓ ┏┓+ + * ┏┛┻━━━ ...
- k-means算法Java一维实现
这里的程序稍微有点变形.k_means方法返回K-means聚类的若干中心点.代码: import java.util.ArrayList; import java.util.Collections; ...
- 感知机学习算法Java实现
感知机学习算法Java实现. Perceptron类用于实现感知机, 其中的perceptronOriginal()方法用于实现感知机学习算法的原始形式: perceptronAnother()方法用 ...
- 一致哈希算法Java实现
一致哈希算法(Consistent Hashing Algorithms)是一个分布式系统中经常使用的算法. 传统的Hash算法当槽位(Slot)增减时,面临全部数据又一次部署的问题.而一致哈希算法确 ...
- 机器学习实战笔记--k近邻算法
#encoding:utf-8 from numpy import * import operator import matplotlib import matplotlib.pyplot as pl ...
随机推荐
- mysql导出csv文件excel打开后数字用科学计数法显示且低位变0的解决方法
Excel显示数字时,如果数字大于12位,它会自动转化为科学计数法:如果数字大于15位,它不仅用于科学技术费表示,还会只保留高15位,其他位都变0. Excel打开csv文件时,只要字段值都是数字,它 ...
- oneThink添加成功,返回到当前请求地址!
其实没什么,就一行代码: $this->success('已采纳',$_SERVER['HTTP_REFERER']);
- 最小树形图(poj3164)
Command Network Time Limit: 1000MS Memory Limit: 131072K Total Submissions: 12834 Accepted: 3718 ...
- rest_framework之规范详解 00
接口开发 方式1:缺点:如果有10张表,则需要40个url. urls.py views.py 缺点:如果有10张表,则需要40个url. 接下来就出现了resrful 规范,比较简洁 方式2: ...
- instanceof 用于确定一个 PHP 变量是否属于某一类 class 的实例 , 返回true或者false
<?phpclass MyClass{} class NotMyClass{}$a = new MyClass; var_dump($a instanceof MyClass);var_dump ...
- Service学习笔记
一 什么是Service Service作为安卓四大组件之一,拥有重要的地位.Service和Activity级别相同,只是没有界面,是运行于后台的服务.这个运行“后台”是指不可见,不是指在后台线程中 ...
- #pragma预处理命令详解
#pragma预处理命令 #pragma可以说是C++中最复杂的预处理指令了,下面是最常用的几个#pragma指令: #pragma comment(lib,"XXX.lib") ...
- ubuntu16.04下笔记本自带摄像头编译运行PTAM
ubuntu16.04下笔记本自带摄像头编译运行PTAM 转载请注明链接:https://i.cnblogs.com/EditPosts.aspx?postid=9014147 个人邮箱:feifan ...
- ORA-39006、ORA-39065、ORA-01403、ORA-39097错误解决办法
今天有同事找说是expdp到出数据时报错: 处理方法:sys用户下执行如下语句重新生成DATAPUMP API用到的视图问题就解决了. SQL> @?/rdbms/admin/catmeta.s ...
- python 类 __module__ __class__
__module__ 和 __class__ __module__ 表示当前操作的对象在那个模块 __class__ 表示当前操作的对象的类是什么 创建一个目录lib 在day7 目录下创 ...