导入库

import numpy as np
from sklearn import datasets
import matplotlib.pyplot as plt

导入数据

iris = datasets.load_iris()

数据准备

X = iris.data
y = iris.target
X.shape, y.shape
((150, 4), (150,))

数据分割(28开)

# 因为训练集矩阵和标签向量是分割的,不能单独对某一个进行乱序
# 需要将其合并整体乱序再分割

X_join_y = np.hstack([X, y.reshape(-1,1)])

# 随机,导致每次数据分割结果都会改变
# 如果有debug需求,需要保证每次运行的分割结果一致
# 则需要对random进行seed设置

np.random.seed(1)
np.random.shuffle(X_join_y)
train,test = np.vsplit(X_join_y, [int(0.8*len(X_join_y))])
train.shape,test.shape
((120, 5), (30, 5))

准备data和target

# X_train, y_train, X_test, y_test 成功拿到了训练集(数据+标签)和测试集(数据+标签)

X_train = train[:,0:4]
y_train = train[:,-1]
X_test = test[:,0:4]
y_test = test[:,-1]

KNN手写算法

import numpy as np
from math import sqrt
from collections import Counter
class KNNClassifier: def __init__(self, k):
# 初始化KNN分类器
self.k = k
self._X_train = None
self._y_train = None def fit(self, X_train, y_train):
# 根据训练集X_train, Y_train训练分类器
self._X_train = X_train
self._y_train = y_train
return self def predict(self, X_predict):
# 给定待遇测的数据集X_predict,返回表示X_predict的结果向量
y_predict = [self._predict(x) for x in X_predict]
return np.array(y_predict) def _predict(self, x):
# 给定单个待遇测数据x,返回x的预测结果值
distances = [sqrt(np.sum((x_train - x) ** 2)) for x_train in self._X_train]
nearest = np.argsort(distances)
topK_y = [self._y_train[i] for i in nearest[:self.k]]
votes = Counter(topK_y)
return votes.most_common(1)[0][0] def __repr__(self):
return "KNN=(%d)" % self.k
from sklearn.model_selection import train_test_split

