Kmeans++算法

Kmeans++算法,主要可以解决初始中心的选择问题,不可解决k的个数问题。

Kmeans++主要思想是选择的初始聚类中心要尽量的远。

做法:

1.    在输入的数据点中随机选一个作为第一个聚类中心。

2.    对于所有数据点,计算它与已有的聚类中心的最小距离D(x)

3.    选择一个数据点作为新增的聚类中心,选择原则:D(x)较大的点被选为聚类中心的概率较大。

4.    重复2~3步骤直到选出k个聚类中心。

5.    运行Kmeans算法。

package com.lfy.main;

import java.util.ArrayList;
import java.util.List;
import java.util.Random; /**
* K均值聚类算法
*/
public class Kmeans {
private int numOfCluster;// 分成多少簇
private int timeOfIteration;// 迭代次数
private int dataSetLength;// 数据集元素个数,即数据集的长度
private ArrayList<float[]> dataSet;// 数据集
private ArrayList<float[]> center;// 质心
private ArrayList<ArrayList<float[]>> cluster; //簇
private ArrayList<Float> sumOfErrorSquare;// 误差平方和
private Random random; /**
* 设置需分组的原始数据集
*
* @param dataSet
*/ public void setDataSet(ArrayList<float[]> dataSet) {
this.dataSet = dataSet;
} /**
* 获取结果分组
*
* @return 结果集
*/ public ArrayList<ArrayList<float[]>> getCluster() {
return cluster;
} /**
* 构造函数,传入需要分成的簇数量
*
* @param numOfCluster
* 簇数量,若numOfCluster<=0时,设置为1,若numOfCluster大于数据源的长度时,置为数据源的长度
*/
public Kmeans(int numOfCluster) {
if (numOfCluster <= 0) {
numOfCluster = 1;
}
this.numOfCluster = numOfCluster;
} /**
* 初始化
*/
private void init() {
timeOfIteration = 0;
random = new Random();
//如果调用者未初始化数据集,则采用内部测试数据集
if (dataSet == null || dataSet.size() == 0) {
initDataSet();
}
dataSetLength = dataSet.size();
//若numOfCluster大于数据源的长度时,置为数据源的长度
if (numOfCluster > dataSetLength) {
numOfCluster = dataSetLength;
}
center = initCenters();
cluster = initCluster();
sumOfErrorSquare = new ArrayList<Float>();
//查看init质心的选取情况
printDataArray(center,"initCenter");
} /**
* 如果调用者未初始化数据集,则采用内部测试数据集
*/
private void initDataSet() {
dataSet = new ArrayList<float[]>();
// 其中{6,3}是一样的,所以长度为15的数据集分成14簇和15簇的误差都为0
float[][] dataSetArray = new float[][] { { 8, 2 }, { 3, 4 }, { 2, 5 },
{ 4, 2 }, { 7, 3 }, { 6, 2 }, { 4, 7 }, { 6, 3 }, { 5, 3 },
{ 6, 3 }, { 6, 9 }, { 1, 6 }, { 3, 9 }, { 4, 1 }, { 8, 6 } }; for (int i = 0; i < dataSetArray.length; i++) {
dataSet.add(dataSetArray[i]);
}
} /**
* 随机选取k个质点
* 初始化中心点,分成多少簇就有多少个中心点
*
* @return 中心点集
*/
private ArrayList<float[]> initCenters() {
ArrayList<float[]> center = new ArrayList<float[]>();
int[] randoms = new int[numOfCluster];
int temp = random.nextInt(dataSetLength);
randoms[0] = temp;
//----------------------
List<Integer> list=new ArrayList<Integer>();
list.add(temp);
//randoms数组中存放dataSet数据集的不同的下标
for (int i = 1; i < numOfCluster; i++) {
// while (true) {
// temp = random.nextInt(dataSetLength);
//
// int j=0;
// for(; j<i; j++){
// if(randoms[j] == temp){
// break;
// }
// }
// //没有与任何一个已经选定的质心重复
// //跳出外层循环,设定一个随机质心
// if (j == i) {
// break;
// }
// }
//----------------------
ArrayList<float[]> ltemp=new ArrayList<float[]>();
//从剩下的点中继续找质点
for (int k = 0; k < dataSetLength; k++) {
//如果该点还没有被选择为质点,则计算它与已有的所有质点的最小距离
if(!list.contains(k)) {
float[] distance = new float[numOfCluster];
for (int j = 0; j < list.size(); j++) {
//某点k到已有中心点的距离
distance[j] = distance(dataSet.get(k), dataSet.get(list.get(j)));
}
int j = minDistance(distance);
float[] f={0,0};
f[0]=k;
f[1]=distance[j];
ltemp.add(f);
}
}
int m=maxDistance(ltemp);
temp=(int) ltemp.get(m)[0];
list.add(temp);
//----------------------
randoms[i] = temp;
} for (int i = 0; i < numOfCluster; i++) {
center.add(dataSet.get(randoms[i]));// 生成初始化中心点集
}
return center;
} /**
* 初始化簇集合
*
* @return 一个分为k簇的空数据的簇集合
*/
private ArrayList<ArrayList<float[]>> initCluster() {
ArrayList<ArrayList<float[]>> cluster = new ArrayList<ArrayList<float[]>>();
for (int i = 0; i < numOfCluster; i++) {
cluster.add(new ArrayList<float[]>());
}
return cluster;
} /**
* 计算两个点之间的距离
*
* @param element
* 点1
* @param center
* 点2
* @return 距离
*/
private float distance(float[] element, float[] center) {
float distance = 0.0f;
float x = element[0] - center[0];
float y = element[1] - center[1];
float z = x * x + y * y;
distance = (float) Math.sqrt(z); return distance;
} /**
* 获取距离集合中最小距离的位置
*
* @param distance
* 距离数组
* @return 最小距离在距离数组中的位置
*/
private int minDistance(float[] distance) {
float minDistance = distance[0];
int minLocation = 0;
for (int i = 1; i < distance.length; i++) {
if (distance[i] <= minDistance) {
minDistance = distance[i];
minLocation = i;
}
}
return minLocation;
} /**
* 获取距离集合中最小距离的最大的位置
*
* @param distance
* 各点最小距离数组
* @return 各点最小距离在距离数组中的位置
*/
private int maxDistance(ArrayList<float[]> distance) {
float[] maxDistance = distance.get(0);
int maxLocation = 0;
for (int i = 1; i < distance.size(); i++) {
if (distance.get(i)[1] >= maxDistance[1]) {
maxDistance = distance.get(i);
maxLocation = i;
}
}
return maxLocation;
} /**
* 核心,将当前元素放到最小距离的簇中
*/
private void clusterSet() {
float[] distance = new float[numOfCluster];
for (int i = 0; i < dataSetLength; i++) {
for (int j = 0; j < numOfCluster; j++) {
//计算数据集点与所有中心点的距离
distance[j] = distance(dataSet.get(i), center.get(j));
}
int j = minDistance(distance);
// 核心,将当前元素放到最小距离中心相关的簇中
cluster.get(j).add(dataSet.get(i));
}
} /**
* 求族中各点到其中心点距离的平方,即误差平方
*
* @param element
* 点1
* @param center
* 点2
* @return 误差平方
*/
private float errorSquare(float[] element, float[] center) {
float x = element[0] - center[0];
float y = element[1] - center[1]; float errSquare = x * x + y * y; return errSquare;
} /**
* 计算一次迭代误差平方和
*/
private void countRule() {
float jcF = 0;
for (int i = 0; i < cluster.size(); i++) {
for (int j = 0; j < cluster.get(i).size(); j++) {
jcF += errorSquare(cluster.get(i).get(j), center.get(i));
}
}
sumOfErrorSquare.add(jcF);
} /**
* 设置新的簇中心方法
*/
private void setNewCenter() {
for (int i = 0; i < numOfCluster; i++) {
int n = cluster.get(i).size();
if (n != 0) {
float[] newCenter = { 0, 0 };
for (int j = 0; j < n; j++) {
newCenter[0] += cluster.get(i).get(j)[0];
newCenter[1] += cluster.get(i).get(j)[1];
}
// 设置一个平均值
newCenter[0] = newCenter[0] / n;
newCenter[1] = newCenter[1] / n;
center.set(i, newCenter);
}
}
printDataArray(center,"newCenter");
} /**
* 打印数据,测试用
*
* @param dataArray
* 数据集
* @param dataArrayName
* 数据集名称
*/
public void printDataArray(ArrayList<float[]> dataArray,
String dataArrayName) {
for (int i = 0; i < dataArray.size(); i++) {
System.out.println("print:" + dataArrayName + "[" + i + "]={"
+ dataArray.get(i)[0] + "," + dataArray.get(i)[1] + "}");
}
System.out.println("===================================");
} /**
* Kmeans算法核心过程方法
*/
private void kmeans() {
init(); // 循环分组,直到误差不变为止
while (true) {
clusterSet(); countRule(); // 误差不变了,分组完成
if (timeOfIteration != 0) {
if (sumOfErrorSquare.get(timeOfIteration) - sumOfErrorSquare.get(timeOfIteration - 1) == 0) {
break;
}
}
//设置各簇新的质心,继续迭代
setNewCenter();
timeOfIteration++;
cluster.clear();
cluster = initCluster();
}
System.out.println("note:the times of repeat:timeOfIteration="+timeOfIteration);//输出迭代次数
} /**
* 执行算法
*/
public void execute() {
long startTime = System.currentTimeMillis();
System.out.println("kmeans begins");
kmeans();
long endTime = System.currentTimeMillis();
System.out.println("kmeans running time=" + (endTime - startTime)
+ "ms");
System.out.println("kmeans ends");
System.out.println();
}
}
package com.lfy.main;

