仿scikit-learn模式写的kNN算法
一、什么是kNN算法
k邻近是指每个样本都可以用它最接近的k个邻居来代表。
核心思想:如果一个样本在特征空间中的k个最相邻的样本中大多数属于一个某类别,则该样本也属于这个类别。
二、将kNN封装成kNNClassifier
1、训练样本的特征在二维空间中的表示
、
2、kNN的训练过程如下图

3、完整代码(kNN.py)
import numpy as np
from math import sqrt
from collections import Counter
from metrics import accuracy_score class kNNClassifier():
def __init__(self, k):
"""初始化kNN分类器"""
assert k >= 1, "k must be valid"
self.k = k
self._x_train = None
self._y_train = None def fit(self, x_train, y_train):
"""根据训练集x_train和y_train训练kNN分类器"""
assert x_train.shape[0] == y_train.shape[0], \
"the size of x_train must be equal to the size of y_train"
assert x_train.shape[0] >= self.k, "the size of x_train must be at least k"
self._x_train = x_train
self._y_train = y_train
return self def predict(self, X_predict):
"""给定待预测数据集X_train,返回表示x_train的结果向量"""
assert self._x_train is not None and self._y_train is not None, \
"must fit before predict"
assert X_predict.shape[1] == self._x_train.shape[1] , \
"the feature number of X_predict must be equal to x_train"
y_predict = [self._predict(x) for x in X_predict]
return np.array(y_predict) def _predict(self, x):
"""给定待预测数据x,返回x预测的结果值"""
assert x.shape[0] == self._x_train.shape[1], \
"the feature number of x must be equal tu x_train"
distances = [sqrt(np.sum((x_train-x)**2)) for x_train in self._x_train]
nearest = np.argsort(distances)
topK_y = [self._y_train[i] for i in nearest[:self.k]]
votes = Counter(topK_y)
return votes.most_common(1)[0][0] def score(self, X_test, y_test):
"""根据数据集X_test 和y_test 得到当前模型的准确度"""
y_predict = self.predict(X_test)
return accuracy_score(y_test, y_predict) def __repr__(self):
return "kNN(k=%d)" % self.k if __name__ == "__main__":
x_train = np.array([[0.31864691, 0.99608349],
[0.8609734 , 0.40706129],
[0.86746155, 0.20136923],
[0.4346735 , 0.17677379],
[0.42842348, 0.68055183],
[0.70661963, 0.76155652],
[0.73379517, 0.6123456 ],
[0.68330672, 0.52193524],
[0.11192091, 0.07885633],
[0.99273292, 0.62484263]])
y_train = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
k = 6
x = np.array([0.756789,0.6123456])
knn = kNNClassifier(k)
knn.fit(x_train,y_train)
x_predict = x.reshape(1,-1)
print(knn.predict(x_predict))
三、测试结果
[1]
四、问题
1、如果直接将上面训练得到的模型直接放在真实环境中使用,但是模型没有得到验证,会造成模型很差,会有真实损失。
2、真实环境下很难拿到符合条件的数据去测试
解决办法:
1、将训练数据拿出一部分作为测试数据,通过测试数据直接判断模型好坏。
2、在模型进入真实环境前改进模型
1、train_test_split.py
import numpy as np def train_test_split(X, Y, train_ratio=0.8, seed=None):
"""将数据X和Y按照train_ratio分割成x_train,y_train,x_test,y_test"""
assert X.shape[0] == Y.shape[0], "the size of X must equal to the size of Y"
assert 0.0 <= train_ratio <= 1.0, "train_ratio must be valid" if seed:
np.random.seed(seed) shuffled_indexes = np.random.permutation(len(X))
train_size = int(len(X) * train_ratio)
train_indexes = shuffled_indexes[:train_size]
test_indexes = shuffled_indexes[train_size:] x_train = X[train_indexes]
y_train = Y[train_indexes] x_test = X[test_indexes]
y_test = Y[test_indexes] return x_train,y_train,x_test,y_test
2、实际操作

2、从最终的结果来看,该模型与原始数据的标签的吻合达到100%。
五、scikit-learn中的train_test_split

