0 引言

决策树的目的在于构造一颗树像下面这样的树。

图1

图2

1. 如何构造呢?

1.1   参考资料。

      本例以图2为例,并参考了以下资料。

写的东西非常经典。

(3)机器学习(Tom.Mitchell著) 第三章 决策树,里面详细介绍了信息增益的计算,和熵的计算。建议大家参考

1.2 数据集(训练数据集)



outlook temperature humidity windy play
sunny hot high FALSE no
sunny hot high TRUE no
overcast hot high FALSE yes
rainy mild high FALSE yes
rainy cool normal FALSE yes
rainy cool normal TRUE no
overcast cool normal TRUE yes
sunny mild high FALSE no
sunny cool normal FALSE yes
rainy mild normal FALSE yes
sunny mild normal TRUE yes
overcast mild high TRUE yes
overcast hot normal FALSE yes
rainy mild high TRUE no

1.3 构造原则—选信息增益最大的

从图中知,一共有四个属性,outlook     temperature    humidity  windy,首先选哪一个作为树的第一个节点呢。答案是选信息增益越大的作为开始的节点。信息增益的计算公式如下:
Entropy(s)是熵,S样本集,Sv是子集。熵的计算公式如下:


举例:
根据以上的数据,我们只知道新的一天打球的概率是9/14,不打的概率是5/14。此时的熵为


对每项指标分别统计:在不同的取值下打球和不打球的次数。

table 2

outlook temperature humidity windy play
  yes no   yes no   yes no   yes no yes no
sunny 2 3 hot 2 2 high 3 4 FALSE 6 2 9 5
overcast 4 0 mild 4 2 normal 6 1 TRUR 3 3    
rainy 3 2 cool 3 1                

下面我们计算当已知变量outlook的值时,信息熵为多少。

outlook=sunny时,2/5的概率打球,3/5的概率不打球。entropy=0.971

outlook=overcast时,entropy=0

outlook=rainy时,entropy=0.971

而根据历史统计数据,outlook取值为sunny、overcast、rainy的概率分别是5/14、4/14、5/14,所以当已知变量outlook的值时,信息熵为:5/14 × 0.971 + 4/14 × 0 + 5/14 × 0.971 = 0.693

这样的话系统熵就从0.940下降到了0.693,信息增溢gain(outlook)为0.940-0.693=0.247

同样可以计算出gain(temperature)=0.029,gain(humidity)=0.152,gain(windy)=0.048。

gain(outlook)最大(即outlook在第一步使系统的信息熵下降得最快),所以决策树的根节点就取outlook。

1.4 为什么选信息增益最大的?

根据参考资料(2)的结论是:信息增益量越大,这个属性作为一棵树的根节点就能使这棵树更简洁(2)

1.5 递归:

接下来要确定N1取temperature、humidity还是windy?在已知outlook=sunny的情况,根据历史数据,我们作出类似table 2的一张表,分别计算gain(temperature)、gain(humidity)和gain(windy),选最大者为N1。

依此类推,构造决策树。当系统的信息熵降为0时,就没有必要再往下构造决策树了,此时叶子节点都是纯的--这是理想情况。最坏的情况下,决策树的高度为属性(决策变量)的个数,叶子节点不纯(这意味着我们要以一定的概率来作出决策)。

1.6 递归结束的条件:

如果Examples都为正,那么返回label =+ 的单结点树Root ,熵为0

 如果Examples都为反,那么返回label =- 的单结点树Root ,熵为0

 如果Attributes为空,那么返回单结点树Root,label=Examples中最普遍的


2. 伪代码


3. java 实现

