'''
Created on Nov 06, 2017
kNN: k Nearest Neighbors Input: inX: vector to compare to existing dataset (1xN)
dataSet: size m data set of known vectors (NxM)
labels: data set labels (1xM vector)
k: number of neighbors to use for comparison (should be an odd number) Output: the most popular class label @author: Liu Chuanfeng
'''
import operator
import numpy as np
import matplotlib.pyplot as plt
from os import listdir def classify0(inX, dataSet, labels, k):
dataSetSize = dataSet.shape[0]
diffMat = np.tile(inX, (dataSetSize,1)) - dataSet
sqDiffMat = diffMat ** 2
sqDistances = sqDiffMat.sum(axis=1)
distances = sqDistances ** 0.5
sortedDistIndicies = distances.argsort()
classCount = {}
for i in range(k):
voteIlabel = labels[sortedDistIndicies[i]]
classCount[voteIlabel] = classCount.get(voteIlabel,0) + 1
sortedClassCount = sorted(classCount.items(), key = operator.itemgetter(1), reverse = True)
return sortedClassCount[0][0] #数据预处理,将文件中数据转换为矩阵类型
def file2matrix(filename):
fr = open(filename)
arrayLines = fr.readlines()
numberOfLines = len(arrayLines)
returnMat = np.zeros((numberOfLines, 3))
classLabelVector = []
index = 0
for line in arrayLines:
line = line.strip()
listFromLine = line.split('\t')
returnMat[index,:] = listFromLine[0:3]
classLabelVector.append(int(listFromLine[-1]))
index += 1
return returnMat, classLabelVector #数据归一化处理:由于矩阵各列数据取值范围的巨大差异导致各列对计算结果的影响大小不一,需要归一化以保证相同的影响权重
def autoNorm(dataSet):
maxVals = dataSet.max(0)
minVals = dataSet.min(0)
ranges = maxVals - minVals
m = dataSet.shape[0]
normDataSet = (dataSet - np.tile(minVals, (m, 1))) / np.tile(ranges, (m, 1))
return normDataSet, ranges, minVals #约会网站测试代码
def datingClassTest():
hoRatio = 0.10
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
m = normMat.shape[0]
numTestVecs = int(m * hoRatio)
errorCount = 0.0
for i in range(numTestVecs):
classifyResult = classify0(normMat[i,:], normMat[numTestVecs:m, :], datingLabels[numTestVecs:m], 3)
print('theclassifier came back with: %d, the real answer is: %d' % (classifyResult, datingLabels[i]))
if ( classifyResult != datingLabels[i]):
errorCount += 1.0
print ('the total error rate is: %.1f%%' % (errorCount/float(numTestVecs) * 100)) #约会网站预测函数
def classifyPerson():
resultList = ['not at all', 'in small doses', 'in large doses']
percentTats = float(input("percentage of time spent playing video games?"))
ffMiles = float(input("frequent flier miles earned per year?"))
iceCream = float(input("liters of ice cream consumed per year?"))
datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
normMat, ranges, minVals = autoNorm(datingDataMat)
inArr = np.array([ffMiles, percentTats, iceCream])
classifyResult = classify0((inArr-minVals)/ranges, normMat, datingLabels, 3)
print ("You will probably like this persoon:", resultList[classifyResult - 1]) #手写识别系统#============================================================================================================
#数据预处理:输入图片为32*32的文本类型,将其形状转换为1*1024
def img2vector(filename):
returnVect = np.zeros((1, 1024))
fr = open(filename)
for i in range(32):
lineStr = fr.readline()
for j in range(32):
returnVect[0, 32*i+j] = int(lineStr[j])
return returnVect #手写数字识别系统测试代码
def handwritingClassTest():
hwLabels = []
trainingFileList = listdir('C:\\Private\\PycharmProjects\\Algorithm\\kNN\digits\\traingDigits')
m = len(trainingFileList)
trainingMat = np.zeros((m, 1024))
for i in range(m): #|
fileNameStr = trainingFileList[i] #|
fileName = fileNameStr.split('.')[0] #| 获取训练集路径下每一个文件,分割文件名,将第一个数字作为标签存储在hwLabels中
classNumber = int(fileName.split('_')[0]) #|
hwLabels.append(classNumber) #|
trainingMat[i,:] = img2vector('C:\\Private\\PycharmProjects\\Algorithm\\kNN\digits\\traingDigits\\%s' % fileNameStr) #变换矩阵形状: from 32*32 to 1*1024
testFileList = listdir('C:\\Private\\PycharmProjects\\Algorithm\\kNN\digits\\testDigits')
errorCount = 0.0
mTest = len(testFileList)
for i in range(mTest): #同训练集
fileNameStr = testFileList[i]
fileName = fileNameStr.split('.')[0]
classNumber = int(fileName.split('_')[0])
vectorUnderTest = img2vector('C:\\Private\\PycharmProjects\\Algorithm\\kNN\digits\\testDigits\\%s' % fileNameStr)
classifyResult = classify0(vectorUnderTest, trainingMat, hwLabels, 3) #计算欧氏距离并分类,返回计算结果
print ('The classifier came back with: %d, the real answer is: %d' % (classifyResult, classNumber))
if (classifyResult != classNumber):
errorCount += 1.0
print ('The total number of errors is: %d' % (errorCount))
print ('The total error rate is: %.1f%%' % (errorCount/float(mTest) * 100)) # Simple unit test of func: file2matrix()
#datingDataMat, datingLabels = file2matrix('datingTestSet2.txt')
#print (datingDataMat)
#print (datingLabels) # Usage of figure construction of matplotlib
#fig=plt.figure()
#ax = fig.add_subplot(111)
#ax.scatter(datingDataMat[:,1], datingDataMat[:,2], 15.0*np.array(datingLabels), 15.0*np.array(datingLabels))
#plt.show() # Simple unit test of func: autoNorm()
#normMat, ranges, minVals = autoNorm(datingDataMat)
#print (normMat)
#print (ranges)
#print (minVals) # Simple unit test of func: img2vector
#testVect = img2vector('C:\\Private\\PycharmProjects\\Algorithm\\kNN\digits\\testDigits\\0_13.txt')
#print (testVect[0, 32:63] ) #约会网站测试
datingClassTest() #约会网站预测
classifyPerson() #手写数字识别系统预测
handwritingClassTest()

