tensorflow kmeans 聚类
iris:
# -*- coding: utf-8 -*-
# K-means with TensorFlow
#----------------------------------
#
# This script shows how to do k-means with TensorFlow import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from sklearn import datasets
from scipy.spatial import cKDTree
from sklearn.decomposition import PCA
from sklearn.preprocessing import scale
from tensorflow.python.framework import ops
ops.reset_default_graph() sess = tf.Session() iris = datasets.load_iris() num_pts = len(iris.data)
num_feats = len(iris.data[0]) # Set k-means parameters
# There are 3 types of iris flowers, see if we can predict them
k=3
generations = 25 data_points = tf.Variable(iris.data)
cluster_labels = tf.Variable(tf.zeros([num_pts], dtype=tf.int64)) # Randomly choose starting points
rand_starts = np.array([iris.data[np.random.choice(len(iris.data))] for _ in range(k)]) centroids = tf.Variable(rand_starts) # In order to calculate the distance between every data point and every centroid, we
# repeat the centroids into a (num_points) by k matrix.
centroid_matrix = tf.reshape(tf.tile(centroids, [num_pts, 1]), [num_pts, k, num_feats])
# Then we reshape the data points into k (3) repeats
point_matrix = tf.reshape(tf.tile(data_points, [1, k]), [num_pts, k, num_feats])
distances = tf.reduce_sum(tf.square(point_matrix - centroid_matrix), axis=2) #Find the group it belongs to with tf.argmin()
centroid_group = tf.argmin(distances, 1) # Find the group average
def data_group_avg(group_ids, data):
# Sum each group
sum_total = tf.unsorted_segment_sum(data, group_ids, 3)
# Count each group
num_total = tf.unsorted_segment_sum(tf.ones_like(data), group_ids, 3)
# Calculate average
avg_by_group = sum_total/num_total
return(avg_by_group) means = data_group_avg(centroid_group, data_points) update = tf.group(centroids.assign(means), cluster_labels.assign(centroid_group)) init = tf.global_variables_initializer() sess.run(init) for i in range(generations):
print('Calculating gen {}, out of {}.'.format(i, generations))
_, centroid_group_count = sess.run([update, centroid_group])
group_count = []
for ix in range(k):
group_count.append(np.sum(centroid_group_count==ix))
print('Group counts: {}'.format(group_count)) [centers, assignments] = sess.run([centroids, cluster_labels]) # Find which group assignments correspond to which group labels
# First, need a most common element function
def most_common(my_list):
return(max(set(my_list), key=my_list.count)) label0 = most_common(list(assignments[0:50]))
label1 = most_common(list(assignments[50:100]))
label2 = most_common(list(assignments[100:150])) group0_count = np.sum(assignments[0:50]==label0)
group1_count = np.sum(assignments[50:100]==label1)
group2_count = np.sum(assignments[100:150]==label2) accuracy = (group0_count + group1_count + group2_count)/150. print('Accuracy: {:.2}'.format(accuracy)) # Also plot the output
# First use PCA to transform the 4-dimensional data into 2-dimensions
pca_model = PCA(n_components=2)
reduced_data = pca_model.fit_transform(iris.data)
# Transform centers
reduced_centers = pca_model.transform(centers) # Step size of mesh for plotting
h = .02 # Plot the decision boundary. For that, we will assign a color to each
x_min, x_max = reduced_data[:, 0].min() - 1, reduced_data[:, 0].max() + 1
y_min, y_max = reduced_data[:, 1].min() - 1, reduced_data[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) # Get k-means classifications for the grid points
xx_pt = list(xx.ravel())
yy_pt = list(yy.ravel())
xy_pts = np.array([[x,y] for x,y in zip(xx_pt, yy_pt)])
mytree = cKDTree(reduced_centers)
dist, indexes = mytree.query(xy_pts) # Put the result into a color plot
indexes = indexes.reshape(xx.shape)
plt.figure(1)
plt.clf()
plt.imshow(indexes, interpolation='nearest',
extent=(xx.min(), xx.max(), yy.min(), yy.max()),
cmap=plt.cm.Paired,
aspect='auto', origin='lower') # Plot each of the true iris data groups
symbols = ['o', '^', 'D']
label_name = ['Setosa', 'Versicolour', 'Virginica']
for i in range(3):
temp_group = reduced_data[(i*50):(50)*(i+1)]
plt.plot(temp_group[:, 0], temp_group[:, 1], symbols[i], markersize=10, label=label_name[i])
# Plot the centroids as a white X
plt.scatter(reduced_centers[:, 0], reduced_centers[:, 1],
marker='x', s=169, linewidths=3,
color='w', zorder=10)
plt.title('K-means clustering on Iris Dataset\n'
'Centroids are marked with white cross')
plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.legend(loc='lower right')
plt.show()

