Today, I want to show how I use Thomas Lin Pederson’s awesome ggraph package to plot decision trees from Random Forest models.

I am very much a visual person, so I try to plot as much of my results as possible because it helps me get a better feel for what is going on with my data.

A nice aspect of using tree-based machine learning, like Random Forest models, is that that they are more easily interpreted than e.g. neural networks as they are based on decision trees. So, when I am using such models, I like to plot final decision trees (if they aren’t too large) to get a sense of which decisions are underlying my predictions.

There are a few very convient ways to plot the outcome if you are using the randomForest package but I like to have as much control as possible about the layout, colors, labels, etc. And because I didn’t find a solution I liked for caret models, I developed the following little function (below you may find information about how I built the model):

As input, it takes part of the output from model_rf <- caret::train(... "rf" ...), that gives the trees of the final model: model_rf$finalModel$forest. From these trees, you can specify which one to plot by index.

library(dplyr)
library(ggraph)
library(igraph) tree_func <- function(final_model,
tree_num) { # get tree by index
tree <- randomForest::getTree(final_model,
k = tree_num,
labelVar = TRUE) %>%
tibble::rownames_to_column() %>%
# make leaf split points to NA, so the 0s won't get plotted
mutate(`split point` = ifelse(is.na(prediction), `split point`, NA)) # prepare data frame for graph
graph_frame <- data.frame(from = rep(tree$rowname, 2),
to = c(tree$`left daughter`, tree$`right daughter`)) # convert to graph and delete the last node that we don't want to plot
graph <- graph_from_data_frame(graph_frame) %>%
delete_vertices("0") # set node labels
V(graph)$node_label <- gsub("_", " ", as.character(tree$`split var`))
V(graph)$leaf_label <- as.character(tree$prediction)
V(graph)$split <- as.character(round(tree$`split point`, digits = 2)) # plot
plot <- ggraph(graph, 'dendrogram') +
theme_bw() +
geom_edge_link() +
geom_node_point() +
geom_node_text(aes(label = node_label), na.rm = TRUE, repel = TRUE) +
geom_node_label(aes(label = split), vjust = 2.5, na.rm = TRUE, fill = "white") +
geom_node_label(aes(label = leaf_label, fill = leaf_label), na.rm = TRUE,
repel = TRUE, colour = "white", fontface = "bold", show.legend = FALSE) +
theme(panel.grid.minor = element_blank(),
panel.grid.major = element_blank(),
panel.background = element_blank(),
plot.background = element_rect(fill = "white"),
panel.border = element_blank(),
axis.line = element_blank(),
axis.text.x = element_blank(),
axis.text.y = element_blank(),
axis.ticks = element_blank(),
axis.title.x = element_blank(),
axis.title.y = element_blank(),
plot.title = element_text(size = 18)) print(plot)
}

We can now plot, e.g. the tree with the smalles number of nodes:

tree_num <- which(model_rf$finalModel$forest$ndbigtree == min(model_rf$finalModel$forest$ndbigtree))

tree_func(final_model = model_rf$finalModel, tree_num)

Or we can plot the tree with the biggest number of nodes:

tree_num <- which(model_rf$finalModel$forest$ndbigtree == max(model_rf$finalModel$forest$ndbigtree))

tree_func(final_model = model_rf$finalModel, tree_num)


Preparing the data and modeling

The data set I am using in these example analyses, is the Breast Cancer Wisconsin (Diagnostic) Dataset. The data was downloaded from the UC Irvine Machine Learning Repository.

The first data set looks at the predictor classes:

  • malignant or
  • benign breast mass.

The features characterize cell nucleus properties and were generated from image analysis of fine needle aspirates (FNA) of breast masses:

  • Sample ID (code number)
  • Clump thickness
  • Uniformity of cell size
  • Uniformity of cell shape
  • Marginal adhesion
  • Single epithelial cell size
  • Number of bare nuclei
  • Bland chromatin
  • Number of normal nuclei
  • Mitosis
  • Classes, i.e. diagnosis
