引言

上一篇博客整理了一下SVM分类算法的基本理论问题,它分类的基本思想是利用最大间隔进行分类,处理非线性问题是通过核函数将特征向量映射到高维空间,从而变成线性可分的,但是运算却是在低维空间运行的。考虑到数据中可能存在噪音,还引入了松弛变量。

理论是抽象的,问题是具体的。站在岸上学不会游泳,光看着梨子不可能知道梨子的滋味。本篇博客就是用SVM分类算法解决一个经典的机器学习问题--手写数字识别。体会一下SVM算法的具体过程,理理它的一般性的思路。

问题的提出

人类视觉系统是世界上众多的奇迹之一。看看下面的手写数字序列:



大多数人毫不费力就能够认出这些数字为504192。如果尝试让计算机程序来识别诸如上面的数字,就会明显感受到视觉模式识别的困难。关于我们识别形状--------“9顶上有一个圈,右下方则是一条竖线”这样的简单直觉,实际上算法很难轻易表达出来。



SVM分类算法以另一个角度来考虑问题。其思路是获取大量的手写数字,常称作训练样本,然后开发出一个可以从这些训练样本中进行学习的系统。换言之,SVM使用样本来自动推断出识别手写数字的规则。随着样本数量的增加,算法可以学到更多关于手写数字的知识,这样就能够提升自身的准确性。

本文采用的数据集就是著名的“MNIST数据集”。这个数据集有60000个训练样本数据集和10000个测试用例。直接调用scikit-learn库中的SVM,使用默认的参数,1000张手写数字图片,判断准确的图片就高达9435张。

SVM的算法过程

通常,对于分类问题。我们会将数据集分成三部分,训练集、测试集、交叉验证集。用训练集训练生成模型,用测试集和交叉验证集进行验证模型的准确性。

加载数据的代码如下:

"""
mnist_loader
~~~~~~~~~~~~
一个加载模式识别图片数据的库。
""" #### Libraries
# Standard library
import cPickle
import gzip # Third-party libraries
import numpy as np def load_data():
"""
返回包含训练数据、验证数据、测试数据的元组的模式识别数据
训练数据包含50,000张图片,测试数据和验证数据都只包含10,000张图片
"""
f = gzip.open('../data/mnist.pkl.gz', 'rb')
training_data, validation_data, test_data = cPickle.load(f)
f.close()
return (training_data, validation_data, test_data)

SVM算法进行训练和预测的代码如下:

"""
mnist_svm
~~~~~~~~~
使用SVM分类器,从MNIST数据集中进行手写数字识别的分类程序
""" #### Libraries
# My libraries
import mnist_loader # Third-party libraries
from sklearn import svm
import time def svm_baseline():
print time.strftime('%Y-%m-%d %H:%M:%S')
training_data, validation_data, test_data = mnist_loader.load_data()
# 传递训练模型的参数,这里用默认的参数
clf = svm.SVC()
# clf = svm.SVC(C=8.0, kernel='rbf', gamma=0.00,cache_size=8000,probability=False)
# 进行模型训练
clf.fit(training_data[0], training_data[1])
# test
# 测试集测试预测结果
predictions = [int(a) for a in clf.predict(test_data[0])]
num_correct = sum(int(a == y) for a, y in zip(predictions, test_data[1]))
print "%s of %s test values correct." % (num_correct, len(test_data[1]))
print time.strftime('%Y-%m-%d %H:%M:%S') if __name__ == "__main__":
svm_baseline()

以上代码没有用验证集进行验证。这是因为本例中,用测试集和验证集要判断的是一个东西,没有必要刻意用验证集再来验证一遍。事实上,我的确用验证集也试了一下,和测试集的结果基本一样。呵呵

直接运行代码,结果如下:

2016-01-02 14:01:46
9435 of 10000 test values correct.
2016-01-02 14:12:37

在我的ubuntu上,运行11分钟左右就可以完成训练,并预测测试集的结果。

需要说明的是,svm.SVC()函数的几个重要参数。直接用help命令查看一下文档,这里我稍微翻译了一下:

