Linear Classification

在上一讲里,我们介绍了图像分类问题以及一个简单的分类模型K-NN模型,我们已经知道K-NN的模型有几个严重的缺陷,第一就是要保存训练集里的所有样本,这个比较消耗存储空间;第二就是要遍历所有的训练样本,这种逐一比较的方式比较耗时而低效。

现在,我们要介绍一种更加强大的图像分类模型,这个模型会很自然地引申出神经网络和Convolutional Neural Networks(CNN),这个模型有两个重要的组成部分,一个是score function,将原始数据映射到输出变量;另外一个就是loss function,衡量预测值与真实值之间的误差。

我们先看模型的第一部分,定义一个score function,将图像的像素值,映射到一个输出变量,这个输出变量表示图像属于每一类的置信度或者说概率,我们假设有一批训练图像,xi∈RD,每一个训练样本都有一个类标签yi,其中,i=1,2,...N,yi∈{1,2,...K},就是说,我们有

N个训练样本,这N个训练样本属于K个类别,我们要定义的score function就是满足如下映射:f:RD→RK,这里我们先介绍一种最简单常用的线性映射,如下所示:

f(xi,W,b)=Wxi+b

在上面的表达式中,xi是一个高维向量,包含一幅图像的所有像素,将图像从m×n×3变成D×1,矩阵W(K×D )和向量b(K×1)称为模型的参数,其中W叫做权值,而b称为偏移向量,我们用下面的图来表示这个映射过程:

为了能够视觉化这个过程,我们假设图像是只有四个像素(实际情况一般至少是几千个像素),将图像变成一个列向量然后与权值W相乘,在加上偏移向量b,最后得到score,从结果来看,这个分类模型将这幅图像判定为是一条狗。

下图展示了线性分类模型对图像分类的过程,因为我们不能将高维向量可视化,所以我们假设在二维平面观看这些图像,那么线性分类模型在各个类别之间的边界就有可能如下图所示:

从上面可以看出,W的每一行都相当于某一类的分类器,从几何意义上看,如果我们改变W中某一行的值,那么该行所对应的分类器将会发生旋转。对应权值W的另外一种解释就是每一行可以看成一种模板:template,一幅图像在每一类上的score可以通过template与该图像做内积获得,这种情况下,线性分类有点像是在做模板匹配,下图给出了在CIFAR-10数据库上利用线性分类模型学习得到的template,W的每一行都相当于一个template。实际运算的时候,我们也会把偏移向量b看成是W的某一列,这样原有的权值W和b组成新的权值W′=[W;b],那么score function也可以由f(xi,W,b)变成f(xi,W)。

之前我们做运算和训练的时候,都是利用图像的原始数据,一般来说,我们需要做一些预处理,我们会将一个训练集里的所有样本做归一化。比如图像,将图像从[0,255]映射到

[-1,1]的范围,而且减去均值向量,保证训练集的均值为0。

我们已经介绍了score function,现在我们要介绍线性分类模型的另外一个重要组成部分:loss function,或者成为cost function,这个用来衡量预测值与目标值之间的误差。定义loss function的方式有很多,这里我们先介绍一种经常使用的loss function,叫做Multiclass Support Vector Machine (SVM) loss。简称 SVM loss,下面给出该函数的定义,假设训练集第i个样本的输入为xi,yi表示该样本属于第几类,利用score function f(xi,W)我们可以计算该样本xi属于每一类的score,比如f(xi,W)j表示样本xi属于第j类的score,那么该loss function定义为:

Li=∑j≠yimax(0,f(xi,W)j−f(xi,W)yi+Δ)

请注意,由于我们这里介绍的是线性模型f(xi,W)=Wxi,所以我们也可以将上式重新写成:

Li=∑j≠yimax(0,wTjxi−wTyixi+Δ)

其中,wTj表示W的第j行,如果是今后介绍的更加复杂的模型,上面这个表达式就不一定成立。上面的max(0,−)函数称为hinge loss,这是线性的hinge loss,有的时候也会用二次的hinge loss:max(0,−)2,下图解释了loss function的作用。Δ给出了其他类与某一类相差的界限,如果其他类与某一类相差的在这个界限之外,那么这些误差不会累计到loss function,反之,如果相差在界限范围内,这些误差就会累计到loss function,所以我们的目标就是寻找满足条件的参数W,使得训练样本都能被正确分类,并且让loss function尽可能地低。

