使用的Decision Tree中,对MNIST中的灰度值进行了0/1处理,方便来进行分类和计算熵。

使用较少的测试数据测试了在对灰度值进行多分类的情况下,分类结果的正确率如何。实验结果如下。

#Test change pixel data into more categories than 0/1:
#int(pixel)/50: 37%
#int(pixel)/64: 45.9%
#int(pixel)/96: 52.3%
#int(pixel)/128: 62.48%
#int(pixel)/152: 59.1%
#int(pixel)/176: 57.6%
#int(pixel)/192: 54.0%

可见,在对灰度数据进行二分类,也就是0/1处理时,效果是最好的。

使用0/1处理,最终结果如下:

#Result:
#Train with 10k, test with 60k: 77.79%
#Train with 60k, test with 10k: 87.3%
#Time cost: 3 hours.

最终结果是87.3%的正确率。与SVM和KNN的超过95%相比,差距不小。而且消耗时间更长。

需要注意的是,此次Decision Tree算法中,并未对决策树进行剪枝。因此,还有可以提升的空间。

python代码见最下面。其中:

calcShannonEntropy(dataSet):是对矩阵的熵进行计算,根据各个数据点的分类情况,使用香农定理计算;

splitDataSet(dataSet, axis, value): 是获取第axis维度上的值为value的所有行所组成的矩阵。对于第axis维度上的数据,分别计算他们的splitDataSet的矩阵的熵,并与该维度上数据的出现概率相乘求和,可以得到使用第axis维度构建决策树后,整体的熵。

chooseBestFeatureToSplit(dataSet): 根据splitDataSet函数,对比得到整体的熵与原矩阵的熵相比,熵的增量最大的维度。根据此维度feature来构建决策树。

createDecisionTree(dataSet, features): 递归构建决策树。若在叶子节点处没法分类,则采用majorityCnt(classList)方法统计出现最多次的class作为分类。

