python实现简单决策树(信息增益)——基于周志华的西瓜书数据
数据集如下:
色泽 根蒂 敲声 纹理 脐部 触感 好瓜
青绿 蜷缩 浊响 清晰 凹陷 硬滑 是
乌黑 蜷缩 沉闷 清晰 凹陷 硬滑 是
乌黑 蜷缩 浊响 清晰 凹陷 硬滑 是
青绿 蜷缩 沉闷 清晰 凹陷 硬滑 是
浅白 蜷缩 浊响 清晰 凹陷 硬滑 是
青绿 稍蜷 浊响 清晰 稍凹 软粘 是
乌黑 稍蜷 浊响 稍糊 稍凹 软粘 是
乌黑 稍蜷 浊响 清晰 稍凹 硬滑 是
乌黑 稍蜷 沉闷 稍糊 稍凹 硬滑 否
青绿 硬挺 清脆 清晰 平坦 软粘 否
浅白 硬挺 清脆 模糊 平坦 硬滑 否
浅白 蜷缩 浊响 模糊 平坦 软粘 否
青绿 稍蜷 浊响 稍糊 凹陷 硬滑 否
浅白 稍蜷 沉闷 稍糊 凹陷 硬滑 否
乌黑 稍蜷 浊响 清晰 稍凹 软粘 否
浅白 蜷缩 浊响 模糊 平坦 硬滑 否
青绿 蜷缩 沉闷 稍糊 稍凹 硬滑 否
基于信息增益的ID3决策树的原理这里不再赘述,读者如果不明白可参考西瓜书对这部分内容的讲解。
python实现代码如下:
from math import log2
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties # 统计label出现次数
def get_counts(data):
total = len(data)
results = {}
for d in data:
results[d[-1]] = results.get(d[-1], 0) + 1
return results, total # 计算信息熵
def calcu_entropy(data):
results, total = get_counts(data)
ent = sum([-1.0*v/total*log2(v/total) for v in results.values()])
return ent # 计算每个feature的信息增益
def calcu_each_gain(column, update_data):
total = len(column)
grouped = update_data.iloc[:, -1].groupby(by=column)
temp = sum([len(g[1])/total*calcu_entropy(g[1]) for g in list(grouped)])
return calcu_entropy(update_data.iloc[:, -1]) - temp # 获取最大的信息增益的feature
def get_max_gain(temp_data):
columns_entropy = [(col, calcu_each_gain(temp_data[col], temp_data)) for col in temp_data.iloc[:, :-1]]
columns_entropy = sorted(columns_entropy, key=lambda f: f[1], reverse=True)
return columns_entropy[0] # 去掉数据中已存在的列属性内容
def drop_exist_feature(data, best_feature):
attr = pd.unique(data[best_feature])
new_data = [(nd, data[data[best_feature] == nd]) for nd in attr]
new_data = [(n[0], n[1].drop([best_feature], axis=1)) for n in new_data]
return new_data # 获得出现最多的label
def get_most_label(label_list):
label_dict = {}
for l in label_list:
label_dict[l] = label_dict.get(l, 0) + 1
sorted_label = sorted(label_dict.items(), key=lambda ll: ll[1], reverse=True)
return sorted_label[0][0] # 创建决策树
def create_tree(data_set, column_count):
label_list = data_set.iloc[:, -1]
if len(pd.unique(label_list)) == 1:
return label_list.values[0]
if all([len(pd.unique(data_set[i])) ==1 for i in data_set.iloc[:, :-1].columns]):
return get_most_label(label_list)
best_attr = get_max_gain(data_set)[0]
tree = {best_attr: {}}
exist_attr = pd.unique(data_set[best_attr])
if len(exist_attr) != len(column_count[best_attr]):
no_exist_attr = set(column_count[best_attr]) - set(exist_attr)
for nea in no_exist_attr:
tree[best_attr][nea] = get_most_label(label_list)
for item in drop_exist_feature(data_set, best_attr):
tree[best_attr][item[0]] = create_tree(item[1], column_count)
return tree # 决策树绘制基本参考《机器学习实战》书内的代码以及博客:http://blog.csdn.net/c406495762/article/details/76262487
# 获取树的叶子节点数目
def get_num_leafs(decision_tree):
num_leafs = 0
first_str = next(iter(decision_tree))
second_dict = decision_tree[first_str]
for k in second_dict.keys():
if isinstance(second_dict[k], dict):
num_leafs += get_num_leafs(second_dict[k])
else:
num_leafs += 1
return num_leafs # 获取树的深度
def get_tree_depth(decision_tree):
max_depth = 0
first_str = next(iter(decision_tree))
second_dict = decision_tree[first_str]
for k in second_dict.keys():
if isinstance(second_dict[k], dict):
this_depth = 1 + get_tree_depth(second_dict[k])
else:
this_depth = 1
if this_depth > max_depth:
max_depth = this_depth
return max_depth # 绘制节点
def plot_node(node_txt, center_pt, parent_pt, node_type):
arrow_args = dict(arrowstyle='<-')
font = FontProperties(fname=r'C:\Windows\Fonts\STXINGKA.TTF', size=15)
create_plot.ax1.annotate(node_txt, xy=parent_pt, xycoords='axes fraction', xytext=center_pt,
textcoords='axes fraction', va="center", ha="center", bbox=node_type,
arrowprops=arrow_args, FontProperties=font) # 标注划分属性
def plot_mid_text(cntr_pt, parent_pt, txt_str):
font = FontProperties(fname=r'C:\Windows\Fonts\MSYH.TTC', size=10)
x_mid = (parent_pt[0] - cntr_pt[0]) / 2.0 + cntr_pt[0]
y_mid = (parent_pt[1] - cntr_pt[1]) / 2.0 + cntr_pt[1]
create_plot.ax1.text(x_mid, y_mid, txt_str, va="center", ha="center", color='red', FontProperties=font) # 绘制决策树
def plot_tree(decision_tree, parent_pt, node_txt):
d_node = dict(boxstyle="sawtooth", fc="0.8")
leaf_node = dict(boxstyle="round4", fc='0.8')
num_leafs = get_num_leafs(decision_tree)
first_str = next(iter(decision_tree))
cntr_pt = (plot_tree.xoff + (1.0 +float(num_leafs))/2.0/plot_tree.totalW, plot_tree.yoff)
plot_mid_text(cntr_pt, parent_pt, node_txt)
plot_node(first_str, cntr_pt, parent_pt, d_node)
second_dict = decision_tree[first_str]
plot_tree.yoff = plot_tree.yoff - 1.0/plot_tree.totalD
for k in second_dict.keys():
if isinstance(second_dict[k], dict):
plot_tree(second_dict[k], cntr_pt, k)
else:
plot_tree.xoff = plot_tree.xoff + 1.0/plot_tree.totalW
plot_node(second_dict[k], (plot_tree.xoff, plot_tree.yoff), cntr_pt, leaf_node)
plot_mid_text((plot_tree.xoff, plot_tree.yoff), cntr_pt, k)
plot_tree.yoff = plot_tree.yoff + 1.0/plot_tree.totalD def create_plot(dtree):
fig = plt.figure(1, facecolor='white')
fig.clf()
axprops = dict(xticks=[], yticks=[])
create_plot.ax1 = plt.subplot(111, frameon=False, **axprops)
plot_tree.totalW = float(get_num_leafs(dtree))
plot_tree.totalD = float(get_tree_depth(dtree))
plot_tree.xoff = -0.5/plot_tree.totalW
plot_tree.yoff = 1.0
plot_tree(dtree, (0.5, 1.0), '')
plt.show() if __name__ == '__main__':
my_data = pd.read_csv('./watermelon2.0.csv', encoding='gbk')
column_count = dict([(ds, list(pd.unique(my_data[ds]))) for ds in my_data.iloc[:, :-1].columns])
d_tree = create_tree(my_data, column_count)
create_plot(d_tree)
绘制的决策树如下:
python实现简单决策树(信息增益)——基于周志华的西瓜书数据的更多相关文章
- 周志华-机器学习西瓜书-第三章习题3.5 LDA
本文为周志华机器学习西瓜书第三章课后习题3.5答案,编程实现线性判别分析LDA,数据集为书本第89页的数据 首先介绍LDA算法流程: LDA的一个手工计算数学实例: 课后习题的代码: # coding ...
- 支持向量机(SVM)算法分析——周志华的西瓜书学习
1.线性可分 对于一个数据集: 如果存在一个超平面X能够将D中的正负样本精确地划分到S的两侧,超平面如下: 那么数据集D就是线性可分的,否则,不可分. w称为法向量,决定了超平面的方向:b为位移量,决 ...
- (二)《机器学习》(周志华)第4章 决策树 笔记 理论及实现——“西瓜树”——CART决策树
CART决策树 (一)<机器学习>(周志华)第4章 决策树 笔记 理论及实现——“西瓜树” 参照上一篇ID3算法实现的决策树(点击上面链接直达),进一步实现CART决策树. 其实只需要改动 ...
- 【深度森林第三弹】周志华等提出梯度提升决策树再胜DNN
[深度森林第三弹]周志华等提出梯度提升决策树再胜DNN 技术小能手 2018-06-04 14:39:46 浏览848 分布式 性能 神经网络 还记得周志华教授等人的“深度森林”论文吗?今天, ...
- 【Todo】【读书笔记】机器学习-周志华
书籍位置: /Users/baidu/Documents/Data/Interview/机器学习-数据挖掘/<机器学习_周志华.pdf> 一共442页.能不能这个周末先囫囵吞枣看完呢.哈哈 ...
- [重磅]Deep Forest,非神经网络的深度模型,周志华老师最新之作,三十分钟理解!
欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术感兴趣的同学加入. 深度学习最大的贡献,个人认为就是表征 ...
- 偶尔转帖:AI会议的总结(by南大周志华)
偶尔转帖:AI会议的总结(by南大周志华) 说明: 纯属个人看法, 仅供参考. tier-1的列得较全, tier-2的不太全, tier-3的很不全. 同分的按字母序排列. 不很严谨地说, tier ...
- 【转载】 AI会议的总结(by南大周志华)
原文地址: https://blog.csdn.net/LiFeitengup/article/details/8441054 最近在查找期刊会议级别的时候发现这篇博客,应该是2012年之前的内容,现 ...
- AI产业将更凸显个人英雄主义 周志华老师的观点是如此的有深度
今天无意间在网上看的了一则推送,<周志华:AI产业将更凸显个人英雄主义> http://tech.163.com/18/0601/13/DJ7J39US00098IEO.html 摘录一些 ...
随机推荐
- for ...in 、for each ...in、 for...of(https://developer.mozilla.org/zh-CN/docs/Web/JavaScript/Reference/Statements/for...of)
1.for in for...in语句以任意顺序遍历一个对象的可枚举性.对于每个不同的属性,语句都会被执行. 语法 for (variable in object) {...} variable 在每 ...
- ajax异步请求/同源策略/跨域传值
基本概念 Ajax 全称是异步的 JavaScript 和 XML . 通过在后台与服务器进行少量数据交换,AJAX 可以使网页实现异步更新.这意味着可以在不重新加载整个网页的情况下,对网页的某部分进 ...
- vue1.0 与 Vue2.0的一些区别 及用法
1.Vue2.0的模板标记外必须使用元素包起来: eg:Vue1.0的写法 <!DOCTYPE html> <html> <head> <meta chars ...
- 如何禁用 Azure 虚拟机的日期时间同步
问题描述 由于 Azure 虚拟机的特殊性,物理主机会实时同步虚拟机的时间和日期.当有特殊需求时,客户想要停止日期时间的同步,但是一些常见的关闭 NTP 服务等操作都会失败. 解决方案 Importa ...
- sql字段合并与分组聚合
http://blog.csdn.net/cuixianlong/article/details/74024846 1 字段合并 原始数据如下:表名为Employee ID FirstName Las ...
- Hadoop ->> HDFS(Hadoop Distributed File System)
HDFS全称是Hadoop Distributed File System.作为分布式文件系统,具有高容错性的特点.它放宽了POSIX对于操作系统接口的要求,可以直接以流(Stream)的形式访问文件 ...
- vue2.x 随记
1. 外部js调用vue的方法等 将vue实例中的this传入外部js文件(比如作为某方法的参数),即可访问传入实例的所有内容.调用该实例中子组件的方法,用$refs. 2. 路由参数 传递:vm.$ ...
- php 获取毫秒时间戳
function getMsec(){//返回毫秒时间戳 $arr = explode(' ',microtime()); $hm = 0; foreach($arr as $v){ $hm += f ...
- CRM中间件里的发布-订阅者模式
从事务码SMW01里能观察到一个BDOC可能被发送往不止一个目的site去,比如下图所示的5个site都会收到该site,而高亮显示的SMOF_ERPSITE代表ERP系统QI3的client 504 ...
- 失去光标display=none事件的坑
1.实现效果: 失去光标进行判断,如果内容为空出现提示. 2.页面代码: <tr class="tableform_tr"> <td width="15 ...