Output:

theclassifier came back with: 3, the real answer is: 3
the total error rate is: 0.0%
theclassifier came back with: 2, the real answer is: 2
the total error rate is: 0.0%
theclassifier came back with: 1, the real answer is: 1
the total error rate is: 0.0%

...

theclassifier came back with: 2, the real answer is: 2
the total error rate is: 4.0%
theclassifier came back with: 1, the real answer is: 1
the total error rate is: 4.0%
theclassifier came back with: 3, the real answer is: 1
the total error rate is: 5.0%

percentage of time spent playing video games?10
frequent flier miles earned per year?10000
liters of ice cream consumed per year?0.5
You will probably like this persoon: in small doses

...

The classifier came back with: 9, the real answer is: 9
The total number of errors is: 27
The total error rate is: 6.8%

 Reference:

《机器学习实战》

k近邻算法python实现 -- 《机器学习实战》的更多相关文章

  1. AI小记-K近邻算法

    K近邻算法和其他机器学习模型比,有个特点:即非参数化的局部模型. 其他机器学习模型一般都是基于训练数据,得出一般性知识,这些知识的表现是一个全局性模型的结构和参数.模型你和好了后,不再依赖训练数据,直 ...

  2. K近邻 Python实现 机器学习实战(Machine Learning in Action)

    算法原理 K近邻是机器学习中常见的分类方法之间,也是相对最简单的一种分类方法,属于监督学习范畴.其实K近邻并没有显式的学习过程,它的学习过程就是测试过程.K近邻思想很简单:先给你一个训练数据集D,包括 ...

  3. 机器学习实战 - python3 学习笔记(一) - k近邻算法

    一. 使用k近邻算法改进约会网站的配对效果 k-近邻算法的一般流程: 收集数据:可以使用爬虫进行数据的收集,也可以使用第三方提供的免费或收费的数据.一般来讲,数据放在txt文本文件中,按照一定的格式进 ...

  4. 02机器学习实战之K近邻算法

    第2章 k-近邻算法 KNN 概述 k-近邻(kNN, k-NearestNeighbor)算法是一种基本分类与回归方法,我们这里只讨论分类问题中的 k-近邻算法. 一句话总结:近朱者赤近墨者黑! k ...

  5. 机器学习实战笔记--k近邻算法

    #encoding:utf-8 from numpy import * import operator import matplotlib import matplotlib.pyplot as pl ...

  6. python 机器学习(二)分类算法-k近邻算法

      一.什么是K近邻算法? 定义: 如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别. 来源: KNN算法最早是由Cover和Hart提 ...

  7. 用Python从零开始实现K近邻算法

    KNN算法的定义: KNN通过测量不同样本的特征值之间的距离进行分类.它的思路是:如果一个样本在特征空间中的k个最相似(即特征空间中最邻近)的样本中的大多数属于某一个类别,则该样本也属于这个类别.K通 ...

  8. 《机实战》第2章 K近邻算法实战(KNN)

    1.准备:使用Python导入数据 1.创建kNN.py文件,并在其中增加下面的代码: from numpy import * #导入科学计算包 import operator #运算符模块,k近邻算 ...

  9. 机器学习之K近邻算法(KNN)

    机器学习之K近邻算法(KNN) 标签: python 算法 KNN 机械学习 苛求真理的欲望让我想要了解算法的本质,于是我开始了机械学习的算法之旅 from numpy import * import ...

