《机器学习实战》ID3算法实现
注释:之前从未接触过决策树,直接上手对着书看源码,有点难,确实有点难~~
本代码是基于ID3编写,之后的ID4.5和CART等还没学习到
一.决策树的原理
没有看网上原理,直接看源码懂得原理,下面是我一个抛砖引玉的例子:

太丑了,在Linux下面操作实在不习惯,用的Kolourpqint画板也不好用,凑合看吧!
假设有两个特征:no surfing 、Flippers ,一个结果:Fish
现在假如给你一个测试:no surfing = 1, Flippers=0, 如何知道Fish的结果?太简单了Fish==A...
现在样本你不知道排序的情况下,那我们操作的步骤只能是两种:
1.no surfing = 1时判断Fish,直接得出结果Fish==A
2.Flippers=0时判断Fish,Fish可能是A也可能是B,再判断no surfing =1时,得出Fish == A
从上面我们可以看出,你选择的特征顺序对结果无影响,但是对计算的过程影响很大,我们能不能找到一种很好的途径去解决这个问题呢?
下面是两种方法:

方法一

方法二
由以上的两种思路可以得出,不同的分类方法差距很大吧?
决策树就是用来解决如何选用最佳的方法的一种算法!!!
一点不了解的,先花几分钟看一下我“信息熵”,这是整个算法的核心。
二.决策树的实现
(1)计算信息熵
为什么计算“信息熵”?自己去看原理就懂了。
def claShannonEnt(setData):
lengthData = len(setData)
dicData = {}
for cnt in range(lengthData):
if setData[cnt,-1] not in dicData.keys():
dicData[setData[cnt,-1]] = 0
dicData[setData[cnt,-1]] += 1
Hent = 0.0#输出信息ent
for key in dicData.keys():
pData = float(dicData[key])/lengthData
Hent -= pData*math.log(pData,2)
return Hent
(2)划分数据集
划分之后计算部分的信息熵之和,信息熵越小越好,信息增益越大越好。
def splitData(setData,axis,value):
''' setData: sample sata
axis : 轴的位置
value : 满足条件的值
'''
lengthData = setData.shape[0]
resultMat = np.zeros([1,setData.shape[1]])
for count in range(lengthData):
if int(setData[count,axis]) == int(value) :
resultMat = np.vstack((resultMat,setData[count,:]))
returnMat = resultMat[1:,:]
resultMat = np.hstack((returnMat[:,0:axis],returnMat[:,axis+1:]))
return resultMat
(3)选择最佳的划分方案
这里的原理就是划分之后的信息熵变小,信息增益变大,其中信息熵越小越好,也就是信息增益越大越好,循环比较每种划分之后的信息增益。
def chooseBestTeature(setData):
numFeature = setData.shape[1] - 1 #特征数量
baceEntropy = claShannonEnt(setData) #信息熵
bestGain = 0.0 #最好增益
bestFeature = 0 #最好特征
for i in range(numFeature):
#featList = [example[i] for example in setData]
featList = setData[:,i]
uniquaVals = set(featList) #不同的Value值,set之后就变成无序集合
newEntropy = 0.0
for value in uniquaVals:
subDataSet = splitData(setData,i,value)#分割特征
prob = len(subDataSet)/float(len(setData))
newEntropy += prob * claShannonEnt(subDataSet)#平均信息熵
infoGain = baceEntropy - newEntropy
if (infoGain > bestGain):#求得最大增益
bestGain = infoGain
bestFeature = i
return bestFeature
(4)计算分类之后的标签
这里有点难理解,准备在下面程序讲解的,写到这里就直接讲解了。
这是为了分类不了的情况做的准备,比如:[1,1,'yes'],[1,1,'no'],[1,0,'no'],[1,0,'yes'],[0,0,'no'],[0,0,'yes'],[0,1,'no'],[0,1,'yes'],大家可以按照上面的方法动手试试怎么分割?
我们可以想象一下,就像以前中学学的解方程,Y1+Y2=10 && 2Y1 +2Y2 =10 ,你怎么求解Y1和Y2 ?两个有冲突的方程和上面的样本之间的冲突是一样的。
这明显是一个出错的样本导致的,那怎么解决呢?
再给出一组样本:[1,1,'yes'],[1,1,'yes'],[1,1,'no'],[1,1,'yes'],我们利用错误的样本为少数,多数的样本为正确的,所以[1,1] = 'YES'
#计算分类之后的标签
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
return sortedClassCount
(5)建立决策树
这里采用递归的方法进行划分
调出循环的条件是:
1.最后的标签相同--->>>也就是最后就省一个答案了,没必要划分直接得出结果了。
2.就是第四点说的无解题,那就多的保留,少的丢弃。

