本人以前主要focus在传统音频的软件开发,接触到的算法主要是音频信号处理相关的,如各种编解码算法和回声消除算法等。最近切到语音识别上,接触到的算法就变成了各种机器学习算法,如GMM等。K-means作为其中比较简单的一种肯定是要好好掌握的。今天就讲讲K-means的基本原理和代码实现。其中基本原理简述(主要是因为:1,K-means比较简单;2,网上有很多讲K-means基本原理的),重点放在代码实现上。

1, K-means基本原理

K均值(K-means)聚类算法是无监督聚类(聚类(clustering)是将数据集中的样本划分为若干个通常是不相交的子集,每个子集称为一个“簇(cluster)”)算法中的一种,也是最常用的聚类算法。K表示类别数,Means表示均值。K-means主要思想是在给定K值和若干样本(点)的情况下,把每个样本(点)分到离其最近的类簇中心点所代表的类簇中,所有点分配完毕之后,根据一个类簇内的所有点重新计算该类簇的中心点(取平均值),然后再迭代的进行分配点和更新类簇中心点的步骤,直至类簇中心点的变化很小,或者达到指定的迭代次数。

K-means算法流程如下:

(a)随机选取K个初始cluster center

(b)分别计算所有样本到这K个cluster center的距离

(c)如果样本离cluster center Ci最近,那么这个样本属于Ci点簇;如果到多个cluster center的距离相等,则可划分到任意簇中

(d)按距离对所有样本分完簇之后,计算每个簇的均值(最简单的方法就是求样本每个维度的平均值),作为新的cluster center

(e)重复(b)(c)(d)直到新的cluster center和上轮cluster center变化很小或者达到指定的迭代次数,算法结束

2, 算法实现

我主要偏底层开发,最熟悉语言是C,所以代码是用C语言来实现的。在二维平面上有一些点,大意如下图,

用K-means算法对其分类,其中类的个数(即K值)和点的个数人为指定。具体的代码如下:

#include<stdio.h>
#include<stdlib.h>
#include<string.h>
#include<math.h>

#define MAX_ROUNDS 100    //最大允许的聚类次数

//“点”的结构体  
typedef struct Point{
  int x_value;           //用于存放点在X轴上的值
  int y_value;           //用于存放点在Y轴上的值
  int cluster_id;        //用于存放该点所属的cluster id
}Point;
Point* data;
 
//cluster center的结构体
typedef struct ClusterCenter{
  double x_value;
  double y_value;
  int cluster_id;
}ClusterCenter;
ClusterCenter* cluster_center;

//计算cluster center的结构体
typedef struct CenterCalc{
  double x_value;
  double y_value;
}CenterCalc;
CenterCalc *center_calc;
 
int is_continue;                               //kmeans 运算是否继续
int* cluster_center_init_index;        //记录每个cluster center最初用的是哪个“点”
double* distance_from_center;      //记录一个“点”到所有cluster center的距离
int* data_size_per_cluster;            //每个cluster点的个数
int data_size_total;                        //设定点的个数
char filename[200];                       //要读取的点的数据的文件名
int cluster_count;                          //设定的cluster的个数
 
void memoryAlloc();
void memoryFree();
void readDataFromFile();
void initialCluster();
void calcDistance2OneCenter(int pointID, int centerID);
void calcDistance2AllCenters(int pointID);
void partition4OnePoint(int pointID);
void partition4AllPointOneCluster();
void calcClusterCenter();
void kmeans();
void compareNewOldClusterCenter(CenterCalc* center_calc);
 
int main(int argc, char* argv[])
{
    if( argc != 4 )
    {
        printf("This application needs 3 parameters to run:"
            "\n the 1st is the size of data set,"
            "\n the 2nd is the file name that contains data"
            "\n the 3rd indicates the cluster_count"
            "\n");
        exit(1);
    }

    data_size_total = atoi(argv[1]);
    strcat(filename, argv[2]);
    cluster_count = atoi(argv[3]);
    //1, memory alloc
    memoryAlloc();
    //2, read point data from file
    readDataFromFile();
    //3, initial cluster
    initialCluster();
    //4, run k-means
    kmeans();
    //5, memory free & end
    memoryFree();
    
    return 0;
}

