学习OpenCV——SVM 手写数字检测
转自http://blog.csdn.net/firefight/article/details/6452188
训练和测试数据来自MNIST手写数字图片库:http://code.google.com/p/supplement-of-the-mnist-database-of-handwritten-digits/downloads/list
其他方法:http://blog.csdn.net/onezeros/article/details/5672192
使用OPENCV训练手写数字识别分类器
1,下载训练数据和测试数据文件,这里用的是MNIST手写数字图片库,其中训练数据库中为60000个,测试数据库中为10000个
2,创建训练数据和测试数据文件读取函数,注意字节顺序为大端
3,确定字符特征方式为最简单的8×8网格内的字符点数
4,创建SVM,训练并读取,结果如下
1000个训练样本,测试数据正确率80.21%(并没有体现SVM小样本高准确率的特性啊)
10000个训练样本,测试数据正确率95.45%
60000个训练样本,测试数据正确率97.67%
5,编写手写输入的GUI程序,并进行验证,效果还可以接受。
以下为主要代码,以供参考
(类似的也实现了随机树分类器,比较发现在相同的样本数情况下,SVM准确率略高)
- #include "stdafx.h"
- #include <fstream>
- #include "opencv2/opencv.hpp"
- #include <vector>
- using namespace std;
- using namespace cv;
- #define SHOW_PROCESS 0
- #define ON_STUDY 0
/**
 * One training sample: a 64-element feature vector and its label.
 *
 * The feature vector is filled elsewhere in this file from an 8x8 image
 * (row-major, one float per pixel). `result` is -1 until a label is assigned.
 */
class NumTrainData
{
public:
    // Value-initialise the array (all elements 0.0f) in the initializer list
    // instead of memset, so no <cstring> dependency is needed here.
    NumTrainData() : data(), result(-1)
    {
    }

public:
    float data[64]; // feature values; length must stay in sync with featureLen
    int result;     // label of the sample, -1 when unset
};

// Samples collected by ReadTrainData() and consumed by the training routines.
std::vector<NumTrainData> buffer;

// Number of floats per feature vector; must not exceed the size of NumTrainData::data.
int featureLen = 64;
/**
 * Reverse the order of the four bytes starting at `buf`, in place.
 *
 * Used in this file to convert 32-bit big-endian header fields before they
 * are copied into an int.
 *
 * @param buf pointer to at least four writable bytes.
 */