result = train_test_split(X, y)
result
[array([[7.2, 3. , 5.8, 1.6],
[5.4, 3.9, 1.3, 0.4],
[6.5, 3.2, 5.1, 2. ],
[6.1, 3. , 4.6, 1.4],
[4.6, 3.2, 1.4, 0.2],
[6.9, 3.2, 5.7, 2.3],
[6.1, 2.8, 4. , 1.3],
[5.7, 3. , 4.2, 1.2],
[5.8, 2.7, 4.1, 1. ],
[5.5, 2.5, 4. , 1.3],
[5.7, 2.5, 5. , 2. ],
[4.6, 3.4, 1.4, 0.3],
[5.9, 3.2, 4.8, 1.8],
[6.3, 2.9, 5.6, 1.8],
[6.8, 3. , 5.5, 2.1],
[6.4, 2.7, 5.3, 1.9],
[6. , 2.9, 4.5, 1.5],
[6. , 2.2, 4. , 1. ],
[4.8, 3. , 1.4, 0.1],
[5.6, 2.5, 3.9, 1.1],
[7.1, 3. , 5.9, 2.1],
[6.7, 3.3, 5.7, 2.1],
[5.5, 2.6, 4.4, 1.2],
[6.3, 3.3, 4.7, 1.6],
[6.7, 3.1, 4.7, 1.5],
[4.3, 3. , 1.1, 0.1],
[4.8, 3.4, 1.9, 0.2],
[6.7, 3.3, 5.7, 2.5],
[6. , 2.7, 5.1, 1.6],
[6.5, 3. , 5.5, 1.8],
[4.9, 2.5, 4.5, 1.7],
[5. , 3.5, 1.3, 0.3],
[5.9, 3. , 4.2, 1.5],
[5.5, 2.4, 3.8, 1.1],
[6.2, 2.2, 4.5, 1.5],
[6.3, 2.7, 4.9, 1.8],
[4.4, 3. , 1.3, 0.2],
[7.7, 3. , 6.1, 2.3],
[7. , 3.2, 4.7, 1.4],
[6.4, 2.8, 5.6, 2.2],
[5.7, 2.8, 4.5, 1.3],
[6.4, 2.9, 4.3, 1.3],
[5.6, 3. , 4.1, 1.3],
[6.3, 2.8, 5.1, 1.5],
[4.9, 3.6, 1.4, 0.1],
[6. , 3.4, 4.5, 1.6],
[5.7, 4.4, 1.5, 0.4],
[4.8, 3. , 1.4, 0.3],
[5.4, 3.7, 1.5, 0.2],
[5.4, 3.4, 1.5, 0.4],
[5. , 2.3, 3.3, 1. ],
[6.9, 3.1, 4.9, 1.5],
[5.1, 3.8, 1.9, 0.4],
[6.4, 2.8, 5.6, 2.1],
[5.1, 3.8, 1.5, 0.3],
[5. , 3.4, 1.5, 0.2],
[5.1, 3.3, 1.7, 0.5],
[5.2, 2.7, 3.9, 1.4],
[6.1, 2.6, 5.6, 1.4],
[7.7, 2.8, 6.7, 2. ],
[5.8, 2.7, 5.1, 1.9],
[6.8, 2.8, 4.8, 1.4],
[4.4, 3.2, 1.3, 0.2],
[5.3, 3.7, 1.5, 0.2],
[6.9, 3.1, 5.4, 2.1],
[5.1, 2.5, 3. , 1.1],
[5.7, 2.8, 4.1, 1.3],
[6.4, 3.1, 5.5, 1.8],
[6.2, 3.4, 5.4, 2.3],
[5.8, 2.7, 5.1, 1.9],
[6.3, 2.5, 4.9, 1.5],
[5.8, 2.6, 4. , 1.2],
[4.6, 3.1, 1.5, 0.2],
[4.9, 3.1, 1.5, 0.2],
[5.6, 2.9, 3.6, 1.3],
[5.1, 3.7, 1.5, 0.4],
[5. , 3.2, 1.2, 0.2],
[6.5, 3. , 5.8, 2.2],
[7.3, 2.9, 6.3, 1.8],
[5.2, 3.4, 1.4, 0.2],
[4.5, 2.3, 1.3, 0.3],
[5.5, 2.3, 4. , 1.3],
[6.5, 3. , 5.2, 2. ],
[5.5, 2.4, 3.7, 1. ],
[7.6, 3. , 6.6, 2.1],
[5. , 3.6, 1.4, 0.2],
[5.9, 3. , 5.1, 1.8],
[6.3, 2.5, 5. , 1.9],
[6.1, 3. , 4.9, 1.8],
[4.9, 3. , 1.4, 0.2],
[6.7, 3. , 5.2, 2.3],
[5.1, 3.5, 1.4, 0.3],
[6.3, 2.3, 4.4, 1.3],
[4.4, 2.9, 1.4, 0.2],
[6.8, 3.2, 5.9, 2.3],
[5.1, 3.8, 1.6, 0.2],
[7.2, 3.6, 6.1, 2.5],
[5.7, 3.8, 1.7, 0.3],
[5. , 2. , 3.5, 1. ],
[5. , 3. , 1.6, 0.2],
[4.8, 3.4, 1.6, 0.2],
[4.8, 3.1, 1.6, 0.2],
[6.7, 3.1, 5.6, 2.4],
[5.8, 2.8, 5.1, 2.4],
[5.8, 4. , 1.2, 0.2],
[6.1, 2.8, 4.7, 1.2],
[5.4, 3.9, 1.7, 0.4],
[6.5, 2.8, 4.6, 1.5],
[4.9, 3.1, 1.5, 0.1],
[5.4, 3.4, 1.7, 0.2],
[4.9, 2.4, 3.3, 1. ],
[5.1, 3.4, 1.5, 0.2]]),
array([[6.2, 2.9, 4.3, 1.3],
[6.7, 3. , 5. , 1.7],
[5.2, 4.1, 1.5, 0.1],
[5.7, 2.6, 3.5, 1. ],
[7.4, 2.8, 6.1, 1.9],
[5.6, 3. , 4.5, 1.5],
[6.9, 3.1, 5.1, 2.3],
[6. , 2.2, 5. , 1.5],
[5.5, 3.5, 1.3, 0.2],
[6.7, 2.5, 5.8, 1.8],
[7.2, 3.2, 6. , 1.8],
[6. , 3. , 4.8, 1.8],
[5.2, 3.5, 1.5, 0.2],
[5.1, 3.5, 1.4, 0.2],
[5. , 3.3, 1.4, 0.2],
[5.6, 2.8, 4.9, 2. ],
[5.6, 2.7, 4.2, 1.3],
[5. , 3.5, 1.6, 0.6],
[7.9, 3.8, 6.4, 2. ],
[6.3, 3.4, 5.6, 2.4],
[5. , 3.4, 1.6, 0.4],
[6.2, 2.8, 4.8, 1.8],
[5.4, 3. , 4.5, 1.5],
[5.5, 4.2, 1.4, 0.2],
[4.6, 3.6, 1. , 0.2],
[6.1, 2.9, 4.7, 1.4],
[6.4, 3.2, 5.3, 2.3],
[5.7, 2.9, 4.2, 1.3],
[7.7, 2.6, 6.9, 2.3],
[7.7, 3.8, 6.7, 2.2],
[6.3, 3.3, 6. , 2.5],
[5.8, 2.7, 3.9, 1.2],
[6.6, 2.9, 4.6, 1.3],
[4.7, 3.2, 1.6, 0.2],
[6.7, 3.1, 4.4, 1.4],
[6.4, 3.2, 4.5, 1.5],
[4.7, 3.2, 1.3, 0.2],
[6.6, 3. , 4.4, 1.4]]),
array([2, 0, 2, 1, 0, 2, 1, 1, 1, 1, 2, 0, 1, 2, 2, 2, 1, 1, 0, 1, 2, 2,
1, 1, 1, 0, 0, 2, 1, 2, 2, 0, 1, 1, 1, 2, 0, 2, 1, 2, 1, 1, 1, 2,
0, 1, 0, 0, 0, 0, 1, 1, 0, 2, 0, 0, 0, 1, 2, 2, 2, 1, 0, 0, 2, 1,
1, 2, 2, 2, 1, 1, 0, 0, 1, 0, 0, 2, 2, 0, 0, 1, 2, 1, 2, 0, 2, 2,
2, 0, 2, 0, 1, 0, 2, 0, 2, 0, 1, 0, 0, 0, 2, 2, 0, 1, 0, 1, 0, 0,
1, 0]),
array([1, 1, 0, 1, 2, 1, 2, 2, 0, 2, 2, 2, 0, 0, 0, 2, 1, 0, 2, 2, 0, 2,
1, 0, 0, 1, 2, 1, 2, 2, 2, 1, 1, 0, 1, 1, 0, 1])]
my_knn_clf = KNNClassifier(k=3)
my_knn_clf.fit(result[0], result[2])
KNN=(3)