void memoryAlloc()
{
  data = (Point*)malloc(sizeof(struct Point) * (data_size_total));
  if( !data )
  {
    printf("malloc error:data!");
    exit(1);
  }
  cluster_center_init_index = (int*)malloc(sizeof(int) * (cluster_count));
  if( !cluster_center_init_index )
  {
    printf("malloc error:cluster_center!\n");
    exit(1);
  }
  distance_from_center = (double*)malloc(sizeof(double) * (cluster_count));
  if( !distance_from_center )
  {
    printf("malloc error: distance_from_center!\n");
    exit(1);
  }
  cluster_center = (ClusterCenter*)malloc(sizeof(struct ClusterCenter) * (cluster_count));
  if( !cluster_center )
  {
    printf("malloc cluster center new error!\n");
    exit(1);
  }

  center_calc = (CenterCalc*)malloc(sizeof(CenterCalc) * cluster_count);
  if( !center_calc )
  {
    printf("malloc error: center_calc!\n");
    exit(1);
  }

  data_size_per_cluster = (int*)malloc(sizeof(int) * (cluster_count));
  if( !data_size_per_cluster )
  {
    printf("malloc error: data_size_per_cluster\n");
    exit(1);
  }
 
}

void memoryFree()
{
  free(data);
  data = NULL;
  free(cluster_center_init_index);
  cluster_center_init_index = NULL;
  free(distance_from_center);
  distance_from_center = NULL;
  free(cluster_center);
  cluster_center = NULL;
  free(center_calc);
  center_calc = NULL;
  free(data_size_per_cluster);
  data_size_per_cluster = NULL;
}

//从文件中读入每个点的x和y值
void readDataFromFile()
{
  int i;
  FILE* fread;
 
  if( NULL == (fread = fopen(filename, "r")))
  {
    printf("open file(%s) error!\n", filename);
    exit(1);
  }

  for( i = 0; i < data_size_total; i++ )
  {
    if( 2 != fscanf(fread, "%d %d ", &data[i].x_value, &data[i].y_value))
    {
      printf("fscanf error: %d\n", i);
    }
    data[i].cluster_id = -1;    //初始时每个点所属的cluster id均置为-1

    printf("After reading, point index:%d, X:%d, Y:%d, cluster_id:%d\n", i, data[i].x_value, data[i].y_value, data[i].cluster_id);
  }
}
 

//根据传入的cluster_count来随机的选择一个点作为 一个cluster的center  
void initialCluster()
{
  int i,j;
  int random;
    
  //产生初始化的cluster_count个聚类  
  for( i = 0; i < cluster_count; i++ )
  {
    cluster_center_init_index[i] = -1;
  }
  //随机选择一个点作为每个cluster的center(不重复)
  for( i = 0; i < cluster_count; i++ )
  {
    Reselect:
        random = rand() % (data_size_total - 1);
        for(j = 0; j < i; j++) {
            if(random == cluster_center_init_index[j])
                goto Reselect;
        }

    cluster_center_init_index[i] = random;
    printf("cluster_id: %d, located in point index:%d\n", i, random);  
  }
  //将随机选择的点作为center,同时这个点的cluster id也就确定了
  for( i = 0; i < cluster_count; i++ )
  {
    cluster_center[i].x_value = data[cluster_center_init_index[i]].x_value;
    cluster_center[i].y_value = data[cluster_center_init_index[i]].y_value;
    cluster_center[i].cluster_id = i;
    data[cluster_center_init_index[i]].cluster_id = i;

    printf("cluster_id:%d, index:%d, x_value:%f, y_value:%f\n", cluster_center[i].cluster_id, cluster_center_init_index[i], cluster_center[i].x_value, cluster_center[i].y_value);
  }
}
 