void swapBuffer(char* buf)
{
    // Walk inwards from both ends: (0,3) then (1,2).
    for (int lo = 0, hi = 3; lo < hi; ++lo, --hi)
    {
        const char held = buf[lo];
        buf[lo] = buf[hi];
        buf[hi] = held;
    }
}
- void GetROI(Mat& src, Mat& dst)
- {
- int left, right, top, bottom;
- left = src.cols;
- right = 0;
- top = src.rows;
- bottom = 0;
- //Get valid area
- for(int i=0; i<src.rows; i++)
- {
- for(int j=0; j<src.cols; j++)
- {
- if(src.at<uchar>(i, j) > 0)
- {
- if(j<left) left = j;
- if(j>right) right = j;
- if(i<top) top = i;
- if(i>bottom) bottom = i;
- }
- }
- }
- //Point center;
- //center.x = (left + right) / 2;
- //center.y = (top + bottom) / 2;
- int width = right - left;
- int height = bottom - top;
- int len = (width < height) ? height : width;
- //Create a squre
- dst = Mat::zeros(len, len, CV_8UC1);
- //Copy valid data to squre center
- Rect dstRect((len - width)/2, (len - height)/2, width, height);
- Rect srcRect(left, top, width, height);
- Mat dstROI = dst(dstRect);
- Mat srcROI = src(srcRect);
- srcROI.copyTo(dstROI);
- }
- int ReadTrainData(int maxCount)
- {
- //Open image and label file
- const char fileName[] = "../res/train-images.idx3-ubyte";
- const char labelFileName[] = "../res/train-labels.idx1-ubyte";
- ifstream lab_ifs(labelFileName, ios_base::binary);
- ifstream ifs(fileName, ios_base::binary);
- if( ifs.fail() == true )
- return -1;
- if( lab_ifs.fail() == true )
- return -1;
- //Read train data number and image rows / cols
- char magicNum[4], ccount[4], crows[4], ccols[4];
- ifs.read(magicNum, sizeof(magicNum));
- ifs.read(ccount, sizeof(ccount));
- ifs.read(crows, sizeof(crows));
- ifs.read(ccols, sizeof(ccols));
- int count, rows, cols;
- swapBuffer(ccount);
- swapBuffer(crows);
- swapBuffer(ccols);
- memcpy(&count, ccount, sizeof(count));
- memcpy(&rows, crows, sizeof(rows));
- memcpy(&cols, ccols, sizeof(cols));
- //Just skip label header
- lab_ifs.read(magicNum, sizeof(magicNum));
- lab_ifs.read(ccount, sizeof(ccount));
- //Create source and show image matrix
- Mat src = Mat::zeros(rows, cols, CV_8UC1);
- Mat temp = Mat::zeros(8, 8, CV_8UC1);
- Mat img, dst;
- char label = 0;
- Scalar templateColor(255, 0, 255 );
- NumTrainData rtd;
- //int loop = 1000;
- int total = 0;
- while(!ifs.eof())
- {
- if(total >= count)
- break;
- total++;
- cout << total << endl;
- //Read label
- lab_ifs.read(&label, 1);
- label = label + '0';
- //Read source data
- ifs.read((char*)src.data, rows * cols);
- GetROI(src, dst);
- #if(SHOW_PROCESS)
- //Too small to watch
- img = Mat::zeros(dst.rows*10, dst.cols*10, CV_8UC1);
- resize(dst, img, img.size());
- stringstream ss;
- ss << "Number " << label;
- string text = ss.str();
- putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
- //imshow("img", img);
- #endif
- rtd.result = label;
- resize(dst, temp, temp.size());
- //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
- for(int i = 0; i<8; i++)
- {
- for(int j = 0; j<8; j++)
- {
- rtd.data[ i*8 + j] = temp.at<uchar>(i, j);
- }
- }
- buffer.push_back(rtd);
- //if(waitKey(0)==27) //ESC to quit
- // break;
- maxCount--;
- if(maxCount == 0)
- break;
- }
- ifs.close();
- lab_ifs.close();
- return 0;
- }
- void newRtStudy(vector<NumTrainData>& trainData)
- {
- int testCount = trainData.size();
- Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);
- Mat res = Mat::zeros(testCount, 1, CV_32SC1);
- for (int i= 0; i< testCount; i++)
- {
- NumTrainData td = trainData.at(i);
- memcpy(data.data + i*featureLen*sizeof(float), td.data, featureLen*sizeof(float));
- res.at<unsigned int>(i, 0) = td.result;
- }
- /////////////START RT TRAINNING//////////////////
- CvRTrees forest;
- CvMat* var_importance = 0;
- forest.train( data, CV_ROW_SAMPLE, res, Mat(), Mat(), Mat(), Mat(),
- CvRTParams(10,10,0,false,15,0,true,4,100,0.01f,CV_TERMCRIT_ITER));
- forest.save( "new_rtrees.xml" );
- }
- int newRtPredict()
- {
- CvRTrees forest;
- forest.load( "new_rtrees.xml" );
- const char fileName[] = "../res/t10k-images.idx3-ubyte";
- const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";
- ifstream lab_ifs(labelFileName, ios_base::binary);
- ifstream ifs(fileName, ios_base::binary);
- if( ifs.fail() == true )
- return -1;
- if( lab_ifs.fail() == true )
- return -1;
- char magicNum[4], ccount[4], crows[4], ccols[4];
- ifs.read(magicNum, sizeof(magicNum));
- ifs.read(ccount, sizeof(ccount));
- ifs.read(crows, sizeof(crows));
- ifs.read(ccols, sizeof(ccols));
- int count, rows, cols;
- swapBuffer(ccount);
- swapBuffer(crows);
- swapBuffer(ccols);
- memcpy(&count, ccount, sizeof(count));
- memcpy(&rows, crows, sizeof(rows));
- memcpy(&cols, ccols, sizeof(cols));
- Mat src = Mat::zeros(rows, cols, CV_8UC1);
- Mat temp = Mat::zeros(8, 8, CV_8UC1);
- Mat m = Mat::zeros(1, featureLen, CV_32FC1);
- Mat img, dst;
- //Just skip label header
- lab_ifs.read(magicNum, sizeof(magicNum));
- lab_ifs.read(ccount, sizeof(ccount));
- char label = 0;
- Scalar templateColor(255, 0, 0);
- NumTrainData rtd;
- int right = 0, error = 0, total = 0;
- int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;
- while(ifs.good())
- {
- //Read label
- lab_ifs.read(&label, 1);
- label = label + '0';
- //Read data
- ifs.read((char*)src.data, rows * cols);
- GetROI(src, dst);
- //Too small to watch
- img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);
- resize(dst, img, img.size());
- rtd.result = label;
- resize(dst, temp, temp.size());
- //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
- for(int i = 0; i<8; i++)
- {
- for(int j = 0; j<8; j++)
- {
- m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);
- }
- }
- if(total >= count)
- break;
- char ret = (char)forest.predict(m);
- if(ret == label)
- {
- right++;
- if(total <= 5000)
- right_1++;
- else
- right_2++;
- }
- else
- {
- error++;
- if(total <= 5000)
- error_1++;
- else
- error_2++;
- }
- total++;
- #if(SHOW_PROCESS)
- stringstream ss;
- ss << "Number " << label << ", predict " << ret;
- string text = ss.str();
- putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
- imshow("img", img);
- if(waitKey(0)==27) //ESC to quit
- break;
- #endif
- }
- ifs.close();
- lab_ifs.close();
- stringstream ss;
- ss << "Total " << total << ", right " << right <<", error " << error;
- string text = ss.str();
- putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
- imshow("img", img);
- waitKey(0);
- return 0;
- }
- void newSvmStudy(vector<NumTrainData>& trainData)
- {
- int testCount = trainData.size();
- Mat m = Mat::zeros(1, featureLen, CV_32FC1);
- Mat data = Mat::zeros(testCount, featureLen, CV_32FC1);
- Mat res = Mat::zeros(testCount, 1, CV_32SC1);
- for (int i= 0; i< testCount; i++)
- {
- NumTrainData td = trainData.at(i);
- memcpy(m.data, td.data, featureLen*sizeof(float));
- normalize(m, m);
- memcpy(data.data + i*featureLen*sizeof(float), m.data, featureLen*sizeof(float));
- res.at<unsigned int>(i, 0) = td.result;
- }
- /////////////START SVM TRAINNING//////////////////
- CvSVM svm = CvSVM();
- CvSVMParams param;
- CvTermCriteria criteria;
- criteria= cvTermCriteria(CV_TERMCRIT_EPS, 1000, FLT_EPSILON);
- param= CvSVMParams(CvSVM::C_SVC, CvSVM::RBF, 10.0, 8.0, 1.0, 10.0, 0.5, 0.1, NULL, criteria);
- svm.train(data, res, Mat(), Mat(), param);
- svm.save( "SVM_DATA.xml" );
- }
- int newSvmPredict()
- {
- CvSVM svm = CvSVM();
- svm.load( "SVM_DATA.xml" );
- const char fileName[] = "../res/t10k-images.idx3-ubyte";
- const char labelFileName[] = "../res/t10k-labels.idx1-ubyte";
- ifstream lab_ifs(labelFileName, ios_base::binary);
- ifstream ifs(fileName, ios_base::binary);
- if( ifs.fail() == true )
- return -1;
- if( lab_ifs.fail() == true )
- return -1;
- char magicNum[4], ccount[4], crows[4], ccols[4];
- ifs.read(magicNum, sizeof(magicNum));
- ifs.read(ccount, sizeof(ccount));
- ifs.read(crows, sizeof(crows));
- ifs.read(ccols, sizeof(ccols));
- int count, rows, cols;
- swapBuffer(ccount);
- swapBuffer(crows);
- swapBuffer(ccols);
- memcpy(&count, ccount, sizeof(count));
- memcpy(&rows, crows, sizeof(rows));
- memcpy(&cols, ccols, sizeof(cols));
- Mat src = Mat::zeros(rows, cols, CV_8UC1);
- Mat temp = Mat::zeros(8, 8, CV_8UC1);
- Mat m = Mat::zeros(1, featureLen, CV_32FC1);
- Mat img, dst;
- //Just skip label header
- lab_ifs.read(magicNum, sizeof(magicNum));
- lab_ifs.read(ccount, sizeof(ccount));
- char label = 0;
- Scalar templateColor(255, 0, 0);
- NumTrainData rtd;
- int right = 0, error = 0, total = 0;
- int right_1 = 0, error_1 = 0, right_2 = 0, error_2 = 0;
- while(ifs.good())
- {
- //Read label
- lab_ifs.read(&label, 1);
- label = label + '0';
- //Read data
- ifs.read((char*)src.data, rows * cols);
- GetROI(src, dst);
- //Too small to watch
- img = Mat::zeros(dst.rows*30, dst.cols*30, CV_8UC3);
- resize(dst, img, img.size());
- rtd.result = label;
- resize(dst, temp, temp.size());
- //threshold(temp, temp, 10, 1, CV_THRESH_BINARY);
- for(int i = 0; i<8; i++)
- {
- for(int j = 0; j<8; j++)
- {
- m.at<float>(0,j + i*8) = temp.at<uchar>(i, j);
- }
- }
- if(total >= count)
- break;
- normalize(m, m);
- char ret = (char)svm.predict(m);
- if(ret == label)
- {
- right++;
- if(total <= 5000)
- right_1++;
- else
- right_2++;
- }
- else
- {
- error++;
- if(total <= 5000)
- error_1++;
- else
- error_2++;
- }
- total++;
- #if(SHOW_PROCESS)
- stringstream ss;
- ss << "Number " << label << ", predict " << ret;
- string text = ss.str();
- putText(img, text, Point(10, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
- imshow("img", img);
- if(waitKey(0)==27) //ESC to quit
- break;
- #endif
- }
- ifs.close();
- lab_ifs.close();
- stringstream ss;
- ss << "Total " << total << ", right " << right <<", error " << error;
- string text = ss.str();
- putText(img, text, Point(50, 50), FONT_HERSHEY_SIMPLEX, 1.0, templateColor);
- imshow("img", img);
- waitKey(0);
- return 0;
- }
- int main( int argc, char *argv[] )
- {
- #if(ON_STUDY)
- int maxCount = 60000;
- ReadTrainData(maxCount);
- //newRtStudy(buffer);
- newSvmStudy(buffer);
- #else
- //newRtPredict();
- newSvmPredict();
- #endif
- return 0;
- }
- //from: http://blog.csdn.net/yangtrees/article/details/7458466
学习OpenCV——SVM 手写数字检测的更多相关文章
- 基于opencv的手写数字识别(MFC,HOG,SVM)
参考了秋风细雨的文章:http://blog.csdn.net/candyforever/article/details/8564746 花了点时间编写出了程序,先看看效果吧. 识别效果大概都能正确. ...
- 深度学习之 mnist 手写数字识别
深度学习之 mnist 手写数字识别 开始学习深度学习,先来一个手写数字的程序 import numpy as np import os import codecs import torch from ...
- 【深度学习系列】手写数字识别卷积神经--卷积神经网络CNN原理详解(一)
上篇文章我们给出了用paddlepaddle来做手写数字识别的示例,并对网络结构进行到了调整,提高了识别的精度.有的同学表示不是很理解原理,为什么传统的机器学习算法,简单的神经网络(如多层感知机)都可 ...
- 基于opencv的手写数字字符识别
摘要 本程序主要参照论文,<基于OpenCV的脱机手写字符识别技术>实现了,对于手写阿拉伯数字的识别工作.识别工作分为三大步骤:预处理,特征提取,分类识别.预处理过程主要找到图像的ROI部 ...
- mnist手写数字检测
# -*- coding: utf-8 -*- """ Created on Tue Apr 23 06:16:04 2019 @author: 92958 " ...
- 简单HOG+SVM mnist手写数字分类
使用工具 :VS2013 + OpenCV 3.1 数据集:minst 训练数据:60000张 测试数据:10000张 输出模型:HOG_SVM_DATA.xml 数据准备 train-images- ...
- 使用神经网络来识别手写数字【译】(三)- 用Python代码实现
实现我们分类数字的网络 好,让我们使用随机梯度下降和 MNIST训练数据来写一个程序来学习怎样识别手写数字. 我们用Python (2.7) 来实现.只有 74 行代码!我们需要的第一个东西是 MNI ...
- SVM学习笔记(二)----手写数字识别
引言 上一篇博客整理了一下SVM分类算法的基本理论问题,它分类的基本思想是利用最大间隔进行分类,处理非线性问题是通过核函数将特征向量映射到高维空间,从而变成线性可分的,但是运算却是在低维空间运行的.考 ...
- 手把手教你使用LabVIEW OpenCV DNN实现手写数字识别(含源码)
@ 目录 前言 一.OpenCV DNN模块 1.OpenCV DNN简介 2.LabVIEW中DNN模块函数 二.TensorFlow pb文件的生成和调用 1.TensorFlow2 Keras模 ...
随机推荐
- 【BZOJ】2277: [Poi2011]Strongbox
题意 有一个密码箱,\(0\)到\(n-1\)中的某些整数是它的密码.如果\(a\)和\(b\)都是它的密码,那么\((a+b)%n\)也是它的密码(\(a,b\)可以相等).某人试了\(k\)次密码 ...
- 【BZOJ1003】1003: [ZJOI2006]物流运输trans SPFA+DP
Description 物流公司要把一批货物从码头A运到码头B.由于货物量比较大,需要n天才能运完.货物运输过程中一般要转停好几个码头.物流公司通常会设计一条固定的运输路线,以便对整个运输过程实施严格 ...
- JavaScript放置位置区别
JavaScript放置位置区别 页面中的脚本会在页面载入浏览器后立即执行.我们并不总希望这样.有时,我们希望当页面载入时执行脚本,而另外的时候,我们则希望当用户触发事件时才执行脚本. 位于 head ...
- HttpClient_httpclient 4.3.1 post get的工具类
package com.ryx.util; import java.util.ArrayList; import java.util.List; import java.util.Map; impor ...
- Let It Be - The Beatles - Lyrics
轉載自 https://www.youtube.com/watch?v=0714IbwC3HA When I find myself in times of trouble, Mother Mary ...
- 浅谈iOS视频开发
浅谈iOS视频开发 这段时间对视频开发进行了一些了解,在这里和大家分享一下我自己觉得学习步骤和资料,希望对那些对视频感兴趣的朋友有些帮助. 一.iOS系统自带播放器 要了解iOS视频开发,首先我们从 ...
- CSS3+HTML5实现块阴影与文字阴影
CSS 3 + HTML 5 是未来的 Web,它们都还没有正式到来,虽然不少浏览器已经开始对它们提供部分支持.本教程分5节介绍了 5 个 CSS3 技巧,可以帮你实现未来的 Web,不过,这些技术不 ...
- Scala命令设置JVM参数的规则
Scala下设置JVM参数简单分析 Scala 启动shell脚本,简化后的scala REPL 启动命令大致如下所示: java -Xmx256M -Xms32M \-Xbootclasspath/ ...
- 一次有趣的XSS漏洞挖掘分析(3)最终篇
这真是最后一次了.真的再不逗这个程序员了.和预期一样,勤奋的程序员今天又更新程序了.因为前面写的payload都有一个致命的弱点,就是document.write()会完全破坏DOM结构.而且再“完事 ...
- XML于JSON
XML:可拓展的标记语言(跨平台数据表现)用于保存数据 XML:标记需要关闭 :单根性 .NET中DOM常用对象: XmlDocument :一个XML文档 XmlNode:xml中的单个节点 Xml ...