Test1.py 主要是用来运行的 代码如下:

  1. # -*- coding: utf-8 -*-
  2.  
  3. from math import log
  4. import operator
  5. import treePlotter
  6.  
  7. def calcShannonEnt(dataSet):
  8. """
  9. 输入:数据集
  10. 输出:数据集的香农熵
  11. 描述:计算给定数据集的香农熵;熵越大,数据集的混乱程度越大
  12. """
  13. # 数据集个数
  14. numEntries = len(dataSet)
  15. #print("dd",numEntries)
  16. # 标签个数
  17. labelCounts = {}
  18. for featVec in dataSet:
  19. # 获取每一行的结果 也就是yes or no
  20. currentLabel = featVec[-1]
  21. # print('e',currentLabel)
  22. # 判断我获取的这个yes or no 在不在labelCounts字典中 如果不在创建新的设置为0
  23. if currentLabel not in labelCounts.keys():
  24. labelCounts[currentLabel] = 0
  25. labelCounts[currentLabel] += 1
  26. #print('r',labelCounts)
  27. shannonEnt = 0.0
  28. for key in labelCounts:
  29. # 计算类别信息熵
  30. prob = float(labelCounts[key])/numEntries
  31. shannonEnt -= prob * log(prob, 2)
  32. #print('----',shannonEnt)
  33. return shannonEnt
  34. # 分别按照这几个属性来计算信息熵 找出最大的,最后按照这一个来划分。
  35. def splitDataSet(dataSet, axis, value):
  36. """
  37. 输入:数据集,选择维度,选择值
  38. 输出:划分数据集
  39. 描述:按照给定特征划分数据集;去除选择维度中等于选择值的项
  40. """
  41. retDataSet = []
  42. # 这个时候 dataSet 还是完整的
  43. for featVec in dataSet:
  44. # print(axis,featVec)
  45. # print('A',featVec[axis],"是不是等于",value)
  46. if featVec[axis] == value:
  47. reduceFeatVec = featVec[:axis]
  48. # print("B",reduceFeatVec,"此时的维度:",axis)
  49. reduceFeatVec.extend(featVec[axis+1:])
  50. retDataSet.append(reduceFeatVec)
  51. # print('GG',retDataSet)
  52. return retDataSet
  53.  
  54. def chooseBestFeatureToSplit(dataSet):
  55. """
  56. 输入:数据集
  57. 输出:最好的划分维度
  58. 描述:选择最好的数据集划分维度
  59. """
  60. # 特征数量 也就是字段个数
  61. numFeatures = len(dataSet[0]) - 1
  62. ###################################################################
  63. # (1)信息增益
  64. # print('cc',numFeatures)
  65. # 信息增益实际上是ID3算法中用来进行属性选择度量的。
  66. # 它选择具有最高信息增益的属性来作为节点N的分裂属性。
  67. # 该属性使结果划分中的元组分类所需信息量最小。
  68. # 对D中的元组分类所需的期望信息为下式:
  69. baseEntropy = calcShannonEnt(dataSet) # 香农熵
  70. # print('z',baseEntropy)
  71. bestInfoGainRatio = 0.0 # 最好的熵
  72. bestFeature = -1 # 最好的特征
  73. for i in range(numFeatures):
  74. featList = [example[i] for example in dataSet]
  75. uniqueVals = set(featList)
  76. # print('s',i,uniqueVals)
  77. newEntropy = 0.0
  78. splitInfo = 0.0
  79. for value in uniqueVals:
  80. # 划分数据集
  81. # print("Bn",i,value)
  82. subDataSet = splitDataSet(dataSet, i, value)
  83. #print("After",subDataSet,i,value)
  84. prob = len(subDataSet)/float(len(dataSet))
  85. # 现在假定按照属性A划分D中的元组,且属性A将D划分成v个不同的类。
  86. # 在该划分之后,为了得到准确的分类还需要的信息由下面的式子度量
  87. newEntropy += prob * calcShannonEnt(subDataSet)
  88. # 信息增益定义为原来的信息需求(即仅基于类比例)与新需求(即对A划分之后得到的)之间的差
  89. splitInfo += -prob * log(prob, 2)
  90. # 信息增益
  91. infoGain = baseEntropy - newEntropy
  92. ##########################################################################
  93. if (splitInfo == 0): # 修复溢出错误
  94. continue
  95. #########################################################################
  96. # (2)信息增益率
  97. # 训练数据集D划分成对应于属性A测试的v个输出的v个划分产生的信息。信息增益率定义:
  98. infoGainRatio = infoGain / splitInfo
  99. # 选择具有最大增益率的属性作为分裂属性。
  100. if (infoGainRatio > bestInfoGainRatio):
  101. bestInfoGainRatio = infoGainRatio
  102. bestFeature = i
  103. return bestFeature
  104.  
  105. def majorityCnt(classList):
  106. """
  107. 输入:分类类别列表
  108. 输出:子节点的分类
  109. 描述:数据集已经处理了所有属性,但是类标签依然不是唯一的,
  110. 采用少数服从多数的原则决定该子节点的分类
  111. """
  112. ''' 找出数量最多的分类 '''
  113. # 分类字典
  114. classCount = {}
  115. for vote in classList:
  116. if vote not in classCount.keys():
  117. classCount[vote] = 0
  118. classCount[vote] += 1
  119. # 创建键值为classList中唯一值的数据字典,字典对象存储了classList中每个类标签出现的频率,最后利用operator操作键值排序字典,并返回出现次数最多的分类名称。
  120. # iteritems:迭代器
  121. # operator模块提供的itemgetter函数用于获取对象的哪些维的数据,参数为一些序号(即需要获取的数据在对象中的序号)
  122. # sorted() 是Python内置的一个排序函数,它会从一个"迭代器"返回一个排好序的新列表。
  123. sortedClassCount = sorted(classCount.iteritems(), key=operator.itemgetter(1), reversed=True)
  124. return sortedClassCount[0][0]
  125. # 创建决策树 参数:数据集、标签
  126. def createTree(dataSet, labels):
  127. """
  128. 输入:数据集,特征标签
  129. 输出:决策树
  130. 描述:递归构建决策树,利用上述的函数
  131. """
  132. # 截取dataSet的最后一行
  133. classList = [example[-1] for example in dataSet]
  134. # 数据集都是同一类的情况
  135. if classList.count(classList[0]) == len(classList):
  136. return classList[0]
  137. # 遍历完所有特征时返回出现次数最多的
  138. #print('bb',dataSet[1])
  139. # 如果数据集只有一个特征的情况
  140. if len(dataSet[0]) == 1:
  141. return majorityCnt(classList)
  142. # 最大增益率的属性作为分裂属性
  143. bestFeat = chooseBestFeatureToSplit(dataSet) # 最好的特征
  144. # print('bestFeat',bestFeat) # 0 2 当选择0(outlook)之后 剩下的012中选择2(windy)中
  145. bestFeatLabel = labels[bestFeat] # 最好的分类
  146. myTree = {bestFeatLabel:{}}
  147. # print(myTree) # {'outlook': {}} {'windy': {}}
  148. del(labels[bestFeat])
  149. # 得到列表包括节点所有的属性值
  150. featValues = [example[bestFeat] for example in dataSet]
  151. # print('featValues',featValues)
  152. uniqueVals = set(featValues)
  153. # print('uniqueVals',uniqueVals)
  154. for value in uniqueVals:
  155. # 去掉前面标签之后剩下的标签
  156. subLabels = labels[:]
  157. # print('subLabels',subLabels)
  158. myTree[bestFeatLabel][value] = createTree(splitDataSet(dataSet, bestFeat, value), subLabels)
  159. # print('myTree',myTree)
  160. return myTree
  161.  
  162. def classify(inputTree, featLabels, testVec):
  163. """
  164. 输入:决策树,分类标签,测试数据
  165. 输出:决策结果
  166. 描述:跑决策树
  167. """
  168. firstStr = list(inputTree.keys())[0]
  169. # print('t2',firstStr)
  170. secondDict = inputTree[firstStr]
  171. #print('t3',secondDict)
  172. featIndex = featLabels.index(firstStr)
  173. # print('t4',featIndex)
  174. for key in secondDict.keys():
  175. #print('key',key)
  176. if testVec[featIndex] == key:
  177. #print('testVec[featIndex]',testVec[featIndex])
  178. if type(secondDict[key]).__name__ == 'dict':
  179. classLabel = classify(secondDict[key], featLabels, testVec)
  180. else:
  181. classLabel = secondDict[key]
  182. # print('t5',classLabel)
  183. return classLabel
  184. # 通过输入的决策树和对应的标签 来对测试集合 进行预测
  185. def classifyAll(inputTree, featLabels, testDataSet):
  186. """
  187. 输入:决策树,分类标签,测试数据集
  188. 输出:决策结果
  189. 描述:跑决策树
  190. """
  191. # 空列表
  192. classLabelAll = []
  193. for testVec in testDataSet:
  194. # print('t1',testVec)
  195. # 将预测结果插入到classLabelAll中
  196. classLabelAll.append(classify(inputTree, featLabels, testVec))
  197. # print("t6",classLabelAll)
  198. return classLabelAll
  199.  
  200. # 训练集
  201. def createDataSet():
  202. """
  203. 天气情况 outlook-> sunny | overcast | rain
  204. 温度情况 temperature-> hot | mild | cool
  205. 湿度情况 humidity-> high | normal
  206. 风力情况 windy-> false | true
  207. """
  208. ######## no or yes is play golf ???
  209. dataSet = [["sunny", "hot", "high", "false", 'no'],
  210. ["sunny", "hot", "high", "true", 'no'],
  211. ["overcast", "hot", "high", "false", 'yes'],
  212. ["rain", "mild", "high", "false", 'yes'],
  213. ["rain", "cool", "normal", "false", 'yes'],
  214. ["rain", "cool", "normal", "true", 'no'],
  215. ["overcast","cool", "normal", "true", 'no'],
  216. ["rain", "hot", "high", "true", 'yes'],
  217. ["sunny", "mild", "high", "true", 'no'],
  218. ["rain", "hot", "normal", "true", 'yes'],
  219. ["overcast","mild", "high", "false", 'no']]
  220. # 对应的标签
  221. labels = ['outlook', 'temperature', 'humidity', 'windy']
  222. return dataSet, labels
  223. # 测试集
  224. def createTestSet():
  225. testSet = [["sunny", "mild", "high", "false"],
  226. ["sunny", "cool", "normal", "false"],
  227. ["rain", "mild", "normal", "false"],
  228. ["sunny", "mild", "normal", "true"],
  229. ["overcast","mild", "high", "true"],
  230. ["rain", "hot", "normal", "true"],
  231. ["sunny", "mild", "normal", "false"],
  232. ["rain", "hot", "high", "true"],
  233. ["sunny", "mild", "high", "true"],
  234. ["rain", "hot", "normal", "true"],
  235. ["overcast", "mild", "high", "false"],
  236. ["rain", "mild", "high", "true"]]
  237. return testSet
  238. #主函数 定义
  239. def main():
  240. dataSet, labels = createDataSet()
  241. labels_tmp = labels[:] # 拷贝 labels
  242. Tree = createTree(dataSet, labels_tmp)
  243. print('Tree:\n', Tree)
  244. treePlotter.createPlot(Tree)
  245. print('------------------------------')
  246. # 获取测试集 进行预测
  247. testSet = createTestSet()
  248. print('classifyResult:\n', classifyAll(Tree, labels, testSet))
  249. # 调用主函数
  250. if __name__ == '__main__':
  251. main()