//计算一个点到一个cluster center的distance
void calcDistance2OneCenter(int point_id,int center_id)
{
  distance_from_center[center_id] = sqrt( (data[point_id].x_value-cluster_center[center_id].x_value)*(double)(data[point_id].x_value-cluster_center[center_id].x_value) + (double)(data[point_id].y_value-cluster_center[center_id].y_value) *              (data[point_id].y_value-cluster_center[center_id].y_value) );
}
 
//计算一个点到每个cluster center的distance
void calcDistance2AllCenters(int point_id)
{
  int i;
  for( i = 0; i < cluster_count; i++ )
  {
    calcDistance2OneCenter(point_id, i);
  }
}
 
//确定一个点属于哪一个cluster center(取距离最小的)
void partition4OnePoint(int point_id)
{
  int i;
  int min_index = 0;
  double min_value = distance_from_center[0];
  for( i = 0; i < cluster_count; i++ )
  {
    if( distance_from_center[i] < min_value )
    {
      min_value = distance_from_center[i];
      min_index = i;
    }
  }
 
  data[point_id].cluster_id = cluster_center[min_index].cluster_id;
}

//在一轮的聚类中得到所有的point所属于的cluster center
void partition4AllPointOneCluster()
{
  int i;
  for( i = 0; i < data_size_total; i++ )
  {
    if( data[i].cluster_id != -1 )  //这个点就是center,不需要计算
      continue;
    else
    {
      calcDistance2AllCenters(i);  //计算第i个点到所有center的distance
      partition4OnePoint(i);          //根据distance对第i个点进行partition
    }
  }
}

//重新计算新的cluster center
void calcClusterCenter()
{
  int i;

  memset(center_calc, 0, sizeof(CenterCalc) * cluster_count);
  memset(data_size_per_cluster, 0, sizeof(int) * cluster_count);
  //分别对每个cluster内的每个点的X和Y求和,并计每个cluster内点的个数
  for( i = 0; i < data_size_total; i++ )
  {
    center_calc[data[i].cluster_id].x_value += data[i].x_value;
    center_calc[data[i].cluster_id].y_value += data[i].y_value;
    data_size_per_cluster[data[i].cluster_id]++;
  }
  //计算每个cluster内点的X和Y的均值作为center
  for( i = 0; i < cluster_count; i++ )
  {
     if(data_size_per_cluster[i] != 0) {
        center_calc[i].x_value = center_calc[i].x_value/ (double)(data_size_per_cluster[i]);
        center_calc[i].y_value = center_calc[i].y_value/ (double)(data_size_per_cluster[i]);

printf(" cluster %d point cnt:%d\n", i, data_size_per_cluster[i]);
        printf(" cluster %d center: X:%f, Y:%f\n", i, center_calc[i].x_value, center_calc[i].y_value);
    }
    else
          printf(" cluster %d count is zero\n", i);
  }
 
  //比较新的和旧的cluster center值的差别。如果是相等的,则停止K-means算法。
  compareNewOldClusterCenter(center_calc);
 
  //将新的cluster center的值放入cluster_center结构体中
  for( i = 0; i < cluster_count; i++ )
  {
    cluster_center[i].x_value = center_calc[i].x_value;
    cluster_center[i].y_value = center_calc[i].y_value;
    cluster_center[i].cluster_id = i;
  }

  //在重新计算了新的cluster center之后,要重新来为每一个Point进行聚类,所以data中用于表示聚类ID的cluster_id要都重新置为-1。
  for( i = 0; i < data_size_total; i++ )
  {
    data[i].cluster_id = -1;
  }
}
 
