【笔记】初探KNN算法(2)
KNN算法(2)
机器学习算法封装
scikit-learn中的机器学习算法封装
在python chame中将算法写好
import numpy as np
from math import sqrt
from collections import Counter
def kNN_classify(k, X_train, y_train , x):
assert 1 <= k <= X_train.shape[0],"k must be valid"
assert X_train.shape[0] == y_train.shape[0], \
"the size of X_train must equal to the size of y_train"
assert X_train.shape[1] == x.shape[0], \
"the feature number of x must be equal to X_train"
distances = [sqrt(np.sum((x_train - x)**2)) for x_train in X_train]
nearest = np.argsort(distances)
topK_y = [y_train[i] for i in nearest[:k]]
votes = Counter(topK_y)
return votes.most_common(1)[0][0]
将所需要的数据提前准备好

使用魔法命令%run调用函数
%run KNN.py
执行即可得到预测结果

k近邻算法是非常特殊的,可以被认为是没有模型的算法,为了和其他的算法统一,可以认为训练数据集就是魔性本身
使用scikit-learn中的kNN
需要调用KNeighborsClassifier类

创建实例,其中n_neighbors=6相当于k=6

然后进行fit操作
kNN_classifier.fit(X_train,y_train)
其返回值就是自身,可以不用接参数

调用predict方法即可实现
不过需要注意的是,这个必须是一个矩阵,不能是一维数组
因此我们先reshape改变结构

最后就可以得到预测的类别

重新整理我们的kNN代码
在同一个文件夹下创建一个kNN1.py的文件
写入KNN算法
import numpy as np
from math import sqrt
from collections import Counter
class KNNClassifier:
def __init__(self, k):
"""初始化KNN分类器"""
assert k >= 1, "k must be valid"
self.k = k
self._X_train = None
self._y_train = None
def fit(self, X_train, y_train):
"""根据训练数据集X_train和y_train训练kNN分类器"""
assert X_train.shape[0] == y_train.shape[0], \
"this size of X_train must be equal to the size of y_train"
assert self.k <= X_train.shape[0], \
"the size of X_train must be at least k."
self._X_train = X_train
self._y_train = y_train
return self
def predict(self, X_predict):
"""给定预测数据集X_predict,返回表示X_predict的结果向量"""
assert self._X_train is not None and self._y_train is not None, \
"must fit before predict!"
assert X_predict.shape[1] == self._X_train.shape[1], \
"the feature number of X_predict must be equal to X_train"
y_predict = [self._predict(x) for x in X_predict]
return np.array(y_predict)
def _predict(self, x):
"""给定单个待预测数据x,返回x的预测结果值"""
assert x.shape[0] == self._X_train.shape[1], \
"the feature number of x must be equal to X_train"
distances = [sqrt(np.sum((x_train - x) ** 2))
for x_train in self._X_train]
nearest = np.argsort(distances)
topK_y = [self._y_train[i] for i in nearest[:self.k]]
votes = Counter(topK_y)
return votes.most_common(1)[0][0]
def __repr__(self):
return "KNN(k=%d)" % self.k
同上操作,即可得到


