决策树算法原理及JAVA实现(ID3)
0 引言
决策树的目的在于构造一颗树像下面这样的树。
图1
图2
1. 如何构造呢?
1.1 参考资料。
写的东西非常经典。
1.2 数据集(训练数据集)
| outlook | temperature | humidity | windy | play |
| sunny | hot | high | FALSE | no |
| sunny | hot | high | TRUE | no |
| overcast | hot | high | FALSE | yes |
| rainy | mild | high | FALSE | yes |
| rainy | cool | normal | FALSE | yes |
| rainy | cool | normal | TRUE | no |
| overcast | cool | normal | TRUE | yes |
| sunny | mild | high | FALSE | no |
| sunny | cool | normal | FALSE | yes |
| rainy | mild | normal | FALSE | yes |
| sunny | mild | normal | TRUE | yes |
| overcast | mild | high | TRUE | yes |
| overcast | hot | normal | FALSE | yes |
| rainy | mild | high | TRUE | no |
1.3 构造原则—选信息增益最大的
对每项指标分别统计:在不同的取值下打球和不打球的次数。
table 2
| outlook | temperature | humidity | windy | play | |||||||||
| yes | no | yes | no | yes | no | yes | no | yes | no | ||||
| sunny | 2 | 3 | hot | 2 | 2 | high | 3 | 4 | FALSE | 6 | 2 | 9 | 5 |
| overcast | 4 | 0 | mild | 4 | 2 | normal | 6 | 1 | TRUR | 3 | 3 | ||
| rainy | 3 | 2 | cool | 3 | 1 | ||||||||
下面我们计算当已知变量outlook的值时,信息熵为多少。
outlook=sunny时,2/5的概率打球,3/5的概率不打球。entropy=0.971
outlook=overcast时,entropy=0
outlook=rainy时,entropy=0.971
而根据历史统计数据,outlook取值为sunny、overcast、rainy的概率分别是5/14、4/14、5/14,所以当已知变量outlook的值时,信息熵为:5/14 × 0.971 + 4/14 × 0 + 5/14 × 0.971 = 0.693
这样的话系统熵就从0.940下降到了0.693,信息增溢gain(outlook)为0.940-0.693=0.247
同样可以计算出gain(temperature)=0.029,gain(humidity)=0.152,gain(windy)=0.048。
gain(outlook)最大(即outlook在第一步使系统的信息熵下降得最快),所以决策树的根节点就取outlook。
1.4 为什么选信息增益最大的?
1.5 递归:
接下来要确定N1取temperature、humidity还是windy?在已知outlook=sunny的情况,根据历史数据,我们作出类似table 2的一张表,分别计算gain(temperature)、gain(humidity)和gain(windy),选最大者为N1。
依此类推,构造决策树。当系统的信息熵降为0时,就没有必要再往下构造决策树了,此时叶子节点都是纯的--这是理想情况。最坏的情况下,决策树的高度为属性(决策变量)的个数,叶子节点不纯(这意味着我们要以一定的概率来作出决策)。
1.6 递归结束的条件:
如果Examples都为反,那么返回label =- 的单结点树Root ,熵为0
如果Attributes为空,那么返回单结点树Root,label=Examples中最普遍的
2. 伪代码
3. java 实现
package sequence.machinelearning.decisiontree.myid3; import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileWriter;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.LinkedList; public class MyID3 { private static LinkedList<String> attribute = new LinkedList<String>(); // 存储属性的名称
private static LinkedList<ArrayList<String>> attributevalue = new LinkedList<ArrayList<String>>(); // 存储每个属性的取值
private static LinkedList<String[]> data = new LinkedList<String[]>();; // 原始数据 public static final String patternString = "@attribute(.*)[{](.*?)[}]";
public static String[] yesNo;
public static TreeNode root; /**
*
* @param lines 传入要分析的数据集
* @param index 哪个属性?attribute的index
*/
public Double getGain(LinkedList<String[]> lines,int index){
Double gain=-1.0;
List<Double> li=new ArrayList<Double>();
//统计Yes No的次数
for(int i=0;i<yesNo.length;i++){
Double sum=0.0;
for(int j=0;j<lines.size();j++){
String[] line=lines.get(j);
//data为结构化数据,如果数据最后一列==yes,sum+1
if(line[line.length-1].equals(yesNo[i])){
sum=sum+1;
}
}
li.add(sum);
}
//计算Entropy(S)计算Entropy(S) 见参考书《机器学习 》Tom.Mitchell著 第3.4.1.2节
Double entropyS=TheMath.getEntropy(lines.size(), li);
//下面计算gain List<String> la=attributevalue.get(index);
List<Point> lasv=new ArrayList<Point>();
for(int n=0;n<la.size();n++){
String attvalue=la.get(n);
//统计Yes No的次数
List<Double> lisub=new ArrayList<Double>();//如:sunny 是yes时发生的次数,是no发生的次数
Double Sv=0.0;//公式3.4中的Sv 见参考书《机器学习(Tom.Mitchell著)》
for(int i=0;i<yesNo.length;i++){
Double sum=0.0;
for(int j=0;j<lines.size();j++){
String[] line=lines.get(j);
//data为结构化数据,如果数据最后一列==yes,sum+1
if(line[index].equals(attvalue)&&line[line.length-1].equals(yesNo[i])){
sum=sum+1;
}
}
Sv=Sv+sum;//计算总数
lisub.add(sum);
}
//计算Entropy(S) 见参考书《机器学习(Tom.Mitchell著)》
Double entropySv=TheMath.getEntropy(Sv.intValue(), lisub);
//
Point p=new Point();
p.setSv(Sv);
p.setEntropySv(entropySv);
lasv.add(p);
}
gain=TheMath.getGain(entropyS,lines.size(),lasv);
return gain;
}
//寻找最大的信息增益,将最大的属性定为当前节点,并返回该属性所在list的位置和gain值
public Maxgain getMaxGain(LinkedList<String[]> lines){
if(lines==null||lines.size()<=0){
return null;
}
Maxgain maxgain = new Maxgain();
Double maxvalue=0.0;
int maxindex=-1;
for(int i=0;i<attribute.size();i++){
Double tmp=getGain(lines,i);
if(maxvalue< tmp){
maxvalue=tmp;
maxindex=i;
}
}
maxgain.setMaxgain(maxvalue);
maxgain.setMaxindex(maxindex);
return maxgain;
}
//剪取数组
public LinkedList<String[]> filterLines(LinkedList<String[]> lines, String attvalue, int index){
LinkedList<String[]> newlines=new LinkedList<String[]>();
for(int i=0;i<lines.size();i++){
String[] line=lines.get(i);
if(line[index].equals(attvalue)){
newlines.add(line);
}
} return newlines;
}
public void createDTree(){
root=new TreeNode();
Maxgain maxgain=getMaxGain(data);
if(maxgain==null){
System.out.println("没有数据集,请检查!");
}
int maxKey=maxgain.getMaxindex();
String nodename=attribute.get(maxKey);
root.setName(nodename);
root.setLiatts(attributevalue.get(maxKey));
insertNode(data,root,maxKey);
}
/**
*
* @param lines 传入的数据集,作为新的递归数据集
* @param node 深入此节点
* @param index 属性位置
*/
public void insertNode(LinkedList<String[]> lines,TreeNode node,int index){
List<String> liatts=node.getLiatts();
for(int i=0;i<liatts.size();i++){
String attname=liatts.get(i);
LinkedList<String[]> newlines=filterLines(lines,attname,index);
if(newlines.size()<=0){
System.out.println("出现异常,循环结束");
return;
}
Maxgain maxgain=getMaxGain(newlines);
double gain=maxgain.getMaxgain();
Integer maxKey=maxgain.getMaxindex();
//不等于0继续递归,等于0说明是叶子节点,结束递归。
if(gain!=0){
TreeNode subnode=new TreeNode();
subnode.setParent(node);
subnode.setFatherAttribute(attname);
String nodename=attribute.get(maxKey);
subnode.setName(nodename);
subnode.setLiatts(attributevalue.get(maxKey));
node.addChild(subnode);
//不等于0,继续递归
insertNode(newlines,subnode,maxKey);
}else{
TreeNode subnode=new TreeNode();
subnode.setParent(node);
subnode.setFatherAttribute(attname);
//叶子节点是yes还是no?取新行中最后一个必是其名称,因为只有完全是yes,或完全是no的情况下才会是叶子节点
String[] line=newlines.get(0);
String nodename=line[line.length-1];
subnode.setName(nodename);
node.addChild(subnode);
}
}
}
//输出决策树
public void printDTree(TreeNode node)
{
if(node.getChildren()==null){
System.out.println("--"+node.getName());
return;
}
System.out.println(node.getName());
List<TreeNode> childs = node.getChildren();
for (int i = 0; i < childs.size(); i++)
{
System.out.println(childs.get(i).getFatherAttribute());
printDTree(childs.get(i));
}
}
public static void main(String[] args) {
// TODO Auto-generated method stub
MyID3 myid3 = new MyID3();
myid3.readARFF(new File("datafile/decisiontree/test/in/weather.nominal.arff"));
myid3.createDTree();
myid3.printDTree(root);
}
//读取arff文件,给attribute、attributevalue、data赋值
public void readARFF(File file) {
try {
FileReader fr = new FileReader(file);
BufferedReader br = new BufferedReader(fr);
String line;
Pattern pattern = Pattern.compile(patternString);
while ((line = br.readLine()) != null) {
if (line.startsWith("@decision")) {
line = br.readLine();
if(line=="")
continue;
yesNo = line.split(",");
}
Matcher matcher = pattern.matcher(line);
if (matcher.find()) {
attribute.add(matcher.group(1).trim());
String[] values = matcher.group(2).split(",");
ArrayList<String> al = new ArrayList<String>(values.length);
for (String value : values) {
al.add(value.trim());
}
attributevalue.add(al);
} else if (line.startsWith("@data")) {
while ((line = br.readLine()) != null) {
if(line=="")
continue;
String[] row = line.split(",");
data.add(row);
}
} else {
continue;
}
}
br.close();
} catch (IOException e1) {
e1.printStackTrace();
}
}
}
版权声明:本文为博主原创文章,未经博主允许不得转载。
决策树算法原理及JAVA实现(ID3)的更多相关文章
- 机器学习相关知识整理系列之一:决策树算法原理及剪枝(ID3,C4.5,CART)
决策树是一种基本的分类与回归方法.分类决策树是一种描述对实例进行分类的树形结构,决策树由结点和有向边组成.结点由两种类型,内部结点表示一个特征或属性,叶结点表示一个类. 1. 基础知识 熵 在信息学和 ...
- 决策树算法原理(ID3,C4.5)
决策树算法原理(CART分类树) CART回归树 决策树的剪枝 决策树可以作为分类算法,也可以作为回归算法,同时特别适合集成学习比如随机森林. 1. 决策树ID3算法的信息论基础 1970年昆兰找 ...
- 决策树算法原理(CART分类树)
决策树算法原理(ID3,C4.5) CART回归树 决策树的剪枝 在决策树算法原理(ID3,C4.5)中,提到C4.5的不足,比如模型是用较为复杂的熵来度量,使用了相对较为复杂的多叉树,只能处理分类不 ...
- 决策树算法原理--good blog
转载于:http://www.cnblogs.com/pinard/p/6050306.html (楼主总结的很好,就拿来主义了,不顾以后还是多像楼主学习) 决策树算法在机器学习中算是很经典的一个算法 ...
- ID3决策树算法原理及C++实现(其中代码转自别人的博客)
分类是数据挖掘中十分重要的组成部分.分类作为一种无监督学习方式被广泛的使用. 之前关于"数据挖掘中十大经典算法"中,基于ID3核心思想的分类算法C4.5榜上有名.所以不难看出ID3 ...
- 决策树算法(1)含java源代码
信息熵:变量的不确定性越大,熵越大.熵可用下面的公式描述:-(p1*logp1+p2*logp2+...+pn*logpn)pi表示事件i发生的概率ID3:GAIN(A)=INFO(D)-INFO_A ...
- scikit-learn决策树算法类库使用小结
之前对决策树的算法原理做了总结,包括决策树算法原理(上)和决策树算法原理(下).今天就从实践的角度来介绍决策树算法,主要是讲解使用scikit-learn来跑决策树算法,结果的可视化以及一些参数调参的 ...
- 决策树算法——ID3
决策树算法是一种有监督的分类学习算法.利用经验数据建立最优分类树,再用分类树预测未知数据. 例子:利用学生上课与作业状态预测考试成绩. 上述例子包含两个可以观测的属性:上课是否认真,作业是否认真,并以 ...
- python机器学习笔记 ID3决策树算法实战
前面学习了决策树的算法原理,这里继续对代码进行深入学习,并掌握ID3的算法实践过程. ID3算法是一种贪心算法,用来构造决策树,ID3算法起源于概念学习系统(CLS),以信息熵的下降速度为选取测试属性 ...
随机推荐
- 【leetcode刷题笔记】Longest Substring Without Repeating Characters
Given a string, find the length of the longest substring without repeating characters. For example, ...
- 0417 封装 property、classmethod、staricmethod
一.封装 把一堆东西装在一个容器里 函数和属性装到了一个非全局的命名空间 class A: __N = 123 # 静态变量 def func(self): print(A.__N) # 在类的内部使 ...
- codeforces 686B
题意:给出一个序列,只允许进行相邻的两两交换,给出使序列变为非降序列的操作方案. 思路:关键点是操作次数不限,冒泡排序. #include<iostream> #include<cs ...
- 侠客群控引擎二次开发SDK可用方法大全(持续更新)
如这篇文章所示 http://www.xiake.net/blog/archives/1 侠客的插件SDK能提供很强大的功能(所有官方使用的方法都有提供) 这篇文章是详细介绍所有SDK可调用的方法 首 ...
- 关于ios::sync_with_stdio(false);和 cin.tie(0)加速c++输入输出流
原文地址:http://www.hankcs.com/program/cpp/cin-tie-with-sync_with_stdio-acceleration-input-and-output.ht ...
- 异常之: The server time zone value '�й���ʱ��' is unrecognized or represents more than one time zone.
在 MySQL 中执行命令试下: set global time_zone=’+8:00’ 设置为东8区 就不报错了. show variables like '%time_zone%'; 解释:在 ...
- 实现stack 加上·getMin功能 时间复杂度为O(n)
package com.hzins.suanfa; import java.util.Stack; /** * 实现stack 加上·getMin功能 时间复杂度为O(n) * @author Adm ...
- nodejs buffer 总结
JavaScript 语言自身只有字符串数据类型,没有二进制数据类型.Buffer 类,该类用来创建一个专门存放二进制数据的缓存区. 一个 Buffer 类似于一个整数数组,但它对应于 V8 堆内存之 ...
- 关于c++中char*、char ch[]和string区别
一.字符串指针: char* ch="hello"; 这里的"hello"是字符串常量,是不可以改变的,即通过ch[0]="s"会编译出错. ...
- linux命令学习笔记( 7 ) : mv 命令
mv命令是move的缩写,可以用来移动文件或者将文件改名(move (rename) files),是Linux系统下常用的命令, 经常用来备份文件或者目录. .命令格式: mv [选项] 源文件或目 ...