def creatTree(dataSet,labels):
classList = dataSet[:,-1]
#标签全部相等的时候退出
if list(classList).count(classList[0]) == len(list(classList)):
return classList[0]
#最后的标签不相同,这个时候没办法分割,所以只能选择一个占比例大的标签了,博客会给具体例子
if len(dataSet[0,:]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestTeature(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
featValue = dataSet[:,bestFeat]
uniqueVals = set(featValue)
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value] = creatTree(splitData(dataSet,bestFeat,value),subLabels)
return myTree
(6)使用决策树
就像建立决策树一样,采用递归一层一层的去找到数据属于哪个类,看懂上面的建立之后现在这里不很简单
def classify(inputTrees,featLabels,testVec):
firstStr = list(inputTrees.keys())[0]#字典首元素
secondDict = inputTrees[firstStr]#下一个字典
featIndex = featLabels.index(firstStr)#标签中的位置
for key in secondDict.keys():
if testVec[featIndex] == int(key):#分支
if type(secondDict[key]).__name__=='dict':#如果还是字典说明还得划分
classLabels = classify(secondDict[key],featLabels,testVec)#迭代划分
else: classLabels = secondDict[key]#不是字典说明已经分类
return classLabels
(7)存储决策树函数
(8)总程序设计
注意:我用的是Numpy数据,而不是List数据,这是有区别的,没有完全按照书上编写!
import numpy as np
import matplotlib.pyplot as ply
import math
import operator def claShannonEnt(setData):
lengthData = len(setData)
dicData = {}
for cnt in range(lengthData):
if setData[cnt,-1] not in dicData.keys():
dicData[setData[cnt,-1]] = 0
dicData[setData[cnt,-1]] += 1
Hent = 0.0#输出信息ent
for key in dicData.keys():
pData = float(dicData[key])/lengthData
Hent -= pData*math.log(pData,2)
return Hent def splitData(setData,axis,value):
''' setData: sample sata
axis : 轴的位置
value : 满足条件的值
'''
lengthData = setData.shape[0]
resultMat = np.zeros([1,setData.shape[1]])
for count in range(lengthData):
if int(setData[count,axis]) == int(value) :
resultMat = np.vstack((resultMat,setData[count,:]))
returnMat = resultMat[1:,:]
resultMat = np.hstack((returnMat[:,0:axis],returnMat[:,axis+1:]))
return resultMat def chooseBestTeature(setData):
numFeature = setData.shape[1] - 1 #特征数量
baceEntropy = claShannonEnt(setData) #信息熵
bestGain = 0.0 #最好增益
bestFeature = 0 #最好特征
for i in range(numFeature):
#featList = [example[i] for example in setData]
featList = setData[:,i]
uniquaVals = set(featList) #不同的Value值,set之后就变成无序集合
newEntropy = 0.0
for value in uniquaVals:
subDataSet = splitData(setData,i,value)#分割特征
prob = len(subDataSet)/float(len(setData))
newEntropy += prob * claShannonEnt(subDataSet)#平均信息熵
infoGain = baceEntropy - newEntropy
if (infoGain > bestGain):#求得最大增益
bestGain = infoGain
bestFeature = i
return bestFeature #计算分类之后的标签
def majorityCnt(classList):
classCount = {}
for vote in classList:
if vote not in classCount.keys():
classCount[vote] = 0
classCount[vote] += 1
sortedClassCount = sorted(classCount.items(),key=operator.itemgetter(1),reverse=True)
return sortedClassCount def creatTree(dataSet,labels):
classList = dataSet[:,-1]
#标签全部相等的时候退出
if list(classList).count(classList[0]) == len(list(classList)):
return classList[0]
#最后的标签不相同,这个时候没办法分割,所以只能选择一个占比例大的标签了,博客会给具体例子
if len(dataSet[0,:]) == 1:
return majorityCnt(classList)
bestFeat = chooseBestTeature(dataSet)
bestFeatLabel = labels[bestFeat]
myTree = {bestFeatLabel:{}}
del(labels[bestFeat])
featValue = dataSet[:,bestFeat]
uniqueVals = set(featValue)
for value in uniqueVals:
subLabels = labels[:]
myTree[bestFeatLabel][value] = creatTree(splitData(dataSet,bestFeat,value),subLabels)
return myTree
import numpy as np
import trees if __name__ == '__main__':
testData = np.array([[1,1,'yes'],[1,1,'no'],[1,0,'no'],[1,0,'yes'],[0,0,'no'],[0,0,'yes'],[0,1,'no'],[0,1,'yes']])
myTree = trees.creatTree(testData,['no surfacing','flippers'])#['yes','yes','no','no','no']
print(myTree)
《机器学习实战》ID3算法实现的更多相关文章
- 机器学习笔记----- ID3算法的python实战
本文申明:本文原创,如有转载请申明.数据代码来自实验数据都是来自[美]Peter Harrington 写的<Machine Learning in Action>这本书,侵删. Hell ...
- 机器学习决策树ID3算法,手把手教你用Python实现
本文始发于个人公众号:TechFlow,原创不易,求个关注 今天是机器学习专题的第21篇文章,我们一起来看一个新的模型--决策树. 决策树的定义 决策树是我本人非常喜欢的机器学习模型,非常直观容易理解 ...
- 学习笔记之机器学习实战 (Machine Learning in Action)
机器学习实战 (豆瓣) https://book.douban.com/subject/24703171/ 机器学习是人工智能研究领域中一个极其重要的研究方向,在现今的大数据时代背景下,捕获数据并从中 ...
- Python四步实现决策树ID3算法,参考机器学习实战
一.编写计算历史数据的经验熵函数 from math import log def calcShannonEnt(dataSet): numEntries = len(dataSet) labelCo ...
- 决策树ID3算法python实现 -- 《机器学习实战》
from math import log import numpy as np import matplotlib.pyplot as plt import operator #计算给定数据集的香农熵 ...
- 《机器学习实战》学习笔记第三章 —— 决策树之ID3、C4.5算法
主要内容: 一.决策树模型 二.信息与熵 三.信息增益与ID3算法 四.信息增益比与C4.5算法 五.决策树的剪枝 一.决策树模型 1.所谓决策树,就是根据实例的特征对实例进行划分的树形结构.其中有两 ...
- python机器学习笔记 ID3决策树算法实战
前面学习了决策树的算法原理,这里继续对代码进行深入学习,并掌握ID3的算法实践过程. ID3算法是一种贪心算法,用来构造决策树,ID3算法起源于概念学习系统(CLS),以信息熵的下降速度为选取测试属性 ...
- 机器学习实战 -- 决策树(ID3)
机器学习实战 -- 决策树(ID3) ID3是什么我也不知道,不急,知道他是干什么的就行 ID3是最经典最基础的一种决策树算法,他会将每一个特征都设为决策节点,有时候,一个数据集中,某些特征属 ...
- 《机器学习实战》学习笔记第九章 —— 决策树之CART算法
相关博文: <机器学习实战>学习笔记第三章 —— 决策树 主要内容: 一.CART算法简介 二.分类树 三.回归树 四.构建回归树 五.回归树的剪枝 六.模型树 七.树回归与标准回归的比较 ...
- 机器学习实战笔记(Python实现)-01-K近邻算法(KNN)
--------------------------------------------------------------------------------------- 本系列文章为<机器 ...
随机推荐
- pyhanlp 共性分析与短语提取内容详解
pyhanlp 共性分析与短语提取内容详解 简介 HanLP中的词语提取是基于互信息与信息熵.想要计算互信息与信息熵有限要做的是 文本分词进行共性分析.在作者的原文中,有几个问题,为了便于说明,这 ...
- const引用返回值
一.引用 引用是别名 必须在定义引用时进行初始化.初始化是指明引用指向哪个对象的唯一方法. const 引用是指向 const 对象的引用: ; const int &refVal = iva ...
- JAVA高并发系列
高并发Java(1):前言 高并发Java(2):多线程基础 高并发Java(3):Java内存模型和线程安全 高并发Java(4):无锁 高并发Java(5):JDK并发包1 高并发Java(6): ...
- XBOX360
[汇总+分享]XBOX360多人游戏汇总贴https://tieba.baidu.com/p/3550398060?pn=13&red_tag=3423139816&traceid= ...
- oracle--分组后获取每组数据第一条数据
SELECT * FROM (SELECT ROW_NUMBER() OVER(PARTITION BY cc.queuename ORDER BY cc.enroldate DESC) rn, cc ...
- ALGO-17_蓝桥杯_算法训练_乘积最大(DP)
问题描述 今年是国际数学联盟确定的“——世界数学年”,又恰逢我国著名数学家华罗庚先生诞辰90周年.在华罗庚先生的家乡江苏金坛,组织了一场别开生面的数学智力竞赛的活动,你的一个好朋友XZ也有幸得以参加. ...
- 导入数据库时出现ORA-01435: 用户不存在
报错信息: IMP-00003: 遇到 ORACLE 错误 1435 ORA-01435: 用户不存在 成功终止导入,但出现警告. 我的导入脚本为: imp system/*****@min file ...
- elasticsearch 口水篇(4)java客户端 - 原生esClient
上一篇(elasticsearch 口水篇(3)java客户端 - Jest)Jest是第三方客户端,基于REST Api进行调用(httpClient),本篇简单介绍下elasticsearch原生 ...
- P3811 乘法逆元
传送 乘法逆元:ax ≡ 1 (mod p),其中x为a的逆元,求模意义下的乘法逆元,通常有一下几种方法: 1.拓展欧几里得(也就是exgcd) ax ≡ 1 (mod p) ax-py=1 这就变成 ...
- NodeJs使用Express框架开发时的快速调试方法
习惯了php开发,可以直接使用echo或者var_dump()将想要查看的变量结果输出到网页查看,非常的方便.但是使用express开发时,每次修改文件后,都需要使用npm start命令重启服务,然 ...