随机推荐

  1. 【LeetCode-面试算法经典-Java实现】【101-Symmetric Tree(对称树)】

    [101-Symmetric Tree(对称树)] [LeetCode-面试算法经典-Java实现][全部题目文件夹索引] 原题 Given a binary tree, check whether ...

  2. jQuery改变label/input的值,改变class,改变img的src

    jQuery改变label/input的值.改变class,改变img的src jQuery改变label的值: $('#aID').text("New Value"); jQue ...

  3. Navicat Premium快速导出数据库ER图和数据字典

    2.快速导出数据库数据字典: SQL Server 数据库,生成数据字典 use YourDatabase --指定要生成数据字典的数据库 go SELECT 表名= then d.name else ...

  4. unity, Global和Local编辑模式

    下图表示是在Local模式下: 下图表示是在Global模式下: 不要搞反.

  5. C语言 文件操作

    /** *@author cody *@date 2014-08-09 *@description copy text file * FILE *fopen(filename,openmode) * ...

  6. memcahced&redis命令行cmd下的操作

    一.memcahced   1.安装 执行memcached.exe -d install 把memcached加入到服务中 执行memcached.exe -d uninstall 卸载memcac ...

  7. mysql 一些常用指令

    登陆: mysql -u root -p //登陆,输入root密码 退出登陆 mysql>exit; mysql 为所有ip授权 mysql> GRANT ALL PRIVILEGES ...

  8. Java 堆内存

    堆内存 Java 中的堆是 JVM 所管理的最大的一块内存空间,主要用于存放各种类的实例对象. 在 Java 中,堆被划分成两个不同的区域:新生代 ( Young ).老年代 ( Old ).新生代 ...

  9. js测试

    <!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8&quo ...

  10. oracle instant client,tnsping,tnsnames.ora和ORACLE_HOME

    前段时间要远程连接oracle数据库,可是又不想在自己电脑上完整安装oracleclient,于是到oracle官网下载了轻量级clientinstant client. 这玩意没有图形界面,全靠sq ...