一、kNN算法基础

# kNN:k-Nearest Neighboors

# 多用于解决分类问题

 1)特点:

  1. 是机器学习中唯一一个不需要训练过程的算法,可以别认为是没有模型的算法,也可以认为训练数据集就是模型本身;
  2. 思想极度简单;
  3. 应用数学知识少(近乎为零);
  4. 效果少;
  5. 可以解释机械学习算法使用过程中的很多细节问题
  6. 更完整的刻画机械学习应用的流程;

 2)思想:

  • 根本思想:两个样本,如果它们的特征足够相似,它们就有更高的概率属于同一个类别;
  • 问题:根据现有训练数据集,判断新的样本属于哪种类型
  • 方法/思路
  1. 求新样本点在样本空间内与所有训练样本的欧拉距离;
  2. 对欧拉距离排序,找出最近的k个点;
  3. 对k个点分类统计,看哪种类型的点数量最多,此类型即为对新样本的预测类型;

 3)代码实现过程:

  • 示例代码:

    import numpy as np
    import matplotlib.pyplot as plt raw_data_x = [[3.3935, 2.3312],
    [3.1101, 1.7815],
    [1.3438, 3.3684],
    [3.5823, 4.6792],
    [2.2804, 2.8670],
    [7.4234, 4.6965],
    [5.7451, 3.5340],
    [9.1722, 2.5111],
    [7.7928, 3.4241],
    [7.9398, 0.7916]]
    raw_data_y = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] # 训练集样本的data
    x_train = np.array(raw_data_x)
    # 训练集样本的label
    y_train = np.array(raw_data_y) # 1)绘制训练集样本与新样本的散点图
    # 根据样本类型(0、1两种类型),绘制所有样本的各特征点
    plt.scatter(x_train[y_train == 0, 0], x_train[y_train == 0, 1], color = 'g')
    plt.scatter(x_train[y_train == 1, 0], x_train[y_train == 1, 1], color = 'r')
    # 新样本
    x = np.array([8.0936, 3.3657])
    # 将新样本的特征点绘制在训练集的样本空间
    plt.scatter(x[0], x[1], color = 'b')
    plt.show() # 2)在特征空间中,计算训练集样本中的所有点与新样本的点的欧拉距离
    from math import sqrt
    # math模块下的sqrt函数:对数值开平方sqrt(number)
    distances = []
    for x_train in x_train:
    d = sqrt(np.sum((x - x_train) ** 2))
    distances.append(d) # 也可以用list的生成表达式实现:
    # distances = [sqrt(np.sum((x - x_train) ** 2)) for x_train in x_train] # 3)找出距离新样本最近的k个点,并得到对新样本的预测类型
    nearest = np.argsort(distances)
    k = 6
    # 找出距离最近的k个点的类型
    topK_y = [y_train[i] for i in nearest[:k]] # 根据类别对k个点的数量进行统计
    from collections import Counter
    votes = Counter(topK_y) # 获取所需的预测类型:predict_y
    predict_y = votes.most_common(1)[0][0]
  • 封装好的Python代码

    import numpy as np
    from math import sqrt
    from collections import Counter def kNN_classify(k, X_train, y_train, x): assert 1 <= k <= X_train.shape[0],"k must be valid"
    assert X_train.shape[0] == y_train.shape[0], \
    "the size of X_train nust equal to the size of y_train"
    assert X-train.shape[1] == x.shape[0],\
    "the feature number of x must be equal to X_train" distances = [sprt(np.sum((x_train - x) ** 2)) for x_train in X_train]
    nearest = np.argsort(distances)
    topK_y = [y_train[i] for i in nearest[:k]]
    vates = Counter(topK_y)
    return votes.most_common(1)[0][0]

   # assert:表示声明;此处对4个参数进行限定;

  • 代码中的其它Python知识:
  1. math模块下的sprt()方法:对数开平方;

    from math import sqrt
    print(sprt(9))
    #
  2. collections模块下的Counter()方法:对列表中的数据进行分类统计,生产一个Counter对象;
    from collections import Counter
    
    my_list = [0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3]
    print(Counter(my_list))
    # 一个Counter对象:Counter({0: 2, 1: 3, 2: 4, 3: 5})
  3. Counter对象的most_common()方法:Counter.most_common(n),返回Counter对象中数量最多的n种数据,返回一个list,list的每个元素为一个tuple;
    from collections import Counter
    
    my_list = [0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3]
    votes = Counter(my_list)
    print(votes.most_common(2))
    # [(3, 5), (2, 4)]

