## ----setup, include = FALSE---------------------------------------------------
# Worked example: estimating treatment effects with a BART model that includes
# a random intercept per subject (stan4bart / bartCause), then tidying and
# plotting the posterior draws with tidytreatment, tidybayes and ggplot2.
# The `## ----<name>, <options>----` lines are knitr chunk markers and are
# kept verbatim.
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.dim = c(6, 4)
)

suppressPackageStartupMessages({
  library(bartCause)
  library(stan4bart)
  library(tidytreatment)
  library(dplyr)
  library(tidybayes)
  library(ggplot2)
})

# load pre-computed data and model
sim <- suhillsim2_ranef

## ----load-data-print, echo = TRUE, eval = FALSE-------------------------------
#
# # load packages
# library(bartCause)
# library(stan4bart)
# library(tidytreatment)
# library(dplyr)
# library(tidybayes)
# library(ggplot2)
#
# # set seed so vignette is reproducible
# set.seed(101)
#
# # simulate data
# sim <- simulate_su_hill_data(n = 100, treatment_linear = FALSE, omega = 0, add_categorical = TRUE,
#                              n_subjects = 10, sd_subjects = 0.1,
#                              coef_categorical_treatment = c(0,0,1),
#                              coef_categorical_nontreatment = c(-1,0,-1)
# )
#

## ----data-summary, echo = TRUE, eval = TRUE-----------------------------------
# non-treated vs treated counts:
table(sim$data$z)

dat <- sim$data

# a selection of data
dat %>%
  select(y, z, c1, x1:x3) %>%
  head()

# repeated observation counts for subjects:
table(sim$data$subject_id)

## ----run-bart, echo = TRUE, eval = TRUE---------------------------------------
# STEP 1 VS Model: Regress y ~ covariates
# (all columns except subject_id and z go into the BART component;
#  subject_id enters as a random intercept)
vs_bart <- stan4bart(
  y ~ bart(. - subject_id - z) + (1 | subject_id),
  data = dat,
  iter = 5000,
  verbose = -1
)

# STEP 2: Variable selection
# select most important vars from y ~ covariates model
# very simple selection mechanism. Should use cross-validation in practice
covar_ranking <- covariate_importance(vs_bart)
var_select <- covar_ranking %>%
  # at minimum: within 1 sd of mean inclusion
  filter(avg_inclusion > mean(avg_inclusion) - sd(avg_inclusion)) %>%
  pull(variable)

# change categorical variables to just one variable
# NOTE(review): the "." in the pattern is an unescaped regex wildcard, so any
# single character between "c1" and the level digit matches. Presumably the
# dummy columns are named "c1.1", "c1.2", "c1.3" -- confirm before tightening
# the pattern to "c1\\.[1-3]$".
var_select <- unique(gsub("c1.[1-3]$", "c1", var_select))

var_select

# includes all covariates

# STEP 3 PS Model: Regress z ~ selected covariates
# The formula does not reference var_select directly: it uses every column
# except subject_id and y, which matches the selection only because the step
# above is reported to keep all covariates.
ps_bart <- stan4bart(
  z ~ bart(. - subject_id - y) + (1 | subject_id),
  data = dat,
  iter = 5000,
  verbose = -1
)

# store fitted propensity scores (a standalone vector, handed to bartc()
# through method.trt below; `dat` itself is not modified)
prop_score <- fitted(ps_bart)

# Step 4 TE Model: Regress y ~ z + covariates + propensity score
# NOTE(review): c1 is not listed in `confounders` although it is used for
# faceting further down -- confirm this omission is intentional.
te_bart <- bartc(
  response = y,
  treatment = z,
  confounders = x1 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 + x10,
  parametric = (1 | subject_id),
  data = dat,
  method.trt = prop_score,
  iter = 5000,
  bart_args = list(keepTrees = TRUE)
)

#* The posterior samples are kept small to manage size on CRAN

## ----tidy-bart-fit, echo=TRUE, cache=FALSE------------------------------------
# get model parameters (excluding BART parameters)
posterior_params <- tidy_draws(te_bart)
posterior_fitted <- epred_draws(te_bart, value = "fitted")

## ----tidy-bart-pred, eval=FALSE, echo=TRUE, cache=FALSE-----------------------
#
# # Function to tidy predicted draws (adds predicted noise to fitted values)
# posterior_pred <- predicted_draws(te_bart, value = "predicted")
#

## ----plot-tidy-bart, echo=TRUE, cache=FALSE-----------------------------------
# Lookup table keyed by row position so the treatment indicator and c1 can be
# joined onto the fitted draws through their `.row` column.
# row_number() replaces 1:n(), which would yield c(1, 0) on zero rows.
treatment_var_and_c1 <- dat %>%
  select(z, c1) %>%
  mutate(.row = row_number(), z = as.factor(z))

posterior_fitted %>%
  left_join(treatment_var_and_c1, by = ".row") %>%
  ggplot() +
  stat_halfeye(aes(x = z, y = fitted)) +
  facet_wrap(~c1, labeller = as_labeller(function(x) paste("c1 =", x))) +
  xlab("Treatment (z)") +
  ylab("Posterior predicted value") +
  theme_bw() +
  ggtitle("Effect of treatment with 'c1' on posterior fitted values")

## ----post-treatment, eval = T-------------------------------------------------
# sample based (using data from fit) conditional treatment effects, posterior draws
posterior_treat_eff <- treatment_effects(te_bart)

# check lines up with summary results...

## ----cates-hist, echo=TRUE, cache=FALSE, eval = T-----------------------------
# Histogram of treatment effect (all draws)
posterior_treat_eff %>%
  ggplot() +
  geom_histogram(aes(x = icate), binwidth = 0.1, colour = "white") +
  theme_bw() +
  ggtitle("Histogram of treatment effect (all draws)")

# Histogram of treatment effect (median for each subject)
# NOTE(review): summarise() is called without an explicit grouping, so the
# "per subject" median presumably relies on treatment_effects() returning a
# grouped data frame -- verify against its documentation.
posterior_treat_eff %>%
  summarise(cte_hat = median(icate)) %>%
  ggplot() +
  geom_histogram(aes(x = cte_hat), binwidth = 0.1, colour = "white") +
  theme_bw() +
  ggtitle("Histogram of treatment effect (median for each subject)")