说明:每个样本都会装入Data样本对象,决策树生成算法接收的是一个Array<Data>样本列表,所以构建测试数据时也要符合格式,最后生成的决策树是树的根节点,通过里面提供的showTree()方法可查看整个树结构,下面奉上源码。

Data.java

package ai.tree.data;

import java.util.HashMap;

/**
* 样本类
* @author ChenLuyang
* @date 2019/2/21
*/
public class Data implements Cloneable{
/**
* K是特征描述,V是特征值
*/
private HashMap<String,String> feature = new HashMap<String, String>(); /**
* 该样本结论
*/
private String result; public Data(HashMap<String,String> feature,String result){
this.feature = feature;
this.result = result;
} public HashMap<String, String> getFeature() {
return feature;
} public String getResult() {
return result;
} private void setFeature(HashMap<String, String> feature) {
this.feature = feature;
} @Override
public Data clone()
{
Data object=null;
try {
object = (Data) super.clone();
object.setFeature((HashMap<String, String>) this.feature.clone());
} catch (CloneNotSupportedException e) {
e.printStackTrace();
} return object;
}
}

  

DecisionTree.java

package ai.tree.algorithm;

import ai.tree.data.Data;

import java.math.BigDecimal;
import java.util.*; /**
* @author ChenLuyang
* @date 2019/2/21
*/
public class DecisionTree {
/**
* 递归构建决策树
*
* @param dataList 样本集合
* @return ai.tree.algorithm.DecisionTree.TreeNode 使用传入样本构建的决策节点
* @author ChenLuyang
* @date 2019/2/21 16:05
*/
public TreeNode createTree(List<Data> dataList) {
//创建当前节点
TreeNode<String, String, String> nowTreeNode = new TreeNode<String, String, String>();
//当前节点的各个分支节点
Map<String, TreeNode> featureDecisionMap = new HashMap<String, TreeNode>(); //统计当前样本集中所有的分类结果
Set<String> resultSet = new HashSet<String>();
for (Data data :
dataList) {
resultSet.add(data.getResult());
} //如果当前样本集只有一种类别,则表示不用分类了,返回当前节点
if (resultSet.size() == 1) {
String resultClassify = resultSet.iterator().next(); nowTreeNode.setResultNode(resultClassify); return nowTreeNode;
} //如果数据集中特征为空,则选择整个集合中出现次数最多的分类,作为分类结果
if (dataList.get(0).getFeature().size() == 0) {
Map<String, Integer> countMap = new HashMap<String, Integer>();
for (Data data :
dataList) {
Integer num = countMap.get(data.getResult());
if (num == null) {
countMap.put(data.getResult(), 1);
} else {
countMap.put(data.getResult(), num + 1);
}
} String tmpResult = "";
Integer tmpNum = 0;
for (String res :
countMap.keySet()) {
if (countMap.get(res) > tmpNum) {
tmpNum = countMap.get(res);
tmpResult = res;
}
} nowTreeNode.setResultNode(tmpResult); return nowTreeNode;
} //寻找当前最优分类
String bestLabel = chooseBestFeatureToSplit(dataList); //提取最优特征的所有可能值
Set<String> bestLabelInfoSet = new HashSet<String>();
for (Data data :
dataList) {
bestLabelInfoSet.add(data.getFeature().get(bestLabel));
} //使用最优特征的各个特征值进行分类
for (String labelInfo :
bestLabelInfoSet) {
for (Data data :
dataList) {
}
List<Data> branchDataList = splitDataList(dataList, bestLabel, labelInfo); //最优特征下该特征值的节点
TreeNode branchTreeNode = createTree(branchDataList);
featureDecisionMap.put(labelInfo, branchTreeNode);
} nowTreeNode.setDecisionNode(bestLabel, featureDecisionMap); return nowTreeNode;
} /**
* 计算传入数据集中的最优分类特征
*
* @param dataList
* @return int 最优分类特征的描述
* @author ChenLuyang
* @date 2019/2/21 14:12
*/
public String chooseBestFeatureToSplit(List<Data> dataList) {
//目前数据集中的特征集合
Set<String> futureSet = dataList.get(0).getFeature().keySet(); //未分类时的熵
BigDecimal baseEntropy = calcShannonEnt(dataList); //熵差
BigDecimal bestInfoGain = new BigDecimal("0");
//最优特征
String bestFeature = ""; //按照各特征分类
for (String future :
futureSet) {
//该特征分类后的熵
BigDecimal futureEntropy = new BigDecimal("0"); //该特征的所有特征值去重集合
Set<String> futureInfoSet = new HashSet<String>();
for (Data data :
dataList) {
futureInfoSet.add(data.getFeature().get(future));
} //按照该特征的特征值一一分类
for (String futureInfo :
futureInfoSet) {
List<Data> splitResultDataList = splitDataList(dataList, future, futureInfo); //分类后样本数占总样本数的比例
BigDecimal tmpProb = new BigDecimal(splitResultDataList.size() + "").divide(new BigDecimal(dataList.size() + ""), 5, BigDecimal.ROUND_HALF_DOWN); //所占比例乘以分类后的样本熵,然后再进行熵的累加
futureEntropy = futureEntropy.add(tmpProb.multiply(calcShannonEnt(splitResultDataList)));
} BigDecimal subEntropy = baseEntropy.subtract(futureEntropy); if (subEntropy.compareTo(bestInfoGain) >= 0) {
bestInfoGain = subEntropy;
bestFeature = future;
}
} return bestFeature;
} /**
* 计算传入样本集的熵值
*
* @param dataList 样本集
* @return java.math.BigDecimal 熵
* @author ChenLuyang
* @date 2019/2/22 9:41
*/
public BigDecimal calcShannonEnt(List<Data> dataList) {
//样本总数
BigDecimal sumEntries = new BigDecimal(dataList.size() + "");
//香农熵
BigDecimal shannonEnt = new BigDecimal("0");
//统计各个分类结果的样本数量
Map<String, Integer> resultCountMap = new HashMap<String, Integer>();
for (Data data :
dataList) {
Integer dataResultCount = resultCountMap.get(data.getResult());
if (dataResultCount == null) {
resultCountMap.put(data.getResult(), 1);
} else {
resultCountMap.put(data.getResult(), dataResultCount + 1);
}
} for (String resultCountKey :
resultCountMap.keySet()) {
BigDecimal resultCountValue = new BigDecimal(resultCountMap.get(resultCountKey).toString()); BigDecimal prob = resultCountValue.divide(sumEntries, 5, BigDecimal.ROUND_HALF_DOWN);
shannonEnt = shannonEnt.subtract(prob.multiply(new BigDecimal(Math.log(prob.doubleValue()) / Math.log(2) + "")));
} return shannonEnt;
} /**
* 根据某个特征的特征值,进行样本数据的划分,将划分后的样本数据集返回
*
* @param dataList 待划分的样本数据集
* @param future 筛选的特征依据
* @param info 筛选的特征值依据
* @return java.util.List<ai.tree.data.Data> 按照指定特征值分类后的数据集
* @author ChenLuyang
* @date 2019/2/21 18:26
*/
public List<Data> splitDataList(List<Data> dataList, String future, String info) {
List<Data> resultDataList = new ArrayList<Data>();
for (Data data :
dataList) {
if (data.getFeature().get(future).equals(info)) {
Data newData = (Data) data.clone();
newData.getFeature().remove(future);
resultDataList.add(newData);
}
} return resultDataList;
} /**
* L:每一个特征的描述信息的类型
* F:特征的类型
* R:最终分类结果的类型
*/
public class TreeNode<L, F, R> {
/**
* 该节点的最优特征的描述信息
*/
private L label; /**
* 根据不同的特征作出响应的决定。
* K为特征值,V为该特征值作出的决策节点
*/
private Map<F, TreeNode> featureDecisionMap; /**
* 是否为最终分类节点
*/
private boolean isFinal; /**
* 最终分类结果信息
*/
private R resultClassify; /**
* 设置叶子节点
*
* @param resultClassify 最终分类结果
* @return void
* @author ChenLuyang
* @date 2019/2/22 18:31
*/
public void setResultNode(R resultClassify) {
this.isFinal = true;
this.resultClassify = resultClassify;
} /**
* 设置分支节点
*
* @param label 当前分支节点的描述信息(特征)
* @param featureDecisionMap 当前分支节点的各个特征值,与其对应的子节点
* @return void
* @author ChenLuyang
* @date 2019/2/22 18:31
*/
public void setDecisionNode(L label, Map<F, TreeNode> featureDecisionMap) {
this.isFinal = false;
this.label = label;
this.featureDecisionMap = featureDecisionMap;
} /**
* 展示当前节点的树结构
*
* @return void
* @author ChenLuyang
* @date 2019/2/22 16:54
*/
public String showTree() {
HashMap<String, String> treeMap = new HashMap<String, String>();
if (isFinal) {
String key = "result";
R value = resultClassify;
treeMap.put(key, value.toString());
} else {
String key = label.toString();
HashMap<F, String> showFutureMap = new HashMap<F, String>();
for (F f :
featureDecisionMap.keySet()) {
showFutureMap.put(f, featureDecisionMap.get(f).showTree());
}
String value = showFutureMap.toString(); treeMap.put(key, value);
} return treeMap.toString();
} public L getLabel() {
return label;
} public Map<F, TreeNode> getFeatureDecisionMap() {
return featureDecisionMap;
} public R getResultClassify() {
return resultClassify;
} public boolean getFinal() {
return isFinal;
}
}
}

  

