这是半成品, 已完成了 fit() 部分, 形成了包含一棵完整树的 node 对象.

后续工作是需解析该 node 对象, 完成 predict() 工作.

# !/usr/bin/python
# -*- coding:utf-8 -*- """
Re-implement ID3 algorithm as a practice
Only information gain criterion supplied in our DT algorithm.
使用该 ID3 re-implement 的前提:
1. train data 的标签必须转成0,1,2,...的形式
2. 只能处理连续特征
""" # Author: 相忠良(Zhong-Liang Xiang) <ugoood@163.com>
# Finished at July ***, 2017 import numpy as np
from sklearn import datasets, cross_validation ## load data
def load_data():
iris = datasets.load_iris()
return cross_validation.train_test_split(iris.data, iris.target, test_size=0.25, random_state=0) class DecisionNode():
def __init__(self, feature_i=None, threshold=None, value=None, left_branch=None, right_branch=None):
self.feature_i = feature_i # Best feature's index
self.threshold = threshold # Best split threshold in the feature
self.value = value # Value if the node is a leaf in the tree
self.left_branch = left_branch # 'Left' subtree
self.right_branch = right_branch # 'Right' subtree
# print feature_i, 'feature_i'
print self.value, 'value' class MyDecisionTreeClassifier():
trees = []
num_eles_in_class_label = 3 # 分类标签类的个数
tree = {}
predict_label = []
X_train = []
y_train = []
max_depth = 3
max_leaf_nodes = 30
min_samples_leaf = 1
count = 0 def __init__(self, ):
self.root = None # TODO
def fit(self, X, y):
self.root = DecisionNode(self.createTree(X, y)) def predict(self, X):
pass def score(self, X, y):
pass ## entropy
# e.g entropy(y_test)
def __entropy(self, label_list):
bincount = np.bincount(label_list, minlength=self.num_eles_in_class_label)
sum = np.sum(bincount)
# print 'sum in entropy ', sum
temp = 1.0 * bincount / sum
tot = 0 # to avoid log2(0)
for e in temp:
if (e != 0):
tot += e * (-np.log2(e))
return tot def gain(self, pre_split_label_list, after_split_label_list_2d):
total = 0
n = after_split_label_list_2d[0].__len__() + after_split_label_list_2d[1].__len__()
for item in after_split_label_list_2d:
total += self.__entropy(item) * (1.0 * item.__len__() / n)
return self.__entropy(pre_split_label_list) - total ## 针对np.bincount()的结果,如[37 34 41],判断是否为纯节点,既[0 22 0]的形式
def isPure(self, bincount_list):
sb = sorted(bincount_list)
if ((sb[-1] != 0) & (sb[-2] == 0)):
return True
else:
return False ## 计算出现次数最多的类别标签
def maxCate(self, bincount_list):
bincount_list = np.array(bincount_list)
return bincount_list.argmax() ## 递归停止条件:
# 如果样例小于等于10,停止
# 如果样例大于10 且 点纯,停止
# 否则 继续分裂
def createTree(self, X, y):
bincount_list = np.bincount(y, minlength=self.num_eles_in_class_label)
if ((self.isPure(bincount_list)) & (np.sum(bincount_list) > 10)):
print bincount_list, '11111'
return DecisionNode(value=self.maxCate(bincount_list))
elif (np.sum(bincount_list) <= 10):
print bincount_list, '22222'
return DecisionNode(value=self.maxCate(bincount_list))
else:
print bincount_list, '33333' f, v, g = self.seek_best_split_feature(X, y)
mask_big = X[:, f] > v
mask_sma = X[:, f] <= v
bigger_X = []
bigger_y = []
smaller_X = []
smaller_y = []
bigger_X.append(X[mask_big])
bigger_y.append(y[mask_big])
smaller_X.append(X[mask_sma])
smaller_y.append(y[mask_sma]) left_branch = self.createTree(bigger_X[0], bigger_y[0])
right_branch = self.createTree(smaller_X[0], smaller_y[0])
return DecisionNode(feature_i=f, threshold=v, left_branch=left_branch, right_branch=right_branch) ## k>=2 特征区间切分点个数
# samples 样本
# labels 样本对应的标签
# return: best_feature, best_split_point, gain_on_that_point
def seek_best_split_feature(self, samples, labels, k=10): # 2 2.84 0.915290847812
samples = np.array(samples)
labels = np.array(labels)
best_split_point_pool = {} # 最佳分裂特征,点,及对应的gain
col_indx = 0 # 遍历所有特征,寻找某特征最佳分裂点
while col_indx < samples.shape[1]:
max = np.max(samples[:, col_indx])
min = np.min(samples[:, col_indx])
split_point = np.linspace(min, max, k, False)[1:]
# 寻找某特征最佳分裂点
temp = []
dic = {}
for p in split_point:
index_less = np.where(samples[:, col_indx] < p)[0] # [1 2]
index_bigger = np.where(samples[:, col_indx] >= p)[0]
label_less = labels[index_less]
label_bigger = labels[index_bigger]
temp.append(list(label_less))
temp.append(list(label_bigger))
g = self.gain(labels, temp)
dic[p] = g
temp = []
best_key = sorted(dic, key=lambda x: dic[x])[-1] # 返回value最大的那个key
dic_temp = {}
dic_temp[best_key] = dic[best_key]
best_split_point_pool[col_indx] = dic_temp
col_indx += 1 # 特征列表
feature_name_box = list(best_split_point_pool.keys())
b = list(best_split_point_pool.values()) # 临时表
# 最大gain列表
gain_box = []
# 最佳切分点列表
point_box = []
for item in b:
gain_box.append(item.values()[0])
point_box.append(item.keys()[0]) best_feature = feature_name_box[np.argmax(gain_box)]
best_split_point = point_box[np.argmax(gain_box)]
gain_on_that_point = np.max(gain_box)
return best_feature, best_split_point, gain_on_that_point ## 测试用例 X_train, X_test, y_train, y_test = load_data()
cls = MyDecisionTreeClassifier() a = [[9, 2, 3, 4],
[5, 6, 7, 8],
[1, 10, 11, 12],
[13, 14, 15, 16]]
b = [0, 1, 2, 3]
a = np.array(a)
b = np.array(b) # xx = [2,1,1]
# print cls.maxCate(xx),'11111111111111111111111' cls.fit(X_train, y_train)
tree = cls.root
print type(cls.root) '''
下面是编程过程中留下的经验
''' # 重要1: np.linspace(0,1,5) 0-1之间,等分5份,包括首尾
# np.linspace(0,1,5)
# [ 0. 0.25 0.5 0.75 1. ] # 重要2: np.where(a[:,0]>2) 返回矩阵a中第0列值大于2的那些行的索引号
# 返回值的样子 (array([1, 2]),) # 重要3: 返回value最大的那个key
# print(sorted(dic, key=lambda x: dic[x])[-1]) # 重要4: np.bincount()指定最小长度
# xxx = [1,1,1,1,1]
# print np.bincount(xxx,minlength=3)
# 结果: [0 5 0]

重写轮子之 ID3的更多相关文章

  1. 重写轮子之 GaussionNB

    我仿照sk-learn 中 GaussionNB 的结构, 重写了该算法的轮子,命名为 MyGaussionNB, 如下: # !/usr/bin/python # -*- coding:utf-8 ...

  2. 重写轮子之 kNN

    # !/usr/bin/python # -*- coding:utf-8 -*- """ Re-implement kNN algorithm as a practic ...

  3. 关于重写ID3 Algorithm Based On MapReduceV1/C++/Streaming的一些心得体会

    心血来潮,同时想用C++连连手.面对如火如荼的MP,一阵念头闪过,如果把一些ML领域的玩意整合到MP里面是不是很有意思 确实很有意思,可惜mahout来高深,我也看不懂.干脆自动动手丰衣足食,加上自己 ...

  4. 【转】C# 重写WndProc 拦截 发送 系统消息 + windows消息常量值(1)

    C# 重写WndProc 拦截 发送 系统消息 + windows消息常量值(1) #region 截获消息        /// 截获消息  处理XP不能关机问题        protected ...

  5. Asp.net Mvc 请求是如何到达 MvcHandler的——UrlRoutingModule、MvcRouteHandler分析,并造个轮子

    这个是转载自:http://www.cnblogs.com/keyindex/archive/2012/08/11/2634005.html(那个比较容易忘记,希望博主不要生气的) 前言 本文假定读者 ...

  6. 拆解轮子之XRecyclerView

    简介 这个轮子是对RecyclerView的封装,主要完成了下拉刷新.上拉加载更多.RecyclerView头部.在我的Material Design学习项目中使用到了项目地址,感觉还不错.趁着毕业答 ...

  7. 跨平台技术实践案例: 用 reactxp 重写墨刀的移动端

    Authors:  Gao Cong, Perry Poon Illustrators:  Shena Bian April 20, 2019 重新编写,又一次,我们又一次重新编写了移动端应用和移动端 ...

  8. 星级评分原理 N次重写的分析

    使用的是雪碧图,用的软件是CSS Sprite Tools 第一次实现与分析: <!DOCTYPE html> <html> <head> <meta cha ...

  9. [18/11/29] 继承(extends)和方法的重写(override,不是重载)

    一.何为继承?(对原有类的扩充) 继承让我们更加容易实现类的扩展. 比如,我们定义了人类,再定义Boy类就只需要扩展人类即可.实现了代码的重用,不用再重新发明轮子(don’t  reinvent  w ...

随机推荐

  1. Python系列-python文件操作

    原链接:https://blog.csdn.net/m0_37745438/article/details/79573414 python提供了一系列方法来对文件进行读取.写入等操作 一.打开文件的方 ...

  2. Linux探索之路1---CentOS入坑笔记整理

    前言 上次跟运维去行方安装行内环境,发现linux命令还是不是很熟练.特别是用户权限分配以及vi下的快捷操作.于是决定在本地安装一个CentOS虚拟机,后面有时间就每天学习一点Linux常用命令. 作 ...

  3. IT技术有感

    今天看技术文章,spring相关的,某一个点以前每次看一直不理解, 可是不知道为什么隔了1年左右,中间什么都没做,现在却都懂了. 在看懂的那一刻,笼罩在我心上的躁动突然平静了许多,我的心这一年来前所未 ...

  4. FPGA与MATLAB数据交互高效率验证算法——仿真阶段

    之前博文是对基本设计技巧的总结和一些小设计随笔,内容有点杂,缺乏目的性.本来后续计划设计几个小项目,但导师的任务比较紧,所以为了提高效率,后续博客会涉及到很多算法方面的设计与验证的内容,主要关于OFD ...

  5. [洛谷P1197/BZOJ1015][JSOI2008]星球大战Starwar - 并查集,离线,联通块

    Description 很久以前,在一个遥远的星系,一个黑暗的帝国靠着它的超级武器统治者整个星系.某一天,凭着一个偶然的机遇,一支反抗军摧毁了帝国的超级武器,并攻下了星系中几乎所有的星球.这些星球通过 ...

  6. url的解码方式

    #coding:utf-8 import urllib legal_person_string = "%E6%B3%95%E5%AE%9A%E4%BB%A3%E8%A1%A8%E4%BA%B ...

  7. requests-文件上传

    import requests files = {'file':open('D://tomas.jpg','rb')}#设定一个files,打开文件对象 response = requests.pos ...

  8. 一 Unicode和UTF-8的异同

    下面就是我的笔记,主要用来整理自己的思路.但是,我尽量试图写得通俗易懂,希望能对其他朋友有用.毕竟,字符编码是计算机技术的基石,想要熟练使用计算机,就必须懂得一点字符编码的知识.1. ASCII码我们 ...

  9. [LeetCode] Array Partition I 数组分割之一

    Given an array of 2n integers, your task is to group these integers into n pairs of integer, say (a1 ...

  10. [HNOI2013]比赛

    题目描述 沫沫非常喜欢看足球赛,但因为沉迷于射箭游戏,错过了最近的一次足球联赛.此次联 赛共N支球队参加,比赛规则如下: (1) 每两支球队之间踢一场比赛. (2) 若平局,两支球队各得1分. (3) ...