treePlotter.py 用来画决策树。 代码如下所示:

  1. import matplotlib.pyplot as plt
  2. # 定义文本框和箭头格式
  3. decisionNode = dict(boxstyle="sawtooth", fc="0.8")
  4. leafNode = dict(boxstyle="round4", fc="0.8")
  5. arrow_args = dict(arrowstyle="<-")
  6. # 绘制带箭头的注释
  7. def plotNode(nodeTxt, centerPt, parentPt, nodeType):
  8. createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', \
  9. xytext=centerPt, textcoords='axes fraction', \
  10. va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)
  11. ''' 获得决策树的叶节点数 '''
  12. def getNumLeafs(myTree):
  13. numLeafs = 0
  14. # fistStr获得字典的键 代表树根
  15. firstStr = list(myTree.keys())[0] # 头结点
  16. # print('firstStr',firstStr)
  17. secondDict = myTree[firstStr] # 取出头结点的的字典
  18. for key in secondDict.keys(): # 测试节点的数据类型是否为字典
  19. if type(secondDict[key]).__name__ == 'dict':
  20. numLeafs += getNumLeafs(secondDict[key])
  21. else:
  22. numLeafs += 1
  23. return numLeafs
  24. ''' 求树的深度 '''
  25. def getTreeDepth(myTree):
  26. maxDepth = 0
  27. firstStr = list(myTree.keys())[0] # 头结点
  28. secondDict = myTree[firstStr]
  29. for key in secondDict.keys(): # 测试节点的数据类型是否为字典
  30. if type(secondDict[key]).__name__ == 'dict':
  31. thisDepth = getTreeDepth(secondDict[key]) + 1
  32. else:
  33. thisDepth = 1
  34. if thisDepth > maxDepth:
  35. maxDepth = thisDepth
  36. return maxDepth
  37. ''' 在父子节点之间填充文本信息 '''
  38. def plotMidText(cntrPt, parentPt, txtString):
  39. xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
  40. yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
  41. createPlot.ax1.text(xMid, yMid, txtString)
  42. ''' 根节点坐标 '''
  43. def plotTree(myTree, parentPt, nodeTxt):
  44. numLeafs = getNumLeafs(myTree) # 子节点数量
  45. depth = getTreeDepth(myTree) # 深度
  46. firstStr = list(myTree.keys())[0] # 根节点的key
  47. '''X坐标=节点的x偏移量 + 叶节点数距离
  48. 所有该节点下子叶子节点的距离:numLeafs / plotTree.totalW
  49. 但是坐标在叶子节点的中心:numLeafs / 2 / plotTree.totalW
  50. 又因为xOff初始坐标点在原点的左边:numLeafs / 2 / plotTree.totalW + 0.5 / plotTree.totalW ,这是偏移量
  51. 那么x = numLeafs / 2 / plotTree.totalW + 0.5 / plotTree.totalW + plotTree.xOff
  52. '''
  53. # 根节点坐标
  54. # 叶子节点距离
  55. cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalw, plotTree.yOff)
  56. # 标记子节点属性值
  57. plotMidText(cntrPt, parentPt, nodeTxt)
  58. plotNode(firstStr, cntrPt, parentPt, decisionNode)
  59. secondDict = myTree[firstStr]
  60. plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
  61. for key in secondDict.keys():
  62. if type(secondDict[key]).__name__ == 'dict':
  63. plotTree(secondDict[key], cntrPt, str(key))
  64. else:
  65. plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalw
  66. plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
  67. plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
  68. plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD
  69. # plot构建树
  70. def createPlot(inTree):
  71. # figure语法 创建自定义图像 定义了一个框架
  72. # num:图像编号或名称,数字为编号
  73. # facecolor:背景颜色
  74. fig = plt.figure(1, facecolor='white')
  75. #plt.close()将完全关闭图形窗口
  76. # plt.clf()将清除图形-您仍然可以在其上绘制另一个绘图。
  77. fig.clf()
  78. # xticks是一个列表,其中的元素就是x轴上将显示的坐标
  79. # yticks是y轴上显示的坐标,这里空列表则不显示坐标
  80. axprops = dict(xticks=[], yticks=[])
  81. # 这里定义一个子图窗口
  82. # 第一个参数xyz含义是,将框架划分为x行y列窗口,ax1代表其第z个窗口。
  83. # ps:111 就是一行一列第一个窗口
  84. # frameon = False将隐藏坐标轴
  85. createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
  86. # plotTree.totalW是决策树的叶子树,也代表宽度
  87. plotTree.totalw = float(getNumLeafs(inTree))
  88. # plotTree.totalD是决策树的深度
  89. plotTree.totalD = float(getTreeDepth(inTree))
  90. # 方便后面加上 1.0 / plotTree.totalW 后位置刚好在中间
  91. plotTree.xOff = -0.5 / plotTree.totalw
  92. plotTree.yOff = 1.0
  93. # 调用函数plotTree(),绘制整棵决策树,最后显示出来。
  94. plotTree(inTree, (0.5, 1.0), '')
  95. plt.show()