bc_data <- read.table("datasets/breast-cancer-wisconsin.data.txt", header = FALSE, sep = ",")
colnames(bc_data) <- c("sample_code_number",
"clump_thickness",
"uniformity_of_cell_size",
"uniformity_of_cell_shape",
"marginal_adhesion",
"single_epithelial_cell_size",
"bare_nuclei",
"bland_chromatin",
"normal_nucleoli",
"mitosis",
"classes") bc_data$classes <- ifelse(bc_data$classes == "2", "benign",
ifelse(bc_data$classes == "4", "malignant", NA)) bc_data[bc_data == "?"] <- NA # impute missing data
library(mice) bc_data[,2:10] <- apply(bc_data[, 2:10], 2, function(x) as.numeric(as.character(x)))
dataset_impute <- mice(bc_data[, 2:10], print = FALSE)
bc_data <- cbind(bc_data[, 11, drop = FALSE], mice::complete(dataset_impute, 1)) bc_data$classes <- as.factor(bc_data$classes) # how many benign and malignant cases are there?
summary(bc_data$classes) # separate into training and test data
library(caret) set.seed(42)
index <- createDataPartition(bc_data$classes, p = 0.7, list = FALSE)
train_data <- bc_data[index, ]
test_data <- bc_data[-index, ] # run model
set.seed(42)
model_rf <- caret::train(classes ~ .,
data = train_data,
method = "rf",
preProcess = c("scale", "center"),
trControl = trainControl(method = "repeatedcv",
number = 10,
repeats = 10,
savePredictions = TRUE,
verboseIter = FALSE))

If you are interested in more machine learning posts, check out the category listing for machine_learning on my blog.


sessionInfo()
## R version 3.3.3 (2017-03-06)
## Platform: x86_64-w64-mingw32/x64 (64-bit)
## Running under: Windows 7 x64 (build 7601) Service Pack 1
##
## locale:
## [1] LC_COLLATE=English_United States.1252
## [2] LC_CTYPE=English_United States.1252
## [3] LC_MONETARY=English_United States.1252
## [4] LC_NUMERIC=C
## [5] LC_TIME=English_United States.1252
##
## attached base packages:
## [1] stats graphics grDevices utils datasets methods base
##
## other attached packages:
## [1] igraph_1.0.1 ggraph_1.0.0 ggplot2_2.2.1.9000
## [4] dplyr_0.5.0
##
## loaded via a namespace (and not attached):
## [1] Rcpp_0.12.9 nloptr_1.0.4 plyr_1.8.4
## [4] viridis_0.3.4 iterators_1.0.8 tools_3.3.3
## [7] digest_0.6.12 lme4_1.1-12 evaluate_0.10
## [10] tibble_1.2 gtable_0.2.0 nlme_3.1-131
## [13] lattice_0.20-34 mgcv_1.8-17 Matrix_1.2-8
## [16] foreach_1.4.3 DBI_0.6 ggrepel_0.6.5
## [19] yaml_2.1.14 parallel_3.3.3 SparseM_1.76
## [22] gridExtra_2.2.1 stringr_1.2.0 knitr_1.15.1
## [25] MatrixModels_0.4-1 stats4_3.3.3 rprojroot_1.2
## [28] grid_3.3.3 caret_6.0-73 nnet_7.3-12
## [31] R6_2.2.0 rmarkdown_1.3 minqa_1.2.4
## [34] udunits2_0.13 tweenr_0.1.5 deldir_0.1-12
## [37] reshape2_1.4.2 car_2.1-4 magrittr_1.5
## [40] units_0.4-2 backports_1.0.5 scales_0.4.1
## [43] codetools_0.2-15 ModelMetrics_1.1.0 htmltools_0.3.5
## [46] MASS_7.3-45 splines_3.3.3 randomForest_4.6-12
## [49] assertthat_0.1 pbkrtest_0.4-6 ggforce_0.1.1
## [52] colorspace_1.3-2 labeling_0.3 quantreg_5.29
## [55] stringi_1.1.2 lazyeval_0.2.0 munsell_0.4.3

转自:https://shiring.github.io/machine_learning/2017/03/16/rf_plot_ggraph

