广义随机森林

了解causal forest之前,需要先了解其forest实现的载体:GENERALIZED RANDOM FORESTS[6](GRF)

其是随机森林的一种推广, 经典的随机森林只能去估计label Y,不能用于估计复杂的目标,比如causal effect,Causal Tree、Cauasl Forest的同一个作者对其进行了改良。先定义一下矩估计参数表达式:

\[\begin{equation} \tag{1}

\mathbb E[\psi_{\theta(x), \upsilon(x)}(O_i)|X=x]=0

\end{equation}

\]

其中,\(\psi\) 是score function,也就是measure metric,\(\theta\) 是我们不得不去计算的参数,比如tree里面的各项参数如特征threshold,叶子节点估计值..etc, \(\upsilon\)

则是一个可选参数。\(O\) 表示和计算相关的值,比如监督信号。像response类的模型,\(O_i={Y_i}\), 像causal 模型,\(O_i={Y_i, W_i}\) \(W\) 表示某种treatment。

该式在实际优化参数的时候,等价于最小化:

\[\tag{2} \left(\hat \theta(x), \upsilon(x)\right)\in argmin_{\theta, \upsilon}\left|\left|\sum\alpha_i(x)\psi_{\theta, \upsilon(O_i)}\right|\right|_2
\]

其中,\(\alpha\) 是一种权重,当然,这里也可以理解为树的权重,假设总共需要学习\(B\) 棵树:

\[\alpha_i(x)=\frac{1}{B}\sum_{b=1}^{B}\alpha_{bi}(x)
\]
\[\alpha_{bi(x)}=\frac{1(\{x\in L_b(x)\})}{|L_b(x)|}
\]

其中,\(L_b(x)\) 表示叶子节点里的样本。本质上,这个权重表示的是:训练样本和推理或者测试样本的相似度,因为如果某个样本\(x_i\)落入叶子\(L_b\) ,且我们可以认为叶子节点内的样本同质的情况下,那么可以认为这个样本和当前落入的tree有相似性。

当然,按照这个公式,如果\(L_b\) 很大,说明进入这个叶子的训练样本很多,意味着没划分完全,异质性低,则最后分配给这棵树的权重就低,反之亦然。

分裂准则框架

对于每棵树,父节点\(P\) 通过最优化下式进行分裂:

\[\tag{3}\left(\hat{\theta}_P, \hat{\nu}_P\right)(\mathcal{J}) \in \operatorname{argmin}_{\theta, \nu}\left\{\left\|\sum_{\left\{i \in \mathcal{J}: X_i \in P\right\}} \psi_{\theta, \nu}\left(O_i\right)\right\|_2\right\} .
\]

其中,\(\mathcal{J}\) 表示train set,分裂后形成的2个子节点标准为:通过最小化估计值与真实值间的误差平方:

\[\tag{4}\operatorname{err}\left(C_1, C_2\right)=\sum_{j=1,2} \mathbb{P}\left[X \in C_j \mid X \in P\right] \mathbb{E}\left[\left(\hat{\theta}_{C_j}(\mathcal{J})-\theta(X)\right)^2 \mid X \in C_j\right]
\]

等价于最大化节点间的异质性:

\[\tag{5}\Delta\left(C_1, C_2\right):=n_{C_1} n_{C_2} / n_P^2\left(\hat{\theta}_{C_1}(\mathcal{J})-\hat{\theta}_{C_2}(\mathcal{J})\right)^2
\]

但是\(\theta\) 参数比较难优化,交给梯度下降:

\[\tag{6}\tilde{\theta}_C=\hat{\theta}_P-\frac{1}{\left|\left\{i: X_i \in C\right\}\right|} \sum_{\left\{i: X_i \in C\right\}} \xi^{\top} A_P^{-1} \psi_{\hat{\theta}_P, \hat{\nu}_P}\left(O_i\right)
\]

其中,\(\hat \theta_P\) 通过 (2) 式获得, \(A_p\) 为score function的梯度

\[\tag{7}A_P=\frac{1}{\left|\left\{i: X_i \in P\right\}\right|} \sum_{\left\{i: X_i \in P\right\}} \nabla \psi_{\hat{\theta}_P, \hat{\nu}_P}\left(O_i\right),
\]

梯度计算部分包含2个step:

  • step1:labeling-step 得到一个pseudo-outcomes
\[\tag{8}\rho_i=-\xi^{\top} A_P^{-1} \psi_{\hat{\theta}_P, \hat{\nu}_P}\left(O_i\right) \in \mathbb{R}$.
\]
  • step2:回归阶段,用这个pseudo-outcomes 作为信号,传递给split函数, 最终是最大化下式指导节点分割
\[{\Delta}\left(C_1, C_2\right)=\sum_{j=1}^2 \frac{1}{\left|\left\{i: X_i \in C_j\right\}\right|}\left(\sum_{\left\{i: X_i \in C_j\right\}} \rho_i\right)^2
\]

以下是GRF的几种Applications:

Causal Forest