运行结果如下所示:

一个简单的C4.5算法,采用Python语言的更多相关文章

  1. 实现一个简单的邮箱地址爬虫(python)

    我经常收到关于email爬虫的问题.有迹象表明那些想从网页上抓取联系方式的人对这个问题很感兴趣.在这篇文章里,我想演示一下如何使用python实现一个简单的邮箱爬虫.这个爬虫很简单,但从这个例子中你可 ...

  2. 实现一个简单的虚拟demo算法

    假如现在你需要写一个像下面一样的表格的应用程序,这个表格可以根据不同的字段进行升序或者降序的展示. 这个应用程序看起来很简单,你可以想出好几种不同的方式来写.最容易想到的可能是,在你的 JavaScr ...

  3. 通过创建一个简单的骰子游戏来探究 Python

    在我的这系列的第一篇文章 中, 我已经讲解如何使用 Python 创建一个简单的.基于文本的骰子游戏.这次,我将展示如何使用 Python 模块 Pygame 来创建一个图形化游戏.它将需要几篇文章才 ...

  4. day-7 一个简单的决策树归纳算法(ID3)python编程实现

    本文介绍如何利用决策树/判定树(decision tree)中决策树归纳算法(ID3)解决机器学习中的回归问题.文中介绍基于有监督的学习方式,如何利用年龄.收入.身份.收入.信用等级等特征值来判定用户 ...

  5. 通过编写一个简单的漏洞扫描程序学习Python基本语句

    今天开始读<Python绝技:运用Python成为顶级黑客>一书,第一章用一个小例子来讲解Python的基本语法和语句.主要学习的内容有:1. 安装第三方库.2. 变量.字符串.列表.词典 ...

  6. 算法课上机实验(一个简单的GUI排序算法比较程序)

    (在家里的电脑上Linux Deepin截的图,屏幕大一点的话,deepin用着还挺不错的说) 这个应该是大二的算法课程上机实验时做的一个小程序,也是我的第一个GUI小程序,实现什么的都记不清了,只记 ...

  7. C++写一个简单的解析器(分析C语言)

    该方案实现了一个分析C语言的词法分析+解析. 注意: 1.简单语法,部分秕.它可以在本文法的基础上进行扩展,此过程使用自上而下LL(1)语法. 2.自己主动能达到求First 集和 Follow 集. ...

  8. 【学习笔记】PYTHON语言程序设计(北理工 嵩天)

    1 Python基本语法元素 1.1 程序设计基本方法 计算机发展历史上最重要的预测法则     摩尔定律:单位面积集成电路上可容纳晶体管数量约2年翻倍 cpu/gpu.内存.硬盘.电子产品价格等都遵 ...

  9. 一个简单的多机器人编队算法实现--PID

    用PID进行领航跟随法机器人编队控制 课题2:多机器人编队控制研究对象:两轮差动的移动机器人或车式移动机器人研究内容:平坦地形,编队的保持和避障,以及避障和队形切换算法等:起伏地形,还要考虑地形情况对 ...

  10. 一个简单的mock server

    在前后端分离的项目中, 前端无需等后端接口提供了才调试, 后端无需等第三方接口提供了才调试, 基于“契约”,可以通过mock server实现调试, 下面是一个简单的mock server,通过pyt ...