代码如下:

  1. #Decision tree for MNIST dataset by arthur503.
  2. #Data format: 'class    label1:pixel    label2:pixel ...'
  3. #Warning: without fix overfitting!
  4. #
  5. #Test change pixel data into more categories than 0/1:
  6. #int(pixel)/50: 37%
  7. #int(pixel)/64: 45.9%
  8. #int(pixel)/96: 52.3%
  9. #int(pixel)/128: 62.48%
  10. #int(pixel)/152: 59.1%
  11. #int(pixel)/176: 57.6%
  12. #int(pixel)/192: 54.0%
  13. #
  14. #Result:
  15. #Train with 10k, test with 60k: 77.79%
  16. #Train with 60k, test with 10k: 87.3%
  17. #Time cost: 3 hours.
  18. from numpy import *
  19. import operator
  20. def calcShannonEntropy(dataSet):
  21. numEntries = len(dataSet)
  22. labelCounts = {}
  23. for featureVec in dataSet:
  24. currentLabel = featureVec[0]
  25. if currentLabel not in labelCounts.keys():
  26. labelCounts[currentLabel] = 1
  27. else:
  28. labelCounts[currentLabel] += 1
  29. shannonEntropy = 0.0
  30. for key in labelCounts:
  31. prob = float(labelCounts[key])/numEntries
  32. shannonEntropy -= prob  * log2(prob)
  33. return shannonEntropy
  34. #get all rows whose axis item equals value.
  35. def splitDataSet(dataSet, axis, value):
  36. subDataSet = []
  37. for featureVec in dataSet:
  38. if featureVec[axis] == value:
  39. reducedFeatureVec = featureVec[:axis]
  40. reducedFeatureVec.extend(featureVec[axis+1:])   #if axis == -1, this will cause error!
  41. subDataSet.append(reducedFeatureVec)
  42. return subDataSet
  43. def chooseBestFeatureToSplit(dataSet):
  44. #Notice: Actucally, index 0 of numFeatures is not feature(it is class label).
  45. numFeatures = len(dataSet[0])
  46. baseEntropy = calcShannonEntropy(dataSet)
  47. bestInfoGain = 0.0
  48. bestFeature = numFeatures - 1   #DO NOT use -1! or splitDataSet(dataSet, -1, value) will cause error!
  49. #feature index start with 1(not 0)!
  50. for i in range(numFeatures)[1:]:
  51. featureList = [example[i] for example in dataSet]
  52. featureSet = set(featureList)
  53. newEntropy = 0.0
  54. for value in featureSet:
  55. subDataSet = splitDataSet(dataSet, i, value)
  56. prob = len(subDataSet)/float(len(dataSet))
  57. newEntropy += prob * calcShannonEntropy(subDataSet)
  58. infoGain = baseEntropy - newEntropy
  59. if infoGain > bestInfoGain:
  60. bestInfoGain = infoGain
  61. bestFeature = i
  62. return bestFeature
  63. #classify on leaf of decision tree.
  64. def majorityCnt(classList):
  65. classCount = {}
  66. for vote in classList:
  67. if vote not in classCount:
  68. classCount[vote] = 0
  69. classCount[vote] += 1
  70. sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
  71. return sortedClassCount[0][0]
  72. #Create Decision Tree.
  73. def createDecisionTree(dataSet, features):
  74. print 'create decision tree... length of features is:'+str(len(features))
  75. classList = [example[0] for example in dataSet]
  76. if classList.count(classList[0]) == len(classList):
  77. return classList[0]
  78. if len(dataSet[0]) == 1:
  79. return majorityCnt(classList)
  80. bestFeatureIndex = chooseBestFeatureToSplit(dataSet)
  81. bestFeatureLabel = features[bestFeatureIndex]
  82. myTree = {bestFeatureLabel:{}}
  83. del(features[bestFeatureIndex])
  84. featureValues = [example[bestFeatureIndex] for example in dataSet]
  85. featureSet = set(featureValues)
  86. for value in featureSet:
  87. subFeatures = features[:]
  88. myTree[bestFeatureLabel][value] = createDecisionTree(splitDataSet(dataSet, bestFeatureIndex, value), subFeatures)
  89. return myTree
  90. def line2Mat(line):
  91. mat = line.strip().split(' ')
  92. for i in range(len(mat)-1):
  93. pixel = mat[i+1].split(':')[1]
  94. #change MNIST pixel data into 0/1 format.
  95. mat[i+1] = int(pixel)/128
  96. return mat
  97. #return matrix as a list(instead of a matrix).
  98. #features is the 28*28 pixels in MNIST dataset.
  99. def file2Mat(fileName):
  100. f = open(fileName)
  101. lines = f.readlines()
  102. matrix = []
  103. for line in lines:
  104. mat = line2Mat(line)
  105. matrix.append(mat)
  106. f.close()
  107. print 'Read file '+str(fileName) + ' to array done! Matrix shape:'+str(shape(matrix))
  108. return matrix
  109. #Classify test file.
  110. def classify(inputTree, featureLabels, testVec):
  111. firstStr = inputTree.keys()[0]
  112. secondDict = inputTree[firstStr]
  113. featureIndex = featureLabels.index(firstStr)
  114. predictClass = '-1'
  115. for key in secondDict.keys():
  116. if testVec[featureIndex] == key:
  117. if type(secondDict[key]) == type({}):
  118. predictClass = classify(secondDict[key], featureLabels, testVec)
  119. else:
  120. predictClass = secondDict[key]
  121. return predictClass
  122. def classifyTestFile(inputTree, featureLabels, testDataSet):
  123. rightCnt = 0
  124. for i in range(len(testDataSet)):
  125. classLabel = testDataSet[i][0]
  126. predictClassLabel = classify(inputTree, featureLabels, testDataSet[i])
  127. if classLabel == predictClassLabel:
  128. rightCnt += 1
  129. if i % 200 == 0:
  130. print 'num '+str(i)+'. ratio: ' + str(float(rightCnt)/(i+1))
  131. return float(rightCnt)/len(testDataSet)
  132. def getFeatureLabels(length):
  133. strs = []
  134. for i in range(length):
  135. strs.append('#'+str(i))
  136. return strs
  137. #Normal file
  138. trainFile = 'train_60k.txt'
  139. testFile = 'test_10k.txt'
  140. #Scaled file
  141. #trainFile = 'train_60k_scale.txt'
  142. #testFile = 'test_10k_scale.txt'
  143. #Test file
  144. #trainFile = 'test_only_1.txt'
  145. #testFile = 'test_only_2.txt'
  146. #train decision tree.
  147. dataSet = file2Mat(trainFile)
  148. #Actually, the 0 item is class, not feature labels.
  149. featureLabels = getFeatureLabels(len(dataSet[0]))
  150. print 'begin to create decision tree...'
  151. myTree = createDecisionTree(dataSet, featureLabels)
  152. print 'create decision tree done.'
  153. #predict with decision tree.
  154. testDataSet = file2Mat(testFile)
  155. featureLabels = getFeatureLabels(len(testDataSet[0]))
  156. rightRatio = classifyTestFile(myTree, featureLabels, testDataSet)
  157. print 'total right ratio: ' + str(rightRatio)
 

