仿scikit-learn模式写的kNN算法
一、什么是kNN算法
k邻近是指每个样本都可以用它最接近的k个邻居来代表。
核心思想:如果一个样本在特征空间中的k个最相邻的样本中大多数属于一个某类别,则该样本也属于这个类别。
二、将kNN封装成kNNClassifier
1、训练样本的特征在二维空间中的表示
、
2、kNN的训练过程如下图
3、完整代码(kNN.py)
import numpy as np
from math import sqrt
from collections import Counter
from metrics import accuracy_score class kNNClassifier():
def __init__(self, k):
"""初始化kNN分类器"""
assert k >= 1, "k must be valid"
self.k = k
self._x_train = None
self._y_train = None def fit(self, x_train, y_train):
"""根据训练集x_train和y_train训练kNN分类器"""
assert x_train.shape[0] == y_train.shape[0], \
"the size of x_train must be equal to the size of y_train"
assert x_train.shape[0] >= self.k, "the size of x_train must be at least k"
self._x_train = x_train
self._y_train = y_train
return self def predict(self, X_predict):
"""给定待预测数据集X_train,返回表示x_train的结果向量"""
assert self._x_train is not None and self._y_train is not None, \
"must fit before predict"
assert X_predict.shape[1] == self._x_train.shape[1] , \
"the feature number of X_predict must be equal to x_train"
y_predict = [self._predict(x) for x in X_predict]
return np.array(y_predict) def _predict(self, x):
"""给定待预测数据x,返回x预测的结果值"""
assert x.shape[0] == self._x_train.shape[1], \
"the feature number of x must be equal tu x_train"
distances = [sqrt(np.sum((x_train-x)**2)) for x_train in self._x_train]
nearest = np.argsort(distances)
topK_y = [self._y_train[i] for i in nearest[:self.k]]
votes = Counter(topK_y)
return votes.most_common(1)[0][0] def score(self, X_test, y_test):
"""根据数据集X_test 和y_test 得到当前模型的准确度"""
y_predict = self.predict(X_test)
return accuracy_score(y_test, y_predict) def __repr__(self):
return "kNN(k=%d)" % self.k if __name__ == "__main__":
x_train = np.array([[0.31864691, 0.99608349],
[0.8609734 , 0.40706129],
[0.86746155, 0.20136923],
[0.4346735 , 0.17677379],
[0.42842348, 0.68055183],
[0.70661963, 0.76155652],
[0.73379517, 0.6123456 ],
[0.68330672, 0.52193524],
[0.11192091, 0.07885633],
[0.99273292, 0.62484263]])
y_train = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1])
k = 6
x = np.array([0.756789,0.6123456])
knn = kNNClassifier(k)
knn.fit(x_train,y_train)
x_predict = x.reshape(1,-1)
print(knn.predict(x_predict))
三、测试结果
[1]
四、问题
1、如果直接将上面训练得到的模型直接放在真实环境中使用,但是模型没有得到验证,会造成模型很差,会有真实损失。
2、真实环境下很难拿到符合条件的数据去测试
解决办法:
1、将训练数据拿出一部分作为测试数据,通过测试数据直接判断模型好坏。
2、在模型进入真实环境前改进模型
1、train_test_split.py
import numpy as np def train_test_split(X, Y, train_ratio=0.8, seed=None):
"""将数据X和Y按照train_ratio分割成x_train,y_train,x_test,y_test"""
assert X.shape[0] == Y.shape[0], "the size of X must equal to the size of Y"
assert 0.0 <= train_ratio <= 1.0, "train_ratio must be valid" if seed:
np.random.seed(seed) shuffled_indexes = np.random.permutation(len(X))
train_size = int(len(X) * train_ratio)
train_indexes = shuffled_indexes[:train_size]
test_indexes = shuffled_indexes[train_size:] x_train = X[train_indexes]
y_train = Y[train_indexes] x_test = X[test_indexes]
y_test = Y[test_indexes] return x_train,y_train,x_test,y_test
2、实际操作
2、从最终的结果来看,该模型与原始数据的标签的吻合达到100%。
五、scikit-learn中的train_test_split
仿scikit-learn模式写的kNN算法的更多相关文章
- 吴裕雄--天生自然python机器学习实战:K-NN算法约会网站好友喜好预测以及手写数字预测分类实验
实验设备与软件环境 硬件环境:内存ddr3 4G及以上的x86架构主机一部 系统环境:windows 软件环境:Anaconda2(64位),python3.5,jupyter 内核版本:window ...
- Python 手写数字识别-knn算法应用
在上一篇博文中,我们对KNN算法思想及流程有了初步的了解,KNN是采用测量不同特征值之间的距离方法进行分类,也就是说对于每个样本数据,需要和训练集中的所有数据进行欧氏距离计算.这里简述KNN算法的特点 ...
- 机器学习--kNN算法识别手写字母
本文主要是用kNN算法对字母图片进行特征提取,分类识别.内容如下: kNN算法及相关Python模块介绍 对字母图片进行特征提取 kNN算法实现 kNN算法分析 一.kNN算法介绍 K近邻(kNN,k ...
- KNN算法识别手写数字
需求: 利用一个手写数字“先验数据”集,使用knn算法来实现对手写数字的自动识别: 先验数据(训练数据)集: ♦数据维度比较大,样本数比较多. ♦ 数据集包括数字0-9的手写体. ♦每个数字大约有20 ...
- 基于OpenCV的KNN算法实现手写数字识别
基于OpenCV的KNN算法实现手写数字识别 一.数据预处理 # 导入所需模块 import cv2 import numpy as np import matplotlib.pyplot as pl ...
- KNN 算法-实战篇-如何识别手写数字
公号:码农充电站pro 主页:https://codeshellme.github.io 上篇文章介绍了KNN 算法的原理,今天来介绍如何使用KNN 算法识别手写数字? 1,手写数字数据集 手写数字数 ...
- Python实现KNN算法及手写程序识别
1.Python实现KNN算法 输入:inX:与现有数据集(1xN)进行比较的向量 dataSet:已知向量的大小m数据集(NxM) 个标签:数据集标签(1xM矢量) k:用于比较的邻居数 ...
- Scikit Learn: 在python中机器学习
转自:http://my.oschina.net/u/175377/blog/84420#OSC_h2_23 Scikit Learn: 在python中机器学习 Warning 警告:有些没能理解的 ...
- 【机器学习】机器学习入门01 - kNN算法
0. 写在前面 近日加入了一个机器学习的学习小组,每周按照学习计划学习一个机器学习的小专题.笔者恰好近来计划深入学习Python,刚刚熟悉了其基本的语法知识(主要是与C系语言的差别),决定以此作为对P ...
随机推荐
- 当遇到npm ERR! Unexpected end of JSON input while parsing near……时的解决办法
运行npm install时有时会遇到以下错误: npm ERR! Unexpected end of JSON input while parsing near ... 这时可以先执行下面的命令: ...
- md5sum c实现
#include <stdio.h>#include <ctype.h> #define STR_VALUE(val) #val#define STR(name) STR_VA ...
- React Native商城项目实战02 - 主要框架部分(tabBar)
1.安装插件,cd到项目根目录下执行: $ npm i react-native-tab-navigator --save 2.主框架文件Main.js /** * 主页面 */ import Rea ...
- eclipse导入工程
一般项目配置信息完全可直接导入,即import 如果缺失.project等文件,eclipse无法识别,则将工程拷贝到工作空间目录下,在eclipse中新建一个同名工程即可
- 基于DRF的图书增删改查
功能演示 信息展示 添加功能 编辑功能 删除功能 DRF构建后台数据 本例的Model如下 from django.db import models class Publish(models.Mode ...
- jmeter之报告输出(html)
在使用jmeter进行测试时,我们需要生成相应的测试报告,jmeter3.0之后有自带的测试报告. 在测试报告的格式和输出内容不满足需求时,我们可以根据需要去修改其配置文件(jmeter.proper ...
- oracle-SYSTEM表空间的备份与恢复
oracle-SYSTEM表空间的备份与恢复 这一篇在介绍备份及恢复数据文件的方法时,以备份和重做日志(包括归档日志和在线日志)没有丢失为前提 所谓关键数据文件:system表空间的数据文件与参数un ...
- eclipse和myeclipse怎么在项目中查找指定代码?https://www.jb51.net/softjc/554889.html
有的童鞋,想eclipse和myeclipse整个项目中查找指定代码,由于补经常使用,可能会补熟悉.如果要去掉项目中所有的某个代码的话,找不到是灰常麻烦的,下面就简单说下怎么查找,希望对需要的人有用. ...
- c语言字串指针 char*
c语言中 char* 不仅能存字符串,还能存二进制数据,所以它的用途因使用者而定. char* 在很多使用场景下,是需要存储ascii码为0的元素的,这样就必须注意一个问题,那就是char*的长度. ...
- [Usaco2017 Jan]Promotion Counting
n只奶牛构成了一个树形的公司,每个奶牛有一个能力值pi,1号奶牛为树根.问对于每个奶牛来说,它的子树中有几个能力值比它大的.Inputn,表示有几只奶牛 n<=100000接下来n行为1-n号奶 ...