## ----setup, include = FALSE---------------------------------------------------
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.dim = c(6, 4)
)

suppressPackageStartupMessages({
  library(BART)
  library(tidytreatment)
  library(dplyr)
  library(tidybayes)
  library(ggplot2)
})

# load pre-computed data and model
sim <- suhillsim1
te_model <- bartmodel1

# pre compute
# Both calls use the full element name `sim$data`; the abbreviated `sim$dat`
# only resolves through `$` partial matching (and is NULL where the container
# disables partial matching).
posterior_treat_eff <-
  treatment_effects(te_model, treatment = "z", newdata = sim$data)
posterior_treat_eff_on_treated <-
  treatment_effects(te_model, treatment = "z", newdata = sim$data,
                    subset = "treated")

## ----load-data-print, echo = TRUE, eval = FALSE-------------------------------
#
# # load packages
# library(BART)
# library(tidytreatment)
# library(dplyr)
# library(tidybayes)
# library(ggplot2)
#
# # set seed so vignette is reproducible
# set.seed(101)
#
# # simulate data
# sim <- simulate_su_hill_data(n = 100, treatment_linear = FALSE, omega = 0, add_categorical = TRUE,
#                              coef_categorical_treatment = c(0,0,1),
#                              coef_categorical_nontreatment = c(-1,0,-1)
# )
#

## ----data-summary, echo = TRUE, eval = TRUE-----------------------------------
# non-treated vs treated counts:
table(sim$data$z)

dat <- sim$data
# a selection of data
dat %>%
  select(y, z, c1, x1:x3) %>%
  head()

## ----run-bart, echo = TRUE, eval = FALSE--------------------------------------
#
# # STEP 1 VS Model: Regress y ~ covariates
# var_select_bart <- wbart(x.train = select(dat,-y,-z),
#                          y.train = pull(dat, y),
#                          sparse = TRUE,
#                          nskip = 2000,
#                          ndpost = 5000)
#
# # STEP 2: Variable selection
# # select most important vars from y ~ covariates model
# # very simple selection mechanism. Should use cross-validation in practice
# covar_ranking <- covariate_importance(var_select_bart)
# var_select <- covar_ranking %>%
#   filter(avg_inclusion >= quantile(avg_inclusion, 0.5)) %>%
#   pull(variable)
#
# # change categorical variables to just one variable
# var_select <- unique(gsub("c1[1-3]$","c1", var_select))
#
# var_select
#
# # STEP 3 PS Model: Regress z ~ selected covariates
# # BART::pbart is for probit regression
# prop_bart <- pbart(
#   x.train = select(dat, all_of(var_select)),
#   y.train = pull(dat, z),
#   nskip = 2000,
#   ndpost = 5000
# )
#
# # store propensity score in data
# dat$prop_score <- prop_bart$prob.train.mean
#
# # Step 4 TE Model: Regress y ~ z + covariates + propensity score
# te_model <- wbart(
#   x.train = select(dat,-y),
#   y.train = pull(dat, y),
#   nskip = 10000L,
#   ndpost = 200L, #*
#   keepevery = 100L #*
# )
#
# #* The posterior samples are kept small to manage size on CRAN
#

## ----tidy-bart-fit, echo=TRUE, cache=FALSE------------------------------------
posterior_fitted <- fitted_draws(te_model, value = "fit",
                                 include_newdata = FALSE)
# include_newdata = FALSE, avoids returning the newdata with the fitted values
# as it is so large. newdata argument must be specified for this option in BART models.
# The `.row` variable makes sure we know which row in the newdata the fitted
# value came from (if we don't include the data in the result).

posterior_fitted

## ----tidy-bart-pred, eval=FALSE, echo=TRUE, cache=FALSE-----------------------
#
# # Function to tidy predicted draws also, this adds random normal noise by default
# posterior_pred <- predicted_draws(te_model, include_newdata = FALSE)
#