C : 浮点型,可选 (默认=1.0)。误差项的惩罚参数C

kernel : 字符型, 可选 (默认='rbf')。指定核函数类型。只能是'linear', 'poly', 'rbf', 'sigmoid', 'precomputed' 或者自定义的。如果没有指定,默认使用'rbf'。如果使用自定义的核函数,需要预先计算核矩阵。

degree : 整形, 可选 (默认=3)。用多项式核函数('poly')时,多项式核函数的参数d,用其他核函数,这个参数可忽略

gamma : 浮点型, 可选 (默认=0.0)。'rbf', 'poly' and 'sigmoid'核函数的系数。如果gamma是0,实际将使用特征维度的倒数值进行运算。也就是说,如果特征是100个维度,实际的gamma是1/100。

coef0 : 浮点型, 可选 (默认=0.0)。核函数的独立项,'poly' 和'sigmoid'核时才有意义。

可以适当调整一下SVM分类算法,看看不同参数的结果。当我的参数选择为C=100.0, kernel='rbf', gamma=0.03时,预测的准确度就已经高达98.5%了。

SVM参数的调优初探

SVM分类算法需要调整的参数就只有几个。那么这些参数如何选取,有没有一些经验性的规律呢?

  • 核函数选择



如上图,线性核函数的分类边界是线性的,非线性核函数分类边界是很复杂的非线性边界。所以当能直观地观察数据时,大致可以判断分类边界,从而有倾向性地选择核函数。

  • 参数gamma和C的选择

机器学习大牛Andrew Ng说,关于SVM分类算法,他一直用的是高斯核函数,其它核函数他基本就没用过。可见,这个核函数应用最广。

gamma参数,当使用高斯核进行映射时,如果选得很小的话,高次特征上的权重实际上衰减得非常快,所以实际上(数值上近似一下)相当于一个低维的子空间;反过来,如果gamma选得很大,则可以将任意的数据映射为线性可分——这样容易导致非常严重的过拟合问题。

C参数是寻找 margin 最大的超平面”和“保证数据点偏差量最小”)之间的权重。C越大,模型允许的偏差越小。

下图是一个简单的二分类情况下,不同的gamma和C对分类结果的影响。



相同的C,gamma越大,分类边界离样本越近。相同的gamma,C越大,分类越严格。

下图是不同C和gamma下分类器交叉验证准确率的热力图



由图可知,模型对gamma参数是很敏感的。如果gamma太大,无论C取多大都不能阻止过拟合。当gamma很小,分类边界很像线性的。取中间值时,好的模型的gamma和C大致分布在对角线位置。还应该注意到,当gamma取中间值时,C取值可以是很大的。

在实际项目中,这几个参数按一定的步长,多试几次,一般就能得到比较好的分类效果了。

小结

回顾一下整个问题。我们进行了如下操作。对数据集分成了三部分,训练集、测试集和交叉验证集。用SVM分类模型进行训练,依据测试集和验证集的预测结果来优化参数。依靠sklearn这个强大的机器学习库,我们也能解决手写识别这么高大上的问题了。事实上,我们只用了几行简单代码,就让测试集的预测准确率高达98.5%。

SVM算法也没有想象的那么高不可攀嘛,呵呵!

事实上,就算是一般性的机器学习问题,我们也是有一些一般性的思路的,如下:

