使用Decision Tree对MNIST数据集进行实验
使用的Decision Tree中,对MNIST中的灰度值进行了0/1处理,方便来进行分类和计算熵。
使用较少的测试数据测试了在对灰度值进行多分类的情况下,分类结果的正确率如何。实验结果如下。
#Test change pixel data into more categories than 0/1:
#int(pixel)/50: 37%
#int(pixel)/64: 45.9%
#int(pixel)/96: 52.3%
#int(pixel)/128: 62.48%
#int(pixel)/152: 59.1%
#int(pixel)/176: 57.6%
#int(pixel)/192: 54.0%
可见,在对灰度数据进行二分类,也就是0/1处理时,效果是最好的。
使用0/1处理,最终结果如下:
#Result:
#Train with 10k, test with 60k: 77.79%
#Train with 60k, test with 10k: 87.3%
#Time cost: 3 hours.
最终结果是87.3%的正确率。与SVM和KNN的超过95%相比,差距不小。而且消耗时间更长。
需要注意的是,此次Decision Tree算法中,并未对决策树进行剪枝。因此,还有可以提升的空间。
python代码见最下面。其中:
calcShannonEntropy(dataSet):是对矩阵的熵进行计算,根据各个数据点的分类情况,使用香农定理计算;
splitDataSet(dataSet, axis, value): 是获取第axis维度上的值为value的所有行所组成的矩阵。对于第axis维度上的数据,分别计算他们的splitDataSet的矩阵的熵,并与该维度上数据的出现概率相乘求和,可以得到使用第axis维度构建决策树后,整体的熵。
chooseBestFeatureToSplit(dataSet): 根据splitDataSet函数,对比得到整体的熵与原矩阵的熵相比,熵的增量最大的维度。根据此维度feature来构建决策树。
createDecisionTree(dataSet, features): 递归构建决策树。若在叶子节点处没法分类,则采用majorityCnt(classList)方法统计出现最多次的class作为分类。
代码如下:
- #Decision tree for MNIST dataset by arthur503.
- #Data format: 'class label1:pixel label2:pixel ...'
- #Warning: without fix overfitting!
- #
- #Test change pixel data into more categories than 0/1:
- #int(pixel)/50: 37%
- #int(pixel)/64: 45.9%
- #int(pixel)/96: 52.3%
- #int(pixel)/128: 62.48%
- #int(pixel)/152: 59.1%
- #int(pixel)/176: 57.6%
- #int(pixel)/192: 54.0%
- #
- #Result:
- #Train with 10k, test with 60k: 77.79%
- #Train with 60k, test with 10k: 87.3%
- #Time cost: 3 hours.
- from numpy import *
- import operator
- def calcShannonEntropy(dataSet):
- numEntries = len(dataSet)
- labelCounts = {}
- for featureVec in dataSet:
- currentLabel = featureVec[0]
- if currentLabel not in labelCounts.keys():
- labelCounts[currentLabel] = 1
- else:
- labelCounts[currentLabel] += 1
- shannonEntropy = 0.0
- for key in labelCounts:
- prob = float(labelCounts[key])/numEntries
- shannonEntropy -= prob * log2(prob)
- return shannonEntropy
- #get all rows whose axis item equals value.
- def splitDataSet(dataSet, axis, value):
- subDataSet = []
- for featureVec in dataSet:
- if featureVec[axis] == value:
- reducedFeatureVec = featureVec[:axis]
- reducedFeatureVec.extend(featureVec[axis+1:]) #if axis == -1, this will cause error!
- subDataSet.append(reducedFeatureVec)
- return subDataSet
- def chooseBestFeatureToSplit(dataSet):
- #Notice: Actucally, index 0 of numFeatures is not feature(it is class label).
- numFeatures = len(dataSet[0])
- baseEntropy = calcShannonEntropy(dataSet)
- bestInfoGain = 0.0
- bestFeature = numFeatures - 1 #DO NOT use -1! or splitDataSet(dataSet, -1, value) will cause error!
- #feature index start with 1(not 0)!
- for i in range(numFeatures)[1:]:
- featureList = [example[i] for example in dataSet]
- featureSet = set(featureList)
- newEntropy = 0.0
- for value in featureSet:
- subDataSet = splitDataSet(dataSet, i, value)
- prob = len(subDataSet)/float(len(dataSet))
- newEntropy += prob * calcShannonEntropy(subDataSet)
- infoGain = baseEntropy - newEntropy
- if infoGain > bestInfoGain:
- bestInfoGain = infoGain
- bestFeature = i
- return bestFeature
- #classify on leaf of decision tree.
- def majorityCnt(classList):
- classCount = {}
- for vote in classList:
- if vote not in classCount:
- classCount[vote] = 0
- classCount[vote] += 1
- sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reverse=True)
- return sortedClassCount[0][0]
- #Create Decision Tree.
- def createDecisionTree(dataSet, features):
- print 'create decision tree... length of features is:'+str(len(features))
- classList = [example[0] for example in dataSet]
- if classList.count(classList[0]) == len(classList):
- return classList[0]
- if len(dataSet[0]) == 1:
- return majorityCnt(classList)
- bestFeatureIndex = chooseBestFeatureToSplit(dataSet)
- bestFeatureLabel = features[bestFeatureIndex]
- myTree = {bestFeatureLabel:{}}
- del(features[bestFeatureIndex])
- featureValues = [example[bestFeatureIndex] for example in dataSet]
- featureSet = set(featureValues)
- for value in featureSet:
- subFeatures = features[:]
- myTree[bestFeatureLabel][value] = createDecisionTree(splitDataSet(dataSet, bestFeatureIndex, value), subFeatures)
- return myTree
- def line2Mat(line):
- mat = line.strip().split(' ')
- for i in range(len(mat)-1):
- pixel = mat[i+1].split(':')[1]
- #change MNIST pixel data into 0/1 format.
- mat[i+1] = int(pixel)/128
- return mat
- #return matrix as a list(instead of a matrix).
- #features is the 28*28 pixels in MNIST dataset.
- def file2Mat(fileName):
- f = open(fileName)
- lines = f.readlines()
- matrix = []
- for line in lines:
- mat = line2Mat(line)
- matrix.append(mat)
- f.close()
- print 'Read file '+str(fileName) + ' to array done! Matrix shape:'+str(shape(matrix))
- return matrix
- #Classify test file.
- def classify(inputTree, featureLabels, testVec):
- firstStr = inputTree.keys()[0]
- secondDict = inputTree[firstStr]
- featureIndex = featureLabels.index(firstStr)
- predictClass = '-1'
- for key in secondDict.keys():
- if testVec[featureIndex] == key:
- if type(secondDict[key]) == type({}):
- predictClass = classify(secondDict[key], featureLabels, testVec)
- else:
- predictClass = secondDict[key]
- return predictClass
- def classifyTestFile(inputTree, featureLabels, testDataSet):
- rightCnt = 0
- for i in range(len(testDataSet)):
- classLabel = testDataSet[i][0]
- predictClassLabel = classify(inputTree, featureLabels, testDataSet[i])
- if classLabel == predictClassLabel:
- rightCnt += 1
- if i % 200 == 0:
- print 'num '+str(i)+'. ratio: ' + str(float(rightCnt)/(i+1))
- return float(rightCnt)/len(testDataSet)
- def getFeatureLabels(length):
- strs = []
- for i in range(length):
- strs.append('#'+str(i))
- return strs
- #Normal file
- trainFile = 'train_60k.txt'
- testFile = 'test_10k.txt'
- #Scaled file
- #trainFile = 'train_60k_scale.txt'
- #testFile = 'test_10k_scale.txt'
- #Test file
- #trainFile = 'test_only_1.txt'
- #testFile = 'test_only_2.txt'
- #train decision tree.
- dataSet = file2Mat(trainFile)
- #Actually, the 0 item is class, not feature labels.
- featureLabels = getFeatureLabels(len(dataSet[0]))
- print 'begin to create decision tree...'
- myTree = createDecisionTree(dataSet, featureLabels)
- print 'create decision tree done.'
- #predict with decision tree.
- testDataSet = file2Mat(testFile)
- featureLabels = getFeatureLabels(len(testDataSet[0]))
- rightRatio = classifyTestFile(myTree, featureLabels, testDataSet)
- print 'total right ratio: ' + str(rightRatio)
使用Decision Tree对MNIST数据集进行实验的更多相关文章
- 使用libsvm对MNIST数据集进行实验
使用libsvm对MNIST数据集进行实验 在学SVM中的实验环节,老师介绍了libsvm的使用.当时看完之后感觉简单的说不出话来. 1. libsvm介绍 虽然原理要求很高的数学知识等,但是libs ...
- 使用libsvm对MNIST数据集进行实验---浅显易懂!
原文:http://blog.csdn.net/arthur503/article/details/19974057 在学SVM中的实验环节,老师介绍了libsvm的使用.当时看完之后感觉简单的说不出 ...
- 使用KNN对MNIST数据集进行实验
由于KNN的计算量太大,还没有使用KD-tree进行优化,所以对于60000训练集,10000测试集的数据计算比较慢.这里只是想测试观察一下KNN的效果而已,不调参. K选择之前看过貌似最好不要超过2 ...
- 决策树Decision Tree 及实现
Decision Tree 及实现 标签: 决策树熵信息增益分类有监督 2014-03-17 12:12 15010人阅读 评论(41) 收藏 举报 分类: Data Mining(25) Pyt ...
- 用于分类的决策树(Decision Tree)-ID3 C4.5
决策树(Decision Tree)是一种基本的分类与回归方法(ID3.C4.5和基于 Gini 的 CART 可用于分类,CART还可用于回归).决策树在分类过程中,表示的是基于特征对实例进行划分, ...
- (转)Decision Tree
Decision Tree:Analysis 大家有没有玩过猜猜看(Twenty Questions)的游戏?我在心里想一件物体,你可以用一些问题来确定我心里想的这个物体:如是不是植物?是否会飞?能游 ...
- 从零到一:caffe-windows(CPU)配置与利用mnist数据集训练第一个caffemodel
一.前言 本文会详细地阐述caffe-windows的配置教程.由于博主自己也只是个在校学生,目前也写不了太深入的东西,所以准备从最基础的开始一步步来.个人的计划是分成配置和运行官方教程,利用自己的数 ...
- CART分类与回归树与GBDT(Gradient Boost Decision Tree)
一.CART分类与回归树 资料转载: http://dataunion.org/5771.html Classification And Regression Tree(CART)是决策 ...
- class-决策树Decision Tree
顾名思义,决策树model是树形结构,在分类中,表示基于特征对实例进行分类的过程.可以认为是"if-else"的合集,也可以认为是特征空间,类空间上条件概率分布.主要优点是分类速度 ...
随机推荐
- Tomcat类加载器机制
Tomcat为什么需要定制自己的ClassLoader: 1.定制特定的规则:隔离webapp,安全考虑,reload热插拔 2.缓存类 3.事先加载 要说Tomcat的Classloader机制,我 ...
- Lintcode: Median
Given a unsorted array with integers, find the median of it. A median is the middle number of the ar ...
- 转:Selenium之CSS Selector定位详解
CSS selector定位 CSS(Cascading Style Sheets)是一种语言,它被用来描述 HTML 和 XML 文档的样式. 百度输入框: <input name=&quo ...
- 树形DP +01背包(HDU 1011)
题意:有n个房间,有n-1条道路连接着n个房间,每个房间都有若干个野怪和一定的能量值,有m个士兵从1房间入口进去,到达每个房间必须要留下若干士兵杀死所有的野怪,然后其他人继续走,(一个士兵可以杀死20 ...
- Java基础(10):java基础第一部分综合测试题,成绩合法性校验与排序
题目: 编写一个 JAVA 程序,实现输出考试成绩的前三名 要求: 1. 考试成绩已保存在数组 scores 中,数组元素依次为 89 , -23 , 64 , 91 , 119 , 52 , 73 ...
- .NET: XML
XML在平常生活中用得很多,它的结构很简单,跟windows explorer有点像. 对它进行操作主要有三种方式:XmlDocument, 假设有这么一个XML文件Book.XML <?xml ...
- org.openqa.selenium.WebDriverException: f.QueryInterface is not a function Command duration or timeout:
今天偶遇一个问题,运行项目时,发现这个问题: org.openqa.selenium.WebDriverException: f.QueryInterface is not a functionCom ...
- bzoj3489 A simple rmq problem 可持久化树套树
先预处理出两个个数组pre,next.pre[i]表示上一个与i位置数字相同的位置,若不存在则设为0:next[i]表示下一个与i位置数字相同的位置,若不存在则设为n+1.那么一个满足在区间[L,R] ...
- springday04-go2
练习:计算一个人的bmi指数.算法如下: 身高 单位是米 比如1.70 体重 单位是公斤 比如90 bmi指数 = 体重/身高/身高 如果bmi>24过重,否则正常.视图需要两个,一个是bmi_ ...
- 解决Android抽屉被击穿问题
1,创建一个抽屉DrawerLayout,在V4包下android.support.v4.widget.DrawerLayout,在要设置抽屉的布局中设置android:layout_gravity= ...