来自OpenCV2.3.1 sample/c/mushroom.cpp

1.首先读入agaricus-lepiota.data的训练样本。

样本中第一项是e或p代表有毒或无毒的标志位;其他是特征,可以把每个样本看做一个特征向量;

cvSeqPush( seq, el_ptr );读入序列seq中,每一项都存储一个样本即特征向量;

之后,把特征向量与标志位分别读入CvMat* data与CvMat* reponses中

还有一个CvMat* missing保留丢失位当前小于0位置;

2.训练样本

  1. dtree = new CvDTree;
  2. dtree->train( data, CV_ROW_SAMPLE, responses, 0, 0, var_type, missing,
  3. CvDTreeParams( 8, // max depth
  4. 10, // min sample count 样本数小于10时,停止分裂
  5. 0, // regression accuracy: N/A here;回归树的限制精度
  6. true, // compute surrogate split, as we have missing data;;为真时,计算missing data和变量的重要性
  7. 15, // max number of categories (use sub-optimal algorithm for larger numbers)类型上限以保证计算速度。树会以次优分裂(suboptimal split)的形式生长。只对2种取值以上的树有意义
  8. 10, // the number of cross-validation folds;If cv_folds > 1 then prune a tree with K-fold cross-validation where K is equal to cv_folds
  9. true, // use 1SE rule => smaller tree;If true 修剪树. 这将使树更紧凑,更能抵抗训练数据噪声,但有点不太准确
  10. true, // throw away the pruned tree branches
  11. priors //错分类的代价我们判断的:有毒VS无毒 错误的代价比 the array of priors, the bigger p_weight, the more attention
  12. // to the poisonous mushrooms
  13. // (a mushroom will be judjed to be poisonous with bigger chance)
  14. ));

3.

  1. double r = dtree->predict( &sample, &mask )->value;//使用predict来预测样本,结果为 CvDTreeNode结构,dtree->predict(sample,mask)->value是分类情况下的类别或回归情况下的函数估计值;

