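The XAItest package includes several classic feature importance algorithms and supports the addition of new ones. To integrate an XGBoost model and generate its feature importance metrics using the SHAP package shapr, we define a custom function, `featureImportanceXGBoost`, shown below.
The function should accept:
- the data frame as its first argument (`df`);
- the name of the target column as its second argument (`y`, defaulting to `"y"`).
The function should return a list containing:
- `featImps`: the feature importance values;
- `model`: the fitted model;
- `modelPredictions`: the classes predicted by the model.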
# Load the libraries
library(XAItest)
library(ggplot2)
library(ggforce)
library(SummarizedExperiment)
# Locate and load the example classification dataset shipped with XAItest.
se_path <- system.file("extdata", "seClassif.rds", package = "XAItest")
dataset_classif <- readRDS(se_path)

# Take the "counts" assay and transpose it before binding the target column.
data_matrix <- t(assay(dataset_classif, "counts"))
metadata <- as.data.frame(colData(dataset_classif))

# Bind the target column `y` from the column metadata to the assay values.
df_simu_classif <- as.data.frame(cbind(data_matrix, y = metadata[["y"]]))

# Coerce every column except the target `y` to numeric.
is_feature <- names(df_simu_classif) != "y"
df_simu_classif[is_feature] <- lapply(df_simu_classif[is_feature], as.numeric)
# Custom feature-importance function to plug into XAI.test().
#
# Trains an XGBoost model on a two-class target and summarises each column of
# the shapr explanation by its mean absolute value.
#
# Arguments:
#   df   Data frame holding the feature columns and the target column.
#   y    Name of the target column in `df` (default "y").
#   ...  Accepted so callers may pass extra arguments; not used here.
#
# Returns a list with:
#   featImps          Named numeric vector: mean absolute value of each column
#                     of `explanation$dt`.
#   model             The fitted xgboost model.
#   modelPredictions  Character vector of predicted class labels; predictions
#                     below 0.5 map to the first class, the others to the
#                     second class.
featureImportanceXGBoost <- function(df, y="y", ...){
  # Prepare data. drop = FALSE keeps a named one-column matrix when `df`
  # contains a single feature.
  matX <- as.matrix(df[, colnames(df) != y, drop = FALSE])
  labels <- as.character(df[[y]])

  # Class labels in order of first appearance, captured once from the
  # untouched labels. Recomputing unique() after part of the vector has been
  # overwritten with "0" merges the two classes when one of the original
  # labels is itself "0" (e.g. labels "1"/"0" with "1" appearing first).
  classes <- unique(labels)

  # Encode first class -> 0, second class -> 1; anything else stays NA.
  vecY <- rep(NA_real_, length(labels))
  vecY[which(labels == classes[1])] <- 0
  vecY[which(labels == classes[2])] <- 1

  # Train the XGBoost model
  model <- xgboost::xgboost(data = matX, label = vecY,
                            nrounds = 10, verbose = FALSE)

  # Map the numeric predictions back to the original class labels using a
  # 0.5 threshold.
  modelPredictions <- predict(model, matX)
  modelPredictionsCat <- modelPredictions
  modelPredictionsCat[modelPredictions < 0.5] <- classes[1]
  modelPredictionsCat[modelPredictions >= 0.5] <- classes[2]

  # Specifying the phi_0, i.e. the expected prediction without any features
  p <- mean(vecY)

  # Computing the actual Shapley values with kernelSHAP accounting
  # for feature dependence using the empirical (conditional)
  # distribution approach with bandwidth parameter sigma = 0.1 (default)
  explainer <- shapr::shapr(matX, model, n_combinations = 2000)
  explanation <- shapr::explain(
    matX,
    approach = "empirical",
    explainer = explainer,
    prediction_zero = p,
    n_combinations = 1000
  )

  # Mean absolute value per column of the explanation table.
  # NOTE(review): `explanation$dt` may also carry a non-feature baseline
  # column (presumably named `none`) -- confirm whether XAI.test expects it.
  results <- colMeans(abs(explanation$dt), na.rm = TRUE)

  list(featImps = results, model = model, modelPredictions = modelPredictionsCat)
}
# Seed the random number generator before running the test.
set.seed(123)

# Run XAI.test on the classification dataset with the t-test and linear-model
# default methods plus the custom XGBoost/SHAP feature importance.
results <- XAI.test(
  dataset_classif,
  "y",
  simData = TRUE,
  simPvalTarget = 0.001,
  defaultMethods = c("ttest", "lm"),
  customFeatImps = list("XGB_SHAP_feat_imp" = featureImportanceXGBoost)
)
## The specified model provides feature classes that are NA. The classes of data are taken as the truth.
The mapPvalImportance function reveals that both the custom XGB_SHAP_feat_imp and other feature importance metrics identify the biDistrib feature as significant.
Display as a data.frame:
# Map p-values to feature importances, using the adjusted t-test p-value
# column with a reference p-value of 0.001.
mpi <- mapPvalImportance(
  results,
  refPvalColumn = "ttest_adjPval",
  refPval = 0.001
)

# Show the first rows of the data.frame view.
head(mpi[["df"]])
## ttest_pval isSign_ttest_pval ttest_adjPval isSign_ttest_adjPval
## diff_distrib02 1.02e-43 1 1.42e-42 1
## diff_distrib01 7.93e-37 1 1.11e-35 1
## simFeat 3.89e-05 1 5.45e-04 1
## norm_noise03 7.97e-02 0 1.00e+00 0
## norm_noise08 1.11e-01 0 1.00e+00 0
## norm_noise09 1.44e-01 0 1.00e+00 0
## lm_pval isSign_lm_pval lm_adjPval isSign_lm_adjPval
## diff_distrib02 3.57e-21 1 5.00e-20 1
## diff_distrib01 1.27e-11 1 1.78e-10 1
## simFeat 7.98e-02 0 1.00e+00 0
## norm_noise03 6.11e-01 0 1.00e+00 0
## norm_noise08 8.86e-01 0 1.00e+00 0
## norm_noise09 5.16e-01 0 1.00e+00 0
## XGB_SHAP_feat_imp isSign_XGB_SHAP_feat_imp
## diff_distrib02 0.1220 1.0
## diff_distrib01 0.1200 1.0
## simFeat 0.0347 0.5
## norm_noise03 0.0412 0.5
## norm_noise08 0.0154 0.0
## norm_noise09 0.0234 0.0
Display as a datatable:
# Datatable view of the p-value / importance map.
mpi[["dt"]]

# Plot of the XGBoost generated model
plotModel(
  results,
  "XGB_SHAP_feat_imp",
  "diff_distrib01",
  "biDistrib"
)

# Print the R session details used to build this document.
utils::sessionInfo()
## R Under development (unstable) (2025-01-20 r87609)
## Platform: x86_64-pc-linux-gnu
## Running under: Ubuntu 24.04.1 LTS
##
## Matrix products: default
## BLAS: /home/biocbuild/bbs-3.21-bioc/R/lib/libRblas.so
## LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.12.0 LAPACK version 3.12.0
##
## locale:
## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
## [3] LC_TIME=en_GB LC_COLLATE=C
## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C
## [9] LC_ADDRESS=C LC_TELEPHONE=C
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
##
## time zone: America/New_York
## tzcode source: system (glibc)
##
## attached base packages:
## [1] stats4 stats graphics grDevices utils datasets methods
## [8] base
##
## other attached packages:
## [1] caret_7.0-1 lattice_0.22-6
## [3] SummarizedExperiment_1.37.0 Biobase_2.67.0
## [5] GenomicRanges_1.59.1 GenomeInfoDb_1.43.4
## [7] IRanges_2.41.2 S4Vectors_0.45.2
## [9] BiocGenerics_0.53.6 generics_0.1.3
## [11] MatrixGenerics_1.19.1 matrixStats_1.5.0
## [13] gridExtra_2.3 ggforce_0.4.2
## [15] ggplot2_3.5.1 XAItest_0.99.24
##
## loaded via a namespace (and not attached):
## [1] pROC_1.18.5 rlang_1.1.5 magrittr_2.0.3
## [4] e1071_1.7-16 compiler_4.5.0 lime_0.5.3
## [7] vctrs_0.6.5 reshape2_1.4.4 stringr_1.5.1
## [10] fastmap_1.2.0 pkgconfig_2.0.3 shape_1.4.6.1
## [13] crayon_1.5.3 XVector_0.47.2 labeling_0.4.3
## [16] markdown_1.13 prodlim_2024.06.25 UCSC.utils_1.3.1
## [19] purrr_1.0.2 xfun_0.50 glmnet_4.1-8
## [22] cachem_1.1.0 randomForest_4.7-1.2 shapr_0.2.2
## [25] jsonlite_1.8.9 recipes_1.1.0 DelayedArray_0.33.4
## [28] tweenr_2.0.3 parallel_4.5.0 R6_2.5.1
## [31] bslib_0.9.0 stringi_1.8.4 limma_3.63.3
## [34] parallelly_1.42.0 rpart_4.1.24 xgboost_1.7.8.1
## [37] jquerylib_0.1.4 lubridate_1.9.4 Rcpp_1.0.14
## [40] assertthat_0.2.1 iterators_1.0.14 knitr_1.49
## [43] future.apply_1.11.3 Matrix_1.7-2 splines_4.5.0
## [46] nnet_7.3-20 timechange_0.3.0 tidyselect_1.2.1
## [49] yaml_2.3.10 abind_1.4-8 timeDate_4041.110
## [52] codetools_0.2-20 listenv_0.9.1 tibble_3.2.1
## [55] plyr_1.8.9 withr_3.0.2 evaluate_1.0.3
## [58] future_1.34.0 survival_3.8-3 proxy_0.4-27
## [61] polyclip_1.10-7 pillar_1.10.1 kernelshap_0.7.0
## [64] DT_0.33 foreach_1.5.2 commonmark_1.9.2
## [67] munsell_0.5.1 scales_1.3.0 globals_0.16.3
## [70] class_7.3-23 glue_1.8.0 tools_4.5.0
## [73] data.table_1.16.4 ModelMetrics_1.2.2.2 gower_1.0.2
## [76] grid_4.5.0 crosstalk_1.2.1 ipred_0.9-15
## [79] colorspace_2.1-1 nlme_3.1-167 GenomeInfoDbData_1.2.13
## [82] cli_3.6.3 S4Arrays_1.7.1 lava_1.8.1
## [85] dplyr_1.1.4 gtable_0.3.6 sass_0.4.9
## [88] digest_0.6.37 SparseArray_1.7.4 htmlwidgets_1.6.4
## [91] farver_2.1.2 htmltools_0.5.8.1 lifecycle_1.0.4
## [94] hardhat_1.4.1 httr_1.4.7 mime_0.12
## [97] statmod_1.5.0 MASS_7.3-64