SVM学习笔记(二)----手写数字识别
引言
上一篇博客整理了一下SVM分类算法的基本理论问题,它分类的基本思想是利用最大间隔进行分类,处理非线性问题是通过核函数将特征向量映射到高维空间,从而变成线性可分的,但是运算却是在低维空间运行的。考虑到数据中可能存在噪音,还引入了松弛变量。
理论是抽象的,问题是具体的。站在岸上学不会游泳,光看着梨子不可能知道梨子的滋味。本篇博客就是用SVM分类算法解决一个经典的机器学习问题--手写数字识别。体会一下SVM算法的具体过程,理理它的一般性的思路。
问题的提出
人类视觉系统是世界上众多的奇迹之一。看看下面的手写数字序列:
大多数人毫不费力就能够认出这些数字为504192。如果尝试让计算机程序来识别诸如上面的数字,就会明显感受到视觉模式识别的困难。关于我们识别形状--------“9顶上有一个圈,右下方则是一条竖线”这样的简单直觉,实际上算法很难轻易表达出来。
SVM分类算法以另一个角度来考虑问题。其思路是获取大量的手写数字,常称作训练样本,然后开发出一个可以从这些训练样本中进行学习的系统。换言之,SVM使用样本来自动推断出识别手写数字的规则。随着样本数量的增加,算法可以学到更多关于手写数字的知识,这样就能够提升自身的准确性。
本文采用的数据集就是著名的“MNIST数据集”。这个数据集有60000个训练样本数据集和10000个测试用例。直接调用scikit-learn库中的SVM,使用默认的参数,1000张手写数字图片,判断准确的图片就高达9435张。
SVM的算法过程
通常,对于分类问题。我们会将数据集分成三部分,训练集、测试集、交叉验证集。用训练集训练生成模型,用测试集和交叉验证集进行验证模型的准确性。
加载数据的代码如下:
"""
mnist_loader
~~~~~~~~~~~~
一个加载模式识别图片数据的库。
"""
#### Libraries
# Standard library
import cPickle
import gzip
# Third-party libraries
import numpy as np
def load_data():
"""
返回包含训练数据、验证数据、测试数据的元组的模式识别数据
训练数据包含50,000张图片,测试数据和验证数据都只包含10,000张图片
"""
f = gzip.open('../data/mnist.pkl.gz', 'rb')
training_data, validation_data, test_data = cPickle.load(f)
f.close()
return (training_data, validation_data, test_data)
SVM算法进行训练和预测的代码如下:
"""
mnist_svm
~~~~~~~~~
使用SVM分类器,从MNIST数据集中进行手写数字识别的分类程序
"""
#### Libraries
# My libraries
import mnist_loader
# Third-party libraries
from sklearn import svm
import time
def svm_baseline():
print time.strftime('%Y-%m-%d %H:%M:%S')
training_data, validation_data, test_data = mnist_loader.load_data()
# 传递训练模型的参数,这里用默认的参数
clf = svm.SVC()
# clf = svm.SVC(C=8.0, kernel='rbf', gamma=0.00,cache_size=8000,probability=False)
# 进行模型训练
clf.fit(training_data[0], training_data[1])
# test
# 测试集测试预测结果
predictions = [int(a) for a in clf.predict(test_data[0])]
num_correct = sum(int(a == y) for a, y in zip(predictions, test_data[1]))
print "%s of %s test values correct." % (num_correct, len(test_data[1]))
print time.strftime('%Y-%m-%d %H:%M:%S')
if __name__ == "__main__":
svm_baseline()
以上代码没有用验证集进行验证。这是因为本例中,用测试集和验证集要判断的是一个东西,没有必要刻意用验证集再来验证一遍。事实上,我的确用验证集也试了一下,和测试集的结果基本一样。呵呵
直接运行代码,结果如下:
2016-01-02 14:01:46
9435 of 10000 test values correct.
2016-01-02 14:12:37
在我的ubuntu上,运行11分钟左右就可以完成训练,并预测测试集的结果。
需要说明的是,svm.SVC()函数的几个重要参数。直接用help命令查看一下文档,这里我稍微翻译了一下:
C : 浮点型,可选 (默认=1.0)。误差项的惩罚参数C
kernel : 字符型, 可选 (默认='rbf')。指定核函数类型。只能是'linear', 'poly', 'rbf', 'sigmoid', 'precomputed' 或者自定义的。如果没有指定,默认使用'rbf'。如果使用自定义的核函数,需要预先计算核矩阵。
degree : 整形, 可选 (默认=3)。用多项式核函数('poly')时,多项式核函数的参数d,用其他核函数,这个参数可忽略
gamma : 浮点型, 可选 (默认=0.0)。'rbf', 'poly' and 'sigmoid'核函数的系数。如果gamma是0,实际将使用特征维度的倒数值进行运算。也就是说,如果特征是100个维度,实际的gamma是1/100。
coef0 : 浮点型, 可选 (默认=0.0)。核函数的独立项,'poly' 和'sigmoid'核时才有意义。
可以适当调整一下SVM分类算法,看看不同参数的结果。当我的参数选择为C=100.0, kernel='rbf', gamma=0.03时,预测的准确度就已经高达98.5%了。
SVM参数的调优初探
SVM分类算法需要调整的参数就只有几个。那么这些参数如何选取,有没有一些经验性的规律呢?
- 核函数选择
如上图,线性核函数的分类边界是线性的,非线性核函数分类边界是很复杂的非线性边界。所以当能直观地观察数据时,大致可以判断分类边界,从而有倾向性地选择核函数。
- 参数gamma和C的选择
机器学习大牛Andrew Ng说,关于SVM分类算法,他一直用的是高斯核函数,其它核函数他基本就没用过。可见,这个核函数应用最广。
gamma参数,当使用高斯核进行映射时,如果选得很小的话,高次特征上的权重实际上衰减得非常快,所以实际上(数值上近似一下)相当于一个低维的子空间;反过来,如果gamma选得很大,则可以将任意的数据映射为线性可分——这样容易导致非常严重的过拟合问题。
C参数是寻找 margin 最大的超平面”和“保证数据点偏差量最小”)之间的权重。C越大,模型允许的偏差越小。
下图是一个简单的二分类情况下,不同的gamma和C对分类结果的影响。
相同的C,gamma越大,分类边界离样本越近。相同的gamma,C越大,分类越严格。
下图是不同C和gamma下分类器交叉验证准确率的热力图
由图可知,模型对gamma参数是很敏感的。如果gamma太大,无论C取多大都不能阻止过拟合。当gamma很小,分类边界很像线性的。取中间值时,好的模型的gamma和C大致分布在对角线位置。还应该注意到,当gamma取中间值时,C取值可以是很大的。
在实际项目中,这几个参数按一定的步长,多试几次,一般就能得到比较好的分类效果了。
小结
回顾一下整个问题。我们进行了如下操作。对数据集分成了三部分,训练集、测试集和交叉验证集。用SVM分类模型进行训练,依据测试集和验证集的预测结果来优化参数。依靠sklearn这个强大的机器学习库,我们也能解决手写识别这么高大上的问题了。事实上,我们只用了几行简单代码,就让测试集的预测准确率高达98.5%。
SVM算法也没有想象的那么高不可攀嘛,呵呵!
事实上,就算是一般性的机器学习问题,我们也是有一些一般性的思路的,如下:
SVM学习笔记(二)----手写数字识别的更多相关文章
- 深度学习之 mnist 手写数字识别
深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...
- 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
- 机器学习框架ML.NET学习笔记【4】多元分类之手写数字识别
一.问题与解决方案 通过多元分类算法进行手写数字识别,手写数字的图片分辨率为8*8的灰度图片.已经预先进行过处理,读取了各像素点的灰度值,并进行了标记. 其中第0列是序号(不参与运算).1-64列是像 ...
- 机器学习框架ML.NET学习笔记【5】多元分类之手写数字识别(续)
一.概述 上一篇文章我们利用ML.NET的多元分类算法实现了一个手写数字识别的例子,这个例子存在一个问题,就是输入的数据是预处理过的,很不直观,这次我们要直接通过图片来进行学习和判断.思路很简单,就是 ...
- 学习笔记CB009:人工神经网络模型、手写数字识别、多层卷积网络、词向量、word2vec
人工神经网络,借鉴生物神经网络工作原理数学模型. 由n个输入特征得出与输入特征几乎相同的n个结果,训练隐藏层得到意想不到信息.信息检索领域,模型训练合理排序模型,输入特征,文档质量.文档点击历史.文档 ...
- 基于tensorflow的MNIST手写数字识别(二)--入门篇
http://www.jianshu.com/p/4195577585e6 基于tensorflow的MNIST手写字识别(一)--白话卷积神经网络模型 基于tensorflow的MNIST手写数字识 ...
- 实现手写数字识别(数据集50000张图片)比较3种算法神经网络、灰度平均值、SVM各自的准确率—Jason niu
对手写数据集50000张图片实现阿拉伯数字0~9识别,并且对结果进行分析准确率, 手写数字数据集下载:http://yann.lecun.com/exdb/mnist/ 首先,利用图片本身的属性,图片 ...
- 用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别
用MXnet实战深度学习之一:安装GPU版mxnet并跑一个MNIST手写数字识别 http://phunter.farbox.com/post/mxnet-tutorial1 用MXnet实战深度学 ...
- MindSpore手写数字识别初体验,深度学习也没那么神秘嘛
摘要:想了解深度学习却又无从下手,不如从手写数字识别模型训练开始吧! 深度学习作为机器学习分支之一,应用日益广泛.语音识别.自动机器翻译.即时视觉翻译.刷脸支付.人脸考勤--不知不觉,深度学习已经渗入 ...
随机推荐
- onActivityResult 通过case对不同情况进行处理
@Overridepublic void onActivityResult(int requestCode, int resultCode, Intent data) { switch (result ...
- Jquery学习笔记(3)--注册验证
嗯哼,验证用户名,密码,重复密码,手机号,邮箱.提交时全部进行验证,通过才跳转. <!DOCTYPE html> <html lang="en"> < ...
- 数据库设计(三)11 important database designing rules which I follow
原文地址:https://www.codeproject.com/Articles/359654/important-database-designing-rules-which-I-fo 麻辣个
- java 获取用户真实ip
/** * 获取用户真实ip * @param request * @return */ public static String getIpAddr(HttpServletRequest reque ...
- 用Vue.js开发一个电影App的前端界面
我们要构建一个什么样的App? 我们大多数人使用在线流媒体服务(如Netflix)观看我们最喜欢的电影或者节目.这篇文章将重点介绍如何通过使用vue.js 2 建立一个类似风格的电影流媒体WEB交互界 ...
- 【Android Studio】之构建项目报错
问题1: 报错: Could not download fastutil.jar (it.unimi.dsi:fastutil:7.2.0): No cached version available ...
- R语言安装sqldb包报错解决办法
我使用Rtudio环境,安装sqldb几次出错.网上没有好的教程. 经过自己试验之后,这样处理.我写出来以后,供大家参考. > install.packages("sqldf" ...
- C++ 类的继承六(多继承的二义性--虚基类)
//多继承的二义性--虚基类(了解为主) #include<iostream> using namespace std; /* 多继承在现在的项目开发中一般不使用,他会增加项目的复杂度 * ...
- 用用匿名函数和闭包加apply强制待定函数调用时使用特定上下文
<button id="test">点我</button> <script> var button={ clicked:false, click ...
- (转)memcache缓存
转自:http://369369.blog.51cto.com/319630/833234/ memcache分布式缓存 Memcache知识点梳理 Memcached概念: Memcached ...