tensorflow kmeans 聚类的更多相关文章
- 用 TensorFlow 实现 k-means 聚类代码解析
k-means 是聚类中比较简单的一种.用这个例子说一下感受一下 TensorFlow 的强大功能和语法. 一. TensorFlow 的安装 按照官网上的步骤一步一步来即可,我使用的是 virtua ...
- K-Means 聚类算法
K-Means 概念定义: K-Means 是一种基于距离的排他的聚类划分方法. 上面的 K-Means 描述中包含了几个概念: 聚类(Clustering):K-Means 是一种聚类分析(Clus ...
- 用scikit-learn学习K-Means聚类
在K-Means聚类算法原理中,我们对K-Means的原理做了总结,本文我们就来讨论用scikit-learn来学习K-Means聚类.重点讲述如何选择合适的k值. 1. K-Means类概述 在sc ...
- K-Means聚类算法原理
K-Means算法是无监督的聚类算法,它实现起来比较简单,聚类效果也不错,因此应用很广泛.K-Means算法有大量的变体,本文就从最传统的K-Means算法讲起,在其基础上讲述K-Means的优化变体 ...
- K-means聚类算法
聚类分析(英语:Cluster analysis,亦称为群集分析) K-means也是聚类算法中最简单的一种了,但是里面包含的思想却是不一般.最早我使用并实现这个算法是在学习韩爷爷那本数据挖掘的书中, ...
- k-means聚类算法python实现
K-means聚类算法 算法优缺点: 优点:容易实现缺点:可能收敛到局部最小值,在大规模数据集上收敛较慢使用数据类型:数值型数据 算法思想 k-means算法实际上就是通过计算不同样本间的距离来判断他 ...
- K-Means 聚类算法原理分析与代码实现
前言 在前面的文章中,涉及到的机器学习算法均为监督学习算法. 所谓监督学习,就是有训练过程的学习.再确切点,就是有 "分类标签集" 的学习. 现在开始,将进入到非监督学习领域.从经 ...
- Kmeans聚类算法原理与实现
Kmeans聚类算法 1 Kmeans聚类算法的基本原理 K-means算法是最为经典的基于划分的聚类方法,是十大经典数据挖掘算法之一.K-means算法的基本思想是:以空间中k个点为中心进行聚类,对 ...
- 机器学习六--K-means聚类算法
机器学习六--K-means聚类算法 想想常见的分类算法有决策树.Logistic回归.SVM.贝叶斯等.分类作为一种监督学习方法,要求必须事先明确知道各个类别的信息,并且断言所有待分类项都有一个类别 ...
随机推荐
- WPF自定义依赖集合属性无法触发更新的问题
通常WPF中通过继承UserControl的来快速创建自定义控件,最近项目上需要设计一个卫星星图显示控件,最终效果如下图所示.完成过程中遇到了自定义集合依赖属性无法触发更新通知的问题,在此记录一下,方 ...
- python实现区块链代码
如果你明白了原理其实挺简单的. 加密算法是python自带的 需要导入hashlib import hashlib as hash sha = hasher.sha256() sha.update(' ...
- vim 模式查找
1. / 正向查找, ?反向查找 2. \v 激活very magic搜索模式,撰写正则表达式更接近于perl的正则表达式,大多数字符不需要进行转义 3. \V 激活noVeryMagic模式,按字符 ...
- AndroidX86模拟器Genymotion的一些使用和另一款Andy模拟器
命令行启动虚拟机 当我们下载安装好,可以通过命令行运行指定名字模拟器 D:\ProgramFiles\Genymobile\Genymotion\player --vm-name "Sam ...
- SQL Server中排名函数row_number,rank,dense_rank,ntile详解
SQL Server中排名函数row_number,rank,dense_rank,ntile详解 从SQL SERVER2005开始,SQL SERVER新增了四个排名函数,分别如下:1.row_n ...
- HDFS源码分析心跳汇报之BPServiceActor工作线程运行流程
在<HDFS源码分析心跳汇报之数据结构初始化>一文中,我们了解到HDFS心跳相关的BlockPoolManager.BPOfferService.BPServiceActor三者之间的关系 ...
- 大数据:Hive常用参数调优
1.limit限制调整 一般情况下,Limit语句还是需要执行整个查询语句,然后再返回部分结果. 有一个配置属性可以开启,避免这种情况---对数据源进行抽样 hive.limit.optimize.e ...
- matlab biplot 符号的困惑
在matlab中做Principal component Analysis 时,常要用biplot 函数来画图,表示原分量与主分量(principal component)之间的关系,以及原始观察数据 ...
- 【C语言】一句printf代码——{ a[0] ? 0[a] }
这是前段时间做的http://fun.coolshell.cn/中的一道题,很有意思,涉及的其实是C的基础,不过当时第一次看见这行代码确实把我弄懵了: printf(&unix["\ ...
- Android UI开发第四十三篇——使用Property Animation实现墨迹天气3.0引导界面及动画实现
前面写过<墨迹天气3.0引导界面及动画实现>,里面完美实现了动画效果,那一篇文章使用的View Animation,这一篇文章使用的Property Animation实现.Propert ...