以Casual-Tree为base,不做任何估计量的改变

与单棵 tree 净化到 ensemble 一样,causal forest[7] 沿用了经典bagging系的随机森林,将一颗causal tree 拓展到多棵:

\[\hat \tau=\frac{1}{B}\sum_{b=1}^{B} \hat \tau_b(x)
\]

其中,每科子树\(\hat \tau\) 为一颗Casual Tree。使用随机森林作为拓展的好处之一是不需要对causal tree做任何的变换,这一点比boosing系的GBM显然成本也更低。

不过这个随机森林使用的是广义随机森林 , 经典的随机森林只能去估计label Y,不能用于估计复杂的目标,比如causal effect,Causal Tree、Cauasl Forest的同一个作者对其进行了改良,放在后面再讲。

在实现上,不考虑GRF,单机可以直接套用sklearn的forest子类,重写fit方法即可。分布式可以直接套用spark ml的forest。

self._estimator = CausalTreeRegressor(
control_name=control_name,
criterion=criterion,
groups_cnt=groups_cnt) trees = [self._make_estimator(append=False, random_state=random_state)
for i in range(n_more_estimators)] trees = Parallel(
n_jobs=self.n_jobs,
verbose=self.verbose,
**_joblib_parallel_args,
)(
delayed(_parallel_build_trees)(
t,
self,
X,
y,
sample_weight,
i,
len(trees),
verbose=self.verbose,
class_weight=self.class_weight,
n_samples_bootstrap=n_samples_bootstrap,
)
for i, t in enumerate(trees)
) self.estimators_.extend(trees)

CAPE:  适用连续treatment 的 causal effect预估

Conditional Average Partial Effects(CAPE)

GRF给定了一种框架:输入任意的score-function,能够指导最大化异质节点的方向持续分裂子树,和response类的模型一样,同样我们需要一些估计值(比如gini index、entropy)来计算分裂前后的score-function变化,计算估计值需要估计量,定义连续treatment的估计量为:

\[\theta(x)=\xi^{\top} \operatorname{Var}\left[W_i \mid X_i=x\right]^{-1} \operatorname{Cov}\left[W_i, Y_i \mid X_i=x\right]
\]

估计量参与指导分裂计算,但最终,叶子节点存储的依然是outcome的期望。

此处的motivation来源于工具变量和线性回归:

\[y=f(x)=wx+b
\]

此处我们假设\(x\)是treatment,y是outcome, \(w\) 作为一个参数简单的描述了施加treatment对结果的直接影响,要寻找到参数我们需要一个指标衡量参数好坏, 也就是loss, 和casual tree一样,通常使用mse:

\[L(w, b) = \frac{1}{2}\sum(f(x)-y)^2
\]

为了最快的找到这个w,当然是往函数梯度的方向, 我们对loss求偏导并令其为0:

\[\tag{1}\frac{\partial L}{\partial w}=\sum(f(x)-y)x=\sum(wx+b-y)x
\]
\[ \tag{2}

\begin{aligned}

\frac{\partial L}{\partial b} & = \sum(f(x)-y)=\sum(wx+b-y) \\

& \Rightarrow \sum b= \sum y-\sum wx \\

& \Rightarrow b = E(y)-wE(x) = \bar y - w\bar x

\end{aligned}

\]

(2) 代入 (1) 式可得:

\[
\begin{aligned}

\frac{\partial L}{\partial w} & \Rightarrow \sum(wx+\bar y-w\bar x-y)x =0 \\

&\Rightarrow w=\frac{\sum xy-\bar y\sum x}{\sum x^2-\bar x\sum x} \\

&\Rightarrow w=\frac{\sum(x-\bar x)(y-\bar y)}{\sum(x-\bar x)^2}\\

&\Rightarrow w=\frac{Cov(x,y)}{Var(x)}

\end{aligned}

\]

可简化得参数w是关于treatment和outcome的协方差/方差。至于\(\xi\) , 似乎影响不大。

refs

  1. https://hwcoder.top/Uplift-1
  2. 工具: scikit-uplift
  3. Meta-learners for Estimating Heterogeneous Treatment Effects using Machine Learning
  4. Athey, Susan, and Guido Imbens. "Recursive partitioning for heterogeneous causal effects." Proceedings of the National Academy of Sciences 113.27 (2016): 7353-7360.
  5. https://zhuanlan.zhihu.com/p/115223013
  6. Athey, Susan, Julie Tibshirani, and Stefan Wager. "Generalized random forests." (2019): 1148-1178.
  7. Wager, Stefan, and Susan Athey. "Estimation and inference of heterogeneous treatment effects using random forests." Journal of the American Statistical Association 113.523 (2018): 1228-1242.
  8. Rzepakowski, P., & Jaroszewicz, S. (2012). Decision trees for uplift modeling with single and multiple treatments. Knowledge and Information Systems32, 303-327.
  9. annik Rößler, Richard Guse, and Detlef Schoder. The best of two worlds: using recent advances from uplift modeling and heterogeneous treatment effects to optimize targeting policies. International Conference on Information Systems, 2022.