仿scikit-learn模式写的kNN算法的更多相关文章
- 吴裕雄--天生自然python机器学习实战:K-NN算法约会网站好友喜好预测以及手写数字预测分类实验
实验设备与软件环境 硬件环境:内存ddr3 4G及以上的x86架构主机一部 系统环境:windows 软件环境:Anaconda2(64位),python3.5,jupyter 内核版本:window ...
- Python 手写数字识别-knn算法应用
在上一篇博文中,我们对KNN算法思想及流程有了初步的了解,KNN是采用测量不同特征值之间的距离方法进行分类,也就是说对于每个样本数据,需要和训练集中的所有数据进行欧氏距离计算.这里简述KNN算法的特点 ...
- 机器学习--kNN算法识别手写字母
本文主要是用kNN算法对字母图片进行特征提取,分类识别.内容如下: kNN算法及相关Python模块介绍 对字母图片进行特征提取 kNN算法实现 kNN算法分析 一.kNN算法介绍 K近邻(kNN,k ...
- KNN算法识别手写数字
需求: 利用一个手写数字“先验数据”集,使用knn算法来实现对手写数字的自动识别: 先验数据(训练数据)集: ♦数据维度比较大,样本数比较多. ♦ 数据集包括数字0-9的手写体. ♦每个数字大约有20 ...
- 基于OpenCV的KNN算法实现手写数字识别
基于OpenCV的KNN算法实现手写数字识别 一.数据预处理 # 导入所需模块 import cv2 import numpy as np import matplotlib.pyplot as pl ...
- KNN 算法-实战篇-如何识别手写数字
公号:码农充电站pro 主页:https://codeshellme.github.io 上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字? 1,手写数字数据集 手写数字数 ...
- Python实现KNN算法及手写程序识别
1.Python实现KNN算法 输入:inX:与现有数据集(1xN)进行比较的向量 dataSet:已知向量的大小m数据集(NxM) 个标签:数据集标签(1xM矢量) k:用于比较的邻居数 ...
- Scikit Learn: 在python中机器学习
转自:http://my.oschina.net/u/175377/blog/84420#OSC_h2_23 Scikit Learn: 在python中机器学习 Warning 警告:有些没能理解的 ...
- 【机器学习】机器学习入门01 - kNN算法
0. 写在前面 近日加入了一个机器学习的学习小组,每周按照学习计划学习一个机器学习的小专题.笔者恰好近来计划深入学习Python,刚刚熟悉了其基本的语法知识(主要是与C系语言的差别),决定以此作为对P ...
随机推荐
- Selenium学习之==>18种定位方式的使用
Selenium的定位方式一共有18种,单数8种,复数8种,最后两种是前面这16种的底层封装 单数形式8种 # 1.id定位 al = driver.find_element_by_id('i1') ...
- Vue 渲染函数
Vue 推荐在绝大多数情况下使用模板来创建你的 HTML.然而在一些场景中,你真的需要 JavaScript 的完全编程的能力.这时你可以用渲染函数,它比模板更接近编译器. 一 项目结构 二 App组 ...
- Django 自带 user 字段扩展及头像上传
django 及 rest_framework 笔记链接如下: django 入门笔记:环境及项目搭建 django 入门笔记:数据模型 django 入门笔记:视图及模版 django 入门笔记:A ...
- Git 创建分支并合并主分支
首先,我们创建dev分支,然后切换到dev分支: $ git checkout -b dev(等价于 $ git branch dev $ git checkout dev ) Switched to ...
- excel导入导出(一)
excel导入导出 依赖 <dependency> <groupId>org.apache.poi</groupId> <artifactId>poi& ...
- Simplify Path(路径简化)
问题: 来源:https://leetcode.com/problems/simplify-path Given an absolute path for a file (Unix-style), s ...
- 在laravel框架中使用mq
本文写于2018-11-28 1.部署laravel项目 https://github.com/laravel/laravel 通过git克隆项目,或者下载zip包然后解压等方式都可以把larave ...
- [转帖]功耗降50%,性能升35%!三星3nm GAA 2021年量产
功耗降50%,性能升35%!三星3nm GAA 2021年量产 http://www.chinaflashmarket.com/Instructor 在三星晶圆代工技术论坛(Samsung Found ...
- JS中值类型和引用类型
一.值类型 例子: var a=10; var b=a; a=20; console.log(b); 例子中,将a的值赋给了b,b=10,然后改变a的值不会影响b的值,a和b是独立的两份,互不影响. ...
- IDEA导入Junit jar包,在JavaSE的Module中使用Junit测试
写代码时偶尔想试一下自己的小想法,于是在IDEA中建了一个JavaEE项目.JavaEE项目中只能在main方法中运行代码块,不如单元测试的@Test灵活. 于是在网上找到了Junit的jar包:Do ...