Start.java

package ai.tree.algorithm;

import ai.tree.data.Data;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List; /**
* @author ChenLuyang
* @date 2019/2/22
*/
public class Start {
/**
* 构建测试样本集,测试样本如下:
样本特征:{头发长短=短发, 身材=胖, 是否戴眼镜=有眼镜} 分类:男
样本特征:{头发长短=长发, 身材=瘦, 是否戴眼镜=有眼镜} 分类:女
样本特征:{头发长短=短发, 身材=胖, 是否戴眼镜=有眼镜} 分类:女
样本特征:{头发长短=长发, 身材=胖, 是否戴眼镜=没眼镜} 分类:男
样本特征:{头发长短=短发, 身材=瘦, 是否戴眼镜=没眼镜} 分类:男
样本特征:{头发长短=长发, 身材=瘦, 是否戴眼镜=有眼镜} 分类:女
样本特征:{头发长短=长发, 身材=胖, 是否戴眼镜=有眼镜} 分类:男
* @author ChenLuyang
* @date 2019/2/21 15:34
* @return java.util.List<ai.tree.data.DecisionTreeTestData.Data> 样本集
*/
public static List<Data> createDataList(){
/**
* 样本特征描述
* @author ChenLuyang
* @date 2019/2/22 18:55
* @return java.util.List<ai.tree.data.Data>
*/
String[] labels = new String[]{"是否戴眼镜", "头发长短", "身材"}; List<Data> dataList = new ArrayList<Data>(); HashMap<String,String> feature1 = new HashMap<String, String>();
feature1.put(labels[0],"有眼镜");
feature1.put(labels[1].toString(),"短发");
feature1.put(labels[2].toString(),"胖");
dataList.add(new Data(feature1,"男")); HashMap<String,String> feature2 = new HashMap<String, String>();
feature2.put(labels[0],"有眼镜");
feature2.put(labels[1],"长发");
feature2.put(labels[2],"瘦");
dataList.add(new Data(feature2,"女")); HashMap<String,String> feature3 = new HashMap<String, String>();
feature3.put(labels[0],"有眼镜");
feature3.put(labels[1],"短发");
feature3.put(labels[2],"胖");
dataList.add(new Data(feature3,"女")); HashMap<String,String> feature4 = new HashMap<String, String>();
feature4.put(labels[0],"没眼镜");
feature4.put(labels[1],"长发");
feature4.put(labels[2],"胖");
dataList.add(new Data(feature4,"男")); HashMap<String,String> feature5 = new HashMap<String, String>();
feature5.put(labels[0],"没眼镜");
feature5.put(labels[1],"短发");
feature5.put(labels[2],"瘦");
dataList.add(new Data(feature5,"男")); HashMap<String,String> feature6 = new HashMap<String, String>();
feature6.put(labels[0],"有眼镜");
feature6.put(labels[1],"长发");
feature6.put(labels[2],"瘦");
dataList.add(new Data(feature6,"女")); HashMap<String,String> feature7 = new HashMap<String, String>();
feature7.put(labels[0],"有眼镜");
feature7.put(labels[1],"长发");
feature7.put(labels[2],"胖");
dataList.add(new Data(feature7,"男")); return dataList;
} public static void main(String[] args) {
DecisionTree decisionTree = new DecisionTree(); //使用测试样本生成决策树
DecisionTree.TreeNode tree = decisionTree.createTree(createDataList()); //展示决策树
System.out.println(tree.showTree());
}
}

  

