【cs231n】图像分类-Nearest Neighbor Classifier(最近邻分类器)【python3实现】
【学习自CS231n课程】
转载请注明出处:http://www.cnblogs.com/GraceSkyer/p/8735908.html
图像分类:
一张图像的表示:长度、宽度、通道(3个颜色通道,分别是红R、绿G、蓝B)。
对于计算机来说,图像是一个由数字组成的巨大的三维数组,数组元素是取值范围从0到255的整数,其中0表示全黑,255表示全白。
图像分类的任务:对于一个给定的图像,预测它属于的那个分类标签。
如何写图像分类算法呢?
数据驱动方法:
图像分类流程:
图像分类:输入一个元素为像素值的数组,然后给它分配一个分类标签。完整流程如下:
- 输入:输入是包含N个图像的集合,每个图像的标签是K种分类标签中的一种。这个集合称为训练集。
- 学习:这一步的任务是使用训练集来学习每个类到底长什么样。一般该步骤叫做训练分类器或者学习一个模型。
- 评价:让分类器来预测它未曾见过的图像的分类标签,并以此来评价分类器的质量。我们会把分类器预测的标签和图像真正的分类标签对比。毫无疑问,分类器预测的分类标签和图像真正的分类标签如果一致,那就是好事,这样的情况越多越好。
Nearest Neighbor Classifier
最简单的分类器......
最近邻算法:在训练机器的过程中,我们什么也不做,我们只是单纯记录所有的训练数据,在图片预测的步骤,我们会拿一些新的图片去在训练数据中寻找与新图片最相似的,然后基于此,来给出一个标签。
例:用最近邻算法用于这数据集中的图片,在训练集中找到最接近的样本:图像分类数据集:CIFAR-10
这个数据集包含了60000张32X32的小图像。每张图像都有10种分类标签中的一种。这60000张图像被分为包含50000张图像的训练集和包含10000张图像的测试集。
我们需要的是,输入一张测试图片,从训练集中找出与它L1距离最近的一张训练图,将其所对应的分类标签作为答案,也就是作为测试图片的分类标签。
下面,让我们看看如何用代码来实现这个分类器。首先,我们将CIFAR-10的数据加载到内存中,并分成4个数组:训练数据和标签,测试数据和标签。在下面的代码中,Xtr(大小是50000x32x32x3)存有训练集中所有的图像,Ytr是对应的长度为50000的1维数组,存有图像对应的分类标签(从0到9):
Xtr, Ytr, Xte, Yte = load_CIFAR10('data/cifar10/') # a magic function we provide
# flatten out all images to be one-dimensional
Xtr_rows = Xtr.reshape(Xtr.shape[0], 32 * 32 * 3) # Xtr_rows becomes 50000 x 3072
Xte_rows = Xte.reshape(Xte.shape[0], 32 * 32 * 3) # Xte_rows becomes 10000 x 3072
现在我们得到所有的图像数据,并且把他们拉长成为行向量了。接下来展示如何训练并评价一个分类器:
nn = NearestNeighbor() # create a Nearest Neighbor classifier class
nn.train(Xtr_rows, Ytr) # train the classifier on the training images and labels
Yte_predict = nn.predict(Xte_rows) # predict labels on the test images
# and now print the classification accuracy, which is the average number
# of examples that are correctly predicted (i.e. label matches)
print 'accuracy: %f' % ( np.mean(Yte_predict == Yte) )
作为评价标准,我们常常使用准确率,它描述了我们预测正确的得分。请注意以后我们实现的所有分类器都需要有这个API:train(X, y)函数。该函数使用训练集的数据和标签来进行训练。从其内部来看,类应该实现一些关于标签和标签如何被预测的模型。这里还有个predict(X)函数,它的作用是预测输入的新数据的分类标签。
L1距离(即两尺寸相同图对应位置间差异的和):
完整代码:
Python3 实现:【使用L1距离的Nearest Neighbor分类器】
【我是将cifar-10-batches-py数据放在代码同一目录下,代码中具体文件位置请根据情况自行设置】
import numpy as np
import pickle
import os class NearestNeighbor(object):
def __init__(self):
pass def train(self, X, y):
""" X is N x D where each row is an example. Y is 1-dimension of size N """
# the nearest neighbor classifier simply remembers all the training data
self.Xtr = X
self.ytr = y def predict(self, X):
""" X is N x D where each row is an example we wish to predict label for """
num_test = X.shape[0]
# lets make sure that the output type matches the input type
Ypred = np.zeros(num_test, dtype=self.ytr.dtype) # loop over all test rows
for i in range(num_test):
# find the nearest training image to the i'th test image
# using the L1 distance (sum of absolute value differences)
distances = np.sum(np.abs(self.Xtr - X[i, :]), axis=1)
min_index = np.argmin(distances) # get the index with smallest distance
Ypred[i] = self.ytr[min_index] # predict the label of the nearest example return Ypred def load_CIFAR_batch(file):
""" load single batch of cifar """
with open(file, 'rb') as f:
datadict = pickle.load(f, encoding='latin1')
X = datadict['data']
Y = datadict['labels']
X = X.reshape(10000, 3, 32, 32).transpose(0, 2, 3, 1).astype("float")
Y = np.array(Y)
return X, Y def load_CIFAR10(ROOT):
""" load all of cifar """
xs = []
ys = []
for b in range(1,6):
f = os.path.join(ROOT, 'data_batch_%d' % (b, ))
X, Y = load_CIFAR_batch(f)
xs.append(X)
ys.append(Y)
Xtr = np.concatenate(xs) # 使变成行向量
Ytr = np.concatenate(ys)
del X, Y
Xte, Yte = load_CIFAR_batch(os.path.join(ROOT, 'test_batch'))
return Xtr, Ytr, Xte, Yte Xtr, Ytr, Xte, Yte = load_CIFAR10('data/cifar-10-batches-py/') # a magic function we provide
# flatten out all images to be one-dimensional
Xtr_rows = Xtr.reshape(Xtr.shape[0], 32 * 32 * 3) # Xtr_rows becomes 50000 x 3072
Xte_rows = Xte.reshape(Xte.shape[0], 32 * 32 * 3) # Xte_rows becomes 10000 x 3072 nn = NearestNeighbor() # create a Nearest Neighbor classifier class
nn.train(Xtr_rows, Ytr) # train the classifier on the training images and labels
Yte_predict = nn.predict(Xte_rows) # predict labels on the test images
# and now print the classification accuracy, which is the average number
# of examples that are correctly predicted (i.e. label matches)
print('accuracy: %f' % (np.mean(Yte_predict == Yte)))
然后我运行了很久......几个小时?
运行结果:accuracy: 0.385900
若用L2距离(欧式距离):
修改上面代码1行就可以:
distances = np.sqrt(np.sum(np.square(self.Xtr - X[i,:]), axis = 1))
关于简单分类器的问题:
- 如果我们有N个实例,训练和测试的过程可以有多快?
- Train:O(1),predict:O(n)
这是糟糕的。训练是连续的过程,因为我们并不需要做任何事情, 我们只需存储数据。但在测试时,我们需要停下来,将数据集中N个训练实例 与我们的测试图像进行比较,这是个很慢的过程。
但是在实际使用中,我们希望训练过程比较慢,而测试过程快。因为,训练过程是在数据中心中完成的,它可以负担起非常大的运算量,从而训练出一个优秀的分类器。 然而,当你在测试过程部署分类器时,你希望它运行在手机上、浏览器、或其他低功耗设备,但你又希望分类器能够快速地运行,由此看来,最近邻算法有点落后了。。。
参考:
https://www.bilibili.com/video/av17204303/?from=search&seid=6625954842411789830
https://zhuanlan.zhihu.com/p/20894041?refer=intelligentunit
https://blog.csdn.net/dawningblue/article/details/75119639
https://www.cnblogs.com/hans209/p/6919851.html
【cs231n】图像分类-Nearest Neighbor Classifier(最近邻分类器)【python3实现】的更多相关文章
- 【cs231n】图像分类 k-Nearest Neighbor Classifier(K最近邻分类器)【python3实现】
[学习自CS231n课程] 转载请注明出处:http://www.cnblogs.com/GraceSkyer/p/8763616.html k-Nearest Neighbor(KNN)分类器 与其 ...
- CS231n——图像分类(KNN实现)
图像分类 目标:已有固定的分类标签集合,然后对于输入的图像,从分类标签集合中找出一个分类标签,最后把分类标签分配给该输入图像. 图像分类流程 输入:输入是包含N个图像的集合,每个图像的标签是K ...
- cs231n笔记 (一) 线性分类器
Liner classifier 线性分类器用作图像分类主要有两部分组成:一个是假设函数, 它是原始图像数据到类别的映射.另一个是损失函数,该方法可转化为一个最优化问题,在最优化过程中,将通过更新假设 ...
- Nearest neighbor graph | 近邻图
最近在开发一套自己的单细胞分析方法,所以copy paste事业有所停顿. 实例: R eNetIt v0.1-1 data(ralu.site) # Saturated spatial graph ...
- Approximate Nearest Neighbors.接近最近邻搜索
(一):次优最近邻:http://en.wikipedia.org/wiki/Nearest_neighbor_search 有少量修改:如有疑问,请看链接原文.....1.Survey:Neares ...
- K Nearest Neighbor 算法
文章出处:http://coolshell.cn/articles/8052.html K Nearest Neighbor算法又叫KNN算法,这个算法是机器学习里面一个比较经典的算法, 总体来说KN ...
- Visualizing MNIST with t-SNE, MDS, Sammon’s Mapping and Nearest neighbor graph
MNIST 可视化 Visualizing MNIST: An Exploration of Dimensionality Reduction At some fundamental level, n ...
- Nearest Neighbor Search
## Nearest Neighbor Search ## Input file: standard input Output file: standard output Time limit: 1 ...
- K NEAREST NEIGHBOR 算法(knn)
K Nearest Neighbor算法又叫KNN算法,这个算法是机器学习里面一个比较经典的算法, 总体来说KNN算法是相对比较容易理解的算法.其中的K表示最接近自己的K个数据样本.KNN算法和K-M ...
随机推荐
- AE文档保存
private void barButtonItem4_ItemClick(object sender, DevExpress.XtraBars.ItemClickEventArgs e)//保存 { ...
- Walkway.js – 创建简约的 SVG 线条动画
Walkway.js 是一个使用线条和路径元素组成 SVG 动画图像的简单方法.只需根据提供的配置对象创建一个新的 Walkway 实例就可以了.这种效果特别适合那些崇尚简约设计风格的网页.目前, W ...
- JQuery判断浏览器类型(IE, Firefox…)
1 2 3 4 5 6 7 8 9 10 11 $(function() { if ($.browser.msie) { alert("这是一个IE浏览器&q ...
- 撩课-Web大前端每天5道面试题-Day26
1.vuejs与angularjs以及react的区别? .与AngularJS的区别 相同点: 都支持指令:内置指令和自定义指令. 都支持过滤器:内置过滤器和自定义过滤器. 都支持双向数据绑定. 都 ...
- Telnet 模拟邮件发送过程
Telnet 模拟邮件发送过程 windows要提前开启Telnet客户端的功能,再按照下面步骤完成邮件发送: 1.通过 cmd 进入命令窗口 2.连接要发送邮件的服务器:telnet smtp.al ...
- Android开发之旅5:应用程序基础及组件
引言 上篇Android开发之旅:应用程序基础及组件介绍了应用程序的基础知识及Android的四个组件,本篇将介绍如何激活组关闭组件等.本文的主题如下: 1.激活组件:意图(Intents) 1.1. ...
- Python中操作HTTP请求的urllib模块详解
urllib 是 Python 标准库中用于网络请求的库.该库有四个模块,分别是urllib.request,urllib.error,urllib.parse,urllib.robotparser. ...
- 微信支付报错:time_expire时间过短,刷卡至少1分钟,其他5分钟]
查了下代码: $input->SetTime_expire(date("YmdHis", time() + 600));//二维码过期时间.默认10min 10分钟,没问题. ...
- ORA-00054 资源正忙
现象: 执行update.truncate提示 ORA-00054: resource busy and acquire with NOWAIT specified. 解决方法: 因为系统是RAC系统 ...
- ios 下拉列表
#import <UIKit/UIKit.h> @class FVPullDownMenu; /** 指示器状态*/ typedef enum { IndicatorStateShow = ...