二、总结

 1)k近邻算法的作用

  1、解决分类问题,而且天然可以解决多分类问题;

  2、也可以解决回归问题,其中scikit-learn库中封装的KNeighborsRegressor,就是解决回归问题;

 2)缺点

  • 缺点1:效率低下
  1. 原因:如果训练集有m个样本,n个特征,预测每一个新样本,需要计算与m个样本的距离,每计算一个距离,要使用n个时间复杂度,则计算m个样本的距离,使用m * n个时间复杂度;
  2. 算法的时间复杂度:反映了程序执行时间随输入规模增长而增长的量级,在很大程度上能很好反映出算法的优劣与否。
  3. 算法的时间复杂度与空间复杂度,参考:算法的时间复杂度和空间复杂度
  4. 可以通过树结构对k近邻算法优化:KD-Tree、Ball-Tree,但即便进行优化,效率依然不高;
  • 缺点2:高度数据相关
  1. 机器学习算法,就是通过喂给数据进行预测,理论上所有机器学习算法都是高度数据相关;
  2. k近邻算法对outlier更加敏感:比如三近邻算法,在特征空间中,如果在需要预测的样本周边,一旦有两个样本出现错误值,就足以使预测结果错误,哪怕在更高的范围里,在特征空间中有大量正确的样本;
  • 缺点3:预测的结果不具有可解释性
  1. 按k近邻算法的逻辑:找到和预测样本比较近的样本,就得出预测样本和其最近的这个样本类型相同;
  2. 问题:为什么预测的样本类型就是离它最近的样本的类型?
  3. 很多情况下,只是拿到预测结果是不够的,还需要对此结果有解释性,进而通过解释推广使用,或者制作更多工具,或者以此为基础发现新的理论/规则,来改进生产活动中的其它方面——这些是kNN算法做不到的;
  • 缺点4:维数灾难
  1. 维数灾难:随着维度的增加,“看似相近”的两个点之间的距离越来越大;
  2. 例:[0, 0, 0, ...0]和[1, 1, 1,...1],按欧拉定理计算,元素个数越多,两点距离越大;
  3. 方案:降维(PCA);

三、使用机器学习算法的流程

  • 获取原始数据——数据分割——数据归一化——训练模型——预测
  1. 获取原始数据:一般可从scikit-learn库中调用——# 调用数据集的操作流程  机器学习:scikit-learn中算法的调用、封装并使用自己所写的算法
  2. 数据分割:一般按2 :8进行分割——# 分割数据的代码实现过程、通过scikit-learn库分割数据的操作流程  机器学习:训练数据集、测试数据集
  3. 数据归一化:参见  机器学习:数据归一化(Scaler)
  4. 训练模型、模型预测:  机器学习:scikit-learn中算法的调用、封装并使用自己所写的算法