为了进一步提升模型的稳健性,我们会引入regularization penalty,R(W),最常见的形式是二次式:R(W)=∑i∑jW2ij,所以引入R(W)之后,loss

function就包含数据误差和regularization penalty两部分,如下式所示:

L=1N∑iLi+λR(W)

展开之后得到:

L=1N∑i∑j≠yi[max(0,f(xi,W)j−f(xi,W)yi+Δ)]+λ∑i∑jW2ij

通过引入regularization penalty,可以使得权值的分布更加平衡,不会单独侧重于某些局部变量。

前面我们忽略了Δ值的探讨,Δ应该选择多少比较合适?在实际应用中,我们发现把Δ设为1.0是非常安全的,事实上,参数Δ,λ都是控制loss function中数据偏差与regularization penalty之间的平衡的,因为W的幅值对score有直接的影响,如果我们把幅值增大,那么预测的score也会变大,反之同样成立,所以Δ设为1.0还是100.0对最终的数据偏差不会有太多影响,因为可以通过调整W的幅值来消除Δ大小带来的影响,因此,起关键作用的是λ,控制着W以多大的步幅变化。

Softmax classifier

前面介绍的SVM是线性分类器,现在我们介绍另外一种常用的非线性分类器,Softmax classifier。SVM将预测值看做是一种score,而Softmax classifier将预测值看成是一种概率,Softmax classifier的映射函数没有变化,还是f(xi;W)=Wxi,但是它的loss function采取了另外一种形式,称为cross-entropy loss,其定义如下:

Li=−log⎛⎝efyi∑jefj⎞⎠=−fyi+log∑jefj

这里,我们用fj表示对第j类的预测值,与SVM一样,整个训练集的loss function将是所有样本的平均loss加上regularization误差R(W),函数fj(z)=ezj∑kezk称为softmax函数,它可以将一组实数映射到[0,1]之间,并且其和为1,从信息论的角度看,cross entropy衡量地是一个实际分布p和一个估计的分布q之间的相关性:

H(p,q)=−∑xp(x)log(q(x))

因此,Softmax分类器是缩小预测的每一类的概率与实际概率的cross entropy。

从概率的角度来看,我们可以看到表达式:

P(yi|xi;W)=efyi∑jefj

可以看做是给定一张图像,其属于某一类的概率,指数项给出了概率值,而分母的归一化保证概率在[0,1]之间,而且其和为1,这样我们可以引入最大似然估计去解释这个

模型,如果进一步的,我们假设W是属于某一特定分布,比如高斯分布,那么我们可以用最大后验概率估计去解释这个模型,这里提到这些,只是为了让大家对此有一个

直观的了解。实际编写程序的时候,由于指数运算可能会涉及到很大的值,可能会使得模型在数值上不够稳定,所以一般会引入一个常数项C,如下所示:

efyi∑jefj=CefyiC∑jefj=efyi+logC∑jefj+logC

C的选择没有特别地规定,可以自由选择,通常我们定义logC=−maxjfj。下图显示了SVM与Softmax分类器做图像分类的区别:

声明:lecture notes里的图片都来源于该课程的网站,只能用于学习,