//比较新旧的cluster center的值,完全一样表示聚类完成
void compareNewOldClusterCenter(CenterCalc* center_calc)
{
  int i;
  is_continue = 0;       //等于0表示不要继续,1表示要继续
  for( i = 0; i < cluster_count; i++ )
  {
    if( center_calc[i].x_value != cluster_center[i].x_value || center_calc[i].y_value != cluster_center[i].y_value)
    {
      is_continue = 1;
      break;
    }
  }
}
 
//K-means算法
void kmeans()
{
  int rounds;
  for( rounds = 0; rounds < MAX_ROUNDS; rounds++ )
  {
    printf("\nRounds : %d             \n", rounds+1);
    partition4AllPointOneCluster();
    calcClusterCenter();
    if( 0 == is_continue )
    {
       printf("\n after %d rounds, the classification is ok and can stop.\n", rounds+1);
       break;  
    }
  }
}

编译后生成可执行文件kmeans,输入的文件里共有6个点,分别为(0, 0), (4, 4), (4, 5), (0, 1), (3, 6) ,(4, 9),要求分成两类。运行可执行程序后得到结果如下:

$ ./kmeans 6 data 2
After reading, point index:0, X:0, Y:0, cluster_id:-1
After reading, point index:1, X:4, Y:4, cluster_id:-1
After reading, point index:2, X:4, Y:5, cluster_id:-1
After reading, point index:3, X:0, Y:1, cluster_id:-1
After reading, point index:4, X:3, Y:6, cluster_id:-1
After reading, point index:5, X:4, Y:9, cluster_id:-1

cluster_id: 0, located in point index:3
cluster_id: 1, located in point index:1
cluster_id:0, index:3, x_value:0.000000, y_value:1.000000
cluster_id:1, index:1, x_value:4.000000, y_value:4.000000

Rounds : 1             
 cluster 0 point cnt:2
 cluster 0 center: X:0.000000, Y:0.500000
 cluster 1 point cnt:4
 cluster 1 center: X:3.750000, Y:6.000000

Rounds : 2             
 cluster 0 point cnt:2
 cluster 0 center: X:0.000000, Y:0.500000
 cluster 1 point cnt:4
 cluster 1 center: X:3.750000, Y:6.000000

 after 2 rounds, the classification is ok and can stop.

即两轮后聚类就好了,(0, 0),(0, 1)一类,(4, 4), (4, 5), (3, 6) ,(4, 9)一类。