机器学习:k-NN算法(也叫k近邻算法)的更多相关文章

  1. 机器学习实战笔记(Python实现)-01-K近邻算法(KNN)

    --------------------------------------------------------------------------------------- 本系列文章为<机器 ...

  2. machine_learning-knn算法具体解释(近邻算法)

    近邻算法是机器学习算法中的入门算法,该算法用于针对已有数据集对未知数据进行分类. 该算法核心思想是通过计算预測数据与已有数据的相似度猜測结果. 举例: 如果有例如以下一组数据(在下面我们统一把该数据作 ...

  3. 机器学习(四) 分类算法--K近邻算法 KNN (上)

    一.K近邻算法基础 KNN------- K近邻算法--------K-Nearest Neighbors 思想极度简单 应用数学知识少 (近乎为零) 效果好(缺点?) 可以解释机器学习算法使用过程中 ...

  4. KNN-k近邻算法

    目录 KNN-k近邻算法 一.KNN基础 二.自己写一个knn函数 三.使用sklearn中的KNN 四.自己写一个面向对象的KNN 五.分割数据集 六.使用sklearn中的鸢尾花数据测试KNN 七 ...

  5. K近邻算法:机器学习萌新必学算法

    摘要:K近邻(k-NearestNeighbor,K-NN)算法是一个有监督的机器学习算法,也被称为K-NN算法,由Cover和Hart于1968年提出,可以用于解决分类问题和回归问题. 1. 为什么 ...

  6. 机器学习之K近邻算法(KNN)

    机器学习之K近邻算法(KNN) 标签: python 算法 KNN 机械学习 苛求真理的欲望让我想要了解算法的本质,于是我开始了机械学习的算法之旅 from numpy import * import ...

  7. 机器学习03:K近邻算法

    本文来自同步博客. P.S. 不知道怎么显示数学公式以及排版文章.所以如果觉得文章下面格式乱的话请自行跳转到上述链接.后续我将不再对数学公式进行截图,毕竟行内公式截图的话排版会很乱.看原博客地址会有更 ...

  8. [机器学习] k近邻算法

    算是机器学习中最简单的算法了,顾名思义是看k个近邻的类别,测试点的类别判断为k近邻里某一类点最多的,少数服从多数,要点摘录: 1. 关键参数:k值 && 距离计算方式 &&am ...

  9. 机器学习 Python实践-K近邻算法

    机器学习K近邻算法的实现主要是参考<机器学习实战>这本书. 一.K近邻(KNN)算法 K最近邻(k-Nearest Neighbour,KNN)分类算法,理解的思路是:如果一个样本在特征空 ...

随机推荐

  1. python(pytest)+allure+jenkins 实现接口自动化的思路

    效果图镇楼: 上述各模块作用: python(pytest): 1:用于读测试用例(本次用例写在csv文件中) 2:环境配置相关 3:提取1中的测试数据,组成请求体 4:发送请求 5:获取结果 6:断 ...

  2. C#线程使用学习

    线程的入口函数可以不带输入参数,也可以带输入参数: form1.cs using System; using System.Collections.Generic; using System.Comp ...

  3. 为什么下了android 4.1 的SDK后在本地用浏览器看api说明文档时,浏览器打开api的html文件很慢?试了好几款浏览器都一样。为什么?

    http://www.oschina.net/question/436724_61401 http://www.google.com/jsapi  他惹的祸 注释掉就可以了- <!-- < ...

  4. Python:笔记(1)——基础语法

    Python:笔记(1)——基础语法 我很抱歉有半年没有在博客园写过笔记了,客观因素有一些,但主观原因居多,再多的谴责和批判也都于事无补,我们能做的就是重振旗鼓,继续出发! ——写在Python之前 ...

  5. 每天一个Linux命令(8)cat命令

    cat命令连接文件并打印到标准输出设备上,cat经常用来显示文件的内容,类似于下的type命令. 注意:当文件较大时,文本在屏幕上迅速闪过(滚屏),用户往往看不清所显示的内容.因此,一般用more等命 ...

  6. 10款CSS3进度条Loading动画

    在线演示 本地下载

  7. Linux 上通过rpm安装mysql

    安装mysql之前要remove掉系统自带的mysql: rpm -qa | grep "MySQL*"    和rpm -qa | grep mysql  要确保卸载干净 rpm ...

  8. 剑指offer之 栈的压入、弹出序列

    题目描述:输入两个整数序列,第一个序列表示栈的压入顺序,请判断第二个序列是否为该栈的弹出序列.假设压入栈的所有数字均不相等.例如序列1/2/3/4/5是某栈的压栈序列,序列4/5/3/2/1是该压栈序 ...

  9. java打包命令

    (1)首先,必须保证java的所有路径都设置好,在dos提示符下输入jar -help 出现C:\Documents and Settings\dly>jar -help 非法选项:h 用法:j ...

  10. R Customizing graphics

    Customizing graphics GraphicsLaTeXLattice (Treillis) plots In this chapter (it tends to be overly co ...