08.手写KNN算法测试
导入库
import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt
导入数据
iris = datasets.load_iris()
数据准备
X = iris.data
y = iris.target
X.shape, y.shape
((150, 4), (150,))
数据分割(28开)
# 因为训练集矩阵和标签向量是分割的,不能单独对某一个进行乱序
# 需要将其合并整体乱序再分割
X_join_y = np.hstack([X, y.reshape(-1,1)])
# 随机,导致每次数据分割结果都会改变
# 如果有debug需求,需要保证每次运行的分割结果一致
# 则需要对random进行seed设置
np.random.seed(1)
np.random.shuffle(X_join_y)
train,test = np.vsplit(X_join_y, [int(0.8*len(X_join_y))])
train.shape,test.shape
((120, 5), (30, 5))
准备data和target
# X_train, y_train, X_test, y_test 成功拿到了训练集(数据+标签)和测试集(数据+标签)
X_train = train[:,0:4]
y_train = train[:,-1]
X_test = test[:,0:4]
y_test = test[:,-1]
KNN手写算法
import numpy as np
from math import sqrt
from collections import Counter
class KNNClassifier: def __init__(self, k):
# 初始化KNN分类器
self.k = k
self._X_train = None
self._y_train = None def fit(self, X_train, y_train):
# 根据训练集X_train, Y_train训练分类器
self._X_train = X_train
self._y_train = y_train
return self def predict(self, X_predict):
# 给定待遇测的数据集X_predict,返回表示X_predict的结果向量
y_predict = [self._predict(x) for x in X_predict]
return np.array(y_predict) def _predict(self, x):
# 给定单个待遇测数据x,返回x的预测结果值
distances = [sqrt(np.sum((x_train - x) ** 2)) for x_train in self._X_train]
nearest = np.argsort(distances)
topK_y = [self._y_train[i] for i in nearest[:self.k]]
votes = Counter(topK_y)
return votes.most_common(1)[0][0] def __repr__(self):
return "KNN=(%d)" % self.k
from sklearn.model_selection import train_test_split result = train_test_split(X, y)
result
[array([[7.2, 3. , 5.8, 1.6],
[5.4, 3.9, 1.3, 0.4],
[6.5, 3.2, 5.1, 2. ],
[6.1, 3. , 4.6, 1.4],
[4.6, 3.2, 1.4, 0.2],
[6.9, 3.2, 5.7, 2.3],
[6.1, 2.8, 4. , 1.3],
[5.7, 3. , 4.2, 1.2],
[5.8, 2.7, 4.1, 1. ],
[5.5, 2.5, 4. , 1.3],
[5.7, 2.5, 5. , 2. ],
[4.6, 3.4, 1.4, 0.3],
[5.9, 3.2, 4.8, 1.8],
[6.3, 2.9, 5.6, 1.8],
[6.8, 3. , 5.5, 2.1],
[6.4, 2.7, 5.3, 1.9],
[6. , 2.9, 4.5, 1.5],
[6. , 2.2, 4. , 1. ],
[4.8, 3. , 1.4, 0.1],
[5.6, 2.5, 3.9, 1.1],
[7.1, 3. , 5.9, 2.1],
[6.7, 3.3, 5.7, 2.1],
[5.5, 2.6, 4.4, 1.2],
[6.3, 3.3, 4.7, 1.6],
[6.7, 3.1, 4.7, 1.5],
[4.3, 3. , 1.1, 0.1],
[4.8, 3.4, 1.9, 0.2],
[6.7, 3.3, 5.7, 2.5],
[6. , 2.7, 5.1, 1.6],
[6.5, 3. , 5.5, 1.8],
[4.9, 2.5, 4.5, 1.7],
[5. , 3.5, 1.3, 0.3],
[5.9, 3. , 4.2, 1.5],
[5.5, 2.4, 3.8, 1.1],
[6.2, 2.2, 4.5, 1.5],
[6.3, 2.7, 4.9, 1.8],
[4.4, 3. , 1.3, 0.2],
[7.7, 3. , 6.1, 2.3],
[7. , 3.2, 4.7, 1.4],
[6.4, 2.8, 5.6, 2.2],
[5.7, 2.8, 4.5, 1.3],
[6.4, 2.9, 4.3, 1.3],
[5.6, 3. , 4.1, 1.3],
[6.3, 2.8, 5.1, 1.5],
[4.9, 3.6, 1.4, 0.1],
[6. , 3.4, 4.5, 1.6],
[5.7, 4.4, 1.5, 0.4],
[4.8, 3. , 1.4, 0.3],
[5.4, 3.7, 1.5, 0.2],
[5.4, 3.4, 1.5, 0.4],
[5. , 2.3, 3.3, 1. ],
[6.9, 3.1, 4.9, 1.5],
[5.1, 3.8, 1.9, 0.4],
[6.4, 2.8, 5.6, 2.1],
[5.1, 3.8, 1.5, 0.3],
[5. , 3.4, 1.5, 0.2],
[5.1, 3.3, 1.7, 0.5],
[5.2, 2.7, 3.9, 1.4],
[6.1, 2.6, 5.6, 1.4],
[7.7, 2.8, 6.7, 2. ],
[5.8, 2.7, 5.1, 1.9],
[6.8, 2.8, 4.8, 1.4],
[4.4, 3.2, 1.3, 0.2],
[5.3, 3.7, 1.5, 0.2],
[6.9, 3.1, 5.4, 2.1],
[5.1, 2.5, 3. , 1.1],
[5.7, 2.8, 4.1, 1.3],
[6.4, 3.1, 5.5, 1.8],
[6.2, 3.4, 5.4, 2.3],
[5.8, 2.7, 5.1, 1.9],
[6.3, 2.5, 4.9, 1.5],
[5.8, 2.6, 4. , 1.2],
[4.6, 3.1, 1.5, 0.2],
[4.9, 3.1, 1.5, 0.2],
[5.6, 2.9, 3.6, 1.3],
[5.1, 3.7, 1.5, 0.4],
[5. , 3.2, 1.2, 0.2],
[6.5, 3. , 5.8, 2.2],
[7.3, 2.9, 6.3, 1.8],
[5.2, 3.4, 1.4, 0.2],
[4.5, 2.3, 1.3, 0.3],
[5.5, 2.3, 4. , 1.3],
[6.5, 3. , 5.2, 2. ],
[5.5, 2.4, 3.7, 1. ],
[7.6, 3. , 6.6, 2.1],
[5. , 3.6, 1.4, 0.2],
[5.9, 3. , 5.1, 1.8],
[6.3, 2.5, 5. , 1.9],
[6.1, 3. , 4.9, 1.8],
[4.9, 3. , 1.4, 0.2],
[6.7, 3. , 5.2, 2.3],
[5.1, 3.5, 1.4, 0.3],
[6.3, 2.3, 4.4, 1.3],
[4.4, 2.9, 1.4, 0.2],
[6.8, 3.2, 5.9, 2.3],
[5.1, 3.8, 1.6, 0.2],
[7.2, 3.6, 6.1, 2.5],
[5.7, 3.8, 1.7, 0.3],
[5. , 2. , 3.5, 1. ],
[5. , 3. , 1.6, 0.2],
[4.8, 3.4, 1.6, 0.2],
[4.8, 3.1, 1.6, 0.2],
[6.7, 3.1, 5.6, 2.4],
[5.8, 2.8, 5.1, 2.4],
[5.8, 4. , 1.2, 0.2],
[6.1, 2.8, 4.7, 1.2],
[5.4, 3.9, 1.7, 0.4],
[6.5, 2.8, 4.6, 1.5],
[4.9, 3.1, 1.5, 0.1],
[5.4, 3.4, 1.7, 0.2],
[4.9, 2.4, 3.3, 1. ],
[5.1, 3.4, 1.5, 0.2]]),
array([[6.2, 2.9, 4.3, 1.3],
[6.7, 3. , 5. , 1.7],
[5.2, 4.1, 1.5, 0.1],
[5.7, 2.6, 3.5, 1. ],
[7.4, 2.8, 6.1, 1.9],
[5.6, 3. , 4.5, 1.5],
[6.9, 3.1, 5.1, 2.3],
[6. , 2.2, 5. , 1.5],
[5.5, 3.5, 1.3, 0.2],
[6.7, 2.5, 5.8, 1.8],
[7.2, 3.2, 6. , 1.8],
[6. , 3. , 4.8, 1.8],
[5.2, 3.5, 1.5, 0.2],
[5.1, 3.5, 1.4, 0.2],
[5. , 3.3, 1.4, 0.2],
[5.6, 2.8, 4.9, 2. ],
[5.6, 2.7, 4.2, 1.3],
[5. , 3.5, 1.6, 0.6],
[7.9, 3.8, 6.4, 2. ],
[6.3, 3.4, 5.6, 2.4],
[5. , 3.4, 1.6, 0.4],
[6.2, 2.8, 4.8, 1.8],
[5.4, 3. , 4.5, 1.5],
[5.5, 4.2, 1.4, 0.2],
[4.6, 3.6, 1. , 0.2],
[6.1, 2.9, 4.7, 1.4],
[6.4, 3.2, 5.3, 2.3],
[5.7, 2.9, 4.2, 1.3],
[7.7, 2.6, 6.9, 2.3],
[7.7, 3.8, 6.7, 2.2],
[6.3, 3.3, 6. , 2.5],
[5.8, 2.7, 3.9, 1.2],
[6.6, 2.9, 4.6, 1.3],
[4.7, 3.2, 1.6, 0.2],
[6.7, 3.1, 4.4, 1.4],
[6.4, 3.2, 4.5, 1.5],
[4.7, 3.2, 1.3, 0.2],
[6.6, 3. , 4.4, 1.4]]),
array([2, 0, 2, 1, 0, 2, 1, 1, 1, 1, 2, 0, 1, 2, 2, 2, 1, 1, 0, 1, 2, 2,
1, 1, 1, 0, 0, 2, 1, 2, 2, 0, 1, 1, 1, 2, 0, 2, 1, 2, 1, 1, 1, 2,
0, 1, 0, 0, 0, 0, 1, 1, 0, 2, 0, 0, 0, 1, 2, 2, 2, 1, 0, 0, 2, 1,
1, 2, 2, 2, 1, 1, 0, 0, 1, 0, 0, 2, 2, 0, 0, 1, 2, 1, 2, 0, 2, 2,
2, 0, 2, 0, 1, 0, 2, 0, 2, 0, 1, 0, 0, 0, 2, 2, 0, 1, 0, 1, 0, 0,
1, 0]),
array([1, 1, 0, 1, 2, 1, 2, 2, 0, 2, 2, 2, 0, 0, 0, 2, 1, 0, 2, 2, 0, 2,
1, 0, 0, 1, 2, 1, 2, 2, 2, 1, 1, 0, 1, 1, 0, 1])]
my_knn_clf = KNNClassifier(k=3)
my_knn_clf.fit(result[0], result[2])
KNN=(3)
y_predict = my_knn_clf.predict(result[1])
sum(y_predict == result[3])
sum(y_predict == result[3])/len(result[3])
08.手写KNN算法测试的更多相关文章
- [纯C#实现]基于BP神经网络的中文手写识别算法
效果展示 这不是OCR,有些人可能会觉得这东西会和OCR一样,直接进行整个字的识别就行,然而并不是. OCR是2维像素矩阵的像素数据.而手写识别不一样,手写可以把用户写字的笔画时间顺序,抽象成一个维度 ...
- 用C实现单隐层神经网络的训练和预测(手写BP算法)
实验要求:•实现10以内的非负双精度浮点数加法,例如输入4.99和5.70,能够预测输出为10.69•使用Gprof测试代码热度 代码框架•随机初始化1000对数值在0~10之间的浮点数,保存在二维数 ...
- 手写KMeans算法
KMeans算法是一种无监督学习,它会将相似的对象归到同一类中. 其基本思想是: 1.随机计算k个类中心作为起始点. 将数据点分配到理其最近的类中心. 3.移动类中心. 4.重复2,3直至类中心不再改 ...
- 手写k-means算法
作为聚类的代表算法,k-means本属于NP难问题,通过迭代优化的方式,可以求解出近似解. 伪代码如下: 1,算法部分 距离采用欧氏距离.参数默认值随意选的. import numpy as np d ...
- Javascript 手写 LRU 算法
LRU 是 Least Recently Used 的缩写,即最近最少使用.作为一种经典的缓存策略,它的基本思想是长期不被使用的数据,在未来被用到的几率也不大,所以当新的数据进来时我们可以优先把这些数 ...
- 手写LRU算法
import java.util.LinkedHashMap; import java.util.Map; public class LRUCache<K, V> extends Link ...
- 手写hashmap算法
/** * 01.自定义一个hashmap * 02.实现put增加键值对,实现key重复时替换key的值 * 03.重写toString方法,方便查看map中的键值对信息 * 04.实现get方法, ...
- 基于kNN的手写字体识别——《机器学习实战》笔记
看完一节<机器学习实战>,算是踏入ML的大门了吧!这里就详细讲一下一个demo:使用kNN算法实现手写字体的简单识别 kNN 先简单介绍一下kNN,就是所谓的K-近邻算法: [作用原理]: ...
- 手写BP(反向传播)算法
BP算法为深度学习中参数更新的重要角色,一般基于loss对参数的偏导进行更新. 一些根据均方误差,每层默认激活函数sigmoid(不同激活函数,则更新公式不一样) 假设网络如图所示: 则更新公式为: ...
随机推荐
- SpringBoot-文件系统-Excel,PDF,XML,CSV
SpringBoot-文件系统-Excel,PDF,XML,CSV 1.Excel文件管理 1.1 POI依赖 1.2 文件读取 1.3 文件创建 1.4 文件导出 1.5 文件导出接口 2.PDF文 ...
- JDBC的操作步骤和实例()
加载JDBC驱动程序 提供JDBC连接的URL 创建数据库的连接 创建一个Statement 执行SQL语句 处理结果 关闭JDBC对象 实例JdbcUtils 创建一个JDBC程序包含7个步骤: 1 ...
- 小白搭建WNMP详细教程---PHP安装与设置
php的安装请参考WAMP中PHP的安装教程https://www.cnblogs.com/missbye/p/12049925.html 需要注意的是,我们下载的PHP版本要下载Non Thread ...
- 2019牛客暑期多校训练营(第四场)D-triples I
>传送门< 题意:求最少需要多少个3的倍数按位或后可以得到数字a 思路:利用3的倍数对应的二进制数的性质来先选出一个x,然后根据数字a再配一个y出来 首先,我们都知道十进制中,任意一个数只 ...
- Codeforces Round #690 (Div. 3)
第一次 ak cf 的正式比赛,不正式的是寒假里 div4 的 Testing Round,好啦好啦不要问我为什么没有 ak div4 了,差一题差一题 =.= 不知不觉已经咕了一个月了2333. 比 ...
- 1150 Travelling Salesman Problem
The "travelling salesman problem" asks the following question: "Given a list of citie ...
- HDU6403 Card Game【基环树 + 树形DP】
HDU6403 Card Game 题意: 给出\(N\)张卡片,卡片正反两面都有数字,现在要翻转一些卡片使得所有卡片的正面的值各不相同,问最小翻转次数和最小翻转情况下的不同方案数 \(N\le 10 ...
- Docker运行时资源限制
Docker 运行时资源限制Docker 基于 Linux 内核提供的 cgroups 功能,可以限制容器在运行时使用到的资源,比如内存.CPU.块 I/O.网络等. 内存限制概述Docker 提供的 ...
- MySQL 事务特征 & 隔离级别
数据库事务特征 Atomicity 原子性 事务是一个原子性质的操作单元,事务里面的对数据库的操作要么都执行,要么都不执行, Consistent 一致性 在事务开始之前和完成之后,数据都必须保持一致 ...
- Nginx 版本回滚
目录 参考信息 源码安装 nginx-1.14.2 版本升级 nginx-1.16.1 版本回滚 ①.对于软件的版本升级.添加官方模块.添加第三方模块,都需要用源码安装包重新生成(configure) ...