SVM学习笔记(二)----手写数字识别的更多相关文章

  1. 深度学习之 mnist 手写数字识别

    深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...

  2. 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)

    上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...

  3. 机器学习框架ML.NET学习笔记【4】多元分类之手写数字识别

    一.问题与解决方案 通过多元分类算法进行手写数字识别,手写数字的图片分辨率为8*8的灰度图片.已经预先进行过处理,读取了各像素点的灰度值,并进行了标记. 其中第0列是序号(不参与运算).1-64列是像 ...

  4. 机器学习框架ML.NET学习笔记【5】多元分类之手写数字识别(续)

    一.概述 上一篇文章我们利用ML.NET的多元分类算法实现了一个手写数字识别的例子,这个例子存在一个问题,就是输入的数据是预处理过的,很不直观,这次我们要直接通过图片来进行学习和判断.思路很简单,就是 ...

  5. 学习笔记CB009:人工神经网络模型、手写数字识别、多层卷积网络、词向量、word2vec

    人工神经网络,借鉴生物神经网络工作原理数学模型. 由n个输入特征得出与输入特征几乎相同的n个结果,训练隐藏层得到意想不到信息.信息检索领域,模型训练合理排序模型,输入特征,文档质量.文档点击历史.文档 ...

  6. 基于tensorflow的MNIST手写数字识别(二)--入门篇

    http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...

  7. 实现手写数字识别(数据集50000张图片)比较3种算法神经网络、灰度平均值、SVM各自的准确率—Jason niu

    对手写数据集50000张图片实现阿拉伯数字0~9识别,并且对结果进行分析准确率, 手写数字数据集下载:http://yann.lecun.com/exdb/mnist/ 首先,利用图片本身的属性,图片 ...

  8. 用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别

    用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 http://phunter.farbox.com/post/mxnet-tutorial1 用MXnet实战深度学 ...

  9. MindSpore手写数字识别初体验,深度学习也没那么神秘嘛

    摘要:想了解深度学习却又无从下手,不如从手写数字识别模型训练开始吧! 深度学习作为机器学习分支之一,应用日益广泛.语音识别.自动机器翻译.即时视觉翻译.刷脸支付.人脸考勤--不知不觉,深度学习已经渗入 ...

随机推荐

  1. android6.0 adbd深入分析(二)adb驱动数据的处理、写数据到adb驱动节点

     上篇博客最后讲到在output_thread中.读取了adb驱动的数据后.就调用write_packet(t->fd, t->serial, &p)函数,把数据网socket ...

  2. 将td中文字过长的部分变成省略号显示的小技巧

    首先设置表格的样式table-layout:"fixed"再设置表格的宽度(这步必须) 最后再设置td样式的三个必要属性 代码如下: text-overflow: ellipsis ...

  3. java 线程之间的协作 wait()与notifyAll()

    watermark/2/text/aHR0cDovL2Jsb2cuY3Nkbi5uZXQvbGlhbmdydWkxOTg4/font/5a6L5L2T/fontsize/400/fill/I0JBQk ...

  4. Cocos3.0 的android返回键功能实现

    比如:Game.h   Game.cpp 头文件Game.h中定义: void onKeyReleased(EventKeyboard::KeyCode keyCode,Event * pEvent) ...

  5. 一个共通的viewModel搞定所有的分页查询一览及数据导出(easyui + knockoutjs + mvc4.0)

    前言 大家看标题就明白了我想写什么了,在做企业信息化系统中可能大家写的最多的一种页面就是查询页面了.其实每个查询页面,除了条件不太一样,数据不太一样,其它的其实都差不多.所以我就想提取一些共通的东西出 ...

  6. K-NN算法 学习总结

    1. K-NN算法简介 K-NN算法 ( K Nearest Neighbor, K近邻算法 ), 是机器学习中的一个经典算法, 比较简单且容易理解. K-NN算法通过计算新数据与训练数据特征值之间的 ...

  7. iOS开发-你真的会用SDWebImage?

    SDWebImage作为眼下最受欢迎的图片下载第三方框架,使用率非常高.可是你真的会用吗?本文接下来将通过例子分析怎样合理使用SDWebImage. 使用场景:自己定义的UITableViewCell ...

  8. 设置label中的对齐方式

    QLabel.setAlignment (self, Qt.Alignment) 下面查看Qt.Alignment: Qt.AlignmentFlag This enum type is used t ...

  9. Django model :add a non-nullable field 'SKU' to product without a default; we can't do that

    You are trying to add a non-nullable field 'SKU' to product without a default; we can't do that (the ...

  10. BZOJ 3261 最大异或和 可持久化Trie树

    题目大意:给定一个序列,提供下列操作: 1.在数组结尾插入一个数 2.给定l,r,x,求一个l<=p<=r,使x^a[p]^a[p+1]^...^a[n]最大 首先我们能够维护前缀和 然后 ...