LVQ聚类与k-means不同之处在于,它是有标记的聚类,设定带标签的k个原型向量(即团簇中心),根据样本标签是否与原型向量的标签一致,对原型向量进行更新。

最后,根据样本到原型向量的距离,对样本进行团簇划分。

伪代码如下:

python实现如下:

1,算法部分

# 学习向量量化LVQ:有标记的聚类
import numpy as np
import random def dis(x,y):
return np.sqrt(np.sum(np.power(x[:-1]-y[:-1],2)))
# lvq算法
def lvq(data,labels,k=4,lr=0.01,epochs=1000,delta=1e-3):
'''
data:np.array,last feature is the label.
labels:1-dimension list or array,label of the data.
k:num_group
lr:learning rate
epochs:max epoch to stop running earlier
delta: max distance for two vectors to be 'equal'.
'''
# 学习向量
q=np.empty(shape=(k,data.shape[-1]),dtype=np.float32)
# 确认是否所有向量更新完了
all_vectors_updated=np.empty(shape=(k,),dtype=np.bool)
num_labels=len(labels)
# 初始化原型向量,从每一类中随机选取样本,如果类别数小于聚类数,循环随机取各类别中的样本
for i in range(k):
q[i]=random.choice(data[data[:,-1]==labels[i%num_labels]])
step=0
while not all_vectors_updated.all() and step<epochs:
# 从样本中随机选取样本,书上是这么写的,为啥不循环,要随机呢?np.random的choice只支持一维
x=random.choice(data)
min_dis=np.inf
index=0
for i in range(k):
distance=dis(x,q[i])
if distance<min_dis:
min_dis=distance
index=i
# 保存更新前向量
temp_q=q[index].copy()
# 如果标签相同,则q更新后接近样本x,否则远离
if x[-1]==q[index][-1]:
q[index][:-1]=q[index][:-1]+lr*(x[:-1]-q[index][:-1])
else:
q[index][:-1]=q[index][:-1]-lr*(x[:-1]-q[index][:-1])
# 更新记录数组
if dis(temp_q,q[index])<delta:
all_vectors_updated[index]=True
step+=1
# 训练完后,样本划分到最近的原型向量簇中
categoried_data=[]
for i in range(k):
categoried_data.append([])
for item in data:
min_dis=np.inf
index=0
for i in range(k):
distance=dis(item,q[i])
if distance<min_dis:
min_dis=distance
index=i
categoried_data[index].append(item)
return q,categoried_data

2,验证、测试

2.1 随机x-y平面上的点,根据y=x将数据划分为2个类别,然后聚类

先看看原始数据分布:

x=np.random.randint(-50,50,size=100)
y=np.random.randint(-50,50,size=100)
x=np.array(list(zip(x,y))) import matplotlib.pyplot as plt
%matplotlib inline plt.plot([item[0] for item in x],[item[1] for item in x],'ro')

处理输入数据:

# y>x:1  y<=x:0
y=np.array([ 1&(item[1]>item[0]) for item in x])
y=np.expand_dims(y,axis=-1)
data=np.concatenate((x,y),axis=1).astype(np.float32)

训练,显示结果

q,categoried_data=lvq(data,np.array([0.,1.]),k=4)

color=['bo','ko','go','co','yo','ro']
for i in range(len(categoried_data)):
data_i=categoried_data[i]
plt.plot([item[0] for item in data_i],[item[1] for item in data_i],color[i])
plt.plot([item[0] for item in q],[item[1] for item in q],color[-1])
plt.show()

这里执行了2次,可以看出与k-means一样,对初值敏感

总结:

根据上图可以看出,聚类的效果是在标记的前提下进行的,即团簇是很少跨过分类边界y=x的。相当于对每一个类别,进行了细分。因为每次训练根据一个样本更新,epochs应该设置大一点。

另外,感觉我这个算法有点问题(不知道是不是没理解好lvq),当团簇数大于分类数时,团簇标记会重叠,这就导致同一个类下的2个团簇,当进行原型向量更新时,可能导致向量靠近另一个团簇的样本。从直觉上看,k-means那种基于多个样本的中心更新看起来更靠谱一些。

