KNN分类算法--python实现
一、kNN算法分析
K最近邻(k-Nearest Neighbor,KNN)分类算法可以说是最简单的机器学习算法了。它采用测量不同特征值之间的距离方法进行分类。它的思想很简单:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别。
KNN算法中,所选择的邻居都是已经正确分类的对象。该方法在定类决策上只依据最邻近的一个或者几个样本的类别来决定待分样本所属的类别。由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重叠较多的待分样本集来说,KNN方法较其他方法更为适合。
该算法在分类时有个主要的不足是,当样本不平衡时,如一个类的样本容量很大,而其他类样本容量很小时,有可能导致当输入一个新样本时,该样本的K个邻居中大容量类的样本占多数。因此可以采用权值的方法(和该样本距离小的邻居权值大)来改进。该方法的另一个不足之处是计算量较大,因为对每一个待分类的文本都要计算它到全体已知样本的距离,才能求得它的K个最近邻点。目前常用的解决方法是事先对已知样本点进行剪辑,事先去除对分类作用不大的样本。该算法比较适用于样本容量比较大的类域的自动分类,而那些样本容量较小的类域采用这种算法比较容易产生误分[参考机器学习十大算法]。
总的来说就是我们已经存在了一个带标签的数据库,然后输入没有标签的新数据后,将新数据的每个特征与样本集中数据对应的特征进行比较,然后算法提取样本集中特征最相似(最近邻)的分类标签。一般来说,只选择样本数据库中前k个最相似的数据。最后,选择k个最相似数据中出现次数最多的分类。其算法描述如下:
1)计算已知类别数据集中的点与当前点之间的距离;
2)按照距离递增次序排序;
3)选取与当前点距离最小的k个点;
4)确定前k个点所在类别的出现频率;
5)返回前k个点出现频率最高的类别作为当前点的预测分类。
代码:
#########################################
# kNN: k Nearest Neighbors # Input: inX: vector to compare to existing dataset (1xN)
# dataSet: size m data set of known vectors (NxM)
# labels: data set labels (1xM vector)
# k: number of neighbors to use for comparison # Output: the most popular class label
######################################### from numpy import *
import operator
import os
from Canvas import Line # classify using kNN
def kNNClassify(newInput, dataSet, labels, k):
numSamples = dataSet.shape[0] # shape[0] stands for the num of row ## step 1: calculate Euclidean distance
# tile(A, reps): Construct an array by repeating A reps times
# the following copy numSamples rows for dataSet
diff = tile(newInput, (numSamples, 1)) - dataSet # Subtract element-wise
squaredDiff = diff ** 2 # squared for the subtract
squaredDist = sum(squaredDiff, axis = 1) # sum is performed by row
distance = squaredDist ** 0.5 ## step 2: sort the distance
# argsort() returns the indices that would sort an array in a ascending order
sortedDistIndices = argsort(distance) classCount = {} # define a dictionary (can be append element)
for i in xrange(k):
## step 3: choose the min k distance
voteLabel = labels[sortedDistIndices[i]] ## step 4: count the times labels occur
# when the key voteLabel is not in dictionary classCount, get()
# will return 0
classCount[voteLabel] = classCount.get(voteLabel, 0) + 1 ## step 5: the max voted class will return
maxCount = 0
for key, value in classCount.items():
if value > maxCount:
maxCount = value
maxIndex = key return maxIndex # convert image to vector
def img2vector(filename):
rows = 32
cols = 32
imgVector = zeros((1, rows * cols)) fileIn = open(filename)
for row in xrange(rows):
lineStr = fileIn.readline()
for col in xrange(cols):
imgVector[0, row * 32 + col] = int(lineStr[col]) return imgVector # load dataSet
def loadDataSet():
## step 1: Getting training set
print "---Getting training set..."
dataSetDir = 'F:/eclipse/workspace/KnnTest/'
trainingFileList = os.listdir(dataSetDir + 'trainingDigits') # load the training set
numSamples = len(trainingFileList) train_x = zeros((numSamples, 1024))
train_y = []
for i in xrange(numSamples):
filename = trainingFileList[i] # get train_x
train_x[i, :] = img2vector(dataSetDir + 'trainingDigits/%s' % filename) # get label from file name such as "1_18.txt"
label = int(filename.split('_')[0]) # return 1
train_y.append(label) ## step 2: Getting testing set
print "---Getting testing set..."
testingFileList = os.listdir(dataSetDir + 'testDigits') # load the testing set
numSamples = len(testingFileList)
test_x = zeros((numSamples, 1024))
test_y = []
for i in xrange(numSamples):
filename = testingFileList[i] # get train_x
test_x[i, :] = img2vector(dataSetDir + 'testDigits/%s' % filename) # get label from file name such as "1_18.txt"
label = int(filename.split('_')[0]) # return 1
test_y.append(label) return train_x, train_y, test_x, test_y # test hand writing class
def testHandWritingClass():
## step 1: load data
print "step 1: load data..."
train_x, train_y, test_x, test_y = loadDataSet() ## step 2: training...
print "step 2: training..."
pass ## step 3: testing
print "step 3: testing..."
numTestSamples = test_x.shape[0]
matchCount = 0
for i in xrange(numTestSamples):
predict = kNNClassify(test_x[i], train_x, train_y, 3)
if predict == test_y[i]:
matchCount += 1
accuracy = float(matchCount) / numTestSamples ## step 4: show the result
print "step 4: show the result..."
print 'The classify accuracy is: %.2f%%' % (accuracy * 100)
另外创建一个脚本knnTest.py
import KNN
KNN.testHandWritingClass()
其中数据集下载链接为:http://download.csdn.net/detail/zouxy09/6610571
KNN分类算法--python实现的更多相关文章
- KNN分类算法及python代码实现
KNN分类算法(先验数据中就有类别之分,未知的数据会被归类为之前类别中的某一类!) 1.KNN介绍 K最近邻(k-Nearest Neighbor,KNN)分类算法是最简单的机器学习算法. 机器学习, ...
- KNN分类算法实现手写数字识别
需求: 利用一个手写数字“先验数据”集,使用knn算法来实现对手写数字的自动识别: 先验数据(训练数据)集: ♦数据维度比较大,样本数比较多. ♦ 数据集包括数字0-9的手写体. ♦每个数字大约有20 ...
- knn分类算法学习
K最近邻(k-Nearest Neighbor,KNN)分类算法,是一个理论上比较成熟的方法,也是最简单的机器学习算法之一.该方法的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的 ...
- 机器学习---K最近邻(k-Nearest Neighbour,KNN)分类算法
K最近邻(k-Nearest Neighbour,KNN)分类算法 1.K最近邻(k-Nearest Neighbour,KNN) K最近邻(k-Nearest Neighbour,KNN)分类算法, ...
- 后端程序员之路 12、K最近邻(k-Nearest Neighbour,KNN)分类算法
K最近邻(k-Nearest Neighbour,KNN)分类算法,是最简单的机器学习算法之一.由于KNN方法主要靠周围有限的邻近的样本,而不是靠判别类域的方法来确定所属类别的,因此对于类域的交叉或重 ...
- KNN分类算法
K邻近算法.K最近邻算法.KNN算法(k-Nearest Neighbour algorithm):是数据挖掘分类技术中最简单的方法之一 KNN的工作原理 所谓K最近邻,就是k个最近的邻居的意思,说的 ...
- 在Ignite中使用k-最近邻(k-NN)分类算法
在本系列前面的文章中,简单介绍了一下Ignite的线性回归算法,下面会尝试另一个机器学习算法,即k-最近邻(k-NN)分类.该算法基于对象k个最近邻中最常见的类来对对象进行分类,可用于确定类成员的关系 ...
- K-NN(最近邻分类算法 python
# algorithm:K-NN(最近邻分类算法)# author:Kermit.L# time: 2016-8-7 #======================================== ...
- kNN分类算法实例1:用kNN改进约会网站的配对效果
目录 实战内容 用sklearn自带库实现kNN算法分类 将内含非数值型的txt文件转化为csv文件 用sns.lmplot绘图反映几个特征之间的关系 参考资料 @ 实战内容 海伦女士一直使用在线约会 ...
随机推荐
- String,StringBuilder性能对照
import java.util.Date; import java.util.UUID; /** * 測试String,StringBuilder性能,推断什么时候改用String,什么时候该用S ...
- HMM MEMM CRF 差别 联系
声明:本文主要是基于网上的材料做了文字编辑,原创部分甚少.參考资料见最后. 隐马尔可夫模型(Hidden Markov Model.HMM),最大熵马尔可夫模型(Maximum Entropy Mar ...
- 使用kubernetes 官网工具kubeadm部署kubernetes(使用阿里云镜像)
系列目录 kubernetes简介 Kubernetes节点架构图: kubernetes组件架构图: 准备基础环境 我们将使用kubeadm部署3个节点的 Kubernetes Cluster,整体 ...
- vim tips 集锦
删除文件中的空行 :g/^$/d g 表示 global,全文件 ^ 是行开始,$ 是行结束 d 表示删除该 这里只能匹配到没有白空符的空行,假如要删除有空白符的空行,则使用: :g/^\s*$/d ...
- 一起学android之怎样卸载指定的 应用程序(25)
效果图例如以下: 代码例如以下: public class MainActivity extends Activity { private Button btn_delete; @Override p ...
- 记使用WaitGroup时的一个错误
记使用WaitGroup时的一个错误 近期重构我之前写的server代码时,不当使用了WaitGroup,碰到了个错误,记录下. package main import ( "fmt&quo ...
- 更改UISearchBar button属性
//点击搜索框时触发 - (BOOL)searchBarShouldBeginEditing:(UISearchBar *)searchBar { self.theSearchUserSearchBa ...
- asp.net mvc4 之Webapi之客户端或服务器端安全控制
一.WebAPI的工作方式 WebAPI的工作方式:HTTP的请求最先是被传递到HOST中的,如果WebAPI是被寄宿在IIS上的,这个HOST就是IIS上,HOST是没有能力也没有必 要进行请求的处 ...
- 新版的Spring4X怎样下载
点击下载 <a href="http://download.csdn.net/detail/zhaoqingkaitt/7733719">点击免费下载</a> ...
- 头文件---#include<***.h>和#include"***.h"的区别
采用"< >"方式进行包含的头文件表示让编译器在编译器的预设标准路径下去搜索相应的头文件,如果找不到则报错. 例如:VS的安装目录\Microsoft Visual S ...