数据集如下:

 色泽    根蒂    敲声    纹理    脐部    触感    好瓜
青绿 蜷缩 浊响 清晰 凹陷 硬滑 是
乌黑 蜷缩 沉闷 清晰 凹陷 硬滑 是
乌黑 蜷缩 浊响 清晰 凹陷 硬滑 是
青绿 蜷缩 沉闷 清晰 凹陷 硬滑 是
浅白 蜷缩 浊响 清晰 凹陷 硬滑 是
青绿 稍蜷 浊响 清晰 稍凹 软粘 是
乌黑 稍蜷 浊响 稍糊 稍凹 软粘 是
乌黑 稍蜷 浊响 清晰 稍凹 硬滑 是
乌黑 稍蜷 沉闷 稍糊 稍凹 硬滑 否
青绿 硬挺 清脆 清晰 平坦 软粘 否
浅白 硬挺 清脆 模糊 平坦 硬滑 否
浅白 蜷缩 浊响 模糊 平坦 软粘 否
青绿 稍蜷 浊响 稍糊 凹陷 硬滑 否
浅白 稍蜷 沉闷 稍糊 凹陷 硬滑 否
乌黑 稍蜷 浊响 清晰 稍凹 软粘 否
浅白 蜷缩 浊响 模糊 平坦 硬滑 否
青绿 蜷缩 沉闷 稍糊 稍凹 硬滑 否

基于信息增益的ID3决策树的原理这里不再赘述,读者如果不明白可参考西瓜书对这部分内容的讲解。