生成树结构:{是否戴眼镜={没眼镜={result=男}, 有眼镜={身材={胖={头发长短={长发={result=男}, 短发={result=女}}}, 瘦={result=女}}}}}

java编写ID3决策树的更多相关文章

  1. ID3决策树预测的java实现

    刚才写了ID3决策树的建立,这个是通过决策树来进行预测.这里主要用到的就是XML的遍历解析,比较简单. 关于xml的解析,参考了: http://blog.csdn.net/soszou/articl ...

  2. 网页动物园2.0发布,经过几个月的努力,采用JAVA编写!

    网页动物园2.0发布,经过几个月的努力,采用JAVA编写! 网页动物园2.0 正式发布!游戏发布 游戏名称: 网页动物园插件 游戏来源: 原创插件 适用版本: Discuz! X1.5 - X3.5 ...

  3. 使用Java编写一个简单的Web的监控系统cpu利用率,cpu温度,总内存大小

    原文:http://www.jb51.net/article/75002.htm 这篇文章主要介绍了使用Java编写一个简单的Web的监控系统的例子,并且将重要信息转为XML通过网页前端显示,非常之实 ...

  4. java 编写hadoop程序中使用第三方libxx.so库

    在使用java编写hadoop处理程序时遇到了,java使用依赖的第三方libxx.so库的情况,找到了一种可行的方法,记录一下,希望对别人也有帮助: 加入需要使用的lib库为libxxx.so 1. ...

  5. 如何用Java编写一段代码引发内存泄露

    本文来自StackOverflow问答网站的一个热门讨论:如何用Java编写一段会发生内存泄露的代码. Q:刚才我参加了面试,面试官问我如何写出会发生内存泄露的Java代码.这个问题我一点思路都没有, ...

  6. Java编写的C语言词法分析器

    Java编写的C语言词法分析器 这是java编写的C语言词法分析器,我也是参考很多代码,然后核心代码整理起来,放在QQ空间和博客上,目的是互相学习借鉴,希望可以得到高手改进.这个词法分析器实现的功能有 ...

  7. delphi调用java编写的webservice

    delphi调用java编写的webservice JAVApojo: public class GroupInfo implements Serializable{    private stati ...

  8. 实战WEB 服务器(JAVA编写WEB服务器)

    实战WEB 服务器(JAVA编写WEB服务器) 标签: web服务服务器javawebsockethttp服务器 2010-04-21 17:09 11631人阅读 评论(24) 收藏 举报  分类: ...

  9. Java 编写小程序,下载指定网页上的所有图片

    使用Java编写一个小程序,可以根据指定的网页地址,下载网页中的所有图片:使用到网络编程.线程池.IO和UUID的技术.具体代码如下: import java.io.File; import java ...

随机推荐

  1. SharePoint2007使用WebPart加载UserControl

    之前一直做SharePoint2010开发,最近转向了2007开发,感觉两者开发时有很多地方不一样,我现在接触到2007开发项目里面使用Module去加载Application Page,而在Appl ...

  2. bzoj1452 [JSOI2009]Count ——二维树状数组

    中文题面,给你一个矩阵,每一个格子有数字,有两种操作. 1. 把i行j列的值更改 2. 询问两个角坐标分别为(x1,y1) (x2,y2)的矩形内有几个值为z的点. 这一题的特点就是给出的z的数据范围 ...

  3. HDU 5616 Jam's balance(Jam的天平)

    HDU 5616 Jam's balance(Jam的天平) Time Limit: 2000/1000 MS (Java/Others)    Memory Limit: 65536/65536 K ...

  4. python对Excel表格操作

    操作场景,给一个Excel表格随机生成10万个手机号码 python中常见的对Excel操作模块 xlwt module 将数据写入Excel表 xlrd module 读取Excel表格 xlsxw ...

  5. Go-单元测试

        文章转载地址:https://www.flysnow.org/2017/05/16/go-in-action-go-unit-test.html 什么是单元测试?      单元测试一般用来测 ...

  6. ES6解构过程添加一个默认值和赋值一个新的值

    const info = { name: 'xiaobe', } const { name: nickName = '未知' } = info; 其中nickName是解构过程中新声明的一个变量,并且 ...

  7. 代码生成器——实现生成pojo,sql,mapper接口

    代码生成器(记录一次兴趣代码,多多指教.转载请标明作者) 在我们开始实现代码生成器之前我们先来对代码生成器有一个简单的了解. 1.什么是代码生成器? 故名思义,也就是生成代码的一个程序.那它是一个什么 ...

  8. js及jsp区别

  9. Halcon 标定与准确测量

  10. 学习笔记74—函数argsort()

    ****************************************************** 如有谬误,请联系指正.转载请注明出处. 联系方式: e-mail: heyi9069@gm ...