Today, I want to show how I use Thomas Lin Pedersen's awesome ggraph package to plot decision trees from Random Forest models.

I am very much a visual person, so I try to plot as many of my results as possible, because it helps me get a better feel for what is going on with my data.

A nice aspect of tree-based machine learning, like Random Forest models, is that they are more easily interpreted than, e.g., neural networks, because they are based on decision trees. So, when I am using such models, I like to plot the final decision trees (if they aren't too large) to get a sense of which decisions underlie my predictions.

There are a few very convenient ways to plot the outcome if you are using the randomForest package, but I like to have as much control as possible over the layout, colors, labels, etc. And because I didn't find a solution I liked for caret models, I developed the following little function (further down you can find information about how I built the model):

As input, it takes part of the output from model_rf <- caret::train(... "rf" ...): the final randomForest model, model_rf$finalModel, whose forest component (model_rf$finalModel$forest) holds the trees. You can then specify which of these trees to plot by index.

library(dplyr)
library(ggraph)
library(igraph)

tree_func <- function(final_model,
                      tree_num) {

  # get tree by index
  tree <- randomForest::getTree(final_model,
                                k = tree_num,
                                labelVar = TRUE) %>%
    tibble::rownames_to_column() %>%
    # set leaf split points to NA, so the 0s won't get plotted
    mutate(`split point` = ifelse(is.na(prediction), `split point`, NA))

  # prepare data frame for graph
  graph_frame <- data.frame(from = rep(tree$rowname, 2),
                            to = c(tree$`left daughter`, tree$`right daughter`))

  # convert to graph and delete the last node that we don't want to plot
  graph <- graph_from_data_frame(graph_frame) %>%
    delete_vertices("0")

  # set node labels
  V(graph)$node_label <- gsub("_", " ", as.character(tree$`split var`))
  V(graph)$leaf_label <- as.character(tree$prediction)
  V(graph)$split <- as.character(round(tree$`split point`, digits = 2))

  # plot
  plot <- ggraph(graph, 'dendrogram') +
    theme_bw() +
    geom_edge_link() +
    geom_node_point() +
    geom_node_text(aes(label = node_label), na.rm = TRUE, repel = TRUE) +
    geom_node_label(aes(label = split), vjust = 2.5, na.rm = TRUE, fill = "white") +
    geom_node_label(aes(label = leaf_label, fill = leaf_label), na.rm = TRUE,
                    repel = TRUE, colour = "white", fontface = "bold", show.legend = FALSE) +
    theme(panel.grid.minor = element_blank(),
          panel.grid.major = element_blank(),
          panel.background = element_blank(),
          plot.background = element_rect(fill = "white"),
          panel.border = element_blank(),
          axis.line = element_blank(),
          axis.text.x = element_blank(),
          axis.text.y = element_blank(),
          axis.ticks = element_blank(),
          axis.title.x = element_blank(),
          axis.title.y = element_blank(),
          plot.title = element_text(size = 18))

  print(plot)
}
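
To make the edge-list step inside the function easier to follow, here is a minimal sketch on a hand-made table. The table only mimics a subset of the columns that randomForest::getTree(..., labelVar = TRUE) returns; all values are invented for illustration and do not come from model_rf. It uses base R only and filters the placeholder rows instead of deleting a graph vertex.

```r
# Toy table mimicking some columns of randomForest::getTree(..., labelVar = TRUE).
# The values are made up for illustration; they are not taken from model_rf.
tree <- data.frame(
  rowname          = c("1", "2", "3"),
  `left daughter`  = c(2, 0, 0),
  `right daughter` = c(3, 0, 0),
  `split var`      = c("clump_thickness", NA, NA),
  `split point`    = c(5.5, 0, 0),
  prediction       = c(NA, "benign", "malignant"),
  check.names      = FALSE,
  stringsAsFactors = FALSE
)

# Same edge construction as in tree_func(): one row per (node, daughter) pair.
graph_frame <- data.frame(
  from = rep(tree$rowname, 2),
  to   = c(tree$`left daughter`, tree$`right daughter`)
)

# Leaves have daughter 0, so they produce edges to a placeholder node "0".
# tree_func() removes these by deleting vertex "0"; here we just filter the rows.
real_edges <- graph_frame[graph_frame$to != 0, ]
print(real_edges)
```

Only the two edges leaving the root node remain, which is why the function has to get rid of vertex "0" before plotting.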

We can now plot, e.g., the tree with the smallest number of nodes:

# which.min() returns a single index, even if several trees tie for the fewest nodes
tree_num <- which.min(model_rf$finalModel$forest$ndbigtree)

tree_func(final_model = model_rf$finalModel, tree_num)

Or we can plot the tree with the biggest number of nodes:

# which.max() returns a single index, even if several trees tie for the most nodes
tree_num <- which.max(model_rf$finalModel$forest$ndbigtree)

tree_func(final_model = model_rf$finalModel, tree_num)
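
One detail when picking a tree by its node count: comparing against min() or max() with which() returns every tied index, whereas which.min() and which.max() return only the first one. Since the plotting function expects a single tree index, the difference matters when trees tie. A small sketch with a made-up vector of node counts (not taken from the model above):

```r
# Made-up node counts for five trees; trees 2 and 4 tie for the minimum.
node_counts <- c(15, 7, 21, 7, 33)

which(node_counts == min(node_counts))  # 2 4 (all tied indices)
which.min(node_counts)                  # 2   (first minimum only)
which.max(node_counts)                  # 5
```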


Preparing the data and modeling

The data set I am using in these example analyses is the Breast Cancer Wisconsin (Original) Dataset. The data was downloaded from the UC Irvine Machine Learning Repository.

The data set distinguishes two outcome classes:

  • malignant or
  • benign breast mass.

The features characterize cell nucleus properties and were generated from image analysis of fine needle aspirates (FNA) of breast masses:

  • Sample ID (code number)
  • Clump thickness
  • Uniformity of cell size
  • Uniformity of cell shape
  • Marginal adhesion
  • Single epithelial cell size
  • Number of bare nuclei
  • Bland chromatin
  • Normal nucleoli
  • Mitoses
  • Classes, i.e. diagnosis

bc_data <- read.table("datasets/breast-cancer-wisconsin.data.txt", header = FALSE, sep = ",")
colnames(bc_data) <- c("sample_code_number",
                       "clump_thickness",
                       "uniformity_of_cell_size",
                       "uniformity_of_cell_shape",
                       "marginal_adhesion",
                       "single_epithelial_cell_size",
                       "bare_nuclei",
                       "bland_chromatin",
                       "normal_nucleoli",
                       "mitosis",
                       "classes")

bc_data$classes <- ifelse(bc_data$classes == "2", "benign",
                          ifelse(bc_data$classes == "4", "malignant", NA))

# missing values are coded as "?" in the raw file
bc_data[bc_data == "?"] <- NA

# impute missing data
library(mice)

bc_data[, 2:10] <- apply(bc_data[, 2:10], 2, function(x) as.numeric(as.character(x)))
dataset_impute <- mice(bc_data[, 2:10], printFlag = FALSE)
bc_data <- cbind(bc_data[, 11, drop = FALSE], mice::complete(dataset_impute, 1))

bc_data$classes <- as.factor(bc_data$classes)

# how many benign and malignant cases are there?
summary(bc_data$classes)

# separate into training and test data
library(caret)

set.seed(42)
index <- createDataPartition(bc_data$classes, p = 0.7, list = FALSE)
train_data <- bc_data[index, ]
test_data <- bc_data[-index, ]

# run model
set.seed(42)
model_rf <- caret::train(classes ~ .,
                         data = train_data,
                         method = "rf",
                         preProcess = c("scale", "center"),
                         trControl = trainControl(method = "repeatedcv",
                                                  number = 10,
                                                  repeats = 10,
                                                  savePredictions = TRUE,
                                                  verboseIter = FALSE))
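
The recoding steps before the imputation can be tried out on a tiny stand-in table. This is only a sketch under assumptions: the data frame toy and its values are invented for illustration and are not part of the real file; it uses base R only and stops before the mice imputation.

```r
# Toy data frame standing in for bc_data; the values are invented for illustration.
toy <- data.frame(
  bare_nuclei = c("1", "?", "10"),
  classes     = c(2, 4, 2),
  stringsAsFactors = FALSE
)

# Recode the class codes 2/4 into labels, as in the code above.
toy$classes <- ifelse(toy$classes == "2", "benign",
                      ifelse(toy$classes == "4", "malignant", NA))

# Replace the "?" placeholder with NA, then convert the column to numeric.
toy[toy == "?"] <- NA
toy$bare_nuclei <- as.numeric(as.character(toy$bare_nuclei))

str(toy)
```

After these steps the "?" entry is a proper NA in a numeric column, which is the form an imputation step can work with.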

If you are interested in more machine learning posts, check out the category listing for machine_learning on my blog.


sessionInfo()
## R version 3.3.3 (2017-03-06)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 7 x64 (build 7601) Service Pack 1
##
## locale:
## [1] LC_COLLATE=English_United States.1252
## [2] LC_CTYPE=English_United States.1252
## [3] LC_MONETARY=English_United States.1252
## [4] LC_NUMERIC=C
## [5] LC_TIME=English_United States.1252
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] igraph_1.0.1 ggraph_1.0.0 ggplot2_2.2.1.9000
## [4] dplyr_0.5.0
##
## loaded via a namespace (and not attached):
## [1] Rcpp_0.12.9 nloptr_1.0.4 plyr_1.8.4
## [4] viridis_0.3.4 iterators_1.0.8 tools_3.3.3
## [7] digest_0.6.12 lme4_1.1-12 evaluate_0.10
## [10] tibble_1.2 gtable_0.2.0 nlme_3.1-131
## [13] lattice_0.20-34 mgcv_1.8-17 Matrix_1.2-8
## [16] foreach_1.4.3 DBI_0.6 ggrepel_0.6.5
## [19] yaml_2.1.14 parallel_3.3.3 SparseM_1.76
## [22] gridExtra_2.2.1 stringr_1.2.0 knitr_1.15.1
## [25] MatrixModels_0.4-1 stats4_3.3.3 rprojroot_1.2
## [28] grid_3.3.3 caret_6.0-73 nnet_7.3-12
## [31] R6_2.2.0 rmarkdown_1.3 minqa_1.2.4
## [34] udunits2_0.13 tweenr_0.1.5 deldir_0.1-12
## [37] reshape2_1.4.2 car_2.1-4 magrittr_1.5
## [40] units_0.4-2 backports_1.0.5 scales_0.4.1
## [43] codetools_0.2-15 ModelMetrics_1.1.0 htmltools_0.3.5
## [46] MASS_7.3-45 splines_3.3.3 randomForest_4.6-12
## [49] assertthat_0.1 pbkrtest_0.4-6 ggforce_0.1.1
## [52] colorspace_1.3-2 labeling_0.3 quantreg_5.29
## [55] stringi_1.1.2 lazyeval_0.2.0 munsell_0.4.3

Reposted from: https://shiring.github.io/machine_learning/2017/03/16/rf_plot_ggraph