python实现代码如下:

 from math import log2
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties # 统计label出现次数
def get_counts(data):
total = len(data)
results = {}
for d in data:
results[d[-1]] = results.get(d[-1], 0) + 1
return results, total # 计算信息熵
def calcu_entropy(data):
results, total = get_counts(data)
ent = sum([-1.0*v/total*log2(v/total) for v in results.values()])
return ent # 计算每个feature的信息增益
def calcu_each_gain(column, update_data):
total = len(column)
grouped = update_data.iloc[:, -1].groupby(by=column)
temp = sum([len(g[1])/total*calcu_entropy(g[1]) for g in list(grouped)])
return calcu_entropy(update_data.iloc[:, -1]) - temp # 获取最大的信息增益的feature
def get_max_gain(temp_data):
columns_entropy = [(col, calcu_each_gain(temp_data[col], temp_data)) for col in temp_data.iloc[:, :-1]]
columns_entropy = sorted(columns_entropy, key=lambda f: f[1], reverse=True)
return columns_entropy[0] # 去掉数据中已存在的列属性内容
def drop_exist_feature(data, best_feature):
attr = pd.unique(data[best_feature])
new_data = [(nd, data[data[best_feature] == nd]) for nd in attr]
new_data = [(n[0], n[1].drop([best_feature], axis=1)) for n in new_data]
return new_data # 获得出现最多的label
def get_most_label(label_list):
label_dict = {}
for l in label_list:
label_dict[l] = label_dict.get(l, 0) + 1
sorted_label = sorted(label_dict.items(), key=lambda ll: ll[1], reverse=True)
return sorted_label[0][0] # 创建决策树
def create_tree(data_set, column_count):
label_list = data_set.iloc[:, -1]
if len(pd.unique(label_list)) == 1:
return label_list.values[0]
if all([len(pd.unique(data_set[i])) ==1 for i in data_set.iloc[:, :-1].columns]):
return get_most_label(label_list)
best_attr = get_max_gain(data_set)[0]
tree = {best_attr: {}}
exist_attr = pd.unique(data_set[best_attr])
if len(exist_attr) != len(column_count[best_attr]):
no_exist_attr = set(column_count[best_attr]) - set(exist_attr)
for nea in no_exist_attr:
tree[best_attr][nea] = get_most_label(label_list)
for item in drop_exist_feature(data_set, best_attr):
tree[best_attr][item[0]] = create_tree(item[1], column_count)
return tree # 决策树绘制基本参考《机器学习实战》书内的代码以及博客:http://blog.csdn.net/c406495762/article/details/76262487
# 获取树的叶子节点数目
def get_num_leafs(decision_tree):
num_leafs = 0
first_str = next(iter(decision_tree))
second_dict = decision_tree[first_str]
for k in second_dict.keys():
if isinstance(second_dict[k], dict):
num_leafs += get_num_leafs(second_dict[k])
else:
num_leafs += 1
return num_leafs # 获取树的深度
def get_tree_depth(decision_tree):
max_depth = 0
first_str = next(iter(decision_tree))
second_dict = decision_tree[first_str]
for k in second_dict.keys():
if isinstance(second_dict[k], dict):
this_depth = 1 + get_tree_depth(second_dict[k])
else:
this_depth = 1
if this_depth > max_depth:
max_depth = this_depth
return max_depth # 绘制节点
def plot_node(node_txt, center_pt, parent_pt, node_type):
arrow_args = dict(arrowstyle='<-')
font = FontProperties(fname=r'C:\Windows\Fonts\STXINGKA.TTF', size=15)
create_plot.ax1.annotate(node_txt, xy=parent_pt, xycoords='axes fraction', xytext=center_pt,
textcoords='axes fraction', va="center", ha="center", bbox=node_type,
arrowprops=arrow_args, FontProperties=font) # 标注划分属性
def plot_mid_text(cntr_pt, parent_pt, txt_str):
font = FontProperties(fname=r'C:\Windows\Fonts\MSYH.TTC', size=10)
x_mid = (parent_pt[0] - cntr_pt[0]) / 2.0 + cntr_pt[0]
y_mid = (parent_pt[1] - cntr_pt[1]) / 2.0 + cntr_pt[1]
create_plot.ax1.text(x_mid, y_mid, txt_str, va="center", ha="center", color='red', FontProperties=font) # 绘制决策树
def plot_tree(decision_tree, parent_pt, node_txt):
d_node = dict(boxstyle="sawtooth", fc="0.8")
leaf_node = dict(boxstyle="round4", fc='0.8')
num_leafs = get_num_leafs(decision_tree)
first_str = next(iter(decision_tree))
cntr_pt = (plot_tree.xoff + (1.0 +float(num_leafs))/2.0/plot_tree.totalW, plot_tree.yoff)
plot_mid_text(cntr_pt, parent_pt, node_txt)
plot_node(first_str, cntr_pt, parent_pt, d_node)
second_dict = decision_tree[first_str]
plot_tree.yoff = plot_tree.yoff - 1.0/plot_tree.totalD
for k in second_dict.keys():
if isinstance(second_dict[k], dict):
plot_tree(second_dict[k], cntr_pt, k)
else:
plot_tree.xoff = plot_tree.xoff + 1.0/plot_tree.totalW
plot_node(second_dict[k], (plot_tree.xoff, plot_tree.yoff), cntr_pt, leaf_node)
plot_mid_text((plot_tree.xoff, plot_tree.yoff), cntr_pt, k)
plot_tree.yoff = plot_tree.yoff + 1.0/plot_tree.totalD def create_plot(dtree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
create_plot.ax1 = plt.subplot(111, frameon=False, **axprops)
plot_tree.totalW = float(get_num_leafs(dtree))
plot_tree.totalD = float(get_tree_depth(dtree))
plot_tree.xoff = -0.5/plot_tree.totalW
plot_tree.yoff = 1.0
plot_tree(dtree, (0.5, 1.0), '')
plt.show() if __name__ == '__main__':
my_data = pd.read_csv('./watermelon2.0.csv', encoding='gbk')
column_count = dict([(ds, list(pd.unique(my_data[ds]))) for ds in my_data.iloc[:, :-1].columns])
d_tree = create_tree(my_data, column_count)
create_plot(d_tree)

绘制的决策树如下:

python实现简单决策树(信息增益)——基于周志华的西瓜书数据的更多相关文章

  1. 周志华-机器学习西瓜书-第三章习题3.5 LDA

    本文为周志华机器学习西瓜书第三章课后习题3.5答案,编程实现线性判别分析LDA,数据集为书本第89页的数据 首先介绍LDA算法流程: LDA的一个手工计算数学实例: 课后习题的代码: # coding ...

  2. 支持向量机(SVM)算法分析——周志华的西瓜书学习

    1.线性可分 对于一个数据集: 如果存在一个超平面X能够将D中的正负样本精确地划分到S的两侧,超平面如下: 那么数据集D就是线性可分的,否则,不可分. w称为法向量,决定了超平面的方向:b为位移量,决 ...

  3. (二)《机器学习》(周志华)第4章 决策树 笔记 理论及实现——“西瓜树”——CART决策树

    CART决策树 (一)<机器学习>(周志华)第4章 决策树 笔记 理论及实现——“西瓜树” 参照上一篇ID3算法实现的决策树(点击上面链接直达),进一步实现CART决策树. 其实只需要改动 ...

  4. 【深度森林第三弹】周志华等提出梯度提升决策树再胜DNN

    [深度森林第三弹]周志华等提出梯度提升决策树再胜DNN   技术小能手 2018-06-04 14:39:46 浏览848 分布式 性能 神经网络   还记得周志华教授等人的“深度森林”论文吗?今天, ...

  5. 【Todo】【读书笔记】机器学习-周志华

    书籍位置: /Users/baidu/Documents/Data/Interview/机器学习-数据挖掘/<机器学习_周志华.pdf> 一共442页.能不能这个周末先囫囵吞枣看完呢.哈哈 ...

  6. [重磅]Deep Forest,非神经网络的深度模型,周志华老师最新之作,三十分钟理解!

    欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术感兴趣的同学加入. 深度学习最大的贡献,个人认为就是表征 ...

  7. 偶尔转帖:AI会议的总结(by南大周志华)

    偶尔转帖:AI会议的总结(by南大周志华) 说明: 纯属个人看法, 仅供参考. tier-1的列得较全, tier-2的不太全, tier-3的很不全. 同分的按字母序排列. 不很严谨地说, tier ...

  8. 【转载】 AI会议的总结(by南大周志华)

    原文地址: https://blog.csdn.net/LiFeitengup/article/details/8441054 最近在查找期刊会议级别的时候发现这篇博客,应该是2012年之前的内容,现 ...

  9. AI产业将更凸显个人英雄主义 周志华老师的观点是如此的有深度

    今天无意间在网上看的了一则推送,<周志华:AI产业将更凸显个人英雄主义> http://tech.163.com/18/0601/13/DJ7J39US00098IEO.html 摘录一些 ...

随机推荐

  1. 多结果集IMultipleResult接口

    在某些任务中,需要执行多条sql语句,这样一次会返回多个结果集,在应用程序就需要处理多个结果集,在OLEDB中支持多结果集的接口是IMultipleResult. 查询数据源是否支持多结果集 并不是所 ...

  2. iDempiere 视频教程下载

    Created by 蓝色布鲁斯,QQ32876341,blog http://www.cnblogs.com/zzyan/ iDempiere官方中文wiki主页 http://wiki.idemp ...

  3. PHP基础--两个数组相加

    在PHP中,当两个数组相加时,会把第二个数组的取值添加到第一个数组上,同时覆盖掉下标相同的值: <?php $a = array("a" => "apple& ...

  4. 04_Redis数据类型(set、zset)

    [set:集合类型(高中的集合知识)] 集合类型:无序.不可重复 列表类型:有序.可重复 [set类型] 1.添加元素 语法:sadd key member1 member2...... 返回值:返回 ...

  5. 笨办法学Python(四十一)

    习题 41: 来自 Percal 25 号行星的哥顿人(Gothons) 你在上一节中发现 dict 的秘密功能了吗?你可以解释给自己吗?让我来给你解释一下,顺便和你自己的理解对比看有什么不同.这里是 ...

  6. STC12C5A60S2 51单片机最小系统

                                                                                    STC12C5A60S2 一.根据芯片文 ...

  7. June 07th 2017 Week 23rd Wednesday

    Failure is the condiment that gives success its flavor. 失败是让成功变美味的调味料. There are kinds of flavors in ...

  8. python入门10 循环语句

    两种循环: 1 for in 2 while #coding:utf-8 #/usr/bin/python """ 2018-11-03 dinghanhua 循环语句 ...

  9. CryptoSwift:密码学

    Hash (Digest) MD5 | SHA1 | SHA224 | SHA256 | SHA384 | SHA512 | SHA3 Cyclic Redundancy Check (CRC) CR ...

  10. 【LOJ6062】「2017 山东一轮集训 Day2」Pair(线段树套路题)

    点此看题面 大致题意: 给出一个长度为\(n\)的数列\(a\)和一个长度为\(m\)的数列\(b\),求\(a\)有多少个长度为\(m\)的子串与\(b\)匹配.数列匹配指存在一种方案使两个数列中的 ...