## ----include=FALSE------------------------------------------------------------
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.width = 7,
  fig.height = 5,
  error = TRUE
)

## -----------------------------------------------------------------------------
library(fairGATE)
library(dplyr)
library(readxl)

# Load the UCI Adult dataset bundled with the package
data("adult_ready_small", package = "fairGATE")
adult_data <- adult_ready_small

adult <- adult_data %>%
  mutate(
    across(where(is.character), ~ trimws(.x)),
    income = as.integer(income)
  )

## -----------------------------------------------------------------------------
# Drop unwanted columns (e.g. identifier columns and those with high multicollinearity)
cols_to_drop <- c("subjectid", "Row.names")

# Make sure any other preprocessing steps (e.g. one-hot encoding) have been applied;
# the fully prepared data is built here
prepared <- fairGATE::prepare_data(
  data = adult,
  outcome_var = "income",
  group_var = "sex",
  cols_to_remove = cols_to_drop
)

## ----include = FALSE----------------------------------------------------------
# --- Safety block: clean any non-finite or zero-variance columns before training ---
X <- prepared$X

fix_col <- function(x) {
  x[!is.finite(x)] <- NA
  if (all(is.na(x))) return(rep(0, length(x)))
  x[is.na(x)] <- stats::median(x, na.rm = TRUE)
  x
}

# Replace any non-finite values
bad <- colSums(!is.finite(X)) > 0
if (any(bad)) X[, bad] <- apply(X[, bad, drop = FALSE], 2, fix_col)

# Drop zero-variance columns (these can cause NaNs on scaling)
zv <- apply(X, 2, function(v) sd(v, na.rm = TRUE) == 0)
if (any(zv)) X <- X[, !zv, drop = FALSE]

# Update the prepared object
prepared$X <- X
prepared$feature_names <- colnames(X)

# Quick sanity check
stopifnot(sum(!is.finite(prepared$X)) == 0, ncol(prepared$X) > 0)

## ----train-demo, results='hide', message=FALSE, warning=FALSE-----------------
# Train a small Gated Neural Network
trained_model <- fairGATE::train_gnn(
  prepared_data = prepared,
  run_tuning = FALSE,   # skip tuning for speed
  best_params = list(
    lr = 0.01,
    hidden_dim = 16,
    dropout_rate = 0.1,
    lambda = 0.0,
    temperature = 1.0
  ),
  num_repeats = 2,      # very short repeated split
  epochs = 20,          # fast CRAN-safe runtime
  verbose = FALSE
)

## -----------------------------------------------------------------------------
# Run the basic analysis
basic_analyses <- analyse_gnn_results(
  gnn_results = trained_model,
  prepared_data = prepared
)

# --- View all plots from the basic analysis ---
cat("## ROC Curve\n")
print(basic_analyses$roc_plot)

cat("\n## Calibration Plot\n")
print(basic_analyses$calibration_plot)

cat("\n## Gate Weight Distribution\n")
print(basic_analyses$gate_density_plot)

cat("\n## Gate Entropy Distribution\n")
print(basic_analyses$entropy_density_plot)

## -----------------------------------------------------------------------------
exp_res <- analyse_experts(
  gnn_results = trained_model,  # from train_gnn()
  prepared_data = prepared,     # from prepare_data()
  top_n_features = 15,          # number of top features to visualise
  verbose = TRUE
)

# View the main objects returned
names(exp_res)
#> [1] "all_weights"           "means_by_group_wide"
#> [3] "pairwise_differences"  "difference_plot"
#> [5] "multi_group_plot"      "top_features_multi"

# View the first few feature importances
head(exp_res$means_by_group_wide)

# Example: view one pairwise difference table
names(exp_res$pairwise_differences)
#> [1] "Female_vs_Male"
head(exp_res$pairwise_differences[[1]])

# Visualise feature specialisation
if (!is.null(exp_res$difference_plot)) print(exp_res$difference_plot)
if (!is.null(exp_res$multi_group_plot)) print(exp_res$multi_group_plot)
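## ----save-expert-plots, eval = FALSE-------------------------------------------
# # Optional: persist the expert-specialisation plots for use in reports.
# # A minimal sketch, assuming the plot objects returned by analyse_experts()
# # are ggplot2 objects (not verified here); the output folder and file name
# # below are illustrative only.
# if (!is.null(exp_res$difference_plot) && requireNamespace("ggplot2", quietly = TRUE)) {
#   dir.create("outputs", showWarnings = FALSE)
#   ggplot2::ggsave(
#     filename = "outputs/expert_difference_plot.png",
#     plot = exp_res$difference_plot,
#     width = 7, height = 5, dpi = 300
#   )
# }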
## -----------------------------------------------------------------------------
# Generate and print the Sankey plot
p <- plot_sankey(
  prepared_data = prepared,       # from prepare_data()
  gnn_results = trained_model,    # from train_gnn()
  expert_results = exp_res,       # from analyse_experts()
  verbose = TRUE
)

print(p)

## ----f360_export, eval = FALSE, message = FALSE-------------------------------
# export_f360_csv(
#   gnn_results = trained_model,              # from train_gnn()
#   prepared_data = prepared,                 # from prepare_data()
#   path = "outputs/fairness360_input.csv",
#   include_gate_cols = TRUE,                 # include expert routing probabilities
#   threshold = 0.5,                          # classification threshold for binary outcome
#   verbose = TRUE
# )
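## ----f360_check, eval = FALSE--------------------------------------------------
# # Optional follow-up: a minimal sketch for sanity-checking the exported file
# # before passing it to a downstream fairness toolkit. It assumes only that
# # export_f360_csv() wrote a plain CSV to the `path` given above; the exact
# # column layout is not assumed, so we just confirm the file exists and look
# # at its dimensions and column names. (If the "outputs/" folder does not
# # exist yet, create it with dir.create("outputs", showWarnings = FALSE)
# # before running the export above.)
# f360_path <- "outputs/fairness360_input.csv"
# if (file.exists(f360_path)) {
#   f360_df <- utils::read.csv(f360_path)
#   dim(f360_df)      # rows (subjects) x exported columns
#   names(f360_df)    # inspect exactly which columns were written
# }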