08.手写KNN算法测试
导入库
import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt
导入数据
iris = datasets.load_iris()
数据准备
X = iris.data
y = iris.target
X.shape, y.shape
((150, 4), (150,))
数据分割(28开)
# 因为训练集矩阵和标签向量是分割的,不能单独对某一个进行乱序
# 需要将其合并整体乱序再分割
X_join_y = np.hstack([X, y.reshape(-1,1)])
# 随机,导致每次数据分割结果都会改变
# 如果有debug需求,需要保证每次运行的分割结果一致
# 则需要对random进行seed设置
np.random.seed(1)
np.random.shuffle(X_join_y)
train,test = np.vsplit(X_join_y, [int(0.8*len(X_join_y))])
train.shape,test.shape
((120, 5), (30, 5))
准备data和target
# X_train, y_train, X_test, y_test 成功拿到了训练集(数据+标签)和测试集(数据+标签)
X_train = train[:,0:4]
y_train = train[:,-1]
X_test = test[:,0:4]
y_test = test[:,-1]
KNN手写算法
import numpy as np
from math import sqrt
from collections import Counter
class KNNClassifier: def __init__(self, k):
# 初始化KNN分类器
self.k = k
self._X_train = None
self._y_train = None def fit(self, X_train, y_train):
# 根据训练集X_train, Y_train训练分类器
self._X_train = X_train
self._y_train = y_train
return self def predict(self, X_predict):
# 给定待遇测的数据集X_predict,返回表示X_predict的结果向量
y_predict = [self._predict(x) for x in X_predict]
return np.array(y_predict) def _predict(self, x):
# 给定单个待遇测数据x,返回x的预测结果值
distances = [sqrt(np.sum((x_train - x) ** 2)) for x_train in self._X_train]
nearest = np.argsort(distances)
topK_y = [self._y_train[i] for i in nearest[:self.k]]
votes = Counter(topK_y)
return votes.most_common(1)[0][0] def __repr__(self):
return "KNN=(%d)" % self.k
from sklearn.model_selection import train_test_split result = train_test_split(X, y)
result
[array([[7.2, 3. , 5.8, 1.6],
[5.4, 3.9, 1.3, 0.4],
[6.5, 3.2, 5.1, 2. ],
[6.1, 3. , 4.6, 1.4],
[4.6, 3.2, 1.4, 0.2],
[6.9, 3.2, 5.7, 2.3],
[6.1, 2.8, 4. , 1.3],
[5.7, 3. , 4.2, 1.2],
[5.8, 2.7, 4.1, 1. ],
[5.5, 2.5, 4. , 1.3],
[5.7, 2.5, 5. , 2. ],
[4.6, 3.4, 1.4, 0.3],
[5.9, 3.2, 4.8, 1.8],
[6.3, 2.9, 5.6, 1.8],
[6.8, 3. , 5.5, 2.1],
[6.4, 2.7, 5.3, 1.9],
[6. , 2.9, 4.5, 1.5],
[6. , 2.2, 4. , 1. ],
[4.8, 3. , 1.4, 0.1],
[5.6, 2.5, 3.9, 1.1],
[7.1, 3. , 5.9, 2.1],
[6.7, 3.3, 5.7, 2.1],
[5.5, 2.6, 4.4, 1.2],
[6.3, 3.3, 4.7, 1.6],
[6.7, 3.1, 4.7, 1.5],
[4.3, 3. , 1.1, 0.1],
[4.8, 3.4, 1.9, 0.2],
[6.7, 3.3, 5.7, 2.5],
[6. , 2.7, 5.1, 1.6],
[6.5, 3. , 5.5, 1.8],
[4.9, 2.5, 4.5, 1.7],
[5. , 3.5, 1.3, 0.3],
[5.9, 3. , 4.2, 1.5],
[5.5, 2.4, 3.8, 1.1],
[6.2, 2.2, 4.5, 1.5],
[6.3, 2.7, 4.9, 1.8],
[4.4, 3. , 1.3, 0.2],
[7.7, 3. , 6.1, 2.3],
[7. , 3.2, 4.7, 1.4],
[6.4, 2.8, 5.6, 2.2],
[5.7, 2.8, 4.5, 1.3],
[6.4, 2.9, 4.3, 1.3],
[5.6, 3. , 4.1, 1.3],
[6.3, 2.8, 5.1, 1.5],
[4.9, 3.6, 1.4, 0.1],
[6. , 3.4, 4.5, 1.6],
[5.7, 4.4, 1.5, 0.4],
[4.8, 3. , 1.4, 0.3],
[5.4, 3.7, 1.5, 0.2],
[5.4, 3.4, 1.5, 0.4],
[5. , 2.3, 3.3, 1. ],
[6.9, 3.1, 4.9, 1.5],
[5.1, 3.8, 1.9, 0.4],
[6.4, 2.8, 5.6, 2.1],
[5.1, 3.8, 1.5, 0.3],
[5. , 3.4, 1.5, 0.2],
[5.1, 3.3, 1.7, 0.5],
[5.2, 2.7, 3.9, 1.4],
[6.1, 2.6, 5.6, 1.4],
[7.7, 2.8, 6.7, 2. ],
[5.8, 2.7, 5.1, 1.9],
[6.8, 2.8, 4.8, 1.4],
[4.4, 3.2, 1.3, 0.2],
[5.3, 3.7, 1.5, 0.2],
[6.9, 3.1, 5.4, 2.1],
[5.1, 2.5, 3. , 1.1],
[5.7, 2.8, 4.1, 1.3],
[6.4, 3.1, 5.5, 1.8],
[6.2, 3.4, 5.4, 2.3],
[5.8, 2.7, 5.1, 1.9],
[6.3, 2.5, 4.9, 1.5],
[5.8, 2.6, 4. , 1.2],
[4.6, 3.1, 1.5, 0.2],
[4.9, 3.1, 1.5, 0.2],
[5.6, 2.9, 3.6, 1.3],
[5.1, 3.7, 1.5, 0.4],
[5. , 3.2, 1.2, 0.2],
[6.5, 3. , 5.8, 2.2],
[7.3, 2.9, 6.3, 1.8],
[5.2, 3.4, 1.4, 0.2],
[4.5, 2.3, 1.3, 0.3],
[5.5, 2.3, 4. , 1.3],
[6.5, 3. , 5.2, 2. ],
[5.5, 2.4, 3.7, 1. ],
[7.6, 3. , 6.6, 2.1],
[5. , 3.6, 1.4, 0.2],
[5.9, 3. , 5.1, 1.8],
[6.3, 2.5, 5. , 1.9],
[6.1, 3. , 4.9, 1.8],
[4.9, 3. , 1.4, 0.2],
[6.7, 3. , 5.2, 2.3],
[5.1, 3.5, 1.4, 0.3],
[6.3, 2.3, 4.4, 1.3],
[4.4, 2.9, 1.4, 0.2],
[6.8, 3.2, 5.9, 2.3],
[5.1, 3.8, 1.6, 0.2],
[7.2, 3.6, 6.1, 2.5],
[5.7, 3.8, 1.7, 0.3],
[5. , 2. , 3.5, 1. ],
[5. , 3. , 1.6, 0.2],
[4.8, 3.4, 1.6, 0.2],
[4.8, 3.1, 1.6, 0.2],
[6.7, 3.1, 5.6, 2.4],
[5.8, 2.8, 5.1, 2.4],
[5.8, 4. , 1.2, 0.2],
[6.1, 2.8, 4.7, 1.2],
[5.4, 3.9, 1.7, 0.4],
[6.5, 2.8, 4.6, 1.5],
[4.9, 3.1, 1.5, 0.1],
[5.4, 3.4, 1.7, 0.2],
[4.9, 2.4, 3.3, 1. ],
[5.1, 3.4, 1.5, 0.2]]),
array([[6.2, 2.9, 4.3, 1.3],
[6.7, 3. , 5. , 1.7],
[5.2, 4.1, 1.5, 0.1],
[5.7, 2.6, 3.5, 1. ],
[7.4, 2.8, 6.1, 1.9],
[5.6, 3. , 4.5, 1.5],
[6.9, 3.1, 5.1, 2.3],
[6. , 2.2, 5. , 1.5],
[5.5, 3.5, 1.3, 0.2],
[6.7, 2.5, 5.8, 1.8],
[7.2, 3.2, 6. , 1.8],
[6. , 3. , 4.8, 1.8],
[5.2, 3.5, 1.5, 0.2],
[5.1, 3.5, 1.4, 0.2],
[5. , 3.3, 1.4, 0.2],
[5.6, 2.8, 4.9, 2. ],
[5.6, 2.7, 4.2, 1.3],
[5. , 3.5, 1.6, 0.6],
[7.9, 3.8, 6.4, 2. ],
[6.3, 3.4, 5.6, 2.4],
[5. , 3.4, 1.6, 0.4],
[6.2, 2.8, 4.8, 1.8],
[5.4, 3. , 4.5, 1.5],
[5.5, 4.2, 1.4, 0.2],
[4.6, 3.6, 1. , 0.2],
[6.1, 2.9, 4.7, 1.4],
[6.4, 3.2, 5.3, 2.3],
[5.7, 2.9, 4.2, 1.3],
[7.7, 2.6, 6.9, 2.3],
[7.7, 3.8, 6.7, 2.2],
[6.3, 3.3, 6. , 2.5],
[5.8, 2.7, 3.9, 1.2],
[6.6, 2.9, 4.6, 1.3],
[4.7, 3.2, 1.6, 0.2],
[6.7, 3.1, 4.4, 1.4],
[6.4, 3.2, 4.5, 1.5],
[4.7, 3.2, 1.3, 0.2],
[6.6, 3. , 4.4, 1.4]]),
array([2, 0, 2, 1, 0, 2, 1, 1, 1, 1, 2, 0, 1, 2, 2, 2, 1, 1, 0, 1, 2, 2,
1, 1, 1, 0, 0, 2, 1, 2, 2, 0, 1, 1, 1, 2, 0, 2, 1, 2, 1, 1, 1, 2,
0, 1, 0, 0, 0, 0, 1, 1, 0, 2, 0, 0, 0, 1, 2, 2, 2, 1, 0, 0, 2, 1,
1, 2, 2, 2, 1, 1, 0, 0, 1, 0, 0, 2, 2, 0, 0, 1, 2, 1, 2, 0, 2, 2,
2, 0, 2, 0, 1, 0, 2, 0, 2, 0, 1, 0, 0, 0, 2, 2, 0, 1, 0, 1, 0, 0,
1, 0]),
array([1, 1, 0, 1, 2, 1, 2, 2, 0, 2, 2, 2, 0, 0, 0, 2, 1, 0, 2, 2, 0, 2,
1, 0, 0, 1, 2, 1, 2, 2, 2, 1, 1, 0, 1, 1, 0, 1])]
my_knn_clf = KNNClassifier(k=3)
my_knn_clf.fit(result[0], result[2])
KNN=(3)
y_predict = my_knn_clf.predict(result[1])
sum(y_predict == result[3])
sum(y_predict == result[3])/len(result[3])
08.手写KNN算法测试的更多相关文章
- [纯C#实现]基于BP神经网络的中文手写识别算法
效果展示 这不是OCR,有些人可能会觉得这东西会和OCR一样,直接进行整个字的识别就行,然而并不是. OCR是2维像素矩阵的像素数据.而手写识别不一样,手写可以把用户写字的笔画时间顺序,抽象成一个维度 ...
- 用C实现单隐层神经网络的训练和预测(手写BP算法)
实验要求:•实现10以内的非负双精度浮点数加法,例如输入4.99和5.70,能够预测输出为10.69•使用Gprof测试代码热度 代码框架•随机初始化1000对数值在0~10之间的浮点数,保存在二维数 ...
- 手写KMeans算法
KMeans算法是一种无监督学习,它会将相似的对象归到同一类中. 其基本思想是: 1.随机计算k个类中心作为起始点. 将数据点分配到理其最近的类中心. 3.移动类中心. 4.重复2,3直至类中心不再改 ...
- 手写k-means算法
作为聚类的代表算法,k-means本属于NP难问题,通过迭代优化的方式,可以求解出近似解. 伪代码如下: 1,算法部分 距离采用欧氏距离.参数默认值随意选的. import numpy as np d ...
- Javascript 手写 LRU 算法
LRU 是 Least Recently Used 的缩写,即最近最少使用.作为一种经典的缓存策略,它的基本思想是长期不被使用的数据,在未来被用到的几率也不大,所以当新的数据进来时我们可以优先把这些数 ...
- 手写LRU算法
import java.util.LinkedHashMap; import java.util.Map; public class LRUCache<K, V> extends Link ...
- 手写hashmap算法
/** * 01.自定义一个hashmap * 02.实现put增加键值对,实现key重复时替换key的值 * 03.重写toString方法,方便查看map中的键值对信息 * 04.实现get方法, ...
- 基于kNN的手写字体识别——《机器学习实战》笔记
看完一节<机器学习实战>,算是踏入ML的大门了吧!这里就详细讲一下一个demo:使用kNN算法实现手写字体的简单识别 kNN 先简单介绍一下kNN,就是所谓的K-近邻算法: [作用原理]: ...
- 手写BP(反向传播)算法
BP算法为深度学习中参数更新的重要角色,一般基于loss对参数的偏导进行更新. 一些根据均方误差,每层默认激活函数sigmoid(不同激活函数,则更新公式不一样) 假设网络如图所示: 则更新公式为: ...
随机推荐
- ScalikeJDBC,操作mysql数据,API
ScalikeJDBC,操作mysql数据,API 一.构建maven项目,添加pom.xml依赖 二.resource文件下创建application.conf文件,并配置以下内容 三.操作mysq ...
- Mysql 5.5升级5.8
前言,因为升级跳板机,需要将mariadb 升级到10.2,也就是对应MySQL的5.8,废话不多说下面开始进行mariadb 5.5 的升级 Welcome to the MariaDB monit ...
- maven高级笔记
Maven高级 1.maven基础知识回顾 1.1 maven介绍 maven 是一个项目管理工具,主要作用是在项目开发阶段对Java项目进行依赖管理和项目构建. 依赖管理:就是对jar包的管理.通过 ...
- 既有Nginx重新动态编译增加http2.0模块
1.HTTP2.0 HTTP2.0相较于http1.x,大幅度的提升了web性能,在与http1.1完全语义兼容的基础上,进一步减少了网络延时.我们现在很多对外的网站都采用https,但是F12一下看 ...
- Hive创建HBase,ES外部表
1.创建HBase外部表 CREATE EXTERNAL TABLE `ods_women`( `rowkey` string COMMENT 'from deserializer', `articl ...
- POJ-3208 Apocalypse Someday (数位DP)
只要某数字的十进制表示中有三个6相邻,则该数字为魔鬼数,求第X小的魔鬼数\(X\le 5e7\) 这一类题目可以先用DP进行预处理,再基于拼凑思想,用"试填法"求出最终的答案 \( ...
- 2020Nowcode多校 Round9 B.Groundhog and Apple Tree
题意 给一棵树 初始\(hp=0\) 经过一条边会掉血\(w_{i}\) 第一次到达一个点可以回血\(a_{i}\) 在一个点休息\(1s\)可以回复\(1hp\) 血不能小于\(0\) 每条边最多经 ...
- 2019-2020 ACM-ICPC Brazil Subregional Programming Contest (11/13)
\(2019-2020\ ACM-ICPC\ Brazil\ Subregional\ Programming\ Contest\) \(A.Artwork\) 并查集,把检测区域能在一起的检测器放在 ...
- 2019牛客多校 Round7
Solved:5 Rank:296 E Find the median (线段树) 题意:最开始一个空的数组 4e5次操作 每次把Li,Ri中的每个数插入进来 问当前的中位数 题解:把这n个区间离散化 ...
- P3384 [模板] 树链剖分
#include <bits/stdc++.h> using namespace std; typedef long long ll; int n, m, rt, mod, cnt, to ...