y_predict = my_knn_clf.predict(result[1])
sum(y_predict == result[3])
sum(y_predict == result[3])/len(result[3])

08.手写KNN算法测试的更多相关文章

  1. [纯C#实现]基于BP神经网络的中文手写识别算法

    效果展示 这不是OCR,有些人可能会觉得这东西会和OCR一样,直接进行整个字的识别就行,然而并不是. OCR是2维像素矩阵的像素数据.而手写识别不一样,手写可以把用户写字的笔画时间顺序,抽象成一个维度 ...

  2. 用C实现单隐层神经网络的训练和预测(手写BP算法)

    实验要求:•实现10以内的非负双精度浮点数加法,例如输入4.99和5.70,能够预测输出为10.69•使用Gprof测试代码热度 代码框架•随机初始化1000对数值在0~10之间的浮点数,保存在二维数 ...

  3. 手写KMeans算法

    KMeans算法是一种无监督学习,它会将相似的对象归到同一类中. 其基本思想是: 1.随机计算k个类中心作为起始点. 将数据点分配到理其最近的类中心. 3.移动类中心. 4.重复2,3直至类中心不再改 ...

  4. 手写k-means算法

    作为聚类的代表算法,k-means本属于NP难问题,通过迭代优化的方式,可以求解出近似解. 伪代码如下: 1,算法部分 距离采用欧氏距离.参数默认值随意选的. import numpy as np d ...

  5. Javascript 手写 LRU 算法

    LRU 是 Least Recently Used 的缩写,即最近最少使用.作为一种经典的缓存策略,它的基本思想是长期不被使用的数据,在未来被用到的几率也不大,所以当新的数据进来时我们可以优先把这些数 ...

  6. 手写LRU算法

    import java.util.LinkedHashMap; import java.util.Map; public class LRUCache<K, V> extends Link ...

  7. 手写hashmap算法

    /** * 01.自定义一个hashmap * 02.实现put增加键值对,实现key重复时替换key的值 * 03.重写toString方法,方便查看map中的键值对信息 * 04.实现get方法, ...

  8. 基于kNN的手写字体识别——《机器学习实战》笔记

    看完一节<机器学习实战>,算是踏入ML的大门了吧!这里就详细讲一下一个demo:使用kNN算法实现手写字体的简单识别 kNN 先简单介绍一下kNN,就是所谓的K-近邻算法: [作用原理]: ...

  9. 手写BP(反向传播)算法

    BP算法为深度学习中参数更新的重要角色,一般基于loss对参数的偏导进行更新. 一些根据均方误差,每层默认激活函数sigmoid(不同激活函数,则更新公式不一样) 假设网络如图所示: 则更新公式为: ...

随机推荐

  1. ICMP&&PING

    ICMP 1.定位:互联网控制报文协议(Internet Control Message Protocol),是TCP/IP协议族的一个子协议,位于网络层.它被IP用于提供许多不同的服务.ICMP是一 ...

  2. codeforces 292E. Copying Data

    We often have to copy large volumes of information. Such operation can take up many computer resourc ...

  3. HDU 3336——Count the string

    It is well known that AekdyCoin is good at string problems as well as number theory problems. When g ...

  4. Educational DP Contest F - LCS (LCS输出路径)

    题意:有两个字符串,求他们的最长公共子序列并输出. 题解:首先跑个LCS记录一下dp数组,然后根据dp数组来反着还原路径,只有当两个位置的字符相同时才输出. 代码: char s[N],t[N]; i ...

  5. __getattr__,__getattribute__和__get__的区别

    dir(object)  列出对象的大多数属性 getattr(object, name) 从object对象中获取name字符串指定的属性 hasattr(object, name) 如果objec ...

  6. C#之Dispose

    前言 谈到Dispose,首先需要理解C#的资源 资源类型 托管资源:由CLR创建和释放 非托管资源:资源的创建和释放不由CLR管理.比如IO.网络连接.数据库连接等等.需要开发人员手动释放. 如何释 ...

  7. Python基础--核心数据类型

    python的核心数据类型: Number 数字(整数,浮点数,复数,布尔型数) String 字符串 List 列表 Tuple 元组 Dictionary 字典 Set 集合 1. 整数(整型数) ...

  8. LINUX - vim高效操作

    (一)可以为操作的一行添加下划线 set cursorline

  9. LINUX - 寄存器和堆栈

    堆栈模型: 函数调用: EBP:ESP EBP当前调用函数的栈底: ESP当前调用函数的栈顶: ---------------------------------------------------- ...

  10. leetcode15 三数之和 双指针

    注意题目没要求数字只能用一次 a + b + c = 0 即为 -b=a+c,同时要求数字不全为正(然后发现a+b+c就行...不过多想想没坏处嘛) 先处理特殊情况,然后 先排序 注意不重复,只需要有 ...