决策树Decision Tree 及实现
Python(24)
Machine Learning(46) 
版权声明:本文为博主原创文章,未经博主允许不得转载。
本文基于python逐步实现Decision Tree(决策树),分为以下几个步骤:
- 加载数据集
- 熵的计算
- 根据最佳分割feature进行数据分割
- 根据最大信息增益选择最佳分割feature
- 递归构建决策树
- 样本分类
关于决策树的理论方面本文几乎不讲,详情请google keywords:“决策树 信息增益 熵”
将分别体现于代码。
本文只建一个.py文件,所有代码都在这个py里
1.加载数据集
我们选用UCI经典Iris为例
Brief of IRIS:
|
Data Set Characteristics: |
Multivariate |
Number of Instances: |
150 |
Area: |
Life |
|
Attribute Characteristics: |
Real |
Number of Attributes: |
4 |
Date Donated |
1988-07-01 |
|
Associated Tasks: |
Classification |
Missing Values? |
No |
Number of Web Hits: |
533125 |
Code:
- from numpy import *
- #load "iris.data" to workspace
- traindata = loadtxt("D:\ZJU_Projects\machine learning\ML_Action\Dataset\Iris.data",delimiter = ',',usecols = (0,1,2,3),dtype = float)
- trainlabel = loadtxt("D:\ZJU_Projects\machine learning\ML_Action\Dataset\Iris.data",delimiter = ',',usecols = (range(4,5)),dtype = str)
- feaname = ["#0","#1","#2","#3"] # feature names of the 4 attributes (features)
Result:
左图为实际数据集,四个离散型feature,一个label表示类别(有Iris-setosa, Iris-versicolor,Iris-virginica 三个类)
2. 熵的计算
entropy是香农提出来的(信息论大牛),定义见wiki
注意这里的entropy是H(C|X=xi)而非H(C|X), H(C|X)的计算见第下一个点,还要乘以概率加和
Code:
- from math import log
- def calentropy(label):
- n = label.size # the number of samples
- #print n
- count = {} #create dictionary "count"
- for curlabel in label:
- if curlabel not in count.keys():
- count[curlabel] = 0
- count[curlabel] += 1
- entropy = 0
- #print count
- for key in count:
- pxi = float(count[key])/n #notice transfering to float first
- entropy -= pxi*log(pxi,2)
- return entropy
- #testcode:
- #x = calentropy(trainlabel)
Result:
3. 根据最佳分割feature进行数据分割
假定我们已经得到了最佳分割feature,在这里进行分割(最佳feature为splitfea_idx)
第二个函数idx2data是根据splitdata得到的分割数据的两个index集合返回datal (samples less than pivot), datag(samples greater than pivot), labell, labelg。 这里我们根据所选特征的平均值作为pivot
- #split the dataset according to label "splitfea_idx"
- def splitdata(oridata,splitfea_idx):
- arg = args[splitfea_idx] #get the average over all dimensions
- idx_less = [] #create new list including data with feature less than pivot
- idx_greater = [] #includes entries with feature greater than pivot
- n = len(oridata)
- for idx in range(n):
- d = oridata[idx]
- if d[splitfea_idx] < arg:
- #add the newentry into newdata_less set
- idx_less.append(idx)
- else:
- idx_greater.append(idx)
- return idx_less,idx_greater
- #testcode:2
- #idx_less,idx_greater = splitdata(traindata,2)
- #give the data and labels according to index
- def idx2data(oridata,label,splitidx,fea_idx):
- idxl = splitidx[0] #split_less_indices
- idxg = splitidx[1] #split_greater_indices
- datal = []
- datag = []
- labell = []
- labelg = []
- for i in idxl:
- datal.append(append(oridata[i][:fea_idx],oridata[i][fea_idx+1:]))
- for i in idxg:
- datag.append(append(oridata[i][:fea_idx],oridata[i][fea_idx+1:]))
- labell = label[idxl]
- labelg = label[idxg]
- return datal,datag,labell,labelg
这里args是参数,决定分裂节点的阈值(每个参数对应一个feature,大于该值分到>branch,小于该值分到<branch),我们可以定义如下:
- args = mean(traindata,axis = 0)
测试:按特征2进行分类,得到的less和greater set of indices分别为:
也就是按args[2]进行样本集分割,<和>args[2]的branch分别有57和93个样本。
4. 根据最大信息增益选择最佳分割feature
信息增益为代码中的info_gain, 注释中是熵的计算
- #select the best branch to split
- def choosebest_splitnode(oridata,label):
- n_fea = len(oridata[0])
- n = len(label)
- base_entropy = calentropy(label)
- best_gain = -1
- for fea_i in range(n_fea): #calculate entropy under each splitting feature
- cur_entropy = 0
- idxset_less,idxset_greater = splitdata(oridata,fea_i)
- prob_less = float(len(idxset_less))/n
- prob_greater = float(len(idxset_greater))/n
- #entropy(value|X) = \sum{p(xi)*entropy(value|X=xi)}
- cur_entropy += prob_less*calentropy(label[idxset_less])
- cur_entropy += prob_greater * calentropy(label[idxset_greater])
- info_gain = base_entropy - cur_entropy #notice gain is before minus after
- if(info_gain>best_gain):
- best_gain = info_gain
- best_idx = fea_i
- return best_idx
- #testcode:
- #x = choosebest_splitnode(traindata,trainlabel)
这里的测试针对所有数据,分裂一次选择哪个特征呢?
5. 递归构建决策树
详见code注释,buildtree递归地构建树。
递归终止条件:
①该branch内没有样本(subset为空) or
②分割出的所有样本属于同一类 or
③由于每次分割消耗一个feature,当没有feature的时候停止递归,返回当前样本集中大多数sample的label
- #create the decision tree based on information gain
- def buildtree(oridata, label):
- if label.size==0: #if no samples belong to this branch
- return "NULL"
- listlabel = label.tolist()
- #stop when all samples in this subset belongs to one class
- if listlabel.count(label[0])==label.size:
- return label[0]
- #return the majority of samples' label in this subset if no extra features avaliable
- if len(feanamecopy)==0:
- cnt = {}
- for cur_l in label:
- if cur_l not in cnt.keys():
- cnt[cur_l] = 0
- cnt[cur_l] += 1
- maxx = -1
- for keys in cnt:
- if maxx < cnt[keys]:
- maxx = cnt[keys]
- maxkey = keys
- return maxkey
- bestsplit_fea = choosebest_splitnode(oridata,label) #get the best splitting feature
- print bestsplit_fea,len(oridata[0])
- cur_feaname = feanamecopy[bestsplit_fea] # add the feature name to dictionary
- print cur_feaname
- nodedict = {cur_feaname:{}}
- del(feanamecopy[bestsplit_fea]) #delete current feature from feaname
- split_idx = splitdata(oridata,bestsplit_fea) #split_idx: the split index for both less and greater
- data_less,data_greater,label_less,label_greater = idx2data(oridata,label,split_idx,bestsplit_fea)
- #build the tree recursively, the left and right tree are the "<" and ">" branch, respectively
- nodedict[cur_feaname]["<"] = buildtree(data_less,label_less)
- nodedict[cur_feaname][">"] = buildtree(data_greater,label_greater)
- return nodedict
- #testcode:
- #mytree = buildtree(traindata,trainlabel)
- #print mytree
Result:
mytree就是我们的结果,#1表示当前使用第一个feature做分割,'<'和'>'分别对应less 和 greater的数据。
6. 样本分类
根据构建出的mytree进行分类,递归走分支
- #classify a new sample
- def classify(mytree,testdata):
- if type(mytree).__name__ != 'dict':
- return mytree
- fea_name = mytree.keys()[0] #get the name of first feature
- fea_idx = feaname.index(fea_name) #the index of feature 'fea_name'
- val = testdata[fea_idx]
- nextbranch = mytree[fea_name]
- #judge the current value > or < the pivot (average)
- if val>args[fea_idx]:
- nextbranch = nextbranch[">"]
- else:
- nextbranch = nextbranch["<"]
- return classify(nextbranch,testdata)
- #testcode
- tt = traindata[0]
- x = classify(mytree,tt)
- print x
Result:
为了验证代码准确性,我们换一下args参数,把它们都设成0(很小)
args = [0,0,0,0]
建树和分类的结果如下:
可见没有小于pivot(0)的项,于是dict中每个<的key对应的value都为空。
本文中全部代码下载:决策树python实现
Reference: Machine Learning in Action
from: http://blog.csdn.net/abcjennifer/article/details/20905311
决策树Decision Tree 及实现的更多相关文章
- 机器学习算法实践:决策树 (Decision Tree)(转载)
前言 最近打算系统学习下机器学习的基础算法,避免眼高手低,决定把常用的机器学习基础算法都实现一遍以便加深印象.本文为这系列博客的第一篇,关于决策树(Decision Tree)的算法实现,文中我将对决 ...
- 数据挖掘 决策树 Decision tree
数据挖掘-决策树 Decision tree 目录 数据挖掘-决策树 Decision tree 1. 决策树概述 1.1 决策树介绍 1.1.1 决策树定义 1.1.2 本质 1.1.3 决策树的组 ...
- 用于分类的决策树(Decision Tree)-ID3 C4.5
决策树(Decision Tree)是一种基本的分类与回归方法(ID3.C4.5和基于 Gini 的 CART 可用于分类,CART还可用于回归).决策树在分类过程中,表示的是基于特征对实例进行划分, ...
- (ZT)算法杂货铺——分类算法之决策树(Decision tree)
https://www.cnblogs.com/leoo2sk/archive/2010/09/19/decision-tree.html 3.1.摘要 在前面两篇文章中,分别介绍和讨论了朴素贝叶斯分 ...
- 决策树decision tree原理介绍_python sklearn建模_乳腺癌细胞分类器(推荐AAA)
sklearn实战-乳腺癌细胞数据挖掘(博主亲自录制视频) https://study.163.com/course/introduction.htm?courseId=1005269003& ...
- 机器学习方法(四):决策树Decision Tree原理与实现技巧
欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术.应用感兴趣的同学加入. 前面三篇写了线性回归,lass ...
- 机器学习-决策树 Decision Tree
咱们正式进入了机器学习的模型的部分,虽然现在最火的的机器学习方面的库是Tensorflow, 但是这里还是先简单介绍一下另一个数据处理方面很火的库叫做sklearn.其实咱们在前面已经介绍了一点点sk ...
- 决策树 Decision Tree
决策树是一个类似于流程图的树结构:其中,每个内部结点表示在一个属性上的测试,每个分支代表一个属性输出,而每个树叶结点代表类或类分布.树的最顶层是根结点.  决策树的构建 想要构建一个决策树,那么咱们 ...
- 【机器学习算法-python实现】决策树-Decision tree(2) 决策树的实现
(转载请注明出处:http://blog.csdn.net/buptgshengod) 1.背景 接着上一节说,没看到请先看一下上一节关于数据集的划分数据集划分.如今我们得到了每一个特征值得 ...
随机推荐
- webstorm常用快捷键(常用)
ctrl+/ 注释 ctrl+shift+/ 注释一块的代码 ctrl+shift+z 返回撤撤销前的操作 ctrl+shift+up/down 代码向上/向下移动 ctrl+b或ctrl+鼠标左键单 ...
- 【BZOJ】2924: [Poi1998]Flat broken lines
题意 平面上有\(n\)个点,如果两个点的线段与\(x\)轴的角在\([-45^{\circ}, 45^{\circ}]\),则两个点可以连线.求最少的折线(折线由线段首尾相连)使得覆盖所有点. 分析 ...
- 【BZOJ1968】【AHoi2005】COMMON约数研究
Description Input 只有一行一个整数 N(0 < N < 1000000). Output 只有一行输出,为整数M,即f(1)到f(N)的累加和. Sample Input ...
- IEqualityComparer<T>
在linq中使用union和distinct都不起作用,结果发现必须传入一个实现了IEqualityComparer<T>的比较器 public class CompareUser : I ...
- 使用Uboot启动内核并挂载NFS根文件系统
配置编译好内核之后,将生成的内核文件uImage拷贝到/tftpboot/下,通过tftp服务器将内核下载到开发板,使用命令:tftp 31000000 uImage.下载完成之后配置bootargs ...
- 既然nodejs是单线程的,那么它怎么处理多请求高并发的?
单线程解决高并发的思路就是采用非阻塞,异步编程的思想.简单概括就是当遇到非常耗时的IO操作时,采用非阻塞的方式,继续执行后面的代码,并且进入事件循环,当IO操作完成时,程序会被通知IO操作已经完成.主 ...
- android api汇集
参考文章: 知乎-想写个 App 练手,有什么有趣的 API 接口推荐吗? 使用 Espresso 和 Dagger 测试网络服务 http://developer.simsimi.com/apps# ...
- HTML基础--JS简介、基本语法、类型转换、变量、运算符、分支语句、循环语句、数组、函数、函数调用.avi
JS简介 1.JavaScript是个什么东西? 它是个脚本语言,需要有宿主文件,它的宿主文件是HTML文件. 2.它与Java什么关系? 没有什么直接的联系,Java是Sun公司(已被Oracle收 ...
- 获取Android studio的SHA1值
D:\Android\BaiduMapsApiASDemo>c: C:\>cd .android 系统找不到指定的路径. C:\>cd Users C:\Users>cd Ad ...
- Jfinal验证码功能
//验证码工具类 import java.awt.Color;import java.awt.Font;import java.awt.Graphics;import java.awt.image.B ...