XAItest: Enhancing Feature Discovery with eXplainable AI

Ghislain FIEVET ghislain.fievet@gmail.com

Add a custom feature importance function

The XAItest package includes several classic feature importance algorithms and supports adding new ones. This section shows how to integrate an XGBoost model and compute its feature importance with SHAP values from the shapr package.

A custom feature importance function must have the following structure.

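The function should accept:

- `df`: a data.frame containing the feature columns and the target column;
- `y`: the name of the target column (`"y"` by default);
- `...`: additional arguments.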

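The function should return a list with the elements:

- `featImps`: a named numeric vector of feature importance scores;
- `model`: the fitted model;
- `modelPredictions`: the model's prediction for each row of `df` (the predicted class labels for a classification task).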

Load libraries and classification dataset

# Load the libraries
library(XAItest)
library(ggplot2)
library(ggforce)
library(SummarizedExperiment)
se_path <- system.file("extdata", "seClassif.rds", package="XAItest")
dataset_classif <- readRDS(se_path)

# Flatten the SummarizedExperiment into a data.frame with samples in rows,
# one numeric column per feature, plus the target column "y".
data_matrix <- t(as.matrix(assay(dataset_classif, "counts")))
storage.mode(data_matrix) <- "double"
metadata <- as.data.frame(colData(dataset_classif))
# data.frame() (rather than cbind() on the matrix) keeps the features
# numeric: cbind() would coerce every value to the label's type and require
# a lossy character -> numeric round trip afterwards. check.names = FALSE
# keeps the feature names exactly as they are in the assay.
df_simu_classif <- data.frame(data_matrix, y = metadata[["y"]],
                              check.names = FALSE)

Build and use the custom feature importance function

# Custom feature importance function for XAI.test(): fits an XGBoost
# binary classifier and scores each feature by its mean absolute Shapley
# value, estimated with the shapr package.
#
# Arguments:
#   df  - data.frame with the feature columns and the target column.
#   y   - name of the target column (default "y"); it must hold exactly
#         two distinct labels.
#   ... - accepted for compatibility with the caller; unused here.
#
# Returns a list with:
#   featImps         - named numeric vector: mean |SHAP value| per feature.
#   model            - the fitted XGBoost model.
#   modelPredictions - predicted class label for every row of df.
featureImportanceXGBoost <- function(df, y = "y", ...) {
    # Prepare data. drop = FALSE keeps a single-feature input as a named
    # one-column matrix instead of collapsing it to an unnamed vector.
    matX <- as.matrix(df[, colnames(df) != y, drop = FALSE])
    labels <- as.character(df[[y]])
    classes <- unique(labels)
    if (length(classes) != 2L) {
        stop("featureImportanceXGBoost() expects exactly two classes in ",
             "column '", y, "', found ", length(classes), ".", call. = FALSE)
    }
    # First class seen -> 0, second -> 1. Comparing against the untouched
    # labels (instead of overwriting them in place and re-running unique())
    # stays correct even when one of the labels is the string "0".
    vecY <- as.numeric(labels == classes[2])

    # Train the XGBoost model. binary:logistic (instead of xgboost's default
    # squared-error regression) makes predict() return class-1
    # probabilities, matching the 0.5 threshold and phi0 below.
    model <- xgboost::xgboost(data = matX, label = vecY,
                              params = list(objective = "binary:logistic"),
                              nrounds = 10, verbose = FALSE)
    modelPredictions <- predict(model, matX)
    modelPredictionsCat <- ifelse(modelPredictions >= 0.5,
                                  classes[2], classes[1])

    # phi_0: the expected prediction without any features, i.e. the
    # fraction of samples in the second class.
    p <- mean(vecY)
    # Computing the actual Shapley values with kernelSHAP accounting
    # for feature dependence using the empirical (conditional)
    # distribution approach with bandwidth parameter sigma = 0.1 (default);
    # iterative estimation is capped at 3 iterations.
    explanation <- shapr::explain(
        model,
        approach = "empirical",
        x_explain = matX,
        x_train = matX,
        phi0 = p,
        iterative = TRUE,
        iterative_args = list(max_iter = 3)
    )
    # shapley_values_est also carries bookkeeping columns ("explain_id" and
    # the baseline "none"); drop them by name rather than by position.
    shapValues <- as.data.frame(explanation$shapley_values_est)
    bookkeeping <- c("explain_id", "none")
    shapValues <- shapValues[setdiff(names(shapValues), bookkeeping)]
    results <- colMeans(abs(shapValues), na.rm = TRUE)
    list(featImps = results, model = model,
         modelPredictions = modelPredictionsCat)
}
# Fix the random seed so that the results shown below are reproducible.
set.seed(123)
# Run the built-in "ttest" and "lm" methods alongside the custom
# XGBoost/SHAP importance, registered under the name "XGB_SHAP_feat_imp".
results <- XAI.test(
    dataset_classif,
    "y",
    simData = TRUE,
    simPvalTarget = 0.001,
    customFeatImps = list(XGB_SHAP_feat_imp = featureImportanceXGBoost),
    defaultMethods = c("ttest", "lm")
)
## Note: Feature classes extracted from the model contains NA.
## Assuming feature classes from the data are correct.
## Success with message:
## max_n_coalitions is NULL or larger than or 2^n_features = 16384, 
## and is therefore set to 2^n_features = 16384.
## 
## ── Starting `shapr::explain()` at 2025-03-23 19:42:58 ──────────────────────────
## • Model class: <xgb.Booster>
## • Approach: empirical
## • Iterative estimation: TRUE
## • Number of feature-wise Shapley values: 14
## • Number of observations to explain: 100
## • Computations (temporary) saved at:
## '/tmp/Rtmp9cHnvJ/shapr_obj_2c079b5a655641.rds'
## 
## ── iterative computation started ──
## 
## ── Iteration 1 ─────────────────────────────────────────────────────────────────
## ℹ Using 200 of 16384 coalitions, 200 new.
## 
## ── Iteration 2 ─────────────────────────────────────────────────────────────────
## ℹ Using 512 of 16384 coalitions, 312 new.
## 
## ── Iteration 3 ─────────────────────────────────────────────────────────────────
## ℹ Using 880 of 16384 coalitions, 368 new.

The mapPvalImportance output shows that the custom XGB_SHAP_feat_imp metric, like the t-test and linear-model p-values, flags the diff_distrib01 and diff_distrib02 features as significant.

Display as a data.frame:

# Gather every method's p-values / importance scores into one table, using
# the adjusted t-test p-value (threshold 0.001) as the reference column.
mpi <- mapPvalImportance(
    results,
    refPvalColumn = "ttest_adjPval",
    refPval = 0.001
)
head(mpi[["df"]], n = 6L)
##                ttest_pval isSign_ttest_pval ttest_adjPval isSign_ttest_adjPval
## diff_distrib02   1.02e-43                 1      1.42e-42                    1
## diff_distrib01   7.93e-37                 1      1.11e-35                    1
## simFeat          3.89e-05                 1      5.45e-04                    1
## norm_noise03     7.97e-02                 0      1.00e+00                    0
## norm_noise08     1.11e-01                 0      1.00e+00                    0
## norm_noise09     1.44e-01                 0      1.00e+00                    0
##                 lm_pval isSign_lm_pval lm_adjPval isSign_lm_adjPval
## diff_distrib02 3.57e-21              1   5.00e-20                 1
## diff_distrib01 1.27e-11              1   1.78e-10                 1
## simFeat        7.98e-02              0   1.00e+00                 0
## norm_noise03   6.11e-01              0   1.00e+00                 0
## norm_noise08   8.86e-01              0   1.00e+00                 0
## norm_noise09   5.16e-01              0   1.00e+00                 0
##                XGB_SHAP_feat_imp isSign_XGB_SHAP_feat_imp
## diff_distrib02            0.0908                      1.0
## diff_distrib01            0.1120                      1.0
## simFeat                   0.0270                      0.5
## norm_noise03              0.0322                      0.5
## norm_noise08              0.0155                      0.0
## norm_noise09              0.0344                      0.5

Display as a datatable:

# Interactive (datatable) view of the mapPvalImportance() results.
mpi[["dt"]]

# Plot of the XGBoost model fitted by the custom "XGB_SHAP_feat_imp" method.
# NOTE(review): the two trailing arguments look like the features to plot;
# "biDistrib" is not among the rows printed by head(mpi$df) - confirm it is
# a feature of the dataset (see ?plotModel).
plotModel(results, "XGB_SHAP_feat_imp", "diff_distrib01", "biDistrib")

plot of chunk unnamed-chunk-2

# Record the R version, platform and package versions used for this run.
utils::sessionInfo()
## R Under development (unstable) (2025-03-13 r87965)
## Platform: x86_64-pc-linux-gnu
## Running under: Ubuntu 24.04.2 LTS
## 
## Matrix products: default
## BLAS:   /home/biocbuild/bbs-3.21-bioc/R/lib/libRblas.so 
## LAPACK: /usr/lib/x86_64-linux-gnu/lapack/liblapack.so.3.12.0  LAPACK version 3.12.0
## 
## locale:
##  [1] LC_CTYPE=en_US.UTF-8       LC_NUMERIC=C              
##  [3] LC_TIME=en_GB              LC_COLLATE=C              
##  [5] LC_MONETARY=en_US.UTF-8    LC_MESSAGES=en_US.UTF-8   
##  [7] LC_PAPER=en_US.UTF-8       LC_NAME=C                 
##  [9] LC_ADDRESS=C               LC_TELEPHONE=C            
## [11] LC_MEASUREMENT=en_US.UTF-8 LC_IDENTIFICATION=C       
## 
## time zone: America/New_York
## tzcode source: system (glibc)
## 
## attached base packages:
## [1] stats4    stats     graphics  grDevices utils     datasets  methods  
## [8] base     
## 
## other attached packages:
##  [1] caret_7.0-1                 lattice_0.22-6             
##  [3] SummarizedExperiment_1.37.0 Biobase_2.67.0             
##  [5] GenomicRanges_1.59.1        GenomeInfoDb_1.43.4        
##  [7] IRanges_2.41.3              S4Vectors_0.45.4           
##  [9] BiocGenerics_0.53.6         generics_0.1.3             
## [11] MatrixGenerics_1.19.1       matrixStats_1.5.0          
## [13] gridExtra_2.3               ggforce_0.4.2              
## [15] ggplot2_3.5.1               XAItest_0.99.25            
## 
## loaded via a namespace (and not attached):
##  [1] pROC_1.18.5             rlang_1.1.5             magrittr_2.0.3         
##  [4] e1071_1.7-16            compiler_4.6.0          lime_0.5.3             
##  [7] vctrs_0.6.5             reshape2_1.4.4          stringr_1.5.1          
## [10] fastmap_1.2.0           pkgconfig_2.0.3         shape_1.4.6.1          
## [13] crayon_1.5.3            XVector_0.47.2          labeling_0.4.3         
## [16] markdown_1.13           prodlim_2024.06.25      UCSC.utils_1.3.1       
## [19] purrr_1.0.4             xfun_0.51               glmnet_4.1-8           
## [22] cachem_1.1.0            randomForest_4.7-1.2    shapr_1.0.2            
## [25] jsonlite_1.9.1          recipes_1.2.0           DelayedArray_0.33.6    
## [28] tweenr_2.0.3            parallel_4.6.0          R6_2.6.1               
## [31] bslib_0.9.0             stringi_1.8.4           limma_3.63.10          
## [34] parallelly_1.42.0       rpart_4.1.24            xgboost_1.7.8.1        
## [37] jquerylib_0.1.4         lubridate_1.9.4         Rcpp_1.0.14            
## [40] assertthat_0.2.1        iterators_1.0.14        knitr_1.50             
## [43] future.apply_1.11.3     Matrix_1.7-3            splines_4.6.0          
## [46] nnet_7.3-20             timechange_0.3.0        tidyselect_1.2.1       
## [49] yaml_2.3.10             abind_1.4-8             timeDate_4041.110      
## [52] codetools_0.2-20        listenv_0.9.1           tibble_3.2.1           
## [55] plyr_1.8.9              withr_3.0.2             evaluate_1.0.3         
## [58] future_1.34.0           survival_3.8-3          proxy_0.4-27           
## [61] polyclip_1.10-7         pillar_1.10.1           kernelshap_0.7.0       
## [64] DT_0.33                 foreach_1.5.2           commonmark_1.9.5       
## [67] munsell_0.5.1           scales_1.3.0            globals_0.16.3         
## [70] class_7.3-23            glue_1.8.0              tools_4.6.0            
## [73] data.table_1.17.0       ModelMetrics_1.2.2.2    gower_1.0.2            
## [76] grid_4.6.0              crosstalk_1.2.1         ipred_0.9-15           
## [79] colorspace_2.1-1        nlme_3.1-167            GenomeInfoDbData_1.2.14
## [82] cli_3.6.4               S4Arrays_1.7.3          lava_1.8.1             
## [85] dplyr_1.1.4             gtable_0.3.6            sass_0.4.9             
## [88] digest_0.6.37           progressr_0.15.1        SparseArray_1.7.7      
## [91] htmlwidgets_1.6.4       farver_2.1.2            htmltools_0.5.8.1      
## [94] lifecycle_1.0.4         hardhat_1.4.1           httr_1.4.7             
## [97] mime_0.13               statmod_1.5.0           MASS_7.3-65