import java.util.ArrayList;

public class KmeansTest {
public static void main(String[] args)
{
//初始化一个Kmean对象,设置k值
Kmeans k=new Kmeans(3);
ArrayList<float[]> dataSet=new ArrayList<float[]>(); dataSet.add(new float[]{3,4});
dataSet.add(new float[]{4,4});
dataSet.add(new float[]{3,3});
dataSet.add(new float[]{4,3});
//
dataSet.add(new float[]{0,2});
dataSet.add(new float[]{1,2});
dataSet.add(new float[]{0,1});
dataSet.add(new float[]{1,1});
//
dataSet.add(new float[]{3,1});
dataSet.add(new float[]{3,0});
dataSet.add(new float[]{5,0});
dataSet.add(new float[]{4,0});
dataSet.add(new float[]{4,1}); //设置原始数据集
k.setDataSet(dataSet);
//执行算法
k.execute();
//得到聚类结果
ArrayList<ArrayList<float[]>> cluster=k.getCluster();
//查看结果
for(int i=0;i<cluster.size();i++)
{
k.printDataArray(cluster.get(i), "cluster["+i+"]");
} }
}

算法 - k-means++的更多相关文章

  1. 第4章 最基础的分类算法-k近邻算法

    思想极度简单 应用数学知识少 效果好(缺点?) 可以解释机器学习算法使用过程中的很多细节问题 更完整的刻画机器学习应用的流程 distances = [] for x_train in X_train ...

  2. KNN 与 K - Means 算法比较

    KNN K-Means 1.分类算法 聚类算法 2.监督学习 非监督学习 3.数据类型:喂给它的数据集是带label的数据,已经是完全正确的数据 喂给它的数据集是无label的数据,是杂乱无章的,经过 ...

  3. ML: 聚类算法-K均值聚类

    基于划分方法聚类算法R包: K-均值聚类(K-means)                   stats::kmeans().fpc::kmeansruns() K-中心点聚类(K-Medoids) ...

  4. 聚类算法:K-means 算法(k均值算法)

    k-means算法:      第一步:选$K$个初始聚类中心,$z_1(1),z_2(1),\cdots,z_k(1)$,其中括号内的序号为寻找聚类中心的迭代运算的次序号. 聚类中心的向量值可任意设 ...

  5. 聚类算法:K均值、凝聚层次聚类和DBSCAN

    聚类分析就仅根据在数据中发现的描述对象及其关系的信息,将数据对象分组(簇).其目标是,组内的对象相互之间是相似的,而不同组中的对象是不同的.组内相似性越大,组间差别越大,聚类就越好. 先介绍下聚类的不 ...

  6. 分类算法——k最近邻算法(Python实现)(文末附工程源代码)

    kNN算法原理 k最近邻(k-Nearest Neighbor)算法是比较简单的机器学习算法.它采用测量不同特征值之间的距离方法进行分类,思想很简单:如果一个样本在特征空间中的k个最近邻(最相似)的样 ...

  7. 【学习笔记】分类算法-k近邻算法

    k-近邻算法采用测量不同特征值之间的距离来进行分类. 优点:精度高.对异常值不敏感.无数据输入假定 缺点:计算复杂度高.空间复杂度高 使用数据范围:数值型和标称型 用例子来理解k-近邻算法 电影可以按 ...

  8. 常见聚类算法——K均值、凝聚层次聚类和DBSCAN比较

    聚类分析就仅根据在数据中发现的描述对象及其关系的信息,将数据对象分组(簇).其目标是,组内的对象相互之间是相似的,而不同组中的对象是不同的.组内相似性越大,组间差别越大,聚类就越好. 先介绍下聚类的不 ...

  9. 分类算法----k近邻算法

    K最近邻(k-Nearest Neighbor,KNN)分类算法,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一.该方法的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的 ...

  10. 【机器学习】聚类算法——K均值算法(k-means)

    一.聚类 1.基于划分的聚类:k-means.k-medoids(每个类别找一个样本来代表).Clarans 2.基于层次的聚类:(1)自底向上的凝聚方法,比如Agnes (2)自上而下的分裂方法,比 ...