## ----plot-tidy-bart, echo=TRUE, cache=FALSE-----------------------------------
# Lookup table keyed on `.row` so the draws can be joined back to the
# treatment indicator and `c1`. row_number() replaces `1:n()`, which would
# yield c(1, 0) on an empty data frame.
treatment_var_and_c1 <- dat %>%
  select(z, c1) %>%
  mutate(.row = row_number(), z = as.factor(z))

posterior_fitted %>%
  left_join(treatment_var_and_c1, by = ".row") %>%
  ggplot() +
  stat_halfeye(aes(x = z, y = fit)) +
  facet_wrap(~c1, labeller = as_labeller(function(x) paste("c1 =", x))) +
  xlab("Treatment (z)") +
  ylab("Posterior predicted value") +
  theme_bw() +
  ggtitle("Effect of treatment with 'c1' on posterior fitted values")

## ----post-treatment, eval = FALSE---------------------------------------------
#
# # sample based (using data from fit) conditional treatment effects, posterior draws
# posterior_treat_eff <-
#   treatment_effects(te_model, treatment = "z", newdata = dat)
#

## ----cates-hist, echo=TRUE, cache=FALSE---------------------------------------
# Histogram of treatment effect (all draws)
posterior_treat_eff %>%
  ggplot() +
  geom_histogram(aes(x = cte), binwidth = 0.1, colour = "white") +
  theme_bw() +
  ggtitle("Histogram of treatment effect (all draws)")

# Histogram of treatment effect (median for each subject)
# NOTE(review): this summarise() has no explicit grouping, so the per-subject
# median presumably relies on `posterior_treat_eff` arriving already grouped
# from treatment_effects() -- confirm.
posterior_treat_eff %>%
  summarise(cte_hat = median(cte)) %>%
  ggplot() +
  geom_histogram(aes(x = cte_hat), binwidth = 0.1, colour = "white") +
  theme_bw() +
  ggtitle("Histogram of treatment effect (median for each subject)")

## ----att-ate, eval=FALSE------------------------------------------------------
# # get the ATE and ATT directly:
#
# posterior_ate <- tidy_ate(te_model, treatment = "z", newdata = dat)
# posterior_att <- tidy_att(te_model, treatment = "z", newdata = dat)
#

## ----ate-trace-setup, eval = TRUE, echo = FALSE-------------------------------
# Average the conditional treatment effects within each posterior draw.
posterior_ate <- posterior_treat_eff %>%
  group_by(.chain, .iteration, .draw) %>%
  summarise(ate = mean(cte), .groups = "drop")

## ----ate-trace, eval=TRUE, echo=TRUE------------------------------------------
posterior_ate %>%
  ggplot(aes(x = .draw, y = ate)) +
  geom_line() +
  theme_bw() +
  ggtitle("Trace plot of ATE")

## ----post-te-treated, echo=TRUE, eval=FALSE-----------------------------------
#
# # sample based (using data from fit) conditional treatment effects, posterior draws
# posterior_treat_eff_on_treated <-
#   treatment_effects(te_model, treatment = "z", newdata = dat, subset = "treated")
#

## ----cates-hist-treated, echo=TRUE, cache=FALSE-------------------------------
posterior_treat_eff_on_treated %>%
  ggplot() +
  geom_histogram(aes(x = cte), binwidth = 0.1, colour = "white") +
  theme_bw() +
  ggtitle("Histogram of treatment effect (all draws from treated subjects)")

## ----cates-stack-plot, echo=TRUE, cache=FALSE---------------------------------
# Order the interval summaries by point estimate; `.orow` is the plotting
# position of each summary row.
posterior_treat_eff %>%
  select(-z) %>%
  point_interval() %>%
  arrange(cte) %>%
  mutate(.orow = row_number()) %>%
  ggplot() +
  geom_interval(aes(x = .orow, y = cte, ymin = .lower, ymax = .upper)) +
  geom_point(aes(x = .orow, y = cte), shape = "circle open", alpha = 0.5) +
  ylab("Median posterior CATE for each subject (95% CI)") +
  theme_bw() +
  coord_flip() +
  scale_colour_brewer() +
  theme(
    axis.title.y = element_blank(),
    axis.text.y = element_blank(),
    axis.ticks.y = element_blank(),
    legend.position = "none"
  )

