The XAItest package includes several classic feature importance algorithms and supports the addition of new ones. The example below integrates an XGBoost model and computes its feature importance from Shapley values estimated with the shapr package.
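The function should accept: a data frame `df` holding the feature columns and the outcome column, the name `y` of the outcome column, and `...` for any additional arguments.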
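The function should return: a list with `featImps` (a named vector of feature importances), `model` (the trained model) and `modelPredictions` (the labels predicted by the model).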
# Load the libraries
library(XAItest)
library(ggplot2)
library(ggforce)
library(SummarizedExperiment)
se_path <- system.file("extdata", "seClassif.rds", package = "XAItest")
dataset_classif <- readRDS(se_path)
# Transpose the "counts" assay so samples become rows and features columns.
data_matrix <- t(assay(dataset_classif, "counts"))
metadata <- as.data.frame(colData(dataset_classif))
# Build the modelling data frame directly rather than via cbind(): cbind()
# with a character outcome coerces every feature to text (and a factor
# outcome to its integer codes), and converting the text back with
# as.numeric() keeps only ~15 significant digits.
df_simu_classif <- as.data.frame(data_matrix)
df_simu_classif[] <- lapply(df_simu_classif, as.numeric)
df_simu_classif[["y"]] <- metadata[["y"]]
# Custom feature-importance method for XAI.test(): fits an XGBoost model on
# all non-outcome columns of `df` and scores each feature by its mean
# absolute Shapley value, estimated with shapr's "empirical" approach.
#
# Arguments:
#   df  - data frame holding the feature columns and the outcome column.
#   y   - name of the outcome column in `df`; it must contain exactly two
#         distinct values.
#   ... - accepted so the function fits the caller's interface; not used.
#
# Returns a list with:
#   featImps         - named numeric vector: mean |Shapley value| per feature.
#   model            - the fitted xgboost model.
#   modelPredictions - predicted class labels (character), one per row of df.
featureImportanceXGBoost <- function(df, y = "y", ...) {
  # Prepare data; drop = FALSE keeps a matrix even with a single feature.
  matX <- as.matrix(df[, colnames(df) != y, drop = FALSE])
  labels <- as.character(df[[y]])
  classes <- unique(labels)
  if (length(classes) != 2L) {
    stop("featureImportanceXGBoost() needs a binary outcome; column '", y,
         "' has ", length(classes), " distinct value(s).", call. = FALSE)
  }
  # Encode the first-seen class as 0 and the other as 1 in one pass.
  # Recomputing unique() after a partial recode breaks when a label is
  # itself "0" (e.g. labels "1" then "0" would all become 0).
  vecY <- as.numeric(labels == classes[2])

  # Train the XGBoost model. No `objective` is given, so xgboost's default
  # applies (squared-error regression in the 1.7.x series -- confirm for the
  # installed version); predictions are thresholded at 0.5 below.
  model <- xgboost::xgboost(data = matX, label = vecY,
                            nrounds = 10, verbose = FALSE)
  modelPredictions <- predict(model, matX)
  # Map predictions back to the original class labels.
  modelPredictionsCat <- ifelse(modelPredictions >= 0.5,
                                classes[2], classes[1])

  # phi0: the expected prediction without any features.
  p <- mean(vecY)
  # Shapley values with kernelSHAP, accounting for feature dependence via the
  # empirical (conditional) distribution approach; the iterative estimation
  # is capped at 3 iterations.
  explanation <- shapr::explain(
    model,
    approach = "empirical",
    x_explain = matX,
    x_train = matX,
    phi0 = p,
    iterative = TRUE,
    iterative_args = list(max_iter = 3)
  )
  results <- colMeans(abs(explanation$shapley_values_est), na.rm = TRUE)
  # Keep only the per-feature columns: drop shapr's bookkeeping columns
  # (row id and the feature-less "none" term) by name, not by position.
  results <- results[setdiff(names(results), c("explain_id", "none"))]
  list(featImps = results, model = model,
       modelPredictions = modelPredictionsCat)
}
set.seed(123)
# Run XAI.test() on the classification dataset with the built-in methods
# restricted to c("ttest", "lm"), plus the custom XGBoost/SHAP importance.
results <- XAI.test(
  dataset_classif, "y",
  simData = TRUE,
  simPvalTarget = 0.001,
  customFeatImps = list("XGB_SHAP_feat_imp" = featureImportanceXGBoost),
  defaultMethods = c("ttest", "lm")
)
## Note: Feature classes extracted from the model contains NA.
## Assuming feature classes from the data are correct.
## Success with message:
## max_n_coalitions is NULL or larger than or 2^n_features = 16384,
## and is therefore set to 2^n_features = 16384.
##
## ── Starting `shapr::explain()` at 2025-03-23 19:42:58 ──────────────────────────
## • Model class: <xgb.Booster>
## • Approach: empirical
## • Iterative estimation: TRUE
## • Number of feature-wise Shapley values: 14
## • Number of observations to explain: 100
## • Computations (temporary) saved at:
## '/tmp/Rtmp9cHnvJ/shapr_obj_2c079b5a655641.rds'
##
## ── iterative computation started ──
##
## ── Iteration 1 ─────────────────────────────────────────────────────────────────
## ℹ Using 200 of 16384 coalitions, 200 new.
##
## ── Iteration 2 ─────────────────────────────────────────────────────────────────
## ℹ Using 512 of 16384 coalitions, 312 new.
##
## ── Iteration 3 ─────────────────────────────────────────────────────────────────
## ℹ Using 880 of 16384 coalitions, 368 new.
The mapPvalImportance output below shows that both the custom XGB_SHAP_feat_imp metric and the other feature importance methods flag diff_distrib01 and diff_distrib02 as significant (the t-test additionally flags simFeat).
Display as a data.frame:
# Map each method's p-values/importances against the adjusted t-test p-value
# reference, using a 0.001 significance threshold.
mpi <- mapPvalImportance(results,
                         refPvalColumn = "ttest_adjPval",
                         refPval = 0.001)
head(mpi$df)
## ttest_pval isSign_ttest_pval ttest_adjPval isSign_ttest_adjPval
## diff_distrib02 1.02e-43 1 1.42e-42 1
## diff_distrib01 7.93e-37 1 1.11e-35 1
## simFeat 3.89e-05 1 5.45e-04 1
## norm_noise03 7.97e-02 0 1.00e+00 0
## norm_noise08 1.11e-01 0 1.00e+00 0
## norm_noise09 1.44e-01 0 1.00e+00 0
## lm_pval isSign_lm_pval lm_adjPval isSign_lm_adjPval
## diff_distrib02 3.57e-21 1 5.00e-20 1
## diff_distrib01 1.27e-11 1 1.78e-10 1
## simFeat 7.98e-02 0 1.00e+00 0
## norm_noise03 6.11e-01 0 1.00e+00 0
## norm_noise08 8.86e-01 0 1.00e+00 0
## norm_noise09 5.16e-01 0 1.00e+00 0
## XGB_SHAP_feat_imp isSign_XGB_SHAP_feat_imp
## diff_distrib02 0.0908 1.0
## diff_distrib01 0.1120 1.0
## simFeat 0.0270 0.5
## norm_noise03 0.0322 0.5
## norm_noise08 0.0155 0.0
## norm_noise09 0.0344 0.5
Display as a datatable:
# Same results as an interactive table (presumably a DT datatable widget,
# rendered by knitr when auto-printed -- confirm).
mpi$dt
# Plot the XGBoost model fitted by the custom XGB_SHAP_feat_imp method,
# presumably over the features diff_distrib01 and biDistrib.
plotModel(results,
          "XGB_SHAP_feat_imp",
          "diff_distrib01",
          "biDistrib")
# Record the R version, platform and package versions used for this run.
utils::sessionInfo()
## R Under development (unstable) (2025-03-13 r87965)
## Platform: x86_64-pc-linux-gnu
## Running under: Ubuntu 24.04.2 LTS
##
## Matrix products: default
## BLAS: /home/biocbuild/bbs-3.21-bioc/R/lib/libRblas.so
## LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.12.0 LAPACK version 3.12.0
##
## locale:
## [1] LC_CTYPE=en_US.UTF-8 LC_NUMERIC=C
## [3] LC_TIME=en_GB LC_COLLATE=C
## [5] LC_MONETARY=en_US.UTF-8 LC_MESSAGES=en_US.UTF-8
## [7] LC_PAPER=en_US.UTF-8 LC_NAME=C
## [9] LC_ADDRESS=C LC_TELEPHONE=C
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C
##
## time zone: America/New_York
## tzcode source: system (glibc)
##
## attached base packages:
## [1] stats4 stats graphics grDevices utils datasets methods
## [8] base
##
## other attached packages:
## [1] caret_7.0-1 lattice_0.22-6
## [3] SummarizedExperiment_1.37.0 Biobase_2.67.0
## [5] GenomicRanges_1.59.1 GenomeInfoDb_1.43.4
## [7] IRanges_2.41.3 S4Vectors_0.45.4
## [9] BiocGenerics_0.53.6 generics_0.1.3
## [11] MatrixGenerics_1.19.1 matrixStats_1.5.0
## [13] gridExtra_2.3 ggforce_0.4.2
## [15] ggplot2_3.5.1 XAItest_0.99.25
##
## loaded via a namespace (and not attached):
## [1] pROC_1.18.5 rlang_1.1.5 magrittr_2.0.3
## [4] e1071_1.7-16 compiler_4.6.0 lime_0.5.3
## [7] vctrs_0.6.5 reshape2_1.4.4 stringr_1.5.1
## [10] fastmap_1.2.0 pkgconfig_2.0.3 shape_1.4.6.1
## [13] crayon_1.5.3 XVector_0.47.2 labeling_0.4.3
## [16] markdown_1.13 prodlim_2024.06.25 UCSC.utils_1.3.1
## [19] purrr_1.0.4 xfun_0.51 glmnet_4.1-8
## [22] cachem_1.1.0 randomForest_4.7-1.2 shapr_1.0.2
## [25] jsonlite_1.9.1 recipes_1.2.0 DelayedArray_0.33.6
## [28] tweenr_2.0.3 parallel_4.6.0 R6_2.6.1
## [31] bslib_0.9.0 stringi_1.8.4 limma_3.63.10
## [34] parallelly_1.42.0 rpart_4.1.24 xgboost_1.7.8.1
## [37] jquerylib_0.1.4 lubridate_1.9.4 Rcpp_1.0.14
## [40] assertthat_0.2.1 iterators_1.0.14 knitr_1.50
## [43] future.apply_1.11.3 Matrix_1.7-3 splines_4.6.0
## [46] nnet_7.3-20 timechange_0.3.0 tidyselect_1.2.1
## [49] yaml_2.3.10 abind_1.4-8 timeDate_4041.110
## [52] codetools_0.2-20 listenv_0.9.1 tibble_3.2.1
## [55] plyr_1.8.9 withr_3.0.2 evaluate_1.0.3
## [58] future_1.34.0 survival_3.8-3 proxy_0.4-27
## [61] polyclip_1.10-7 pillar_1.10.1 kernelshap_0.7.0
## [64] DT_0.33 foreach_1.5.2 commonmark_1.9.5
## [67] munsell_0.5.1 scales_1.3.0 globals_0.16.3
## [70] class_7.3-23 glue_1.8.0 tools_4.6.0
## [73] data.table_1.17.0 ModelMetrics_1.2.2.2 gower_1.0.2
## [76] grid_4.6.0 crosstalk_1.2.1 ipred_0.9-15
## [79] colorspace_2.1-1 nlme_3.1-167 GenomeInfoDbData_1.2.14
## [82] cli_3.6.4 S4Arrays_1.7.3 lava_1.8.1
## [85] dplyr_1.1.4 gtable_0.3.6 sass_0.4.9
## [88] digest_0.6.37 progressr_0.15.1 SparseArray_1.7.7
## [91] htmlwidgets_1.6.4 farver_2.1.2 htmltools_0.5.8.1
## [94] lifecycle_1.0.4 hardhat_1.4.1 httr_1.4.7
## [97] mime_0.13 statmod_1.5.0 MASS_7.3-65