【笔记】初探KNN算法(2)的更多相关文章
- 【笔记】初探KNN算法(3)
KNN算法(3) 测试算法的目的就是为了帮助我们选择一个更好的模型 训练数据集,测试数据集方面 一般来说,我们训练得到的模型直接在真实的环境中使用 这就导致了一些问题 如果模型很差,未经改进就应用在现 ...
- 【笔记】初探KNN算法(1)
KNN算法(1) 全称是K Nearest Neighbors k近邻算法: 思想简单 需要的数学知识很少 效果不错 可以解释机器学习算法使用过程中的很多细节问题 更加完整的刻画机器学习应用的流程 其 ...
- 机器学习实战(笔记)------------KNN算法
1.KNN算法 KNN算法即K-临近算法,采用测量不同特征值之间的距离的方法进行分类. 以二维情况举例: 假设一条样本含有两个特征.将这两种特征进行数值化,我们就可以假设这两种特种分别 ...
- 机器学习笔记(5) KNN算法
这篇其实应该作为机器学习的第一篇笔记的,但是在刚开始学习的时候,我还没有用博客记录笔记的打算.所以也就想到哪写到哪了. 你在网上搜索机器学习系列文章的话,大部分都是以KNN(k nearest nei ...
- kNN算法笔记
kNN算法笔记 标签(空格分隔): 机器学习 kNN是什么 kNN算法是k-NearestNeighbor算法,也就是k邻近算法.是监督学习的一种.所谓监督学习就是有训练数据,训练数据有label标好 ...
- 机器学习笔记--KNN算法2-实战部分
本文申明:本系列的所有实验数据都是来自[美]Peter Harrington 写的<Machine Learning in Action>这本书,侵删. 一案例导入:玛利亚小姐最近寂寞了, ...
- 机器学习笔记--KNN算法1
前言 Hello ,everyone. 我是小花.大四毕业,留在学校有点事情,就在这里和大家吹吹我们的狐朋狗友算法---KNN算法,为什么叫狐朋狗友算法呢,在这里我先卖个关子,且听我慢慢道来. 一 K ...
- 算法学习笔记:knn理论介绍
阅读对象:了解指示函数,了解训练集.测试集的概念. 1.简介 knn算法是监督学习中分类方法的一种.所谓监督学习与非监督学习,是指训练数据是否有标注类别,若有则为监督学习,若否则为非监督学习.所谓K近 ...
- 机器学习简要笔记(三)-KNN算法
#coding:utf-8 import numpy as np import operator def classify(intX,dataSet,labels,k): ''' KNN算法 ''' ...
随机推荐
- shiro框架基础
一.shiro框架简介 Apache Shiro是Java的一个安全框架.其内部架构如下: 下面来介绍下里面的几个重要类: Subject:主体,应用代码直接交互的对象就是Subject.代表了当前用 ...
- JDBC:Connection.close()
https://www.2cto.com/database/201501/369246.html Connection对象在执行close() 方法之后,并不是直接把Connection对象设置为nu ...
- PHP7与php5
php在2015年12月03日发布了7.0正式版,带来了许多新的特性,以下是不完全列表: 性能提升:PHP7比PHP5.6性能提升了两倍. Improved performance: PHP 7 is ...
- 嵌入式Linux会议LinuxCon欧洲的时间表公布
From: http://linuxgizmos.com/embedded-linux-conference-and-linuxcon-europe-schedules-posted/ Linux基金 ...
- 5.Java流程控制
所有的流程控制语句都可以相互嵌套.互不影响 一.用户交互Scanner Scanner对象 之前我们学的基本语法中我们并没有实现程序和人的交互,但是Java给我们提供了这样一个工具类,我们可以获取用户 ...
- python从图片中找图
import aircv as ac def matcha(bb,aa):#从bb查找aa,如果有则返回其坐标位置 yuan=ac.imread(bb) mubi=ac.imread(aa) resu ...
- Spring框架中一个有用的小组件:Spring Retry
1.概述 Spring Retry 是Spring框架中的一个组件, 它提供了自动重新调用失败操作的能力.这在错误可能是暂时发生的(如瞬时网络故障)的情况下很有帮助. 在本文中,我们将看到使用Spri ...
- 小 W 离职了
今天这篇是架构师大刘的系列故事 小W要离职了,大刘并没有挽留,甚至有点庆幸. 小W离职的原因比较简单,这次升职加薪,大刘提拔了和他同期进来,并且工作年限和他差不多的小L,而小W则是原地没动,薪水也没有 ...
- 如何让Spring Boot 的配置动起来?
前言 对于微服务而言配置本地化是个很大的鸡肋,不可能每次需要改个配置都要重新把服务重新启动一遍,因此最终的解决方案都是将配置外部化,托管在一个平台上达到不用重启服务即可一次修改多处生效的目的. 但是对 ...
- 【16位RAW图像处理三】直方图均衡化及局部直方图均衡用于16位图像的细节增强。
通常我们生活中遇到的图像,无论是jpg.还是png或者bmp格式,一般都是8位的(每个通道的像素值范围是0-255),但是随着一些硬件的发展,在很多行业比如医疗.红外.航拍等一些场景下,拥有更宽的量化 ...