此仅贴主要的代码,源码请到我的github下载:
package sequence.machinelearning.decisiontree.myid3;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.LinkedList; public class MyID3 { private static LinkedList<String> attribute = new LinkedList<String>(); // 存储属性的名称
private static LinkedList<ArrayList<String>> attributevalue = new LinkedList<ArrayList<String>>(); // 存储每个属性的取值
private static LinkedList<String[]> data = new LinkedList<String[]>();; // 原始数据 public static final String patternString = "@attribute(.*)[{](.*?)[}]";
public static String[] yesNo;
public static TreeNode root; /**
*
* @param lines 传入要分析的数据集
* @param index 哪个属性?attribute的index
*/
public Double getGain(LinkedList<String[]> lines,int index){
Double gain=-1.0;
List<Double> li=new ArrayList<Double>();
//统计Yes No的次数
for(int i=0;i<yesNo.length;i++){
Double sum=0.0;
for(int j=0;j<lines.size();j++){
String[] line=lines.get(j);
//data为结构化数据,如果数据最后一列==yes,sum+1
if(line[line.length-1].equals(yesNo[i])){
sum=sum+1;
}
}
li.add(sum);
}
//计算Entropy(S)计算Entropy(S) 见参考书《机器学习 》Tom.Mitchell著 第3.4.1.2节
Double entropyS=TheMath.getEntropy(lines.size(), li);
//下面计算gain List<String> la=attributevalue.get(index);
List<Point> lasv=new ArrayList<Point>();
for(int n=0;n<la.size();n++){
String attvalue=la.get(n);
//统计Yes No的次数
List<Double> lisub=new ArrayList<Double>();//如:sunny 是yes时发生的次数,是no发生的次数
Double Sv=0.0;//公式3.4中的Sv 见参考书《机器学习(Tom.Mitchell著)》
for(int i=0;i<yesNo.length;i++){
Double sum=0.0;
for(int j=0;j<lines.size();j++){
String[] line=lines.get(j);
//data为结构化数据,如果数据最后一列==yes,sum+1
if(line[index].equals(attvalue)&&line[line.length-1].equals(yesNo[i])){
sum=sum+1;
}
}
Sv=Sv+sum;//计算总数
lisub.add(sum);
}
//计算Entropy(S) 见参考书《机器学习(Tom.Mitchell著)》
Double entropySv=TheMath.getEntropy(Sv.intValue(), lisub);
//
Point p=new Point();
p.setSv(Sv);
p.setEntropySv(entropySv);
lasv.add(p);
}
gain=TheMath.getGain(entropyS,lines.size(),lasv);
return gain;
}
//寻找最大的信息增益,将最大的属性定为当前节点,并返回该属性所在list的位置和gain值
public Maxgain getMaxGain(LinkedList<String[]> lines){
if(lines==null||lines.size()<=0){
return null;
}
Maxgain maxgain = new Maxgain();
Double maxvalue=0.0;
int maxindex=-1;
for(int i=0;i<attribute.size();i++){
Double tmp=getGain(lines,i);
if(maxvalue< tmp){
maxvalue=tmp;
maxindex=i;
}
}
maxgain.setMaxgain(maxvalue);
maxgain.setMaxindex(maxindex);
return maxgain;
}
//剪取数组
public LinkedList<String[]> filterLines(LinkedList<String[]> lines, String attvalue, int index){
LinkedList<String[]> newlines=new LinkedList<String[]>();
for(int i=0;i<lines.size();i++){
String[] line=lines.get(i);
if(line[index].equals(attvalue)){
newlines.add(line);
}
} return newlines;
}
public void createDTree(){
root=new TreeNode();
Maxgain maxgain=getMaxGain(data);
if(maxgain==null){
System.out.println("没有数据集,请检查!");
}
int maxKey=maxgain.getMaxindex();
String nodename=attribute.get(maxKey);
root.setName(nodename);
root.setLiatts(attributevalue.get(maxKey));
insertNode(data,root,maxKey);
}
/**
*
* @param lines 传入的数据集,作为新的递归数据集
* @param node 深入此节点
* @param index 属性位置
*/
public void insertNode(LinkedList<String[]> lines,TreeNode node,int index){
List<String> liatts=node.getLiatts();
for(int i=0;i<liatts.size();i++){
String attname=liatts.get(i);
LinkedList<String[]> newlines=filterLines(lines,attname,index);
if(newlines.size()<=0){
System.out.println("出现异常,循环结束");
return;
}
Maxgain maxgain=getMaxGain(newlines);
double gain=maxgain.getMaxgain();
Integer maxKey=maxgain.getMaxindex();
//不等于0继续递归,等于0说明是叶子节点,结束递归。
if(gain!=0){
TreeNode subnode=new TreeNode();
subnode.setParent(node);
subnode.setFatherAttribute(attname);
String nodename=attribute.get(maxKey);
subnode.setName(nodename);
subnode.setLiatts(attributevalue.get(maxKey));
node.addChild(subnode);
//不等于0,继续递归
insertNode(newlines,subnode,maxKey);
}else{
TreeNode subnode=new TreeNode();
subnode.setParent(node);
subnode.setFatherAttribute(attname);
//叶子节点是yes还是no?取新行中最后一个必是其名称,因为只有完全是yes,或完全是no的情况下才会是叶子节点
String[] line=newlines.get(0);
String nodename=line[line.length-1];
subnode.setName(nodename);
node.addChild(subnode);
}
}
}
//输出决策树
public void printDTree(TreeNode node)
{
if(node.getChildren()==null){
System.out.println("--"+node.getName());
return;
}
System.out.println(node.getName());
List<TreeNode> childs = node.getChildren();
for (int i = 0; i < childs.size(); i++)
{
System.out.println(childs.get(i).getFatherAttribute());
printDTree(childs.get(i));
}
}
public static void main(String[] args) {
// TODO Auto-generated method stub
MyID3 myid3 = new MyID3();
myid3.readARFF(new File("datafile/decisiontree/test/in/weather.nominal.arff"));
myid3.createDTree();
myid3.printDTree(root);
}
//读取arff文件,给attribute、attributevalue、data赋值
public void readARFF(File file) {
try {
FileReader fr = new FileReader(file);
BufferedReader br = new BufferedReader(fr);
String line;
Pattern pattern = Pattern.compile(patternString);
while ((line = br.readLine()) != null) {
if (line.startsWith("@decision")) {
line = br.readLine();
if(line=="")
continue;
yesNo = line.split(",");
}
Matcher matcher = pattern.matcher(line);
if (matcher.find()) {
attribute.add(matcher.group(1).trim());
String[] values = matcher.group(2).split(",");
ArrayList<String> al = new ArrayList<String>(values.length);
for (String value : values) {
al.add(value.trim());
}
attributevalue.add(al);
} else if (line.startsWith("@data")) {
while ((line = br.readLine()) != null) {
if(line=="")
continue;
String[] row = line.split(",");
data.add(row);
}
} else {
continue;
}
}
br.close();
} catch (IOException e1) {
e1.printStackTrace();
}
}
}