机器学习中K-means聚类算法原理及C语言实现的更多相关文章

  1. 机器学习实战---K均值聚类算法

    一:一般K均值聚类算法实现 (一)导入数据 import numpy as np import matplotlib.pyplot as plt def loadDataSet(filename): ...

  2. 【机器学习】:Kmeans均值聚类算法原理(附带Python代码实现)

    这个算法中文名为k均值聚类算法,首先我们在二维的特殊条件下讨论其实现的过程,方便大家理解. 第一步.随机生成质心 由于这是一个无监督学习的算法,因此我们首先在一个二维的坐标轴下随机给定一堆点,并随即给 ...

  3. k均值聚类算法原理和(TensorFlow)实现

    顾名思义,k均值聚类是一种对数据进行聚类的技术,即将数据分割成指定数量的几个类,揭示数据的内在性质及规律. 我们知道,在机器学习中,有三种不同的学习模式:监督学习.无监督学习和强化学习: 监督学习,也 ...

  4. Kmeans聚类算法原理与实现

    Kmeans聚类算法 1 Kmeans聚类算法的基本原理 K-means算法是最为经典的基于划分的聚类方法,是十大经典数据挖掘算法之一.K-means算法的基本思想是:以空间中k个点为中心进行聚类,对 ...

  5. 【转】K-Means聚类算法原理及实现

    k-means 聚类算法原理: 1.从包含多个数据点的数据集 D 中随机取 k 个点,作为 k 个簇的各自的中心. 2.分别计算剩下的点到 k 个簇中心的相异度,将这些元素分别划归到相异度最低的簇.两 ...

  6. K均值聚类算法

    k均值聚类算法(k-means clustering algorithm)是一种迭代求解的聚类分析算法,其步骤是随机选取K个对象作为初始的聚类中心,然后计算每个对象与各个种子聚类中心之间的距离,把每个 ...

  7. 机器学习之K均值聚类

      聚类的核心概念是相似度或距离,有很多相似度或距离的方法,比如欧式距离.马氏距离.相关系数.余弦定理.层次聚类和K均值聚类等 1. K均值聚类思想   K均值聚类的基本思想是,通过迭代的方法寻找K个 ...

  8. 100天搞定机器学习|day44 k均值聚类数学推导与python实现

    [如何正确使用「K均值聚类」? 1.k均值聚类模型 给定样本,每个样本都是m为特征向量,模型目标是将n个样本分到k个不停的类或簇中,每个样本到其所属类的中心的距离最小,每个样本只能属于一个类.用C表示 ...

  9. OPTICS聚类算法原理

    OPTICS聚类算法原理 基础 OPTICS聚类算法是基于密度的聚类算法,全称是Ordering points to identify the clustering structure,目标是将空间中 ...

随机推荐

  1. WPF圆形环绕的Loading动画

    原文:WPF圆形环绕的Loading动画 版权声明:本文为博主原创文章,未经博主允许不得转载. https://blog.csdn.net/yangyisen0713/article/details/ ...

  2. Carthage 包管理工具,另一种敏捷轻快的 iOS & MAC 开发体验 | SwiftCafe 咖啡时光

    说起 iOS 开发的包管理,大家就不由得会想起 CocoaPods, 它确实是一个强大的工具.但这次咱们来关注另外一个包管理工具 Carthage,如果说 CocoaPods 像一个航母,一应俱全,坚 ...

  3. dotnet pack 打包文件版本号引起 "Could not load file or assembly" 问题

    如果不是遇到,真的不会想到,代码世界的问题真是千奇百怪,这次遇到的是 dotnet pack 打包文件版本号引起的问题. 之前进行 nuget 打包都是在 Visual Studio build 时进 ...

  4. Httpclient Fluent API简单封装

    import java.io.IOException;import java.util.ArrayList;import java.util.HashMap;import java.util.List ...

  5. Hutool 3.0.8 发布,Java 工具集

    Hutool 是一个Java工具包,提供了丰富的文件.日期.日志.正则.字符串.配置文件等工具方法,并封装了一套简单易用的ORM框架. 主页:http://hutool.cn/ 文档:http://h ...

  6. C# 事件详解

    1.事件的本质是什么 答:事件是委托的包装器,就像属性是字段的包装器一样 2.为什么有了委托还有有事件 委托可以被访问就可以被执行,事件则只能在类的内部执行 3.事件要怎么声明 a.明一个委托 //委 ...

  7. Redis实现关注关系

    最近使用关系型数据库实现了用户之间的关注,于是思考换一种思路,使用Redis实现用户之间的关注关系. 综合考虑了一下Redis的几种数据结构后,觉得可以用集合实现一下. 假设"我" ...

  8. shell中select、case的使用

    case和select结构在技术上说并不是循环, 因为它们并不对可执行代码块进行迭代. 但是和循环相似的是, 它们也依靠在代码块顶部或底部的条件判断来决定程序的分支. select   select结 ...

  9. 使用VC2005编译真正的静态Qt程序 good

    首先,你应该该知道什么叫静态引用编译.什么叫动态引用编译.我这里只是简单的提提,具体的可以google一下. 动态引用编译,是指相关的库,以dll的形式引用库.动态编译的Exe程序尺寸比较小,因为相关 ...

  10. LFTP 4.6.2 发布,命令行 FTP 工具。这个东东可以用来做插件

    直击现场  这个东东可以用来做插件 LFTP 4.6.2 发布,新增特征如下: * new command "edit" instead of the edit alias.* n ...