随机推荐

  1. TTTTTTTTTTTTTT hdu 5763 Another Meaning 哈希+dp

    Another Meaning Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 65536/65536 K (Java/Others)T ...

  2. AcWing:111. 畜栏预定(贪心 + 小根堆)

    有N头牛在畜栏中吃草. 每个畜栏在同一时间段只能提供给一头牛吃草,所以可能会需要多个畜栏. 给定N头牛和每头牛开始吃草的时间A以及结束吃草的时间B,每头牛在[A,B]这一时间段内都会一直吃草. 当两头 ...

  3. fanout(Publish/Subscribe)发布/订阅

    引言 它是一种通过广播方式发送消息的路由器,所有和exchange建立的绑定关系的队列都会接收到消息 不处理路由键,只需要简单的将队列绑定到交换机上 fanout交换机转发消息是最快的,它不需要做路由 ...

  4. Vue 新手学习笔记:vue-element-admin 之安装,配置及入门开发

    所属专栏: Vue 开发学习进步 说实话都是逼出来的,对于前端没干过ES6都不会的人,vue视频也就看了基础的一些但没办法,接下来做微服务架构,前端就用 vue,这块你负责....说多了都是泪,脚手架 ...

  5. HDU 5831 Rikka with Parenthesis II ——(括号匹配问题)

    用一个temp变量,每次出现左括号,+1,右括号,-1:用ans来记录出现的最小的值,很显然最终temp不等于0或者ans比-2小都是不可以的.-2是可以的,因为:“))((”可以把最左边的和最右边的 ...

  6. 2.微服务开发框架——Spring Cloud

                     微服务开发框架—Spring Cloud 2.1. Spring Cloud简介及其特点 简介: Spring Cloud为开发人员提供了快速构建分布式系统中一些常见 ...

  7. Leetcode题目33.搜索旋转排序数组(中等)

    题目描述: 假设按照升序排序的数组在预先未知的某个点上进行了旋转. ( 例如,数组 [0,1,2,4,5,6,7] 可能变为 [4,5,6,7,0,1,2] ). 搜索一个给定的目标值,如果数组中存在 ...

  8. java 直接内存

    android 内存结构 : dalvik(jvm)内存---navtive men 两部分. 这个概念相信有经验的开发人员都会知道. java虚拟机分配到的内存是有限的,根据手机不同,大小不一,但也 ...

  9. Linux安全工具之fail2ban防爆力破解

    一:简单介绍 fail2ban是一款实用软件,可以监视你的系统日志,然后匹配日志的错误信息(正则式匹配)执行相应的屏蔽动作 在企业中,有些很多人会开放root登录,这样就有机会给黑客造成暴力破解的机会 ...

  10. 如何实现一个串行promise

    异步执行任务A.B.C,...... 1.使用数组的reduce方法,reduce里有四个参数,pre,next,index,arr, 2.如果then方法里返回的是一个promise对象,那么执行下 ...