4.interactive_classification通过人工输入特征来判断。

    1. #include "opencv2/core/core_c.h"
    2. #include "opencv2/ml/ml.hpp"
    3. #include <stdio.h>
    4. void help()
    5. {
    6. printf("\nThis program demonstrated the use of OpenCV's decision tree function for learning and predicting data\n"
    7. "Usage :\n"
    8. "./mushroom <path to agaricus-lepiota.data>\n"
    9. "\n"
    10. "The sample demonstrates how to build a decision tree for classifying mushrooms.\n"
    11. "It uses the sample base agaricus-lepiota.data from UCI Repository, here is the link:\n"
    12. "\n"
    13. "Newman, D.J. & Hettich, S. & Blake, C.L. & Merz, C.J. (1998).\n"
    14. "UCI Repository of machine learning databases\n"
    15. "[http://www.ics.uci.edu/~mlearn/MLRepository.html].\n"
    16. "Irvine, CA: University of California, Department of Information and Computer Science.\n"
    17. "\n"
    18. "// loads the mushroom database, which is a text file, containing\n"
    19. "// one training sample per row, all the input variables and the output variable are categorical,\n"
    20. "// the values are encoded by characters.\n\n");
    21. }
    22. int mushroom_read_database( const char* filename, CvMat** data, CvMat** missing, CvMat** responses )
    23. {
    24. const int M = 1024;
    25. FILE* f = fopen( filename, "rt" );
    26. CvMemStorage* storage;
    27. CvSeq* seq;
    28. char buf[M+2], *ptr;
    29. float* el_ptr;
    30. CvSeqReader reader;
    31. int i, j, var_count = 0;
    32. if( !f )
    33. return 0;
    34. // read the first line and determine the number of variables
    35. if( !fgets( buf, M, f ))
    36. {
    37. fclose(f);
    38. return 0;
    39. }
    40. for( ptr = buf; *ptr != '\0'; ptr++ )
    41. var_count += *ptr == ',';//计算每个样本的数量,每个样本一个“,”,样本数量=var_count+1;
    42. assert( ptr - buf == (var_count+1)*2 );
    43. // create temporary memory storage to store the whole database
    44. //把样本存入seq中,存储空间是storage;
    45. el_ptr = new float[var_count+1];
    46. storage = cvCreateMemStorage();
    47. seq = cvCreateSeq( 0, sizeof(*seq), (var_count+1)*sizeof(float), storage );//
    48. for(;;)
    49. {
    50. for( i = 0; i <= var_count; i++ )
    51. {
    52. int c = buf[i*2];
    53. el_ptr[i] = c == '?' ? -1.f : (float)c;
    54. }
    55. if( i != var_count+1 )
    56. break;
    57. cvSeqPush( seq, el_ptr );
    58. if( !fgets( buf, M, f ) || !strchr( buf, ',' ) )
    59. break;
    60. }
    61. fclose(f);
    62. // allocate the output matrices and copy the base there
    63. *data = cvCreateMat( seq->total, var_count, CV_32F );//行数:样本数量;列数:样本大小;
    64. *missing = cvCreateMat( seq->total, var_count, CV_8U );
    65. *responses = cvCreateMat( seq->total, 1, CV_32F );//样本标志;
    66. cvStartReadSeq( seq, &reader );
    67. for( i = 0; i < seq->total; i++ )
    68. {
    69. const float* sdata = (float*)reader.ptr + 1;
    70. float* ddata = data[0]->data.fl + var_count*i;
    71. float* dr = responses[0]->data.fl + i;
    72. uchar* dm = missing[0]->data.ptr + var_count*i;
    73. for( j = 0; j < var_count; j++ )
    74. {
    75. ddata[j] = sdata[j];
    76. dm[j] = sdata[j] < 0;
    77. }
    78. *dr = sdata[-1];//样本的第一个位置是标志;
    79. CV_NEXT_SEQ_ELEM( seq->elem_size, reader );
    80. }
    81. cvReleaseMemStorage( &storage );
    82. delete el_ptr;
    83. return 1;
    84. }
    85. CvDTree* mushroom_create_dtree( const CvMat* data, const CvMat* missing,
    86. const CvMat* responses, float p_weight )
    87. {
    88. CvDTree* dtree;
    89. CvMat* var_type;
    90. int i, hr1 = 0, hr2 = 0, p_total = 0;
    91. float priors[] = { 1, p_weight };
    92. var_type = cvCreateMat( data->cols + 1, 1, CV_8U );
    93. cvSet( var_type, cvScalarAll(CV_VAR_CATEGORICAL) ); // all the variables are categorical
    94. dtree = new CvDTree;
    95. dtree->train( data, CV_ROW_SAMPLE, responses, 0, 0, var_type, missing,
    96. CvDTreeParams( 8, // max depth
    97. 10, // min sample count样本数小于10时,停止分裂
    98. 0, // regression accuracy: N/A here;回归树的限制精度
    99. true, // compute surrogate split, as we have missing data;为真时,计算missing data和可变的重要性正确度
    100. 15, // max number of categories (use sub-optimal algorithm for larger numbers)类型上限以保证计算速度。树会以次优分裂(suboptimal split)的形式生长。只对2种取值以上的树有意义
    101. 10, // the number of cross-validation folds;If cv_folds > 1 then prune a tree with K-fold cross-validation
    102. true, // use 1SE rule => smaller treeIf true 修剪树. 这将使树更紧凑,更能抵抗训练数据噪声,但有点不太准确
    103. true, // throw away the pruned tree branches
    104. priors // the array of priors, the bigger p_weight, the more attention
    105. // to the poisonous mushrooms
    106. // (a mushroom will be judjed to be poisonous with bigger chance)
    107. ));
    108. // compute hit-rate on the training database, demonstrates predict usage.
    109. for( i = 0; i < data->rows; i++ )
    110. {
    111. CvMat sample, mask;
    112. cvGetRow( data, &sample, i );
    113. cvGetRow( missing, &mask, i );
    114. double r = dtree->predict( &sample, &mask )->value;//使用predict来预测样本,结果为 CvDTreeNode结构,dtree->predict(sample,mask)->value是分类情况下的类别或回归情况下的函数估计值;
    115. int d = fabs(r - responses->data.fl[i]) >= FLT_EPSILON;//大于阈值FLT_EPSILON被判断为误检
    116. if( d )
    117. {
    118. if( r != 'p' )
    119. hr1++;
    120. else
    121. hr2++;
    122. }
    123. p_total += responses->data.fl[i] == 'p';
    124. }
    125. printf( "Results on the training database:\n"
    126. "\tPoisonous mushrooms mis-predicted: %d (%g%%)\n"
    127. "\tFalse-alarms: %d (%g%%)\n", hr1, (double)hr1*100/p_total,
    128. hr2, (double)hr2*100/(data->rows - p_total) );
    129. cvReleaseMat( &var_type );
    130. return dtree;
    131. }
    132. static const char* var_desc[] =
    133. {
    134. "cap shape (bell=b,conical=c,convex=x,flat=f)",
    135. "cap surface (fibrous=f,grooves=g,scaly=y,smooth=s)",
    136. "cap color (brown=n,buff=b,cinnamon=c,gray=g,green=r,\n\tpink=p,purple=u,red=e,white=w,yellow=y)",
    137. "bruises? (bruises=t,no=f)",
    138. "odor (almond=a,anise=l,creosote=c,fishy=y,foul=f,\n\tmusty=m,none=n,pungent=p,spicy=s)",
    139. "gill attachment (attached=a,descending=d,free=f,notched=n)",
    140. "gill spacing (close=c,crowded=w,distant=d)",
    141. "gill size (broad=b,narrow=n)",
    142. "gill color (black=k,brown=n,buff=b,chocolate=h,gray=g,\n\tgreen=r,orange=o,pink=p,purple=u,red=e,white=w,yellow=y)",
    143. "stalk shape (enlarging=e,tapering=t)",
    144. "stalk root (bulbous=b,club=c,cup=u,equal=e,rhizomorphs=z,rooted=r)",
    145. "stalk surface above ring (ibrous=f,scaly=y,silky=k,smooth=s)",
    146. "stalk surface below ring (ibrous=f,scaly=y,silky=k,smooth=s)",
    147. "stalk color above ring (brown=n,buff=b,cinnamon=c,gray=g,orange=o,\n\tpink=p,red=e,white=w,yellow=y)",
    148. "stalk color below ring (brown=n,buff=b,cinnamon=c,gray=g,orange=o,\n\tpink=p,red=e,white=w,yellow=y)",
    149. "veil type (partial=p,universal=u)",
    150. "veil color (brown=n,orange=o,white=w,yellow=y)",
    151. "ring number (none=n,one=o,two=t)",
    152. "ring type (cobwebby=c,evanescent=e,flaring=f,large=l,\n\tnone=n,pendant=p,sheathing=s,zone=z)",
    153. "spore print color (black=k,brown=n,buff=b,chocolate=h,green=r,\n\torange=o,purple=u,white=w,yellow=y)",
    154. "population (abundant=a,clustered=c,numerous=n,\n\tscattered=s,several=v,solitary=y)",
    155. "habitat (grasses=g,leaves=l,meadows=m,paths=p\n\turban=u,waste=w,woods=d)",
    156. 0
    157. };
    158. void print_variable_importance( CvDTree* dtree, const char** var_desc )
    159. {
    160. const CvMat* var_importance = dtree->get_var_importance();
    161. int i;
    162. char input[1000];
    163. if( !var_importance )
    164. {
    165. printf( "Error: Variable importance can not be retrieved\n" );
    166. return;
    167. }
    168. printf( "Print variable importance information? (y/n) " );
    169. scanf( "%1s", input );
    170. if( input[0] != 'y' && input[0] != 'Y' )
    171. return;
    172. for( i = 0; i < var_importance->cols*var_importance->rows; i++ )
    173. {
    174. double val = var_importance->data.db[i];
    175. if( var_desc )
    176. {
    177. char buf[100];
    178. int len = strchr( var_desc[i], '(' ) - var_desc[i] - 1;
    179. strncpy( buf, var_desc[i], len );
    180. buf[len] = '\0';
    181. printf( "%s", buf );
    182. }
    183. else
    184. printf( "var #%d", i );
    185. printf( ": %g%%\n", val*100. );
    186. }
    187. }
    188. void interactive_classification( CvDTree* dtree, const char** var_desc )
    189. {
    190. char input[1000];
    191. const CvDTreeNode* root;
    192. CvDTreeTrainData* data;
    193. if( !dtree )
    194. return;
    195. root = dtree->get_root();
    196. data = dtree->get_data();
    197. for(;;)
    198. {
    199. const CvDTreeNode* node;
    200. printf( "Start/Proceed with interactive mushroom classification (y/n): " );
    201. scanf( "%1s", input );
    202. if( input[0] != 'y' && input[0] != 'Y' )
    203. break;
    204. printf( "Enter 1-letter answers, '?' for missing/unknown value...\n" );
    205. // custom version of predict
    206. //传统的预测方式;
    207. node = root;
    208. for(;;)
    209. {
    210. CvDTreeSplit* split = node->split;
    211. int dir = 0;
    212. if( !node->left || node->Tn <= dtree->get_pruned_tree_idx() || !node->split )
    213. break;
    214. for( ; split != 0; )
    215. {
    216. int vi = split->var_idx, j;
    217. int count = data->cat_count->data.i[vi];
    218. const int* map = data->cat_map->data.i + data->cat_ofs->data.i[vi];
    219. printf( "%s: ", var_desc[vi] );
    220. scanf( "%1s", input );
    221. if( input[0] == '?' )
    222. {
    223. split = split->next;
    224. continue;
    225. }
    226. // convert the input character to the normalized value of the variable
    227. for( j = 0; j < count; j++ )
    228. if( map[j] == input[0] )
    229. break;
    230. if( j < count )
    231. {
    232. dir = (split->subset[j>>5] & (1 << (j&31))) ? -1 : 1;
    233. if( split->inversed )
    234. dir = -dir;
    235. break;
    236. }
    237. else
    238. printf( "Error: unrecognized value\n" );
    239. }
    240. if( !dir )
    241. {
    242. printf( "Impossible to classify the sample\n");
    243. node = 0;
    244. break;
    245. }
    246. node = dir < 0 ? node->left : node->right;
    247. }
    248. if( node )
    249. printf( "Prediction result: the mushroom is %s\n",
    250. node->class_idx == 0 ? "EDIBLE" : "POISONOUS" );
    251. printf( "\n-----------------------------\n" );
    252. }
    253. }
    254. int main( int argc, char** argv )
    255. {
    256. CvMat *data = 0, *missing = 0, *responses = 0;
    257. CvDTree* dtree;
    258. const char* base_path = argc >= 2 ? argv[1] : "agaricus-lepiota.data";
    259. help();
    260. if( !mushroom_read_database( base_path, &data, &missing, &responses ) )
    261. {
    262. printf( "\nUnable to load the training database\n\n");
    263. help();
    264. return -1;
    265. }
    266. dtree = mushroom_create_dtree( data, missing, responses,
    267. 10 // poisonous mushrooms will have 10x higher weight in the decision tree
    268. );
    269. cvReleaseMat( &data );
    270. cvReleaseMat( &missing );
    271. cvReleaseMat( &responses );
    272. print_variable_importance( dtree, var_desc );
    273. interactive_classification( dtree, var_desc );
    274. delete dtree;
    275. return 0;
    276. }
    277. //from: http://blog.csdn.net/yangtrees/article/details/7490852