版权声明:本文为博主原创文章,未经博主允许不得转载。

决策树算法原理及JAVA实现(ID3)的更多相关文章

  1. 机器学习相关知识整理系列之一:决策树算法原理及剪枝(ID3,C4.5,CART)

    决策树是一种基本的分类与回归方法.分类决策树是一种描述对实例进行分类的树形结构,决策树由结点和有向边组成.结点由两种类型,内部结点表示一个特征或属性,叶结点表示一个类. 1. 基础知识 熵 在信息学和 ...

  2. 决策树算法原理(ID3,C4.5)

    决策树算法原理(CART分类树) CART回归树 决策树的剪枝 决策树可以作为分类算法,也可以作为回归算法,同时特别适合集成学习比如随机森林. 1. 决策树ID3算法的信息论基础   1970年昆兰找 ...

  3. 决策树算法原理(CART分类树)

    决策树算法原理(ID3,C4.5) CART回归树 决策树的剪枝 在决策树算法原理(ID3,C4.5)中,提到C4.5的不足,比如模型是用较为复杂的熵来度量,使用了相对较为复杂的多叉树,只能处理分类不 ...

  4. 决策树算法原理--good blog

    转载于:http://www.cnblogs.com/pinard/p/6050306.html (楼主总结的很好,就拿来主义了,不顾以后还是多像楼主学习) 决策树算法在机器学习中算是很经典的一个算法 ...

  5. ID3决策树算法原理及C++实现(其中代码转自别人的博客)

    分类是数据挖掘中十分重要的组成部分.分类作为一种无监督学习方式被广泛的使用. 之前关于"数据挖掘中十大经典算法"中,基于ID3核心思想的分类算法C4.5榜上有名.所以不难看出ID3 ...

  6. 决策树算法(1)含java源代码

    信息熵:变量的不确定性越大,熵越大.熵可用下面的公式描述:-(p1*logp1+p2*logp2+...+pn*logpn)pi表示事件i发生的概率ID3:GAIN(A)=INFO(D)-INFO_A ...

  7. scikit-learn决策树算法类库使用小结

    之前对决策树的算法原理做了总结,包括决策树算法原理(上)和决策树算法原理(下).今天就从实践的角度来介绍决策树算法,主要是讲解使用scikit-learn来跑决策树算法,结果的可视化以及一些参数调参的 ...

  8. 决策树算法——ID3

    决策树算法是一种有监督的分类学习算法.利用经验数据建立最优分类树,再用分类树预测未知数据. 例子:利用学生上课与作业状态预测考试成绩. 上述例子包含两个可以观测的属性:上课是否认真,作业是否认真,并以 ...

  9. python机器学习笔记 ID3决策树算法实战

    前面学习了决策树的算法原理,这里继续对代码进行深入学习,并掌握ID3的算法实践过程. ID3算法是一种贪心算法,用来构造决策树,ID3算法起源于概念学习系统(CLS),以信息熵的下降速度为选取测试属性 ...