随机推荐

  1. 看懂java序列化,这篇就够了

    前言 相信大家日常开发中,经常看到 Java 对象 "implements Serializable".那么,它到底有什么用呢?本文带你全方位的解读序列化与反序列化这一块知识点. ...

  2. 浅析switch和if(开发中这两者的优缺点;分析出优缺点在使用就能更确定自己需要使用哪个函数了)

    分析 Switch 相较于 if 的优点 1.switch 执行效率  高于  if 的执行效率 分析: switch是在编译阶段将子函数的地址和判断条件绑定了,只要直接将a的直接映射到子函数地址去执 ...

  3. 即构SDK12月迭代:新增多项质量回调,互动白板、云录制SDK同步更新

    即构SDK12月迭代来啦,本月LiveRoom/AudioRoom SDK新增了端到端延迟质量回调.房间会话ID信息,便于在音视频通话.直播场景中进行时延.通话质量的评测.同时还优化了硬件设备权限变更 ...

  4. tensorflow神经网络归一化方法

    参考https://blog.csdn.net/chary8088/article/details/81542879

  5. Linux: rsyslog.conf 配置

    refer to: https://www.debian.org/doc/manuals/debian-handbook/sect.syslog.en.html 日志子系统 Each log mess ...

  6. 2021-8-2 Mysql个人练习题

    创建学生表 CREATE TABLE student( id int, uname VARCHAR(20), chinese FLOAT, english FLOAT, math FLOAT ); I ...

  7. React错误: Can't resolve 'react-dom/client'

    错误截图 解决方案 当你的react版本低于18时,但仍然报这个错误,可以采用如下方案 意外的发现当我采用上述方案时,我的React路由跳转时,页面不刷新的问题也解决了,很神奇,日后技艺精进再补充.

  8. python: ImportError: cannot import name '_unicodefun' from 'click'

    报错 报错原因 click模块版本问题 解决方案 指定click版本为8.0.4 参考链接 https://github.com/psf/black/issues/2964

  9. 达梦数据库: SQL查询报错《不是 GROUP BY 表达式解决方法》

    报错信息: ****: 第4 行附近出现错误: 不是 GROUP BY 表达式 修改办法: 达梦可以配置兼容参数,COMPATIBLE_MODE=4,静态参数,需要重启数据库后生效! sp_set_p ...

  10. React请求机制优化思路

    说起数据加载的机制,有一个绕不开的话题就是前端性能,很多电商门户的首页其实都会做一些垂直的定制优化,比如让请求在页面最早加载,或者在前一个页面就进行预加载等等.随着react18的发布,请求机制这一块 ...