Causal Inference理论学习篇-Tree Based-Causal Forest的更多相关文章

  1. Targeted Learning R Packages for Causal Inference and Machine Learning(转)

    Targeted learning methods build machine-learning-based estimators of parameters defined as features ...

  2. 因果推理综述——《A Survey on Causal Inference》一文的总结和梳理

    因果推理 本文档是对<A Survey on Causal Inference>一文的总结和梳理. 论文地址 简介 关联与因果 先有的鸡,还是先有的蛋?这里研究的是因果关系,因果关系与普通 ...

  3. 【统计】Causal Inference

    [统计]Causal Inference 原文传送门 http://www.stat.cmu.edu/~larry/=sml/Causation.pdf 过程 一.Prediction 和 causa ...

  4. Causal Inference

    目录 Standardization 非参数情况 Censoring 参数模型 Time-varying 静态 IP weighting 无参数 Censoring 参数模型 censoring 条件 ...

  5. A Complete Tutorial on Tree Based Modeling from Scratch (in R & Python)

    A Complete Tutorial on Tree Based Modeling from Scratch (in R & Python) MACHINE LEARNING PYTHON  ...

  6. 算法---FaceNet理论学习篇

    FaceNet算法-理论学习篇 @WP20190228 ==============目 录============ 一.LFW数据集简介 二.FaceNet算法简介 FaceNet算法=MTCNN模型 ...

  7. Decision Tree、Random Forest、AdaBoost、GBDT

    原文地址:https://www.jianshu.com/p/d8ceeee66a6f Decision Tree 基本思想在于每次分裂节点时选取一个特征使得划分后得到的数据集尽可能纯. 划分标准 信 ...

  8. Chapter 2 Randomized Experiments

    目录 概 2.1 Randomization 2.2 Conditional randomization 2.3 Standardization 2.4 Inverse probability wei ...

  9. Chapter 6 Graphical Representation of Causal Effects

    目录 6.1 Causal diagrams 6.2 Causal diagrams and marginal independence 6.3 Causal diagrams and conditi ...

  10. Chapter 1 A Definition of Causal Effect

    目录 1.1 Individual casual effects 1.2 Average casual effects 1.5 Causation versus association Hern\(\ ...

随机推荐

  1. 时间同步 ntp服务器

    目录 一. 定义 二. 项目要求 三. 部署服务端 四. 部署客户端 一. 定义 #01 简介:ntp全名 network time protocol .NTP服务器可以为其他主机提供时间校对服务 # ...

  2. vetur volar 是否可以共用,vue2 vue3项目 如何同时开发?

    vetur volar 是否可以共用,vue2 vue3项目 如何同时开发? 先提问 以后再找答案 20220704 补答 vetur volar 不要一起装 vscode环境

  3. linux 无法找到“/usr/bin/core_perl/gcc” vscode

    解决问题的思路 查看有没有gcc,没有安装 有的话就是,修改安装路径就可以? "/usr/bin/core_perl/gcc".修改成Gcc的绝对路径 我的修改是./usr/bin ...

  4. Java使用Steam流对数组进行排序

    原文地址:Java使用Steam流对数组进行排序 - Stars-One的杂货小窝 简单记下笔记,不是啥难的东西 sorted()方法里传了一个比较器的接口 File file = new File( ...

  5. Morris遍历:常数空间遍历二叉树

    Morris遍历 cur有左树且第一次遍历到,去左孩子 没左树或者第二次遍历到,去右孩子 没右树,去后继节点 得到Morris序.对于该序列中出现两次的节点,只保留第一次遍历,结果就是先序遍历.只保留 ...

  6. SVN 提交文件报错:svn: E155015: Aborting commit:

    svn 提交文件报错: svn: E155015: Commit failed (details follow): svn: E155015: Aborting commit: '文件名称' rema ...

  7. 实时3D渲染它是如何工作的?可以在哪些行业应用?

    随着新兴技术--3D渲染的发展,交互应用的质量有了极大的提高.用实时三维渲染软件创建的沉浸式数字体验,几乎与现实没有区别了.随着技术的逐步改进,在价格较低的个人工作站上渲染3D图像变得更加容易,设计师 ...

  8. java方法的内存及练习

    方法的内存 一.方法调用的基本内存原理: Java内存分配 栈: 方法运行时使用的内存方法进栈运行,运行完毕就出栈 堆: newl出来的,都在堆内存中开辟了一个小空间 方法区: 存储可以运行的clas ...

  9. C++中虚表是什么

    虚函数表,以及虚函数指针是实现多态性(Polymorphism)的关键机制.多态性允许我们通过基类的指针或引用来调用派生类的函数 定义 虚函数(Virtual Function) 定义:类中使用vir ...

  10. 记录--使用Vue开发Chrome插件

    这里给大家分享我在网上总结出来的一些知识,希望对大家有所帮助 环境搭建 Vue Web-Extension - A Web-Extension preset for VueJS (vue-web-ex ...