手写LVQ(学习向量量化)聚类算法的更多相关文章

  1. 零基础学习Kmeans聚类算法的原理与实现过程

    内容导入: 聚类是无监督学习的典型例子,聚类也能为企业运营中也发挥者巨大的作用,比如我们可以利用聚类对目标用户进行群体分类,把目标群体划分成几个具有明显特征区别的细分群体,从而可以在运营活动中为这些细 ...

  2. 快排算法Java版-每次以最左边的值为基准值手写QuickSort

    如题 手写一份快排算法. 注意, 两边双向找值的时候, 先从最右边起找严格小于基准值的值,再从最左边查找严格大于基准base的值; 并且先右后左的顺序不能反!!这个bug改了好久,233~ https ...

  3. 搞定redis面试--Redis的过期策略?手写一个LRU?

    1 面试题 Redis的过期策略都有哪些?内存淘汰机制都有哪些?手写一下LRU代码实现? 2 考点分析 1)我往redis里写的数据怎么没了? 我们生产环境的redis怎么经常会丢掉一些数据?写进去了 ...

  4. 机器学习:weka中添加自己的分类和聚类算法

    不管是实验室研究机器学习算法或是公司研发,都有需要自己改进算法的时候,下面就说说怎么在weka里增加改进的机器学习算法. 一 添加分类算法的流程 1 编写的分类器必须继承 Classifier或是Cl ...

  5. 4.redis 的过期策略都有哪些?内存淘汰机制都有哪些?手写一下 LRU 代码实现?

    作者:中华石杉 面试题 redis 的过期策略都有哪些?内存淘汰机制都有哪些?手写一下 LRU 代码实现? 面试官心理分析 如果你连这个问题都不知道,上来就懵了,回答不出来,那线上你写代码的时候,想当 ...

  6. C#中调用Matlab人工神经网络算法实现手写数字识别

    手写数字识别实现 设计技术参数:通过由数字构成的图像,自动实现几个不同数字的识别,设计识别方法,有较高的识别率 关键字:二值化  投影  矩阵  目标定位  Matlab 手写数字图像识别简介: 手写 ...

  7. Python 手写数字识别-knn算法应用

    在上一篇博文中,我们对KNN算法思想及流程有了初步的了解,KNN是采用测量不同特征值之间的距离方法进行分类,也就是说对于每个样本数据,需要和训练集中的所有数据进行欧氏距离计算.这里简述KNN算法的特点 ...

  8. 在opencv3中实现机器学习算法之:利用最近邻算法(knn)实现手写数字分类

    手写数字digits分类,这可是深度学习算法的入门练习.而且还有专门的手写数字MINIST库.opencv提供了一张手写数字图片给我们,先来看看 这是一张密密麻麻的手写数字图:图片大小为1000*20 ...

  9. TensorFlow 入门之手写识别(MNIST) softmax算法

    TensorFlow 入门之手写识别(MNIST) softmax算法 MNIST flyu6 softmax回归 softmax回归算法 TensorFlow实现softmax softmax回归算 ...

随机推荐

  1. 题解 【HEOI2016】tree树

    题面 解析 其实这题可以考虑离线做法,用并查集解决. 因为仔细想,添加标记并不方便, 但如果用并查集记录下祖先, 再一一删除,就会方便很多. 先把每次操作记录下来, 同时记录下每个点被标记的次数(因为 ...

  2. react创建项目后运行npm run eject命令将配置文件暴露出来时报错解决方法

    最近在用create-react-app创建项目,因要配置各种组件,比如babel,antd等, 需要运行npm run eject命令把项目的配置文件暴露出来,但是还是一如既然碰到报错,因为是在本地 ...

  3. Navicat for MySQL 使用

    库创建标准 表查看sql样式

  4. H5 设计尺寸

    750*1218 微信下 兼容 7plus 内容高度 居中 1000px 内 750*1448 微信下 兼容 iphoneX 微信导航栏高度 64px 64px =  导航栏44+状态栏20 但是现在 ...

  5. FZU 2231 平行四边形数

    FZU - 2231  平行四边形数 题目大意:给你n个点,求能够组成多少个平行四边形? 首先想到的是判断两对边平行且相等,但这样的话得枚举四个顶点,或者把点转换成边然后再枚举所有边相等的麻烦,还不好 ...

  6. noi.ac#458 sequence

    题目链接:戳我 蒟蒻的第一道子序列自动机! 给定两个01串A,B,求一个最短的01串,要求C不是A,B的子序列.要求如果同样短,输出字典序最小的. 那么我们先构建A,B两个串的子序列自动机.然后我们设 ...

  7. 线程系列2--Java线程的互斥技术

    java的多线程互斥主要通过synchronized关键字实现.一个线程就是一个执行线索,多个线程可理解为多个执行线索.进程有独立的内存空间,而进程中的线程则是共享数据对象资源.这样当多个执行线索在C ...

  8. postgres的数据库备份和恢复

    备份和恢复 一条命令就可以解决很简单: 这是备份的命令: pg_dump -h 127/0.0.1 -U postgres databasename > databasename.bak 指令解 ...

  9. kali安装与配置

    闲来没事,把kali虚拟机重新装到了电脑上,记录下步骤 1.在kali官网(https://www.kali.org/downloads/)直接下载的.ova虚拟机,因为之前从官网下载的iso文件不知 ...

  10. Nginx 配置443 HTTPS

    server { listen 443 ssl; server_name localhost; ssl on; ssl_certificate D://newlingshou//nginx-1.12. ...