Title: | Causal Distillation Trees |
Version: | 1.0.0 |
Description: | Causal Distillation Tree (CDT) is a novel machine learning method for estimating interpretable subgroups with heterogeneous treatment effects. CDT allows researchers to fit any machine learning model (or metalearner) to estimate heterogeneous treatment effects for each individual, and then "distills" these predicted heterogeneous treatment effects into interpretable subgroups by fitting an ordinary decision tree to predict the previously-estimated heterogeneous treatment effects. This package provides tools to estimate causal distillation trees (CDT), as detailed in Huang, Tang, and Kenney (2025) <doi:10.48550/arXiv.2502.07275>. |
License: | MIT + file LICENSE |
Encoding: | UTF-8 |
RoxygenNote: | 7.3.2 |
Depends: | R (≥ 4.1.0) |
Suggests: | testthat (≥ 3.0.0) |
Config/testthat/edition: | 3 |
LinkingTo: | Rcpp, RcppArmadillo |
Imports: | bcf, dplyr, ggparty, ggplot2, grf, lifecycle, partykit, purrr, R.utils, Rcpp, rlang, rpart, stringr, tibble, tidyselect |
URL: | https://tiffanymtang.github.io/causalDT/ |
NeedsCompilation: | yes |
Packaged: | 2025-08-28 20:55:06 UTC; ttang4 |
Author: | Tiffany Tang |
Maintainer: | Tiffany Tang <ttang4@nd.edu> |
Repository: | CRAN |
Date/Publication: | 2025-09-03 08:00:13 UTC |
Get list of rules from a party model.
Description
This is a copy of partykit:::.list.rules.party() that is exported for use in the causal distillation tree framework.
Usage
.list.rules.party(x, i = NULL, ...)
Causal Distillation Trees (CDT)
Description
This function implements causal distillation trees (CDT), developed in Huang et al. (2025). Briefly, CDT is a two-stage procedure that allows researchers to identify interpretable subgroups with heterogeneous treatment effects. In the first stage, researchers are free to use any machine learning model or metalearner to predict the heterogeneous treatment effects for each individual in the dataset. In the second stage, CDT “distills” these predicted heterogeneous treatment effects into interpretable subgroups by fitting an ordinary decision tree using the predicted heterogeneous treatment effects from the first stage as the response variable.
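The two stages described above can be sketched with base R and rpart alone. This is a minimal illustration, not the package's implementation: the teacher here is a simple linear T-learner chosen only for brevity (the package default is a causal forest), and the sketch omits the holdout split and cross-fitting that causalDT() performs. All object names below are hypothetical.

```r
library(rpart)  # recommended package bundled with standard R installations

set.seed(1)
n <- 400
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
Z <- rbinom(n, 1, 0.5)
Y <- 2 * Z * (X$x1 > 0) + X$x2 + rnorm(n, sd = 0.5)

# Stage 1 (teacher): separate linear outcome models for treated and control
# units (a "T-learner"); any CATE estimator could be substituted here.
dat <- cbind(X, Y = Y)
fit_treated <- lm(Y ~ x1 + x2, data = dat[Z == 1, ])
fit_control <- lm(Y ~ x1 + x2, data = dat[Z == 0, ])
tau_hat <- unname(
  predict(fit_treated, newdata = X) - predict(fit_control, newdata = X)
)

# Stage 2 (student): an ordinary regression tree fit to the teacher's
# predicted treatment effects, which defines the subgroups (leaves).
student <- rpart(
  tau_hat ~ x1 + x2,
  data = cbind(X, tau_hat = tau_hat),
  method = "anova"
)

# Mean predicted effect within each leaf of the student tree
tapply(tau_hat, student$where, mean)
```

Because the simulated effect depends only on x1, the student tree in this sketch splits on x1 at the root.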
Usage
causalDT(
X,
Y,
Z,
W = NULL,
holdout_prop = 0.3,
holdout_idxs = NULL,
teacher_model = "causal_forest",
teacher_predict = NULL,
student_model = "rpart",
rpart_control = NULL,
rpart_prune = c("none", "min", "1se"),
nfolds_crossfit = NULL,
nreps_crossfit = NULL,
B_stability = 100,
max_depth_stability = NULL,
...
)
Arguments
X |
A tibble, data.frame, or matrix of covariates. |
Y |
A vector of outcomes. |
Z |
A vector of treatments. |
W |
A vector of weights corresponding to treatment propensities. |
holdout_prop |
Proportion of data to hold out for honest estimation of
treatment effects. Used only if |
holdout_idxs |
A vector of indices to hold out for honest estimation of
treatment effects. If NULL, a holdout set of size |
teacher_model |
Teacher model used to estimate individual-level
treatment effects. Should be either "causal_forest" (default),
"bcf", or a function.
If "causal_forest", |
teacher_predict |
Function used to predict individual-level treatment
effects from the teacher model. Should take in two arguments as input: the
first being the model object returned by |
student_model |
Student model used to estimate subgroups of individuals
and their corresponding estimated treatment effects. Should be either
"rpart" (default) or a function. If "rpart", |
rpart_control |
A list of control parameters for the |
rpart_prune |
Method for pruning the tree. Default is |
nfolds_crossfit |
Number of folds in cross-fitting procedure.
If |
nreps_crossfit |
Number of repetitions of the cross-fitting procedure.
If |
B_stability |
Number of bootstrap samples to use in evaluating stability
diagnostics (which can be used to select an appropriate teacher model).
Default is 100. Stability diagnostics are only performed if
|
max_depth_stability |
Maximum depth of the decision tree used in
evaluating stability diagnostics. If |
... |
Additional arguments passed to the |
Value
A list with the following elements:
estimate |
Estimated subgroup average treatment effects tibble with the following columns:
|
student_fit |
Output of
|
teacher_fit |
A list of (cross-fitted) teacher model fits. |
teacher_predictions |
The predicted individual-level treatment effects, averaged across all cross-fitted teacher models. |
teacher_predictions_ls |
A list of predicted individual-level treatment effects from each (cross-fitted) teacher model fit. |
crossfit_idxs_ls |
A list of fold indices used in each cross-fit. |
stability_diagnostics |
A list of stability diagnostics with the following elements:
|
holdout_idxs |
Indices of the holdout set. |
References
Huang, M., Tang, T. M., and Kenney, A. M. (2025). Distilling heterogeneous treatment effects: Stable subgroup estimation in causal inference. arXiv preprint arXiv:2502.07275.
Examples
n <- 50
p <- 3
X <- matrix(rnorm(n * p), nrow = n, ncol = p)
Z <- rbinom(n, 1, 0.5)
Y <- 2 * Z * (X[, 1] > 0) + X[, 2] + rnorm(n, 0.1)
# causal distillation trees using causal forest teacher model
out <- causalDT(X, Y, Z)
Crossfit wrapper around estimators
Description
Crossfit wrapper around estimators
Usage
crossfit(estimator, X, Y, Z, W = NULL, split_idxs, ...)
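For readers unfamiliar with cross-fitting, the general idea is that each observation's prediction comes from a model that was not trained on that observation. The sketch below shows that generic idea with a linear model. It is an assumption-laden illustration: the format of split_idxs expected by crossfit() is not described here, so the fold bookkeeping (fold_id) is a hypothetical stand-in.

```r
set.seed(1)
n <- 100
x <- rnorm(n)
y <- 1 + 2 * x + rnorm(n)

# Randomly assign each observation to one of K folds (hypothetical layout)
K <- 5
fold_id <- sample(rep(seq_len(K), length.out = n))

pred <- numeric(n)
for (k in seq_len(K)) {
  held <- fold_id == k
  # Fit on all folds except fold k ...
  fit <- lm(y ~ x, data = data.frame(x = x[!held], y = y[!held]))
  # ... and predict only for the held-out fold k
  pred[held] <- predict(fit, newdata = data.frame(x = x[held]))
}
```

Repeating this procedure with different random fold assignments and averaging the predictions corresponds to the idea behind multiple cross-fitting repetitions.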
Subgroup CATE estimation.
Description
This function estimates the conditional average treatment effect for each subgroup given by the fitted decision tree. The conditional average treatment effect is estimated as the difference in the average outcome between treated and control units that fall within each subgroup (i.e., each leaf node in the decision tree).
Usage
estimate_group_cates(fit, X, Y, Z)
Arguments
fit |
Fitted subgroup model used to determine subgroup membership of
individuals. Typically, this is a |
X |
A tibble, data.frame, or matrix of covariates. |
Y |
A vector of outcomes. |
Z |
A vector of treatments. |
Value
Estimated subgroup average treatment effects tibble with the following columns:
leaf_id |
Leaf node identifier. |
subgroup |
String representation of the subgroup. |
estimate |
Estimated conditional average treatment effect for the subgroup. |
variance |
Asymptotic variance of the estimated conditional average treatment effect. |
.var1 |
Sample variance for treated observations in the subgroup. |
.var0 |
Sample variance for control observations in the subgroup. |
.n1 |
Number of treated observations in the subgroup. |
.n0 |
Number of control observations in the subgroup. |
.sample_idxs |
Indices of (holdout) observations in the subgroup. |
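The difference-in-means estimate described above can be reproduced by hand on toy data. The sketch below assumes that the reported variance is the usual two-sample form, var1 / n1 + var0 / n0; that formula is an assumption, not something stated in this documentation. The leaf assignments and outcome values are made up for illustration.

```r
# Toy holdout data with a known leaf assignment (hypothetical values)
leaf <- c(1, 1, 1, 1, 2, 2, 2, 2)
Z    <- c(1, 1, 0, 0, 1, 1, 0, 0)
Y    <- c(3, 5, 1, 1, 2, 2, 2, 4)

# Difference in means between treated and control units in one subgroup
group_cate <- function(y, z) {
  y1 <- y[z == 1]
  y0 <- y[z == 0]
  data.frame(
    estimate = mean(y1) - mean(y0),
    # assumed variance formula: var1 / n1 + var0 / n0
    variance = var(y1) / length(y1) + var(y0) / length(y0),
    .var1 = var(y1),
    .var0 = var(y0),
    .n1 = length(y1),
    .n0 = length(y0)
  )
}

# Apply the estimator separately within each leaf
idx_by_leaf <- split(seq_along(Y), leaf)
res <- do.call(
  rbind,
  lapply(idx_by_leaf, function(idx) group_cate(Y[idx], Z[idx]))
)
res <- cbind(leaf_id = names(idx_by_leaf), res)
res
```

In this toy data, leaf 1 has an estimated effect of 3 and leaf 2 an estimated effect of -1.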
Examples
n <- 50
p <- 3
X <- matrix(rnorm(n * p), nrow = n, ncol = p)
Z <- rbinom(n, 1, 0.5)
Y <- 2 * Z * (X[, 1] > 0) + X[, 2] + rnorm(n, 0.1)
# causal distillation tree output
out <- causalDT(X, Y, Z)
# compute subgroup CATEs manually
group_cates <- estimate_group_cates(
out$student_fit$fit,
X = X[out$holdout_idxs, , drop = FALSE],
Y = Y[out$holdout_idxs],
Z = Z[out$holdout_idxs]
)
all.equal(out$estimate, group_cates)
Subgroup stability diagnostics
Description
This function evaluates the stability of the estimated subgroups from causal distillation trees (CDT) using the Jaccard subgroup stability index (SSI), developed in Huang et al. (2025). It is generally recommended to choose teacher models in CDT that result in the most stable subgroups, as indicated by high SSI values.
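To build intuition for Jaccard-based comparisons of subgroups, the sketch below computes the Jaccard index between two sets of sample indices and then averages the best match for each subgroup across two partitions. This matching-and-averaging rule is only one simple way to compare partitions and is an assumption made for illustration; it is not necessarily the SSI defined in Huang et al. (2025). The function names are hypothetical.

```r
# Jaccard index between two sets of sample indices
jaccard <- function(a, b) length(intersect(a, b)) / length(union(a, b))

# Illustrative partition similarity: for each subgroup in p1, take its best
# Jaccard match among the subgroups of p2, then average over subgroups.
partition_similarity <- function(p1, p2) {
  g1 <- split(seq_along(p1), p1)
  g2 <- split(seq_along(p2), p2)
  best <- vapply(
    g1,
    function(a) max(vapply(g2, function(b) jaccard(a, b), numeric(1))),
    numeric(1)
  )
  mean(best)
}

# Leaf memberships of six samples under two candidate trees
p_a <- c(1, 1, 1, 2, 2, 2)
p_b <- c(1, 1, 2, 2, 2, 2)

partition_similarity(p_a, p_a)  # identical partitions give 1
partition_similarity(p_a, p_b)  # (2/3 + 3/4) / 2 = 17/24
```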
Usage
evaluate_subgroup_stability(
estimator,
fit,
X,
y,
Z = NULL,
rpart_control = NULL,
B = 100,
max_depth = NULL
)
Arguments
estimator |
Function used to estimate subgroups of individuals and their
corresponding estimated treatment effects. The function should take in
|
fit |
Fitted subgroup model (often, the output of |
X |
A tibble, data.frame, or matrix of covariates. |
y |
A vector of responses to predict. |
Z |
A vector of treatments. |
rpart_control |
A list of control parameters for the |
B |
Number of bootstrap samples to use in evaluating stability diagnostics. Default is 100. |
max_depth |
Maximum depth of the tree to consider when evaluating
stability diagnostics. If |
Value
A list with the following elements:
jaccard_mean |
Vector of mean Jaccard similarity index for each tree depth. The tree depth is given by the vector index. |
jaccard_distribution |
List of Jaccard similarity indices across all bootstraps for each tree depth. |
bootstrap_predictions |
List of mean student model predictions (for training (non-holdout) data) across all bootstraps for each tree depth. |
bootstrap_predictions_var |
List of variance of student model predictions (for training (non-holdout) data) across all bootstraps for each tree depth. |
leaf_ids |
List of leaf node identifiers, indicating the leaf membership of each training sample in the (original) fitted student model. |
References
Huang, M., Tang, T. M., and Kenney, A. M. (2025). Distilling heterogeneous treatment effects: Stable subgroup estimation in causal inference. arXiv preprint arXiv:2502.07275.
Examples
n <- 200
p <- 10
X <- matrix(rnorm(n * p), nrow = n, ncol = p)
Z <- rbinom(n, 1, 0.5)
Y <- 2 * Z * (X[, 1] > 0) + X[, 2] + rnorm(n, 0.1)
# run causal distillation trees without stability diagnostics
out <- causalDT(X, Y, Z, B_stability = 0)
# run stability diagnostics
stability_out <- evaluate_subgroup_stability(
estimator = student_rpart,
fit = out$student_fit$fit,
X = X[-out$holdout_idxs, , drop = FALSE],
y = out$student_fit$predictions
)
Get depth of each node in a party model.
Description
Get depth of each node in a party model.
Usage
get_party_node_depths(party_fit, return_features = FALSE)
Arguments
party_fit |
A |
return_features |
Logical indicating whether to return the feature associated with each node. |
Get decision paths from a party model.
Description
Return the decision paths for each leaf node in a party model as character strings.
Usage
get_party_paths(party_fit)
Arguments
party_fit |
A |
Value
A list of character vectors, where each element corresponds to the decision path for a leaf node in the party_fit model.
Get decision paths from an rpart model.
Description
Return the decision paths for each leaf node in an rpart model as character strings.
Usage
get_rpart_paths(rpart_fit)
Arguments
rpart_fit |
An |
Value
A list of character vectors, where each element corresponds to the decision path for a leaf node in the rpart_fit model.
Get split information from an rpart model.
Description
Return the split information for each node in an rpart model as a data frame.
Usage
get_rpart_tree_info(rpart_fit, X = NULL, digits = getOption("digits"))
Arguments
rpart_fit |
An |
X |
Optional data frame containing the features used in the |
digits |
Number of digits to round the split values to. |
Value
A data.frame with information regarding the feature/threshold used for each split in the rpart model.
Plot causal distillation tree object
Description
Visualize the subgroups (i.e., the student tree) from a causal distillation tree object.
Usage
plot_cdt(cdt, show_digits = 2)
Arguments
cdt |
A causal distillation tree object, typically the output of
|
show_digits |
Number of digits to show in the plot labels. Default is 2. |
Value
A plot of the causal distillation tree.
Examples
n <- 200
p <- 10
X <- matrix(rnorm(n * p), nrow = n, ncol = p)
Z <- rbinom(n, 1, 0.5)
Y <- 2 * Z * (X[, 1] > 0) + X[, 2] + rnorm(n, 0.1)
cdt <- causalDT(X, Y, Z)
plot_cdt(cdt)
Plot Jaccard subgroup similarity index (SSI) for causal distillation tree objects
Description
The Jaccard subgroup similarity index (SSI) measures the similarity between two candidate partitions of the data into subgroups. In CDT, the Jaccard SSI can be used to select the teacher model that recovers the most stable subgroups.
Usage
plot_jaccard(...)
Arguments
... |
Two or more causal distillation tree objects, each typically
the output of |
Value
A plot of the Jaccard SSI for each tree depth.
Examples
n <- 50
p <- 2
X <- matrix(rnorm(n * p), nrow = n, ncol = p)
Z <- rbinom(n, 1, 0.5)
Y <- 2 * Z * (X[, 1] > 0) + X[, 2] + rnorm(n, 0.1)
# number of bootstraps for stability diagnostics (setting to small value for faster example)
B <- 10
# run CDT with default causal forest teacher model
cdt1 <- causalDT(X, Y, Z, B_stability = B)
# run CDT with custom BCF teacher model
cdt2 <- causalDT(
X, Y, Z,
# set BCF training parameters to be small for faster example
teacher_model = purrr::partial(bcf, nsim = 100, nburn = 10),
teacher_predict = predict_bcf,
# set number of cross-fitting replications to be small for faster example
nreps_crossfit = 5,
B_stability = B
)
plot_jaccard(`Causal Forest` = cdt1, `BCF` = cdt2)
Predict method for cross-fitted estimators
Description
Predict method for cross-fitted estimators
Usage
predict_crossfit(fits, X, split_idxs, predict_fun)
Predict wrappers for teacher models for causal distillation trees
Description
These functions are predict() method wrappers for various heterogeneous treatment effect learners that can be easily used as teacher models in the causal distillation tree framework.
- predict_causal_forest(): wrapper around predict() for causal_forest() models.
- predict_bcf(): wrapper around predict() for bcf() models.
Usage
predict_causal_forest(...)
predict_bcf(...)
Arguments
... |
Additional arguments to pass to the base model |
Value
Vector of predicted conditional average treatment effects (CATEs).
Rlearner teacher model wrapper for causal distillation trees
Description
This is a wrapper function to convert any of the rlearner model functions into a format that can be used as a teacher model in the causal distillation tree framework.
Usage
rlearner_teacher(rlearner_fun, ...)
Arguments
rlearner_fun |
One of |
... |
Additional arguments to pass to the base model functions. |
Value
Outputs a function that can be used as a teacher model in the causal distillation tree framework. The returned function has the signature function(X, Y, Z, W = NULL, ...).
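The adapter pattern described here, turning a learner with its own argument names into a function with the teacher signature, can be sketched with a closure. Everything below is hypothetical: toy_learner() and as_teacher() are made-up names, and the (x, w, y) argument order is an assumed convention for the base learner. This is one way such an adapter could be written, not the package's implementation.

```r
# Hypothetical base learner using (x, w, y) argument names; it returns a
# difference in means between treated and control units, scaled by `shrink`.
toy_learner <- function(x, w, y, shrink = 1, ...) {
  list(tau = shrink * (mean(y[w == 1]) - mean(y[w == 0])))
}

# Hypothetical adapter: stores extra arguments now and returns a function
# with the teacher-style signature function(X, Y, Z, W = NULL, ...).
as_teacher <- function(learner_fun, ...) {
  fixed_args <- list(...)
  function(X, Y, Z, W = NULL, ...) {
    # W is accepted for signature compatibility but unused in this sketch
    args <- c(list(x = X, w = Z, y = Y), fixed_args, list(...))
    do.call(learner_fun, args)
  }
}

X <- matrix(0, nrow = 4, ncol = 1)
Y <- c(1, 2, 3, 4)
Z <- c(0, 0, 1, 1)

teacher <- as_teacher(toy_learner, shrink = 0.5)
fit <- teacher(X, Y, Z)
fit$tau  # 0.5 * (3.5 - 1.5) = 1
```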
Arguments that are shared across functions
Description
Arguments that are shared across functions
Arguments
X |
A tibble, data.frame, or matrix of covariates. |
Y |
A vector of outcomes. |
y |
A vector of responses to predict. |
Z |
A vector of treatments. |
W |
A vector of weights corresponding to treatment propensities. |
rpart_control |
A list of control parameters for the |
rpart_fit |
An |
party_fit |
A |
prune |
Method for pruning the tree. Default is |
rpart_prune |
Method for pruning the tree. Default is |
Rpart wrapper for causal distillation trees.
Description
This function is a wrapper around rpart::rpart() that can be easily used as a student model in the causal distillation tree framework.
Usage
student_rpart(
X,
y,
method = "anova",
rpart_control = NULL,
prune = c("none", "min", "1se"),
fit_only = FALSE
)
Arguments
X |
A tibble, data.frame, or matrix of covariates. |
y |
A vector of responses to predict. |
method |
Same as |
rpart_control |
A list of control parameters for the |
prune |
Method for pruning the tree. Default is |
fit_only |
Logical. If |
Value
If fit_only = TRUE, the fitted model is returned. Otherwise, a list with the following components is returned:
fit |
Fitted model. An |
tree_info |
Data frame with tree structure/split information. |
subgroups |
List of subgroups given by their string representation. |
predictions |
Student model predictions for the given |
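The pruning options can be illustrated directly with rpart. The sketch below assumes that "min" and "1se" refer to the usual cost-complexity pruning rules (minimum cross-validated error, and the simplest tree within one standard error of that minimum); this reading is an assumption, and the code is not taken from student_rpart().

```r
library(rpart)

set.seed(1)
n <- 300
X <- data.frame(x1 = rnorm(n), x2 = rnorm(n))
y <- 2 * (X$x1 > 0) + rnorm(n, sd = 0.5)

# Grow a deliberately large tree (cp = 0) so that pruning has an effect
fit <- rpart(
  y ~ .,
  data = cbind(X, y = y),
  method = "anova",
  control = rpart.control(cp = 0)
)
cp_tab <- fit$cptable

# "min" rule: complexity parameter with the smallest cross-validated error
i_min <- which.min(cp_tab[, "xerror"])
fit_min <- rpart::prune(fit, cp = cp_tab[i_min, "CP"])

# "1se" rule: simplest tree whose error is within one SE of the minimum
thresh <- cp_tab[i_min, "xerror"] + cp_tab[i_min, "xstd"]
i_1se <- which(cp_tab[, "xerror"] <= thresh)[1]
fit_1se <- rpart::prune(fit, cp = cp_tab[i_1se, "CP"])

# Number of leaves under each rule
n_leaves <- function(f) sum(f$frame$var == "<leaf>")
c(none = n_leaves(fit), min = n_leaves(fit_min), `1se` = n_leaves(fit_1se))
```

The "1se" tree is never larger than the "min" tree, which in turn is never larger than the unpruned tree.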
Teacher models for causal distillation trees
Description
These functions are wrappers around various heterogeneous treatment effect learners that can be easily used as teacher models in the causal distillation tree framework.
- causal_forest(): wrapper around grf::causal_forest().
- bcf(): wrapper around bcf::bcf().
- rboost(): (defunct) wrapper around rlearner::rboost().
- rlasso(): (defunct) wrapper around rlearner::rlasso().
- rkern(): (defunct) wrapper around rlearner::rkern().
Warning: The rboost(), rlasso(), and rkern() functions are defunct as of version 1.0.0. Use rlearner_teacher() (e.g., rlearner_teacher(rlearner::rboost)) instead to convert rlearner functions into the correct format for use as a teacher model in CDT.
Usage
causal_forest(X, Y, Z, W = NULL, ...)
rboost(X, Y, Z, W = NULL, ...)
rlasso(X, Y, Z, W = NULL, ...)
rkern(X, Y, Z, W = NULL, ...)
bcf(
X,
Y,
Z,
W = NULL,
pihat = "default",
w = NULL,
nburn = 2000,
nsim = 1000,
n_threads = 1,
no_output = TRUE,
...
)
Arguments
X |
A tibble, data.frame, or matrix of covariates. |
Y |
A vector of outcomes. |
Z |
A vector of treatments. |
W |
A vector of weights corresponding to treatment propensities. |
... |
Additional arguments to pass to the base model functions. |
pihat |
Length n estimates of propensity score |
w |
An optional vector of weights. When present, BCF fits a model |
nburn |
Number of burn-in MCMC iterations |
nsim |
Number of MCMC iterations to save after burn-in. The chain will run for nsim*nthin iterations after burn-in |
n_threads |
An optional integer of the number of threads to parallelize within chain bcf operations on |
no_output |
Logical, whether to suppress writing trees and training log to text files. Defaults to TRUE. |
Value
Outputs of the respective base model functions:
- causal_forest(): see output of grf::causal_forest().
- rboost() (defunct): see output of rlearner::rboost().
- rlasso() (defunct): see output of rlearner::rlasso().
- rkern() (defunct): see output of rlearner::rkern().