随机推荐

  1. 【leetcode刷题笔记】Longest Substring Without Repeating Characters

    Given a string, find the length of the longest substring without repeating characters. For example, ...

  2. 0417 封装 property、classmethod、staricmethod

    一.封装 把一堆东西装在一个容器里 函数和属性装到了一个非全局的命名空间 class A: __N = 123 # 静态变量 def func(self): print(A.__N) # 在类的内部使 ...

  3. codeforces 686B

    题意:给出一个序列,只允许进行相邻的两两交换,给出使序列变为非降序列的操作方案. 思路:关键点是操作次数不限,冒泡排序. #include<iostream> #include<cs ...

  4. 侠客群控引擎二次开发SDK可用方法大全(持续更新)

    如这篇文章所示 http://www.xiake.net/blog/archives/1 侠客的插件SDK能提供很强大的功能(所有官方使用的方法都有提供) 这篇文章是详细介绍所有SDK可调用的方法 首 ...

  5. 关于ios::sync_with_stdio(false);和 cin.tie(0)加速c++输入输出流

    原文地址:http://www.hankcs.com/program/cpp/cin-tie-with-sync_with_stdio-acceleration-input-and-output.ht ...

  6. 异常之: The server time zone value '�й���׼ʱ��' is unrecognized or represents more than one time zone.

    在 MySQL 中执行命令试下: set global time_zone=’+8:00’  设置为东8区 就不报错了. show variables like '%time_zone%'; 解释:在 ...

  7. 实现stack 加上·getMin功能 时间复杂度为O(n)

    package com.hzins.suanfa; import java.util.Stack; /** * 实现stack 加上·getMin功能 时间复杂度为O(n) * @author Adm ...

  8. nodejs buffer 总结

    JavaScript 语言自身只有字符串数据类型,没有二进制数据类型.Buffer 类,该类用来创建一个专门存放二进制数据的缓存区. 一个 Buffer 类似于一个整数数组,但它对应于 V8 堆内存之 ...

  9. 关于c++中char*、char ch[]和string区别

    一.字符串指针: char* ch="hello"; 这里的"hello"是字符串常量,是不可以改变的,即通过ch[0]="s"会编译出错. ...

  10. linux命令学习笔记( 7 ) : mv 命令

    mv命令是move的缩写,可以用来移动文件或者将文件改名(move (rename) files),是Linux系统下常用的命令, 经常用来备份文件或者目录. .命令格式: mv [选项] 源文件或目 ...