使用Decision Tree对MNIST数据集进行实验的更多相关文章

  1. 使用libsvm对MNIST数据集进行实验

    使用libsvm对MNIST数据集进行实验 在学SVM中的实验环节,老师介绍了libsvm的使用.当时看完之后感觉简单的说不出话来. 1. libsvm介绍 虽然原理要求很高的数学知识等,但是libs ...

  2. 使用libsvm对MNIST数据集进行实验---浅显易懂!

    原文:http://blog.csdn.net/arthur503/article/details/19974057 在学SVM中的实验环节,老师介绍了libsvm的使用.当时看完之后感觉简单的说不出 ...

  3. 使用KNN对MNIST数据集进行实验

    由于KNN的计算量太大,还没有使用KD-tree进行优化,所以对于60000训练集,10000测试集的数据计算比较慢.这里只是想测试观察一下KNN的效果而已,不调参. K选择之前看过貌似最好不要超过2 ...

  4. 决策树Decision Tree 及实现

    Decision Tree 及实现 标签: 决策树熵信息增益分类有监督 2014-03-17 12:12 15010人阅读 评论(41) 收藏 举报  分类: Data Mining(25)  Pyt ...

  5. 用于分类的决策树(Decision Tree)-ID3 C4.5

    决策树(Decision Tree)是一种基本的分类与回归方法(ID3.C4.5和基于 Gini 的 CART 可用于分类,CART还可用于回归).决策树在分类过程中,表示的是基于特征对实例进行划分, ...

  6. (转)Decision Tree

    Decision Tree:Analysis 大家有没有玩过猜猜看(Twenty Questions)的游戏?我在心里想一件物体,你可以用一些问题来确定我心里想的这个物体:如是不是植物?是否会飞?能游 ...

  7. 从零到一:caffe-windows(CPU)配置与利用mnist数据集训练第一个caffemodel

    一.前言 本文会详细地阐述caffe-windows的配置教程.由于博主自己也只是个在校学生,目前也写不了太深入的东西,所以准备从最基础的开始一步步来.个人的计划是分成配置和运行官方教程,利用自己的数 ...

  8. CART分类与回归树与GBDT(Gradient Boost Decision Tree)

    一.CART分类与回归树 资料转载: http://dataunion.org/5771.html        Classification And Regression Tree(CART)是决策 ...

  9. class-决策树Decision Tree

    顾名思义,决策树model是树形结构,在分类中,表示基于特征对实例进行分类的过程.可以认为是"if-else"的合集,也可以认为是特征空间,类空间上条件概率分布.主要优点是分类速度 ...

随机推荐

  1. asmdisk opened & asmdisk cached

    ASMDISK OPENED - Disk is present in the storage system and is being accessed by Automatic Storage Ma ...

  2. js 多选题选项内容显示在标题下

    <body><div class="page-container"> <div class="view-container"> ...

  3. linux:什么是linux

    1>.linux是一套作业系统(linux就是核心与呼叫这两层),每一种作业系统都是在他专门的硬体机器上面运行的:linux是一个Open Source的作业系统,具有可移植性 2>.li ...

  4. IOS 设备参数

    Iphone,Ipad,ITouch 各个型号参数对比

  5. [原创]java WEB学习笔记82:Hibernate学习之路---映射 一对多关联关系,配置,CRUD方法测试及注意点

    本博客的目的:①总结自己的学习过程,相当于学习笔记 ②将自己的经验分享给大家,相互学习,互相交流,不可商用 内容难免出现问题,欢迎指正,交流,探讨,可以留言,也可以通过以下方式联系. 本人互联网技术爱 ...

  6. BackgroundWorker的使用方法

    http://msdn.microsoft.com/zh-cn/library/system.componentmodel.backgroundworker(VS.80).aspx Backgroun ...

  7. springmvc+spring+mybatis分页查询实例版本3,添加条件检索

    在第二个版本上添加了姓名模糊查询,年龄区间查询;自以为easy,结果发现mybatis的各种参数写法基本搞混或是忘了,zuo啊,直接上代码,然后赶紧把mybatis整理一遍再研究自己的项目,应该还会有 ...

  8. linux中模块的相关操作

    /lib/modules/[kernel版本]/modules.dep 这个文件记录了模块的依赖关系 modprobe 和 insmod 这两个命令都可以加载模块,但是modeprobe会自动分析模块 ...

  9. C++字符串和string类介绍

    一.C风格字符串 ◆ 1.字符串是用字符型数组存储的,字符串要求其尾部以'\0'作为结束标志.如:    char string[ ]="C++ programming language&q ...

  10. [转]Delphi多线程编程入门(二)——通过调用API实现多线程

    以下是一篇很值得看的关于Delphi多线程编程的文章,内容很全面,建议收藏. 一.入门 ㈠. function CreateThread(    lpThreadAttributes: Pointer ...