Plotting trees from Random Forest models with ggraph的更多相关文章

  1. 机器学习算法 --- Pruning (decision trees) & Random Forest Algorithm

    一.Table for Content 在之前的文章中我们介绍了Decision Trees Agorithms,然而这个学习算法有一个很大的弊端,就是很容易出现Overfitting,为了解决此问题 ...

  2. Random Forest And Extra Trees

    随机森林 我们对使用决策树随机取样的集成学习有个形象的名字–随机森林. scikit-learn 中封装的随机森林,在决策树的节点划分上,在随机的特征子集上寻找最优划分特征. import numpy ...

  3. Random Forest Classification of Mushrooms

    There is a plethora of classification algorithms available to people who have a bit of coding experi ...

  4. 机器学习方法(六):随机森林Random Forest,bagging

    欢迎转载,转载请注明:本文出自Bin的专栏blog.csdn.net/xbinworld. 技术交流QQ群:433250724,欢迎对算法.技术感兴趣的同学加入. 前面机器学习方法(四)决策树讲了经典 ...

  5. [Machine Learning & Algorithm] 随机森林(Random Forest)

    1 什么是随机森林? 作为新兴起的.高度灵活的一种机器学习算法,随机森林(Random Forest,简称RF)拥有广泛的应用前景,从市场营销到医疗保健保险,既可以用来做市场营销模拟的建模,统计客户来 ...

  6. paper 85:机器统计学习方法——CART, Bagging, Random Forest, Boosting

    本文从统计学角度讲解了CART(Classification And Regression Tree), Bagging(bootstrap aggregation), Random Forest B ...

  7. 多分类问题中,实现不同分类区域颜色填充的MATLAB代码(demo:Random Forest)

    之前建立了一个SVM-based Ordinal regression模型,一种特殊的多分类模型,就想通过可视化的方式展示模型分类的效果,对各个分类区域用不同颜色表示.可是,也看了很多代码,但基本都是 ...

  8. 统计学习方法——CART, Bagging, Random Forest, Boosting

    本文从统计学角度讲解了CART(Classification And Regression Tree), Bagging(bootstrap aggregation), Random Forest B ...

  9. sklearn_随机森林random forest原理_乳腺癌分类器建模(推荐AAA)

     sklearn实战-乳腺癌细胞数据挖掘(博主亲自录制视频) https://study.163.com/course/introduction.htm?courseId=1005269003& ...

随机推荐

  1. salesforce 零基础学习(七十)使用jquery table实现树形结构模式

    项目中UI需要用到树形结构显示内容,后来尽管不需要做了,不过还是自己做着玩玩,mark一下,免得以后项目中用到. 实现树形结构在此使用的是jquery的dynatree.js.关于dynatree的使 ...

  2. Linux上常用的文件传输方式以及比较

    tp ftp 命令使用文件传输协议(File Transfer Protocol, FTP)在本地主机和远程主机之间或者在两个远程主机之间进行文件传输. FTP 协议允许数据在不同文件系统的主机之间传 ...

  3. CF #CROC 2016 - Elimination Round D. Robot Rapping Results Report 二分+拓扑排序

    题目链接:http://codeforces.com/contest/655/problem/D 大意是给若干对偏序,问最少需要前多少对关系,可以确定所有的大小关系. 解法是二分答案,利用拓扑排序看是 ...

  4. CF #93 div1 B. Password KMP/Z

    题目链接:http://codeforces.com/problemset/problem/126/B 大意:给一个字符串,问最长的既是前缀又是后缀又是中缀(这里指在内部出现)的子串. 我自己的做法是 ...

  5. Go - 第一个 go 程序 -- helloworld

    创建程序目录 接着上一节的内容,在我们的workspace (D:\Gopher) 里面创建子目录 hello,他的绝对路径为:D:\Gopher\src\github.com\tuo\hello 创 ...

  6. Google Earth影像数据破解之旅

    "Zed, you are so excellent." 为什么要写这句英文?容我卖个关子稍后再解释. 相信大多数人都体验过Google Earth(简称GE),我对GE最初的印象 ...

  7. AspNet.OData 协议概述

    1. 什么是OData? OData 全称 Open Data Protocol,字面理解为开放数据协议,是一个基于Http协议且API实现为Restful风格的协议标准,目前由微软支持大力推广,你可 ...

  8. python字符串实战

    haproxy配置文件 思路:读一行,写一行 global log 127.0.0.1 local2 daemon maxconn 256 log 127.0.0.1 local2 info defa ...

  9. [ext4]空间管理 - 分配机制

     在Ext4系统中,存在很多分配策略,比如预分配.多块分配.延迟分配等   Prealloc预分配 在ext4系统中,对于小文件和大文件的空间申请请求,都有不同的分配策略.对用小文件的空间请求,e ...

  10. Linux系统OOM killer机制详解

    介绍: Linux下面有个特性叫OOM killer(Out Of Memory killer),会在系统内存耗尽的情况下出现,选择性的干掉一些进程以求释放一些内存.广大从事Linux方面的IT农民工 ...