## ----cates-line-plot, echo=TRUE, cache=FALSE----------------------------------
# Attach `c1` to every draw via `.row`. seq_along() replaces
# `1:length(dat$c1)`; the former group_by(c1) was dropped because no grouped
# verb followed it before plotting.
posterior_treat_eff %>%
  left_join(tibble(c1 = dat$c1, .row = seq_along(dat$c1)), by = ".row") %>%
  ggplot() +
  stat_halfeye(aes(x = c1, y = cte), alpha = 0.7) +
  scale_fill_brewer() +
  theme_bw() +
  ggtitle("Treatment effect by `c1`")

## ----common-support, echo=TRUE, results='hide', cache=FALSE-------------------
# calculate common support directly
# argument 'modeldata' must be specified for BART models
csupp_chisq <- has_common_support(te_model, treatment = "z", modeldata = dat,
                                  method = "chisq", cutoff = 0.05)
csupp_chisq %>% filter(!common_support)

csupp_sd <- has_common_support(te_model, treatment = "z", modeldata = dat,
                               method = "sd", cutoff = 1)
csupp_sd %>% filter(!common_support)

# calculate treatment effects (on those who were treated)
# and include only those estimates with common support
# (this overwrites the pre-computed object from the setup chunk)
posterior_treat_eff_on_treated <-
  treatment_effects(te_model, treatment = "z", subset = "treated",
                    newdata = dat, common_support_method = "sd", cutoff = 1)

## ----interaction-investigator, echo=TRUE, cache=FALSE-------------------------
treatment_interactions <-
  covariate_with_treatment_importance(te_model, treatment = "z")

treatment_interactions %>%
  ggplot() +
  geom_bar(aes(x = variable, y = avg_inclusion), stat = "identity") +
  theme_bw() +
  ggtitle("Important variables interacting with treatment ('z')") +
  ylab("Inclusion counts") +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

variable_importance <- covariate_importance(te_model)

variable_importance %>%
  ggplot() +
  geom_bar(aes(x = variable, y = avg_inclusion), stat = "identity") +
  theme_bw() +
  ggtitle("Important variables overall") +
  ylab("Inclusion counts") +
  theme(axis.text.x = element_text(angle = 45, hjust = 1))

## ----sigma-trace, echo=TRUE, cache=FALSE--------------------------------------
# includes skipped MCMC samples
# The output column is named `sigsq` (previously misspelled "siqsq", which
# also surfaced as the y-axis label).
# NOTE(review): assumes `value` is only the name given to the output column --
# confirm against the variance_draws() documentation.
variance_draws(te_model, value = "sigsq") %>%
  filter(.draw > 10000) %>%
  ggplot(aes(x = .draw, y = sigsq)) +
  geom_line() +
  theme_bw() +
  ggtitle("Trace plot of model variance post warm-up")

## ----convergence-bart, echo=TRUE, cache=FALSE---------------------------------
res <- residual_draws(te_model, response = pull(dat, y),
                      include_newdata = FALSE)

# Interval summaries of the residuals plotted against the observed response.
res %>%
  point_interval(.residual, y, .width = c(0.95)) %>%
  select(-y.lower, -y.upper) %>%
  ggplot() +
  geom_pointinterval(
    aes(x = y, y = .residual, ymin = .residual.lower, ymax = .residual.upper),
    alpha = 0.2
  ) +
  scale_fill_brewer() +
  theme_bw() +
  ggtitle("Residuals vs observations")

# Mean fitted value against the observed response, with a linear smooth.
res %>%
  summarise(.fitted = mean(.fitted), y = first(y)) %>%
  ggplot(aes(x = y, y = .fitted)) +
  geom_point() +
  geom_smooth(method = "lm") +
  theme_bw() +
  ggtitle("Observations vs fitted")

# Normal Q-Q plot of the mean residuals.
res %>%
  summarise(.residual = mean(.residual)) %>%
  ggplot(aes(sample = .residual)) +
  geom_qq() +
  geom_qq_line() +
  theme_bw() +
  ggtitle("Q-Q plot of residuals")