请勿作其它用途,如需转载,请说明该课程(http://cs231n.stanford.edu/)为引用来源。

Convolutional Neural Networks for Visual Recognition 2的更多相关文章

  1. Convolutional Neural Networks for Visual Recognition 1

    Introduction 这是斯坦福计算机视觉大牛李菲菲最新开设的一门关于deep learning在计算机视觉领域的相关应用的课程.这个课程重点介绍了deep learning里的一种比较流行的模型 ...

  2. Convolutional Neural Networks for Visual Recognition

    http://cs231n.github.io/   里面有很多相当好的文章 http://cs231n.github.io/convolutional-networks/ Table of Cont ...

  3. 卷积神经网络用于视觉识别Convolutional Neural Networks for Visual Recognition

    Table of Contents: Architecture Overview ConvNet Layers Convolutional Layer Pooling Layer Normalizat ...

  4. Convolutional Neural Networks for Visual Recognition 8

    Convolutional Neural Networks (CNNs / ConvNets) 前面做了如此漫长的铺垫,现在终于来到了课程的重点.Convolutional Neural Networ ...

  5. Convolutional Neural Networks for Visual Recognition 5

    Setting up the data and the model 前面我们介绍了一个神经元的模型,通过一个激励函数将高维的输入域权值的点积转化为一个单一的输出,而神经网络就是将神经元排列到每一层,形 ...

  6. Convolutional Neural Networks for Visual Recognition 7

    Two Simple Examples softmax classifier 后,我们介绍两个简单的例子,一个是线性分类器,一个是神经网络.由于网上的讲义给出的都是代码,我们这里用公式来进行推导.首先 ...

  7. Convolutional Neural Networks for Visual Recognition 4

    Modeling one neuron 下面我们开始介绍神经网络,我们先从最简单的一个神经元的情况开始,一个简单的神经元包括输入,激励函数以及输出.如下图所示: 一个神经元类似一个线性分类器,如果激励 ...

  8. cs231n spring 2017 lecture1 Introduction to Convolutional Neural Networks for Visual Recognition 听课笔记

    1. 生物学家做实验发现脑皮层对简单的结构比如角.边有反应,而通过复杂的神经元传递,这些简单的结构最终帮助生物体有了更复杂的视觉系统.1970年David Marr提出的视觉处理流程遵循这样的原则,拿 ...

  9. Stanford CS231n - Convolutional Neural Networks for Visual Recognition

    网易云课堂上有汉化的视频:http://study.163.com/course/courseLearn.htm?courseId=1003223001#/learn/video?lessonId=1 ...

随机推荐

  1. 阿里巴巴产品实习生N天

    时间貌似有些太遥远,已经没办法从刚来时的日子一天一天数.连上内网打开内外.看到45天,每一次不经意的邂逅总会让人认为奇妙而微妙,每一次的巧合总会让人认为是神在显灵(但愿天津安好,这里也曾在我心中滋润过 ...

  2. 微信小程序事件

    微信小程序事件1.什么是事件2.事件类别3.事件冒泡4.事件绑定5.事件对象详解笔记:1.事件是一种用户的行为,是一种通讯方式.2.事件类别:    点击事件:tap    长按事件:longtap  ...

  3. (转)linux设备驱动之USB数据传输分析 一

    三:传输过程的实现说到传输过程,我们必须要从URB开始说起,这个结构的就好比是网络子系统中的skb,好比是I/O中的bio.USB系统的信息传输就是打成URB结构,然后再过行传送的.URB的全称叫US ...

  4. 百度地图SnapshotReadyCallback截屏

    今天碰到了地图截图的功能,不太会,查查资料知道怎么弄了,跟大家分享一下 直接上代码,弄了一个方法,将截取的图片上传至服务器,返回给我们图片路径 //获取地图截图 private void getscr ...

  5. group_concat函数导致的主从同步异常

    group_concat函数导致的主从同步异常的问题总结 今天在处理一个group_concat函数导致的主从异常的问题,排查过程比较简单,不过第一次遇到这个问题记录一下排查的思路,后面如果再遇到其他 ...

  6. extendgcd模板

    看了数论第一章,终于搞懂了扩展欧几里德,其实就是普通欧几里德的逆推过程. // ax+by = gcd(a,b) ->求解x,y 其中a,b不全为0,可以为负数// 复杂度:O(log2a)vo ...

  7. 还在用 kill -9 停机?这才是最优雅的姿势(转)

    _ 最近瞥了一眼项目的重启脚本,发现运维一直在使用 kill-9<pid> 的方式重启 springboot embedded tomcat,其实大家几乎一致认为:kill-9<pi ...

  8. cocos2dx使用cocostudio导出的animation

    local uilocal function createLayerUI() if not ui then ui=cc.Layer:create(); createLayerUI=nil; end r ...

  9. Python中pymysql模块详解

    安装 pip install pymysql 使用操作 执行SQL #!/usr/bin/env pytho # -*- coding:utf-8 -*- import pymysql # 创建连接 ...

  10. Linux c编程:同步属性

    就像线程具有属性一样,线程的同步对象(如互斥量.读写锁.条件变量.自旋锁和屏障)也有属性 1.互斥量属性 用pthread_mutexattr_init初始化pthread_mutexattr_t结构 ...