OpenCV码源笔记——Decision Tree决策树的更多相关文章

  1. OpenCV码源笔记——RandomTrees (二)(Forest)

    源码细节: ● 训练函数 bool CvRTrees::train( const CvMat* _train_data, int _tflag,                        cons ...

  2. OpenCV码源笔记——RandomTrees (一)

    OpenCV2.3中Random Trees(R.T.)的继承结构: API: CvRTParams 定义R.T.训练用参数,CvDTreeParams的扩展子类,但并不用到CvDTreeParams ...

  3. Decision tree(决策树)算法初探

    0. 算法概述 决策树(decision tree)是一种基本的分类与回归方法.决策树模型呈树形结构(二分类思想的算法模型往往都是树形结构) 0x1:决策树模型的不同角度理解 在分类问题中,表示基于特 ...

  4. decision tree 决策树(一)

    一 决策树 原理:分类决策树模型是一种描述对实例进行分类的树形结构.决策树由结点(node)和有向边(directed edge)组成.结点有两种类型:内部结点(internal node)和叶结点( ...

  5. Decision tree——决策树

    基本流程 决策树是通过分次判断样本属性来进行划分样本类别的机器学习模型.每个树的结点选择一个最优属性来进行样本的分流,最终将样本类别划分出来. 决策树的关键就是分流时最优属性$a$的选择.使用所谓信息 ...

  6. 决策树Decision Tree 及实现

    Decision Tree 及实现 标签: 决策树熵信息增益分类有监督 2014-03-17 12:12 15010人阅读 评论(41) 收藏 举报  分类: Data Mining(25)  Pyt ...

  7. Spark MLlib - Decision Tree源码分析

    http://spark.apache.org/docs/latest/mllib-decision-tree.html 以决策树作为开始,因为简单,而且也比较容易用到,当前的boosting或ran ...

  8. [ML学习笔记] 决策树与随机森林(Decision Tree&Random Forest)

    [ML学习笔记] 决策树与随机森林(Decision Tree&Random Forest) 决策树 决策树算法以树状结构表示数据分类的结果.每个决策点实现一个具有离散输出的测试函数,记为分支 ...

  9. 【机器学习】决策树(Decision Tree) 学习笔记

    [机器学习]决策树(decision tree) 学习笔记 标签(空格分隔): 机器学习 决策树简介 决策树(decision tree)是一个树结构(可以是二叉树或非二叉树).其每个非叶节点表示一个 ...

随机推荐

  1. Operating Cisco Router

    Operating Cisco Router consider the hardware on the ends of the serial link, in particular where the ...

  2. Ubuntu 14.04安装Chromium浏览器并添加Flash插件Pepper Flash Player

    安装方法Ubuntu 14.04及衍生版本用户命令: 因为默认库里面有Chromium和Pepper Flash Player,安装非常容易,打开终端,输入以下命令: sudo apt-get upd ...

  3. ASP.NET MVC +EasyUI 权限设计(二)环境搭建

    请注明转载地址:http://www.cnblogs.com/arhat 今天突然发现博客园出问题了,老魏使用了PC,手机,平板都访问博客园了,都是不能正常的访问,原因是不能加载CSS,也就是不能访问 ...

  4. 微软职位内部推荐-Senior Speech TTS

    微软近期Open的职位: Job Description: Responsibilities Do you want to change the way the world interacts wit ...

  5. 实现 iframe 子页面调用父页面中的js方法

    父页面:index.html(使用iframe包含子页面child.html) [xhtml] view plaincopyprint? <html> <head> <s ...

  6. PHP中应用Service Locator服务定位及单例模式

    单例模式将一个对象实例化后,放在静态变量中,供程序调用. 服务定位(ServiceLocator)就是对象工场Factory,调用者对象直接调用Service Locator,与被调用对象减轻了依赖关 ...

  7. 【UOJ】【34】多项式乘法

    快速傅里叶变换模板题 算法理解请看<算法导论>第30章<多项式与快速傅里叶变换>,至于证明插值唯一性什么的看不懂也没关系啦-只要明白这个过程是怎么算的就ok. 递归版:(425 ...

  8. Java多线程——<四>让线程有返回值

    一.概述 到目前为止,我们已经能够声明并使一个线程任务运行起来了.但是遇到一个问题:现在定义的任务都没有任何返回值,那么加入我们希望一个任务运行结束后告诉我一个结果,该结果表名任务执行成功或失败,此时 ...

  9. C# 中请求数据方式

    #region 根据URL获取结果集        /// <summary>        /// 根据URL获取结果集 默认为GET,如果数据量大了可以传入POST        // ...

  10. ios开发之网络数据的下载与上传

    要实现网络数据的下载与上传,主要有三种方式 > NSURLConnection  针对少量数据,使用“GET”或“POST”方法从服务器获取数据,使用“POST”方法向服务器传输数据; > ...