This notebook consolidates the evaluation of two fine-tuned XGBoost models for predicting hospital readmission and mortality. It prepares longitudinal data for survival analysis by eliminating immortal time bias, handles competing outcomes through cause-specific hazard preparation, and constructs an evaluation grid to assess each model's predictive performance for both events across defined time intervals.
from IPython.display import display, Markdown

if isinstance(imputations_list_jan26, list) and len(imputations_list_jan26) > 0:
    display(Markdown(f"**First element type:** `{type(imputations_list_jan26[0])}`"))
    if isinstance(imputations_list_jan26[0], dict):
        display(Markdown(f"**First element keys:** `{list(imputations_list_jan26[0].keys())}`"))
    elif isinstance(imputations_list_jan26[0], (pd.DataFrame, np.ndarray)):
        display(Markdown(f"**First element shape:** `{imputations_list_jan26[0].shape}`"))
First element type: <class 'pandas.DataFrame'>
First element shape: (88504, 56)
This code block:
Imports the pickle library: This library implements binary protocols for serializing and de-serializing a Python object structure.
Specifies the file_path: It points to the .pkl file you selected.
Opens the file in binary read mode ('rb'): This is necessary for loading pickle files.
Loads the object: pickle.load(f) reads the serialized object from the file and reconstructs it in memory.
Prints confirmation and basic information: It verifies that the file was loaded and shows the type of the loaded object, and some details about the first element if it’s a list containing common data structures.
Format data
Due to inconsistencies and structural heterogeneity across previously merged datasets, we decided not to proceed with a direct inspection and comparison of column names between the first imputed dataset from imputations_list_jan26 (which likely included dummy-encoded variables) and imputation_nodum_1 (which likely retained non–dummy-encoded variables).
Instead, we reconstructed the analytic datasets de novo using the most recent source files available in the original directory (BASE_DIR). Time-to-event variables were re-derived to ensure internal consistency. Variables that could introduce information leakage (e.g., time from admission) were excluded, and the center identifier variable was removed prior to modeling.
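A hedged sketch of the kind of time-to-event re-derivation described above, on toy data (the raw date column names here are hypothetical, not the actual BASE_DIR schema; the target column names match those used later in the notebook):

```python
import pandas as pd

# Hypothetical raw date columns; the real source files may use different names.
df = pd.DataFrame({
    "discharge_date": pd.to_datetime(["2020-01-31", "2020-03-15"]),
    "readmission_date": pd.to_datetime(["2020-06-30", pd.NaT]),
    "end_of_followup": pd.to_datetime(["2021-01-31", "2021-03-15"]),
})

# Event indicator: readmission observed vs. censored at end of follow-up
event = df["readmission_date"].notna()
stop = df["readmission_date"].fillna(df["end_of_followup"])

# Time from discharge in months (30.44 days per average month)
df["readmit_time_from_disch_m"] = (stop - df["discharge_date"]).dt.days / 30.44
df["readmit_event"] = event.astype(int)
```

Deriving times from discharge (rather than admission) is what keeps the leakage-prone `*_from_adm_m` variables out of the design matrix.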
Code
# 1.2. Build Surv objects from df_final
from IPython.display import display, Markdown
from sksurv.util import Surv

for i in range(1, 6):
    # Get the DataFrame
    df = globals()[f"imputation_nodum_{i}"]

    # Extract time and event arrays
    time_readm = df["readmit_time_from_disch_m"].to_numpy()
    event_readm = df["readmit_event"].to_numpy() == 1
    time_death = df["death_time_from_disch_m"].to_numpy()
    event_death = df["death_event"].to_numpy() == 1

    # Create survival objects
    y_surv_readm = Surv.from_arrays(event=event_readm, time=time_readm)
    y_surv_death = Surv.from_arrays(event=event_death, time=time_death)

    # Store in global variables (optional but matches your pattern)
    globals()[f"y_surv_readm_{i}"] = y_surv_readm
    globals()[f"y_surv_death_{i}"] = y_surv_death

    # Print info
    display(Markdown(f"\n--- Imputation {i} ---"))
    display(Markdown(
        f"**y_surv_readm dtype:** {y_surv_readm.dtype}\n"
        f"**shape:** {y_surv_readm.shape}"
    ))
    display(Markdown(
        f"**y_surv_death dtype:** {y_surv_death.dtype}\n"
        f"**shape:** {y_surv_death.shape}"
    ))
fold_output("Show imputation_nodum_1 (newer database) glimpse",
            lambda: glimpse(imputation_nodum_1))
fold_output("Show first db of imputations_list_jan26 (older) glimpse",
            lambda: glimpse(imputations_list_jan26[0]))
For each imputed dataset (1–5), we identified and removed predictors with zero variance, as they provide no useful information and can destabilize models. We printed the dropped variables and produced a cleaned version of each design matrix. This ensures that all downstream analyses use only informative predictors.
Code
# Keep only these objects
objects_to_keep = {
    "objects_to_keep",
    "imputation_nodum_1",
    "imputation_nodum_2",
    "imputation_nodum_3",
    "imputation_nodum_4",
    "imputation_nodum_5",
    "y_surv_readm",
    "y_surv_death",
    "imputations_list_jan26",
}

import types

for name in list(globals().keys()):
    obj = globals()[name]
    if (
        name not in objects_to_keep
        and not name.startswith("_")
        and not callable(obj)
        and not isinstance(obj, types.ModuleType)  # <- protects ALL modules
    ):
        del globals()[name]
Code
from IPython.display import display, Markdown

# 1. Define columns to exclude (same as before)
target_cols = [
    "readmit_time_from_disch_m",
    "readmit_event",
    "death_time_from_disch_m",
    "death_event",
]
leak_time_cols = [
    "readmit_time_from_adm_m",
    "death_time_from_adm_m",
]
center_id = ["center_id"]
cols_to_exclude = target_cols + center_id + leak_time_cols

# 2. Create list of your EXISTING imputation DataFrames (1-5)
imputed_dfs = [
    imputation_nodum_1,
    imputation_nodum_2,
    imputation_nodum_3,
    imputation_nodum_4,
    imputation_nodum_5,
]

# 3. Preprocessing loop
X_reduced_list = []
for d, df in enumerate(imputed_dfs):
    imputation_num = d + 1  # Convert 0-index to 1-index for display
    display(Markdown(f"\n=== Imputation dataset {imputation_num} ==="))

    # a) Identify and drop constant predictors
    const_mask = (df.nunique(dropna=False) <= 1)
    dropped_const = df.columns[const_mask].tolist()
    display(Markdown(f"**Constant predictors dropped ({len(dropped_const)}):**"))
    display(Markdown(f"{dropped_const if dropped_const else 'None'}"))

    # b) Remove constant columns
    X_reduced = df.loc[:, ~const_mask]

    # c) Drop target/leakage columns (if present)
    cols_to_drop = [col for col in cols_to_exclude if col in X_reduced.columns]
    if cols_to_drop:
        X_reduced = X_reduced.drop(columns=cols_to_drop)
        display(Markdown(f"**Dropped target/leakage columns:** {cols_to_drop}"))
    else:
        display(Markdown("No target/leakage columns found to drop"))

    # d) Store cleaned DataFrame
    X_reduced_list.append(X_reduced)

    # e) Report shapes
    display(Markdown(f"**Original shape:** {df.shape}"))
    display(Markdown(
        f"**Cleaned shape:** {X_reduced.shape} "
        f"(removed {df.shape[1] - X_reduced.shape[1]} columns)"
    ))

display(Markdown("\n✅ **Preprocessing complete! X_reduced_list contains 5 cleaned DataFrames.**"))
A structured preprocessing pipeline was implemented prior to modeling. Ordered categorical variables (e.g., housing status, educational attainment, clinical evaluations, and substance use frequency) were manually mapped to numeric scales reflecting their natural ordering. For nominal categorical variables, prespecified reference categories were enforced to ensure consistent baseline comparisons across imputations. All remaining categorical predictors were then converted to dummy variables using one-hot encoding with the first category dropped to prevent multicollinearity. The procedure was applied consistently across all imputed datasets to ensure harmonized model inputs.
Code
import pandas as pd
import numpy as np
from pandas.api.types import CategoricalDtype


def preprocess_features_robust(df):
    df_clean = df.copy()

    # ---------------------------------------------------------
    # 1. Ordinal encoding
    # ---------------------------------------------------------
    ordered_mappings = {
        # --- NEW: Housing & Urbanicity ---
        "tenure_status_household": {
            "illegal settlement": 4,                          # Situación Calle (street situation)
            "stays temporarily with a relative": 3,           # Allegado (lodging with relatives)
            "others": 2,                                      # En pensión / Otros (boarding house / other)
            "renting": 1,                                     # Arrendando (renting)
            "owner/transferred dwellings/pays dividends": 0,  # Vivienda Propia (own home)
        },
        "urbanicity_cat": {
            "1.Rural": 2,
            "2.Mixed": 1,
            "3.Urban": 0,
        },
        # --- Clinical Evaluations (minimal -> intermediate -> high achievement) ---
        "evaluacindelprocesoteraputico": {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_consumo": {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_fam": {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_relinterp": {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_ocupacion": {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_sm": {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_fisica": {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_transgnorma": {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        # --- Frequency (less frequent -> more frequent) ---
        "prim_sub_freq_rec": {
            "1.≤1 day/wk": 0,
            "2.2–6 days/wk": 1,
            "3.Daily": 2,
        },
        # --- Education (less -> more) ---
        "ed_attainment_corr": {
            "3-Completed primary school or less": 2,
            "2-Completed high school or less": 1,
            "1-More than high school": 0,
        },
    }

    for col, mapping in ordered_mappings.items():
        if col in df_clean.columns:
            df_clean[col] = df_clean[col].astype(str).str.strip()
            df_clean[col] = df_clean[col].map(mapping)
            n_missing = df_clean[col].isnull().sum()
            if n_missing > 0:
                if n_missing == len(df_clean):
                    print(f"⚠️ WARNING: Mapping failed completely for '{col}'.")
                mode_val = df_clean[col].mode()[0]
                df_clean[col] = df_clean[col].fillna(mode_val)

    # ---------------------------------------------------------
    # 2. FORCE reference categories for dummies
    # ---------------------------------------------------------
    dummy_reference = {
        "sex_rec": "man",
        "marital_status_rec": "married/cohabiting",
        "cohabitation": "alone",
        "sub_dep_icd10_status": "hazardous consumption",
        "tr_outcome": "completion",
        "adm_motive": "spontaneous consultation",
        "tipo_de_vivienda_rec2": "formal housing",
        # NOTE: "plan_type_corr" was listed twice in the original dict
        # ("ambulatory" and "pg-pab"); only the last entry takes effect.
        "plan_type_corr": "pg-pab",
        "occupation_condition_corr24": "employed",
        "any_violence": "0.No domestic violence/sex abuse",
        "first_sub_used": "marijuana",
        "primary_sub_mod": "marijuana",
    }

    for col, ref in dummy_reference.items():
        if col in df_clean.columns:
            df_clean[col] = df_clean[col].astype(str).str.strip()
            cats = df_clean[col].unique().tolist()
            if ref in cats:
                new_order = [ref] + [c for c in cats if c != ref]
                cat_type = CategoricalDtype(categories=new_order, ordered=False)
                df_clean[col] = df_clean[col].astype(cat_type)
            else:
                print(f"⚠️ Reference '{ref}' not found in {col}")

    # ---------------------------------------------------------
    # 3. One-hot encoding
    # ---------------------------------------------------------
    df_final = pd.get_dummies(df_clean, drop_first=True, dtype=float)
    return df_final


X_encoded_list_final = [preprocess_features_robust(X) for X in X_reduced_list]
X_encoded_list_final = [clean_names(X) for X in X_encoded_list_final]
Code
from IPython.display import display, Markdown

# 1. DIAGNOSTIC: Check exact string values
display(Markdown("### --- Diagnostic Check ---"))
sample_df = X_encoded_list_final[0]

if 'tenure_status_household' in sample_df.columns:
    display(Markdown("**Unique values in 'tenure_status_household':**"))
    display(Markdown(str(sample_df['tenure_status_household'].unique())))
else:
    display(Markdown("❌ 'tenure_status_household' is missing entirely from input data!"))

if 'urbanicity_cat' in sample_df.columns:
    display(Markdown("**Unique values in 'urbanicity_cat':**"))
    display(Markdown(str(sample_df['urbanicity_cat'].unique())))

if 'ed_attainment_corr' in sample_df.columns:
    display(Markdown("**Unique values in 'ed_attainment_corr':**"))
    display(Markdown(str(sample_df['ed_attainment_corr'].unique())))
— Diagnostic Check —
Unique values in ‘tenure_status_household’:
[3 0 1 2 4]
Unique values in ‘urbanicity_cat’:
[0 1 2]
Unique values in ‘ed_attainment_corr’:
[1 2 0]
We recoded the first substance used so that small categories are grouped into an "Others" category.
Code
# Columns to combine
cols_to_group = [
    "first_sub_used_opioids",
    "first_sub_used_others",
    "first_sub_used_hallucinogens",
    "first_sub_used_inhalants",
    "first_sub_used_tranquilizers_hypnotics",
    "first_sub_used_amphetamine_type_stimulants",
]

# Loop over datasets 0-4 and modify in place
for i in range(5):
    df = X_encoded_list_final[i].copy()

    # Collapse into one dummy: if any of these == 1, mark as 1
    df["first_sub_used_other"] = df[cols_to_group].max(axis=1)

    # Drop the rest except the new combined column
    df = df.drop(columns=[c for c in cols_to_group if c != "first_sub_used_other"])

    # Replace the dataset in the original list
    X_encoded_list_final[i] = df
Code
import sys

fold_output("Show first db of X_encoded_list_final (newer) glimpse",
            lambda: glimpse(X_encoded_list_final[0]))
Show first db of X_encoded_list_final (newer) glimpse
For each imputed dataset, we fitted two regularized Cox models (one for readmission and one for death) using Coxnet, which applies elastic-net penalization with a strong LASSO component to enable variable selection. The loop fits both models on every imputation, prints basic model information, and stores all fitted models so they can later be combined or compared across imputations.
Create bins for follow-up (landmarks)
We extracted the observed event times and corresponding event indicators directly from the structured survival objects (y_surv_readm and y_surv_death). Using the observed event times, we constructed evaluation grids based on the 5th to 95th percentiles of the event-time distribution. These grids define standardized time points at which model performance is assessed for both readmission and mortality outcomes.
Code
import numpy as np
from IPython.display import display, Markdown

# Extract event times directly from structured arrays
event_times_readm = y_surv_readm["time"][y_surv_readm["event"]]
event_times_death = y_surv_death["time"][y_surv_death["event"]]

# Build evaluation grids (5th-95th percentiles, 50 points)
times_eval_readm = np.unique(
    np.quantile(event_times_readm, np.linspace(0.05, 0.95, 50))
)
times_eval_death = np.unique(
    np.quantile(event_times_death, np.linspace(0.05, 0.95, 50))
)

# Display only final result
display(Markdown(f"**Eval times (readmission):** `{times_eval_readm[:5]}` ... `{times_eval_readm[-5:]}`"))
display(Markdown(f"**Eval times (death):** `{times_eval_death[:5]}` ... `{times_eval_death[-5:]}`"))
First, we eliminated immortal time bias (otherwise, patients who died would appear as readmission-free follow-up).
This correction is essentially the cause-specific hazard preparation. It is the correct way to handle Aim 3 unless you switch to a Fine-Gray model (which codes death as a distinct event type, 2, rather than as censoring, 0). For RSF/Coxnet, censoring (0) is the correct approach.
Code
import numpy as np

# Step 3. Replicate across imputations (safe copies)
n_imputations = len(X_encoded_list_final)
y_surv_readm_list = [y_surv_readm.copy() for _ in range(n_imputations)]
y_surv_death_list = [y_surv_death.copy() for _ in range(n_imputations)]


def correct_competing_risks(y_readm_list, y_death_list):
    corrected = []
    for y_readm, y_death in zip(y_readm_list, y_death_list):
        y_corr = y_readm.copy()
        # Death observed strictly before the readmission/censoring time:
        # censor the readmission record at the death time
        mask = (y_death["event"]) & (y_death["time"] < y_corr["time"])
        y_corr["event"][mask] = False
        y_corr["time"][mask] = y_death["time"][mask]
        corrected.append(y_corr)
    return corrected


# Step 4. Apply correction
y_surv_readm_list_corrected = correct_competing_risks(
    y_surv_readm_list, y_surv_death_list
)
Code
# Check type and length
type(y_surv_readm_list_corrected), len(y_surv_readm_list_corrected)

# Look at the first element
y_surv_readm_list_corrected[0][:5]  # first 5 rows
from IPython.display import display, HTML
import html


def nb_print(*args, sep=" "):
    msg = sep.join(str(a) for a in args)
    display(HTML(f"<pre style='margin:0'>{html.escape(msg)}</pre>"))
The fully preprocessed and encoded feature matrices were renamed from X_encoded_list_final to imputations_list_mar26 to reflect the finalized March 2026 analytic version of the imputed datasets.
This object contains the harmonized, ordinal-encoded, and one-hot encoded predictor matrices for all five imputations and will serve as the definitive input for subsequent modeling procedures.
# counts per stratum in train/test
train_counts = sdiag.iloc[train_idx].value_counts()
test_counts = sdiag.iloc[test_idx].value_counts()
min_train = int(train_counts.min())
min_test = int(test_counts.min())
nb_print_md(f"**Min stratum count in TRAIN (used strata):** `{min_train}`")
nb_print_md(f"**Min stratum count in TEST (used strata):** `{min_test}`")

# strata that got 0 in test or 0 in train
zero_in_test = sorted(set(train_counts.index) - set(test_counts.index))
zero_in_train = sorted(set(test_counts.index) - set(train_counts.index))
nb_print_md(f"**Strata with 0 in TEST:** `{len(zero_in_test)}`")
nb_print_md(f"**Strata with 0 in TRAIN:** `{len(zero_in_train)}`")

# show examples with their full-data counts
if len(zero_in_test) > 0:
    ex = zero_in_test[:10]
    nb_print_md(f"**Examples 0 in TEST (up to 10):** `{ex}`")
    nb_print_md(f"**Full-data counts:** `{[int(sdiag.value_counts()[k]) for k in ex]}`")
# Use the actual stratification mode that was used to split
strata_used, strat_mode, _, _ = build_strata(full["X"][0], full["y_readm"][0], full["y_death"][0])
s = pd.Series(strata_used)

train_strata = set(s.iloc[train_idx].unique())
test_strata = set(s.iloc[test_idx].unique())
missing_in_test = sorted(train_strata - test_strata)
missing_in_train = sorted(test_strata - train_strata)

display(Markdown(f"**Strata used:** `{strat_mode}`"))
display(Markdown(f"**# strata in train:** `{len(train_strata)}` | **# strata in test:** `{len(test_strata)}`"))
display(Markdown(f"**Strata present in train but missing in test:** `{len(missing_in_test)}`"))
display(Markdown(f"**Strata present in test but missing in train:** `{len(missing_in_train)}`"))
Strata used: fallback(plan+readm+death)
# strata in train: 20 | # strata in test: 20
Strata present in train but missing in test: 0
Strata present in test but missing in train: 0
Code
from pathlib import Path
import pandas as pd
import numpy as np
from IPython.display import display, Markdown

PROJECT_ROOT = find_project_root()  # no hardcoded absolute path
OUT_DIR = PROJECT_ROOT / "_out"
OUT_DIR.mkdir(parents=True, exist_ok=True)

SPLIT_PARQUET = OUT_DIR / f"comb_split_seed{SEED}_test{int(TEST_SIZE*100)}_mar26.parquet"
split_df = pd.DataFrame({
    "row_id": np.arange(n),
    "is_train": np.isin(np.arange(n), train_idx),
})
split_df.to_parquet(SPLIT_PARQUET, index=False)

display(Markdown(f"**Project root:** `{PROJECT_ROOT}`"))
display(Markdown(f"**Saved split to:** `{SPLIT_PARQUET}`"))
In this section, we transition to a Gradient Boosted Decision Tree (GBDT) framework using XGBoost. This approach serves as a robust, non-linear benchmark to validate findings from the neural network, specifically optimized for high-imbalance survival data (approx. 4% death rate).
Full Metrics
The updated pipeline uses a cause-specific (death-censored) framework for readmission and reports discrimination/calibration metrics accordingly.
Trains two cause-specific XGBoost survival models (survival:cox): one for death and one for readmission.
Uses 5-fold stratified cross-validation across all imputations, with composite stratification (event type × treatment plan).
Encodes survival targets for XGBoost as +time if event, −time if censored.
Uses output_margin=True, clips risk scores to [-15, 15], then exponentiates for hazard-scale calculations.
Converts fold-level risk scores to survival probabilities using the Breslow baseline hazard estimator (for both endpoints).
Computes Global and time-dependent Uno’s C-index with a safe fallback (tau truncation when needed).
Computes Global IBS and time-dependent IBS (only when at least 2 time points are available), plus horizon-specific point Brier Score.
For horizon-specific classification metrics, restricts to valid case/control subjects at each horizon and estimates thresholds on training folds only (F1 for death, Youden for readmission).
Uses pre-encoded plan variables (no re-dummying in this step) and applies early stopping (up to 5000 rounds), logging best_iteration per fold.
Stores fold artifacts for reproducibility: predictions, baseline hazards, and exact CV train/validation splits.
Aggregates out-of-fold SHAP values across folds and imputations by patient index, with index re-alignment before export.
Saves a complete artifact set:
xgb6_corr_DUAL_metrics_<timestamp>.csv
xgb6_corr_DUAL_final_ev_hyp_<timestamp>.pkl
xgb6_corr_DUAL_BaselineHazards_<timestamp>.pkl
xgb6_corr_DUAL_CV_Splits_<timestamp>.pkl
xgb6_corr_DUAL_SHAP_Aggregated_<timestamp>.pkl (when SHAP is computed)
📌 5 Core Assumptions of This Pipeline (Cause-Specific Framework)
Cause-specific hazard assumption
Death and readmission are modeled separately; competing events are treated as censoring for each endpoint.
Independent censoring assumption
Censoring (including competing-event censoring) is assumed conditionally independent of the event process given covariates.
Cox-type risk structure assumption
survival:cox optimizes a Cox partial likelihood, so effects are interpreted on the log-risk scale.
Imputation and aggregation assumption
Multiple imputations are analyzed via repeated CV and cross-imputation aggregation (not formal Rubin pooling for every metric).
This script produces competing-risk calibration curves for readmission using Aalen-Johansen.
Predictions come from cross-validated folds, avoiding optimistic bias.
It correctly treats death as a competing event, not simple censoring.
Observed risks are estimated using the Aalen-Johansen estimator, which is appropriate for CIF.
Predicted risk is defined as: 1 − S_readmission(t) from the cause-specific Cox model.
Patients are grouped into quantile-based risk bins (default 10).
Calibration is evaluated at multiple time horizons simultaneously.
A patient-level master dataset is saved for future bootstrap calibration inference.
The pipeline separates modeling and calibration — improving reproducibility.
Output figures are publication-ready (PNG 300dpi + PDF).
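The quantile-based risk binning (default 10 bins) can be sketched as follows; the `obs` column here is a simplified stand-in for the Aalen-Johansen CIF estimate that the real script computes per bin:

```python
import numpy as np
import pandas as pd

# Synthetic predicted risks: 1 - S_readmission(t) for 1000 patients
rng = np.random.default_rng(2)
pred_risk = rng.beta(2, 8, size=1000)
observed = rng.random(1000) < pred_risk  # placeholder for AJ-estimated outcomes

df = pd.DataFrame({"pred": pred_risk, "obs": observed})

# Quantile-based risk bins (deciles of predicted risk)
df["bin"] = pd.qcut(df["pred"], q=10, labels=False, duplicates="drop")

# Per-bin mean predicted vs. "observed" risk; the real pipeline replaces
# mean_obs with the Aalen-Johansen CIF at the horizon
calib = df.groupby("bin").agg(
    mean_pred=("pred", "mean"),
    mean_obs=("obs", "mean"),
    n=("pred", "size"),
)
```

Plotting `mean_obs` against `mean_pred` per bin gives the calibration curve at each horizon.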
5 Assumptions of This Code
Cause-Specific Hazard Validity
The readmission survival probabilities generated via Breslow are correctly specified under a cause-specific Cox model.
Independence of Competing Events
Death and readmission are assumed to follow the standard competing risks framework (non-informative censoring conditional on covariates).
Proper Cross-Validation Aggregation
Combining all validation folds into a single master dataset assumes that pooling cross-validated predictions is unbiased.
(This is generally acceptable.)
Quantile Binning Adequacy
Risk binning assumes that quantile groups meaningfully represent calibration strata.
Calibration results can change with:
For XGB Cox, SHAP values are on the log-hazard scale, so case ranking should use absolute risk computed from the baseline hazard.
Baseline hazards let you compute absolute risk per patient/horizon, but this does not automatically convert TreeSHAP values into additive risk-point SHAP; that would require a different explainer setup (a probability-output function, usually model-agnostic and slower).
Use baseline hazards only to compute absolute risk at horizon for ranking/labeling patients and calibration.
Feature contributions do NOT sum on the absolute-risk scale because of the non-linear link function: converting individual SHAP values to absolute risk breaks additivity, so waterfall plots should not display absolute-risk contributions. On the log-hazard scale, f(x) = base + Σ SHAP_i is linear and therefore additive.
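A small numeric illustration of the additivity point (values are arbitrary; `H0_t` plays the role of the Breslow baseline cumulative hazard at the horizon):

```python
import numpy as np

base = -0.5                             # model base value (log-hazard)
shap_vals = np.array([0.8, -0.3, 0.1])  # per-feature SHAP contributions
H0_t = 0.05                             # baseline cumulative hazard at horizon t


def risk(margin):
    """Absolute risk at horizon t: 1 - exp(-H0(t) * exp(margin))."""
    return 1.0 - np.exp(-H0_t * np.exp(margin))


# Additivity holds exactly on the log-hazard scale
margin = base + shap_vals.sum()

# Naively transforming each contribution to the risk scale and summing
# does NOT reproduce risk(margin): the non-linear link breaks additivity
naive_sum = risk(base) + sum(risk(base + s) - risk(base) for s in shap_vals)
print(risk(margin), naive_sum)
```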
SHAP (SHapley Additive exPlanations) values were computed for both the readmission and mortality models. To prevent data leakage, SHAP values were calculated on out-of-sample cross-validated predictions: for each patient, explanations were generated by models trained on the four folds that did not include that patient. SHAP values were then averaged across the 5 cross-validation folds and the 5 multiple imputations, yielding internally validated feature-importance estimates. Note that this constitutes internal validation only, not external validation.
Code
#@title Step 8: SHAP Analysis & Plots (DUAL, Multi-Horizon, source_tag-traceable)
# Corrected version - compatible with Step 5 outputs from XGboost_combined_mar26.ipynb
import os
import re
import glob
import json
import pickle
import numpy as np
import pandas as pd
import shap
import matplotlib.pyplot as plt
import matplotlib as mpl
from IPython.display import display, Markdown

# -----------------------------
# 0) Config
# -----------------------------
if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the root setup cell first.")
PROJECT_ROOT = os.path.abspath(str(PROJECT_ROOT))
IN_DIR = os.path.join(PROJECT_ROOT, "_out")
OUT_DIR = os.path.join(PROJECT_ROOT, "_out")
FIG_DIR = os.path.join(PROJECT_ROOT, "_figs")

# Set None to auto-use all eval_times found in raw logs
TARGET_HORIZONS = [12, 60]
MAX_BEESWARM_N = 90000
RNG = np.random.RandomState(2125)

# CI config (optional but enabled by default)
BOOTSTRAP_CI = True
N_BOOTSTRAP = 500
CI_ALPHA = 0.05
BOOTSTRAP_MAX_N = 90000

# Multicollinearity check
CORR_THRESHOLD = 0.90

# Unified DPI
FIG_DPI = 300

os.makedirs(IN_DIR, exist_ok=True)
os.makedirs(OUT_DIR, exist_ok=True)
os.makedirs(FIG_DIR, exist_ok=True)

# NOTE: These key candidates are matched against Step 5 output structure
# Step 5 saves: 'risk_pred_readm', 'risk_pred_death', 'probs_readm_matrix', 'probs_death_matrix', etc.
OUTCOME_CFG = {
    "readm": {
        "label": "Readmission",
        "shap_key_candidates": ["shap_r_all", "shap_readm_all"],
        "margin_key_candidates": ["risk_pred_readm", "risk_pred_r", "margin_pred_readm"],
        "probs_key_candidates": ["probs_readm_matrix", "probs_r_matrix", "surv_probs_readm_matrix"],
        "hz_times_candidates": ["times_r", "times_readm"],
        "hz_vals_candidates": ["h0_r", "H0_r", "h0_readm"],
    },
    "death": {
        "label": "Death",
        "shap_key_candidates": ["shap_d_all", "shap_death_all", "shap_mort_all"],
        "margin_key_candidates": ["risk_pred_death", "risk_pred_d", "risk_pred_mort", "margin_pred_death"],
        "probs_key_candidates": ["probs_death_matrix", "probs_d_matrix", "probs_mort_matrix",
                                 "surv_probs_death_matrix"],
        "hz_times_candidates": ["times_d", "times_death", "times_mort"],
        "hz_vals_candidates": ["h0_d", "H0_d", "h0_death", "h0_mort"],
    },
}

VAL_ID_KEYS = ["val_ids", "valid_ids", "val_idx", "val_index", "idx_val"]

mpl.rcParams.update({
    "font.family": "serif",
    "font.serif": ["Times New Roman", "Times", "Nimbus Roman", "DejaVu Serif"],
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
    "axes.labelsize": 14,
    "axes.titlesize": 15,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "figure.dpi": FIG_DPI,
})

# -----------------------------
# 1) Pick latest complete Step 5 bundle (simplified)
# -----------------------------
TS_RE = re.compile(r"(\d{8}_\d{4})")  # Just find the timestamp anywhere

def pick_latest_complete_bundle(in_dir):
    """Find the most recent complete set of Step 5 output files."""
    shap_files = glob.glob(os.path.join(in_dir, "xgb6_corr_DUAL_SHAP_Aggregated_*.pkl"))
    candidates = []
    for shapf in shap_files:
        m = TS_RE.search(os.path.basename(shapf))
        if not m:
            continue
        tag = m.group(1)
        # Check for companion files - try both naming patterns
        for suffix in [f"_{tag}_mar26.pkl", f"_{tag}.pkl"]:
            rawf = os.path.join(in_dir, f"xgb6_corr_DUAL_final_ev_hyp{suffix}")
            hzf = os.path.join(in_dir, f"xgb6_corr_DUAL_BaselineHazards{suffix}")
            splitf = os.path.join(in_dir, f"xgb6_corr_DUAL_CV_Splits{suffix}")
            if all(os.path.exists(p) for p in (rawf, hzf, splitf)):
                dt = pd.to_datetime(tag, format="%Y%m%d_%H%M", errors="coerce")
                if pd.notna(dt):
                    candidates.append((dt, tag, shapf, rawf, hzf, splitf))
                break  # Found complete set for this tag
    if not candidates:
        raise FileNotFoundError(
            f"No complete Step 5 bundle found in '{in_dir}'. "
            f"Need: xgb6_corr_DUAL_SHAP_Aggregated_*.pkl + "
            f"final_ev_hyp + BaselineHazards + CV_Splits"
        )
    candidates.sort(key=lambda x: x[0])
    return candidates[-1]  # Latest

source_tag, shap_file, raw_file, hz_file, split_file = pick_latest_complete_bundle(IN_DIR)[1:]

# For traceability, use source_tag in filenames
FILE_TAG = source_tag
RUN_TS = pd.Timestamp.now().strftime("%Y%m%d_%H%M")

# -----------------------------
# 2) Load artifacts
# -----------------------------
try:
    with open(shap_file, "rb") as f:
        shap_data = pickle.load(f)
except Exception as e:
    raise RuntimeError(f"Failed to load SHAP file {shap_file}: {e}") from e

try:
    with open(raw_file, "rb") as f:
        raw_data_log = pickle.load(f)
except Exception as e:
    raise RuntimeError(f"Failed to load raw data file {raw_file}: {e}") from e

try:
    with open(hz_file, "rb") as f:
        baseline_hazards_log = pickle.load(f)
except Exception as e:
    raise RuntimeError(f"Failed to load baseline hazards file {hz_file}: {e}") from e

try:
    with open(split_file, "rb") as f:
        cv_splits_log = pickle.load(f)
except Exception as e:
    raise RuntimeError(f"Failed to load CV splits file {split_file}: {e}") from e

required_keys = {"X_all", "feature_names"}
missing = required_keys - set(shap_data.keys())
if missing:
    raise KeyError(f"Missing keys in SHAP file: {missing}")

X_all = shap_data["X_all"]
feature_names = list(shap_data["feature_names"])
if not isinstance(X_all, pd.DataFrame):
    X_all = pd.DataFrame(X_all, columns=feature_names)
if list(X_all.columns) != feature_names:
    X_all = X_all.reindex(columns=feature_names)
if not X_all.index.is_unique:
    raise ValueError("X_all index must be unique.")

# -----------------------------
# 3) Helpers
# -----------------------------
def get_first(dct, keys, default=None):
    """Return first value from dct matching any key in keys."""
    for k in keys:
        if k in dct and dct[k] is not None:
            return dct[k]
    return default

def find_first_key(dct, keys):
    """Return first key from keys that exists in dct."""
    for k in keys:
        if k in dct:
            return k
    return None

def h0_at_t(times, h0_vals, t):
    """Get baseline cumulative hazard at time t using left-continuous interpolation."""
    times = np.asarray(times, dtype=float).ravel()
    h0_vals = np.asarray(h0_vals, dtype=float).ravel()
    if times.size == 0 or h0_vals.size == 0 or len(times) != len(h0_vals):
        return np.nan
    t = float(t)
    if t < times[0]:
        return 0.0  # Before first event, cumulative hazard is 0
    i = np.searchsorted(times, t, side="right") - 1
    i = max(0, min(i, len(h0_vals) - 1))  # Clip to valid range
    return float(h0_vals[i])

def fmt_horizon(h):
    h = float(h)
    return str(int(h)) if abs(h - round(h)) < 1e-9 else f"{h:g}"

def horizon_token(h):
    return f"{fmt_horizon(h).replace('.', 'p')}m"

def save_current_figure(stem, outcome, horizon=None):
    fig = plt.gcf()
    hz = f"_{horizon_token(horizon)}" if horizon is not None else ""
    base = f"xgb8_dual_{outcome}_{stem}{hz}_{FILE_TAG}"
    png = os.path.join(FIG_DIR, f"{base}.png")
    pdf = os.path.join(FIG_DIR, f"{base}.pdf")
    fig.savefig(png, dpi=FIG_DPI, bbox_inches="tight")
    fig.savefig(pdf, bbox_inches="tight")
    return [png, pdf]

def discover_horizons(raw_log):
    """Extract all unique eval_times from raw data log."""
    vals = []
    for rec in raw_log:
        ev = np.asarray(rec.get("eval_times", []), dtype=float).ravel()
        vals.extend([v for v in ev if np.isfinite(v)])
    return sorted(set(vals))

def bootstrap_mean_abs_shap(shap_vals, n_boot=200, alpha=0.05, seed=2026, max_n=None):
    """Bootstrap confidence intervals for mean absolute SHAP values."""
    n, p = shap_vals.shape
    rng = np.random.RandomState(seed)
    if max_n is not None and n > max_n:
        idx = rng.choice(n, size=max_n, replace=False)
        X_sub = shap_vals[idx, :]
    else:
        X_sub = shap_vals
    n_eff = X_sub.shape[0]
    point = np.abs(X_sub).mean(axis=0)
    boot = np.empty((n_boot, p), dtype=float)
    for b in range(n_boot):
        ib = rng.choice(n_eff, size=n_eff, replace=True)
        boot[b] = np.abs(X_sub[ib]).mean(axis=0)
    lo = np.quantile(boot, alpha / 2.0, axis=0)
    hi = np.quantile(boot, 1.0 - alpha / 2.0, axis=0)
    return point, lo, hi, n_eff

def collect_margin_by_id(raw_log, split_map, cfg):
    """Collect margin (log-hazard) predictions by patient ID, averaged across folds."""
    rows = []
    for rec in raw_log:
        if "imp_idx" not in rec or "fold_idx" not in rec:
            continue
        key = (int(rec["imp_idx"]), int(rec["fold_idx"]))
        split_rec = split_map.get(key)
        if split_rec is None:
            continue
        val_ids = get_first(split_rec, VAL_ID_KEYS, [])
        margins = np.asarray(get_first(rec, cfg["margin_key_candidates"], []), dtype=float).ravel()
        if len(val_ids) != len(margins) or len(val_ids) == 0:
            continue
        # FIXED: Ensure IDs are strings for consistent matching
        for i, pid in enumerate(val_ids):
            rows.append((str(pid), float(margins[i])))
    if not rows:
        return pd.DataFrame(columns=["id", "margin"])
    return pd.DataFrame(rows, columns=["id", "margin"]).groupby("id", as_index=False)["margin"].mean()

def collect_risk_by_id(raw_log, split_map, hz_map, cfg, horizon):
    """Collect absolute risk predictions by patient ID at a specific horizon."""
    rows = []
    t = float(horizon)
    for rec in raw_log:
        if "imp_idx" not in rec or "fold_idx" not in rec:
            continue
        key = (int(rec["imp_idx"]), int(rec["fold_idx"]))
        split_rec = split_map.get(key)
        hz_rec = hz_map.get(key)
        if split_rec is None:
            continue
        val_ids = get_first(split_rec, VAL_ID_KEYS, [])
        margins = np.asarray(get_first(rec, cfg["margin_key_candidates"], []), dtype=float).ravel()
        if len(val_ids) != len(margins) or len(val_ids) == 0:
            continue
        risk_vec = None
        eval_times = np.asarray(rec.get("eval_times", []), dtype=float).ravel()
        probs_mat = np.asarray(get_first(rec, cfg["probs_key_candidates"], []), dtype=float)
        # Try to get risk from survival probability matrix
        if eval_times.size > 0 and probs_mat.ndim == 2:
            j = np.where(np.isclose(eval_times, t))[0]
            if j.size > 0:
                jj = int(j[0])
                if probs_mat.shape[0] == len(val_ids) and probs_mat.shape[1] == eval_times.size:
                    risk_vec = 1.0 - probs_mat[:, jj]
                elif probs_mat.shape[1] == len(val_ids) and probs_mat.shape[0] == eval_times.size:
                    risk_vec = 1.0 - probs_mat[jj, :]
                if risk_vec is not None:
                    risk_vec = np.asarray(risk_vec, dtype=float).ravel()
        # Fallback: compute from baseline hazard and margins
        if risk_vec is None and hz_rec is not None:
            times = get_first(hz_rec, cfg["hz_times_candidates"], [])
            h0_vals = get_first(hz_rec, cfg["hz_vals_candidates"], [])
            H0_t = h0_at_t(times, h0_vals, t)
            if np.isfinite(H0_t):
                surv = np.exp(-np.exp(margins) * H0_t)
                risk_vec = 1.0 - surv
        if risk_vec is not None and len(risk_vec) == len(val_ids):
            risk_vec = np.clip(np.asarray(risk_vec, dtype=float).ravel(), 0.0, 1.0)
            for i, pid in enumerate(val_ids):
                rv = float(risk_vec[i])
                if np.isfinite(rv):
                    rows.append((str(pid), rv))
    if not rows:
        return pd.DataFrame(columns=["id", "risk"])
    return pd.DataFrame(rows, columns=["id", "risk"]).groupby("id", as_index=False)["risk"].mean()

def collect_risk_samples_by_id(raw_log, split_map, hz_map, cfg, horizon):
    """Collect all risk samples by patient ID (for CI computation)."""
    rows = []
    t = float(horizon)
    for rec in raw_log:
        if "imp_idx" not in rec or "fold_idx" not in rec:
            continue
        key = (int(rec["imp_idx"]), int(rec["fold_idx"]))
        split_rec = split_map.get(key)
        hz_rec = hz_map.get(key)
        if split_rec is None:
            continue
        val_ids = get_first(split_rec, VAL_ID_KEYS, [])
        margins = np.asarray(get_first(rec, cfg["margin_key_candidates"], []), dtype=float).ravel()
        if len(val_ids) != len(margins) or len(val_ids) == 0:
            continue
        risk_vec = None
        eval_times = np.asarray(rec.get("eval_times", []), dtype=float).ravel()
        probs_mat = np.asarray(get_first(rec, cfg["probs_key_candidates"], []), dtype=float)
        # Try survival probability matrix
        if eval_times.size > 0 and probs_mat.ndim == 2:
            j = np.where(np.isclose(eval_times, t))[0]
            if j.size > 0:
                jj = int(j[0])
                if probs_mat.shape[0] == len(val_ids) and probs_mat.shape[1] == eval_times.size:
                    risk_vec = 1.0 - probs_mat[:, jj]
                elif probs_mat.shape[1] == len(val_ids) and probs_mat.shape[0] == eval_times.size:
                    risk_vec = 1.0 - probs_mat[jj, :]
        # Fallback to baseline hazard
        if risk_vec is None and hz_rec is not None:
            times = get_first(hz_rec, cfg["hz_times_candidates"], [])
            h0_vals = get_first(hz_rec, cfg["hz_vals_candidates"], [])
            H0_t = h0_at_t(times, h0_vals, t)
            if np.isfinite(H0_t):
                risk_vec = 1.0 - np.exp(-np.exp(margins) * H0_t)
        if risk_vec is None:
            continue
        risk_vec = np.clip(np.asarray(risk_vec, dtype=float).ravel(), 0.0, 1.0)
        if len(risk_vec) != len(val_ids):
            continue
        for i, pid in enumerate(val_ids):
            rv = float(risk_vec[i])
            if np.isfinite(rv):
                rows.append((str(pid), rv))
    if not rows:
        return pd.DataFrame(columns=["id", "risk"])
    return pd.DataFrame(rows, columns=["id", "risk"])

def summarize_risk_ci(df_samples, alpha=0.05):
    """Summarize risk samples with mean and quantile CI."""
    if df_samples.empty:
        return pd.DataFrame(columns=["id", "risk_mean", "risk_ci_low", "risk_ci_high", "n_samples"])
    g = df_samples.groupby("id")["risk"]
    out = g.agg(risk_mean="mean", n_samples="size").reset_index()
    out["risk_ci_low"] = g.quantile(alpha / 2.0).values
    out["risk_ci_high"] = g.quantile(1.0 - alpha / 2.0).values
    return out

def correlation_pairs_report(X_df, threshold=0.85):
    """Find highly correlated feature pairs."""
    X_num = X_df.apply(pd.to_numeric, errors="coerce")
    valid_cols = [c for c in X_num.columns if X_num[c].std(skipna=True) > 0]
    X_num = X_num[valid_cols]
    corr = X_num.corr(method="pearson")
    pairs = []
    cols = list(corr.columns)
    for i in range(len(cols)):
        for j in range(i + 1, len(cols)):
            r = corr.iat[i, j]
            if np.isfinite(r) and abs(r) >= threshold:
                pairs.append((cols[i], cols[j], float(r), float(abs(r))))
    pairs_df = pd.DataFrame(pairs, columns=["feature_1", "feature_2", "pearson_r", "abs_r"])
    if len(pairs_df):
        pairs_df = pairs_df.sort_values("abs_r", ascending=False).reset_index(drop=True)
    return pairs_df, corr

# -----------------------------
# 4) Prepare maps/horizons + multicollinearity check
# -----------------------------
split_map = {
    (int(s["imp_idx"]), int(s["fold_idx"])): s
    for s in cv_splits_log
    if "imp_idx" in s and "fold_idx" in s
}
hz_map = {
    (int(h["imp_idx"]), int(h["fold_idx"])): h
    for h in baseline_hazards_log
    if "imp_idx" in h and "fold_idx" in h
}
available_horizons = discover_horizons(raw_data_log)

# FIXED: Handle empty TARGET_HORIZONS and validate requested horizons
if TARGET_HORIZONS is None or len(TARGET_HORIZONS) == 0:
    horizons =
available_horizons if available_horizons else [12.0]else: horizons =sorted(set(float(h) for h in TARGET_HORIZONS))# Warn about missing horizons missing_h =set(horizons) -set(available_horizons)if missing_h: nb_print(f"Warning: Requested horizons not in data: {sorted(missing_h)}") nb_print(f"Available horizons: {available_horizons}")corr_pairs_df, corr_mat = correlation_pairs_report(X_all, threshold=CORR_THRESHOLD)corr_pairs_file = os.path.join(OUT_DIR, f"xgb8_dual_feature_corr_pairs_{FILE_TAG}.csv")corr_mat_file = os.path.join(OUT_DIR, f"xgb8_dual_feature_corr_matrix_{FILE_TAG}.csv")corr_pairs_df.to_csv(corr_pairs_file, index=False)corr_mat.to_csv(corr_mat_file)display(Markdown(f"### Step 8 SHAP (DUAL, Multi-Horizon)\n"f"- Source bundle tag: **{source_tag}**\n"f"- Run time: **{RUN_TS}**\n"f"- Patients: **{X_all.shape[0]}**\n"f"- Features: **{X_all.shape[1]}**\n"f"- Horizons (months): **{', '.join(fmt_horizon(h) for h in horizons)}**\n"f"- SHAP scale: **log-hazard**\n"f"- Multicollinearity threshold: **|r| >= {CORR_THRESHOLD:.2f}**"))iflen(corr_pairs_df) >0: display(Markdown(f"Found **{len(corr_pairs_df)}** correlated feature pairs (|r| >= {CORR_THRESHOLD:.2f}).")) display(corr_pairs_df.head(20))else: display(Markdown(f"No feature pairs above |r| >= {CORR_THRESHOLD:.2f}."))# -----------------------------# 5) Run SHAP per outcome and horizon# -----------------------------saved_plot_files = []saved_out_files = [corr_pairs_file, corr_mat_file]all_case_rows = []horizon_rows = []processed_outcomes = []# FIXED: id_to_row uses string keys for consistent matchingid_to_row = {str(idx): i for i, idx inenumerate(X_all.index)}for outcome_name, cfg in OUTCOME_CFG.items(): shap_key = find_first_key(shap_data, cfg["shap_key_candidates"])if shap_key isNone:print(f"Skipping {cfg['label']}: missing SHAP key among {cfg['shap_key_candidates']}.")continue shap_vals = np.asarray(shap_data[shap_key], dtype=float)if shap_vals.shape != X_all.shape:raiseValueError(f"{cfg['label']} SHAP shape 
mismatch: {shap_vals.shape} vs X_all {X_all.shape}")# Recover baseline on margin (log-hazard) scale df_margin = collect_margin_by_id(raw_data_log, split_map, cfg) df_shap_sum = pd.DataFrame({"id": X_all.index.astype(str),"shap_sum": shap_vals.sum(axis=1) }) tmp = df_shap_sum.merge(df_margin, on="id", how="inner") base_margin =float((tmp["margin"] - tmp["shap_sum"]).mean()) iflen(tmp) >0else0.0 explanation = shap.Explanation( values=shap_vals, base_values=np.full(X_all.shape[0], base_margin, dtype=float), data=X_all.to_numpy(), feature_names=feature_names )# Global mean |SHAP| + CIif BOOTSTRAP_CI: point, ci_low, ci_high, ci_n = bootstrap_mean_abs_shap( shap_vals, n_boot=N_BOOTSTRAP, alpha=CI_ALPHA, seed=2125, max_n=BOOTSTRAP_MAX_N )else: point = np.abs(shap_vals).mean(axis=0) ci_low = np.full_like(point, np.nan, dtype=float) ci_high = np.full_like(point, np.nan, dtype=float) ci_n =int(shap_vals.shape[0]) df_top = pd.DataFrame({"outcome": cfg["label"],"feature": feature_names,"mean_abs_shap_log_hazard": point,"ci95_low": ci_low,"ci95_high": ci_high,"bootstrap_n": int(ci_n),"n_bootstrap": int(N_BOOTSTRAP if BOOTSTRAP_CI else0),"ci_alpha": float(CI_ALPHA if BOOTSTRAP_CI else np.nan), }).sort_values("mean_abs_shap_log_hazard", ascending=False).reset_index(drop=True) top_file = os.path.join(OUT_DIR, f"xgb8_dual_{outcome_name}_shap_top_features_{FILE_TAG}.csv") df_top.to_csv(top_file, index=False) saved_out_files.append(top_file) processed_outcomes.append(outcome_name) display(Markdown(f"## {cfg['label']}")) display(Markdown("SHAP values and global importance are on the **log-hazard** scale.")) display(df_top.head(20))# Bar plot with CI (log-hazard SHAP) df_bar = df_top.head(20).sort_values("mean_abs_shap_log_hazard", ascending=True)# FIXED: Explicit dtype specification x = df_bar["mean_abs_shap_log_hazard"].to_numpy(dtype=float) has_ci = BOOTSTRAP_CI and np.isfinite(df_bar["ci95_low"]).all() and np.isfinite(df_bar["ci95_high"]).all() plt.figure(figsize=(11, 8))if has_ci: 
lo = df_bar["ci95_low"].to_numpy(dtype=float) hi = df_bar["ci95_high"].to_numpy(dtype=float) xerr = np.vstack([np.clip(x - lo, 0.0, None), np.clip(hi - x, 0.0, None)]) plt.barh(df_bar["feature"], x, xerr=xerr, color="#4C72B0", alpha=0.9, ecolor="black", capsize=2) plt.title(f"{cfg['label']} Global mean |SHAP| (log-hazard) with 95% bootstrap CI")else: plt.barh(df_bar["feature"], x, color="#4C72B0", alpha=0.9) plt.title(f"{cfg['label']} Global mean |SHAP| (log-hazard)") plt.xlabel("mean |SHAP| (log-hazard)") plt.tight_layout() saved_plot_files.extend(save_current_figure("bar_ci", outcome_name)) plt.show() plt.close()# Beeswarm (distribution on log-hazard SHAP scale)if X_all.shape[0] > MAX_BEESWARM_N: idx = RNG.choice(X_all.shape[0], MAX_BEESWARM_N, replace=False) exp_bee = explanation[idx]else: exp_bee = explanation plt.figure(figsize=(12, 8)) shap.plots.beeswarm(exp_bee, max_display=20, show=False) plt.title(f"{cfg['label']} SHAP Beeswarm (log-hazard scale)") plt.tight_layout() saved_plot_files.extend(save_current_figure("beeswarm", outcome_name)) plt.show() plt.close()# Horizon-specific risk ranking + waterfallsfor h in horizons: df_risk_samples = collect_risk_samples_by_id(raw_data_log, split_map, hz_map, cfg, h)# FIXED: Use CI_ALPHA instead of hardcoded 0.05 df_risk = summarize_risk_ci(df_risk_samples, alpha=CI_ALPHA)# FIXED: Ensure string IDs for consistent matching df_risk["id"] = df_risk["id"].astype(str) df_risk = df_risk[df_risk["id"].isin(id_to_row.keys())] n_h =int(len(df_risk)) n_total =int(X_all.shape[0]) horizon_rows.append({"outcome": cfg["label"],"horizon_months": float(h),"n_patients_with_risk": n_h,"n_total_patients": n_total,"coverage_pct": (100.0* n_h / n_total) if n_total >0else np.nan })if n_h ==0: nb_print(f"{cfg['label']} @ {fmt_horizon(h)}m: no absolute risk available; skipping waterfalls.")continue# Strictly rank by absolute risk (not SHAP score/log-hazard score) hi = df_risk.sort_values("risk_mean", ascending=False).iloc[0] lo = 
df_risk.sort_values("risk_mean", ascending=True).iloc[0] high_id =str(hi["id"]) low_id =str(lo["id"]) high_risk =float(hi["risk_mean"]) low_risk =float(lo["risk_mean"]) high_low =float(hi["risk_ci_low"]) high_high =float(hi["risk_ci_high"]) low_low =float(lo["risk_ci_low"]) low_high =float(lo["risk_ci_high"]) high_n =int(hi["n_samples"]) low_n =int(lo["n_samples"]) high_row = id_to_row[high_id] low_row = id_to_row[low_id] all_case_rows.append({"outcome": cfg["label"],"horizon_months": float(h),"case": "highest","id": high_id,"risk_at_horizon": high_risk,"n_patients_with_risk": n_h }) all_case_rows.append({"outcome": cfg["label"],"horizon_months": float(h),"case": "lowest","id": low_id,"risk_at_horizon": low_risk,"n_patients_with_risk": n_h })# Waterfall high risk (SHAP still log-hazard contribution) plt.figure(figsize=(10, 7)) shap.plots.waterfall(explanation[high_row], max_display=12, show=False) plt.title(f"{cfg['label']} highest absolute risk @ {fmt_horizon(h)}m "f"(ID {high_id}, risk={high_risk:.3f} [{high_low:.3f}, {high_high:.3f}], n={high_n})\n"f"Waterfall shows SHAP contributions on log-hazard scale" ) plt.tight_layout() saved_plot_files.extend(save_current_figure("waterfall_high", outcome_name, h)) plt.show() plt.close()# Waterfall low risk plt.figure(figsize=(10, 7)) shap.plots.waterfall(explanation[low_row], max_display=12, show=False) plt.title(f"{cfg['label']} lowest absolute risk @ {fmt_horizon(h)}m "f"(ID {low_id}, risk={low_risk:.3f} [{low_low:.3f}, {low_high:.3f}], n={low_n})\n"f"Waterfall shows SHAP contributions on log-hazard scale" ) plt.tight_layout() saved_plot_files.extend(save_current_figure("waterfall_low", outcome_name, h)) plt.show() plt.close()# -----------------------------# 6) Export combined outputs# -----------------------------cases_file = os.path.join(OUT_DIR, f"xgb8_dual_shap_extreme_cases_{FILE_TAG}.csv")horizon_file = os.path.join(OUT_DIR, f"xgb8_dual_horizon_sample_sizes_{FILE_TAG}.csv")info_file = os.path.join(OUT_DIR, 
f"xgb8_dual_shap_run_info_{FILE_TAG}.json")df_cases = pd.DataFrame(all_case_rows)df_horizon = pd.DataFrame(horizon_rows)df_cases.to_csv(cases_file, index=False)df_horizon.to_csv(horizon_file, index=False)saved_out_files.extend([cases_file, horizon_file])run_info = {"source_bundle_tag": source_tag,"run_timestamp": RUN_TS,"file_tag_used_for_outputs": FILE_TAG,"source_files": {"shap": shap_file,"raw": raw_file,"baseline_hazards": hz_file,"cv_splits": split_file },"n_patients": int(X_all.shape[0]),"n_features": int(X_all.shape[1]),"horizons_months": [float(h) for h in horizons],"available_eval_times_months": [float(h) for h in available_horizons],"outcomes_processed": processed_outcomes,"shap_scale": "log-hazard","global_importance_metric": "mean absolute SHAP (log-hazard)","bootstrap_ci": {"enabled": bool(BOOTSTRAP_CI),"n_bootstrap": int(N_BOOTSTRAP) if BOOTSTRAP_CI else0,"alpha": float(CI_ALPHA) if BOOTSTRAP_CI elseNone,"max_n_for_bootstrap": int(BOOTSTRAP_MAX_N) if BOOTSTRAP_CI elseNone },"multicollinearity_check": {"method": "pairwise Pearson correlation","threshold_abs_r": float(CORR_THRESHOLD),"pairs_file": corr_pairs_file,"corr_matrix_file": corr_mat_file },"note": "Absolute risk is used for extreme-case ranking at each horizon. SHAP values remain log-hazard contributions."}withopen(info_file, "w", encoding="utf-8") as f: json.dump(run_info, f, indent=2)saved_out_files.append(info_file)print("\nSaved plots to _figs (PNG/PDF):")for p in saved_plot_files: nb_print(" -", p)print("\nSaved tables/metadata to _out:")for p in saved_out_files: nb_print(" -", p)
### Step 8 SHAP (DUAL, Multi-Horizon)

- Source bundle tag: **20260306_1821**
- Run time: **20260306_1828**
- Patients: **70521**
- Features: **56**
- Horizons (months): **12, 60**
- SHAP scale: **log-hazard**
- Multicollinearity threshold: **|r| >= 0.90**

No feature pairs above |r| >= 0.90.
Readmission
SHAP values and global importance are on the log-hazard scale.
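Because the SHAP values are additive on the Cox margin (log relative hazard), exponentiating a contribution converts it to a multiplicative effect on the hazard. A minimal sketch with made-up numbers (the `base_margin` and `shap_contribs` values are illustrative, not taken from the fitted model):

```python
import numpy as np

# Hypothetical baseline log-hazard and per-feature SHAP contributions.
base_margin = -0.30
shap_contribs = np.array([0.09, -0.04, 0.02])

# SHAP is additive on the log-hazard scale ...
eta = base_margin + shap_contribs.sum()

# ... so exp() turns contributions into multiplicative hazard effects:
# a SHAP value of +0.09 corresponds to roughly a 1.09x hazard.
per_feature_hr = np.exp(shap_contribs)
hazard_ratio_vs_base = np.exp(shap_contribs.sum())
```

This is why the global importances in the table below are interpretable as average shifts in log relative hazard, not as changes in absolute risk.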
| | outcome | feature | mean_abs_shap_log_hazard | ci95_low | ci95_high | bootstrap_n | n_bootstrap | ci_alpha |
|---|---|---|---|---|---|---|---|---|
| 0 | Readmission | primary_sub_mod_cocaine_paste | 0.090015 | 0.089795 | 0.090213 | 70521 | 500 | 0.05 |
| 1 | Readmission | adm_age_rec3 | 0.078743 | 0.077885 | 0.079599 | 70521 | 500 | 0.05 |
| 2 | Readmission | porc_pobr | 0.073050 | 0.072710 | 0.073402 | 70521 | 500 | 0.05 |
| 3 | Readmission | sex_rec_woman | 0.070474 | 0.070155 | 0.070807 | 70521 | 500 | 0.05 |
| 4 | Readmission | plan_type_corr_pg_pr | 0.069722 | 0.069139 | 0.070305 | 70521 | 500 | 0.05 |
| 5 | Readmission | plan_type_corr_m_pr | 0.056074 | 0.055319 | 0.056821 | 70521 | 500 | 0.05 |
| 6 | Readmission | ethnicity | 0.054893 | 0.054140 | 0.055594 | 70521 | 500 | 0.05 |
| 7 | Readmission | dit_m | 0.047543 | 0.047252 | 0.047825 | 70521 | 500 | 0.05 |
| 8 | Readmission | eva_consumo | 0.046565 | 0.046266 | 0.046839 | 70521 | 500 | 0.05 |
| 9 | Readmission | ed_attainment_corr | 0.041594 | 0.041322 | 0.041856 | 70521 | 500 | 0.05 |
| 10 | Readmission | occupation_condition_corr24_unemployed | 0.036454 | 0.036315 | 0.036597 | 70521 | 500 | 0.05 |
| 11 | Readmission | dg_psiq_cie_10_dg | 0.028379 | 0.028299 | 0.028446 | 70521 | 500 | 0.05 |
| 12 | Readmission | sub_dep_icd10_status_drug_dependence | 0.028375 | 0.028231 | 0.028495 | 70521 | 500 | 0.05 |
| 13 | Readmission | polysubstance_strict | 0.025277 | 0.025137 | 0.025427 | 70521 | 500 | 0.05 |
| 14 | Readmission | primary_sub_mod_alcohol | 0.024515 | 0.024404 | 0.024631 | 70521 | 500 | 0.05 |
| 15 | Readmission | eva_sm | 0.023360 | 0.023182 | 0.023537 | 70521 | 500 | 0.05 |
| 16 | Readmission | primary_sub_mod_cocaine_powder | 0.023091 | 0.022960 | 0.023230 | 70521 | 500 | 0.05 |
| 17 | Readmission | evaluacindelprocesoteraputico | 0.021649 | 0.021516 | 0.021774 | 70521 | 500 | 0.05 |
| 18 | Readmission | tr_outcome_adm_discharge_rule_violation_undet | 0.021562 | 0.021317 | 0.021812 | 70521 | 500 | 0.05 |
| 19 | Readmission | prim_sub_freq_rec | 0.019463 | 0.019331 | 0.019586 | 70521 | 500 | 0.05 |
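The `ci95_low`/`ci95_high` columns above come from a percentile bootstrap over patients (`n_bootstrap = 500`, `ci_alpha = 0.05`). A minimal sketch of that computation on synthetic data; the function name mirrors `bootstrap_mean_abs_shap` from the cell above, but the signature shown here is a simplified assumption:

```python
import numpy as np

rng = np.random.default_rng(2125)
shap_vals = rng.normal(size=(1000, 5))  # synthetic (n_patients, n_features) SHAP matrix

def bootstrap_mean_abs_shap(S, n_boot=500, alpha=0.05, seed=0):
    """Percentile-bootstrap CI for the per-feature mean |SHAP| across patients."""
    rng = np.random.default_rng(seed)
    n = S.shape[0]
    point = np.abs(S).mean(axis=0)
    boots = np.empty((n_boot, S.shape[1]))
    for b in range(n_boot):
        idx = rng.integers(0, n, size=n)  # resample patients with replacement
        boots[b] = np.abs(S[idx]).mean(axis=0)
    lo = np.quantile(boots, alpha / 2, axis=0)
    hi = np.quantile(boots, 1 - alpha / 2, axis=0)
    return point, lo, hi

point, lo, hi = bootstrap_mean_abs_shap(shap_vals)
```

Resampling whole patients (rows) preserves the correlation between features within a patient, which is why the intervals are computed jointly rather than per feature.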
Death
SHAP values and global importance are on the log-hazard scale.
Unlike DeepHit (which learns different risk functions over time), XGBoost Cox models are time-invariant (proportional hazards). The model's structure, and hence its interactions, is constant across all time horizons. We therefore do not need to loop over the 3-, 6-, and 12-month horizons; we run the discovery once per outcome to find the "Global Interactions" inherent in the model structure.
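The time-invariance claim can be verified in a few lines: under proportional hazards, a patient's risk at horizon t is 1 - exp(-exp(eta) * H0(t)), which is monotone in the log-hazard margin eta for every t, so the patient ordering never changes between horizons. A sketch with hypothetical margins and illustrative baseline cumulative hazards:

```python
import numpy as np

# Hypothetical log-hazard margins (eta) for three patients and assumed
# baseline cumulative hazard values H0(t) at three horizons (months).
eta = np.array([-0.5, 0.2, 1.1])
H0 = {3: 0.05, 12: 0.20, 60: 0.80}

# risk_i(t) = 1 - exp(-exp(eta_i) * H0(t)): the absolute risks grow with t,
# but the ranking of patients is identical at every horizon.
rankings = {t: np.argsort(1.0 - np.exp(-np.exp(eta) * h0)) for t, h0 in H0.items()}
```

This is exactly why the global interaction scan below is run once per outcome, while the horizon loop only matters for the risk-scale (absolute risk) approximation.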
Code
#@title ⚡ Step 11: Interaction Discovery (DUAL, Global + Time-Dependent @ 3/12/36/60m)
import os
import re
import glob
import json
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display, Markdown, HTML  # FIXED: Added HTML
from scipy.stats import t as student_t

# -----------------------------
# 0) Config
# -----------------------------
if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the root setup cell first.")
PROJECT_ROOT = os.path.abspath(str(PROJECT_ROOT))
IN_DIR = os.path.join(PROJECT_ROOT, "_out")
OUT_DIR = os.path.join(PROJECT_ROOT, "_out")
FIG_DIR = os.path.join(PROJECT_ROOT, "_figs")
# FIXED: Removed duplicate os.makedirs calls
os.makedirs(IN_DIR, exist_ok=True)
os.makedirs(OUT_DIR, exist_ok=True)
os.makedirs(FIG_DIR, exist_ok=True)

TARGET_HORIZONS = [3, 6, 12, 36, 60, 96]
MAIN_TOP_K = 15         # top main effects to scan for interactions
INTER_TOP_K = 5         # top interactors per main feature for reporting
MIN_VALID_N = 500       # minimum patients required for a valid correlation
WEIGHT_CLIP_PCT = 99.5  # cap extreme risk-gradient weights
TIME_DEP_RANK_RANGE_MIN = 8    # marks a time-dependent signal if the ranking changes by at least 8 positions between horizons
TIME_DEP_ABS_DELTA_MIN = 0.05  # or if interaction strength (correlation) changes by at least this much between horizons
PLOT_TOP_TIME_DEP = 12         # maximum number of time-dependent interactions to plot
FIG_DPI = 300
DIRECTION_MIN_GROUP_N = 30     # requires at least 30 cases in each group to compare Q1 vs. Q4
FDR_ALPHA = 0.01               # significance threshold

OUTCOME_CFG = {
    "readm": {
        "label": "Readmission",
        "shap_key_candidates": ["shap_r_all", "shap_readm_all"],
        "margin_key_candidates": ["risk_pred_readm", "risk_pred_r", "margin_pred_readm"],
        "probs_key_candidates": ["probs_readm_matrix", "probs_r_matrix", "surv_probs_readm_matrix"],
        "hz_times_candidates": ["times_r", "times_readm"],
        "hz_vals_candidates": ["h0_r", "H0_r", "h0_readm"],
    },
    "death": {
        "label": "Death",
        "shap_key_candidates": ["shap_d_all", "shap_death_all", "shap_mort_all"],
        "margin_key_candidates": ["risk_pred_death", "risk_pred_d", "risk_pred_mort", "margin_pred_death"],
        "probs_key_candidates": ["probs_death_matrix", "probs_d_matrix", "probs_mort_matrix", "surv_probs_death_matrix"],
        "hz_times_candidates": ["times_d", "times_death", "times_mort"],
        "hz_vals_candidates": ["h0_d", "H0_d", "h0_death", "h0_mort"],
    },
}

VAL_ID_KEYS = ["val_ids", "valid_ids", "val_idx", "val_index", "idx_val"]
# FIXED: Changed TS_RE to match both _mar26 and non-_mar26 suffixes
TS_RE = re.compile(r"(\d{8}_\d{4})")

# -----------------------------
# 1) Helpers
# -----------------------------
def _tag_from_path(path):
    m = TS_RE.search(os.path.basename(path))
    return m.group(1) if m else None


def pick_latest_complete_bundle(in_dir):
    """Pick latest SHAP bundle, handling both _mar26 and non-_mar26 suffixes."""
    shap_files = glob.glob(os.path.join(in_dir, "xgb6_corr_DUAL_SHAP_Aggregated_*.pkl"))
    candidates = []
    for shapf in shap_files:
        tag = _tag_from_path(shapf)
        if not tag:
            continue
        # FIXED: Try both naming patterns
        for suffix in [f"_{tag}_mar26.pkl", f"_{tag}.pkl"]:
            rawf = os.path.join(in_dir, f"xgb6_corr_DUAL_final_ev_hyp{suffix}")
            hzf = os.path.join(in_dir, f"xgb6_corr_DUAL_BaselineHazards{suffix}")
            splitf = os.path.join(in_dir, f"xgb6_corr_DUAL_CV_Splits{suffix}")
            if all(os.path.exists(p) for p in (rawf, hzf, splitf)):
                dt = pd.to_datetime(tag, format="%Y%m%d_%H%M", errors="coerce")
                if pd.notna(dt):
                    candidates.append((dt, tag, shapf, rawf, hzf, splitf))
                break  # Found complete set
    if not candidates:
        raise FileNotFoundError(
            f"No complete Step 5 bundle found in '{in_dir}' with prefix xgb6_corr_DUAL_*."
        )
    candidates.sort(key=lambda x: x[0])
    _, tag, shapf, rawf, hzf, splitf = candidates[-1]
    return tag, shapf, rawf, hzf, splitf


def get_first(dct, keys, default=None):
    for k in keys:
        if k in dct and dct[k] is not None:
            return dct[k]
    return default


def find_first_key(dct, keys):
    for k in keys:
        if k in dct:
            return k
    return None


def h0_at_t(times, h0_vals, t):
    times = np.asarray(times, dtype=float).ravel()
    h0_vals = np.asarray(h0_vals, dtype=float).ravel()
    # FIXED: Added length check and improved boundary handling
    if times.size == 0 or h0_vals.size == 0 or len(times) != len(h0_vals):
        return np.nan
    t = float(t)
    if t < times[0]:
        return 0.0  # Before first event
    i = np.searchsorted(times, t, side="right") - 1
    i = max(0, min(i, len(h0_vals) - 1))  # Clip to valid range
    return float(h0_vals[i])


def safe_corr(x, y, min_n=500, eps=1e-12):
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    m = np.isfinite(x) & np.isfinite(y)
    n = int(m.sum())
    if n < min_n:
        return np.nan, n
    xx = x[m]
    yy = y[m]
    if np.std(xx) < eps or np.std(yy) < eps:
        return np.nan, n
    r = np.corrcoef(xx, yy)[0, 1]
    return float(r), n


def linear_residual(x, y, min_n=500, eps=1e-12):
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    m = np.isfinite(x) & np.isfinite(y)
    n = int(m.sum())
    if n < min_n:
        return None, n, np.nan, np.nan
    xx = x[m]
    yy = y[m]
    resid = np.full_like(y, np.nan, dtype=float)
    if np.std(xx) < eps:
        mu = float(np.mean(yy))
        resid[m] = yy - mu
        return resid, n, 0.0, mu
    A = np.column_stack([xx, np.ones_like(xx)])
    coef, *_ = np.linalg.lstsq(A, yy, rcond=None)
    slope = float(coef[0])
    intercept = float(coef[1])
    resid[m] = yy - (slope * xx + intercept)
    return resid, n, slope, intercept


def interaction_scores(X, shap_mat, feature_names, main_indices, main_importance=None, min_n=500):
    X = np.asarray(X, dtype=float)
    S = np.asarray(shap_mat, dtype=float)
    p = X.shape[1]
    rows = []
    for main_rank, i in enumerate(main_indices, start=1):
        resid, n_main, slope, intercept = linear_residual(X[:, i], S[:, i], min_n=min_n)
        if resid is None:
            continue
        for j in range(p):
            if j == i:
                continue
            r, n_valid = safe_corr(X[:, j], resid, min_n=min_n)
            if not np.isfinite(r):
                continue
            delta_q, dir_q, n_dir, n_q1, n_q4 = quartile_delta_direction(
                X[:, j], resid, min_n=min_n, min_group_n=DIRECTION_MIN_GROUP_N
            )
            rows.append({
                "main_idx": int(i),
                "inter_idx": int(j),
                "main_feature": feature_names[i],
                "main_rank_global": int(main_rank),
                "main_importance_global": float(main_importance[i]) if main_importance is not None else np.nan,
                "interactor": feature_names[j],
                "corr_resid_vs_interactor": float(r),
                "abs_corr": float(abs(r)),
                "delta_q4_q1_resid": float(delta_q) if np.isfinite(delta_q) else np.nan,
                "direction_q4_q1": dir_q,
                "n_dir_valid": int(n_dir),
                "n_q1": int(n_q1),
                "n_q4": int(n_q4),
                "n_valid": int(n_valid),
                "main_linear_slope": float(slope),
                "main_linear_intercept": float(intercept),
            })
    if not rows:
        return pd.DataFrame(columns=[
            "main_idx", "inter_idx", "main_feature", "main_rank_global",
            "main_importance_global", "interactor", "corr_resid_vs_interactor",
            "abs_corr", "delta_q4_q1_resid", "direction_q4_q1", "n_dir_valid",
            "n_q1", "n_q4", "n_valid", "main_linear_slope", "main_linear_intercept"
        ])
    return pd.DataFrame(rows)


def top_interactions(df_scores, top_k):
    if len(df_scores) == 0:
        return df_scores.copy()
    df = df_scores.sort_values(
        ["outcome", "horizon_months", "main_rank_global", "abs_corr"],
        ascending=[True, True, True, False]
    )
    return (
        df.groupby(["outcome", "horizon_months", "main_feature"], dropna=False, as_index=False, group_keys=False)
        .head(top_k)
        .reset_index(drop=True)
    )


def collect_risk_by_id(raw_log, split_map, hz_map, cfg, horizon):
    rows = []
    t = float(horizon)
    for rec in raw_log:
        if "imp_idx" not in rec or "fold_idx" not in rec:
            continue
        key = (int(rec["imp_idx"]), int(rec["fold_idx"]))
        split_rec = split_map.get(key)
        hz_rec = hz_map.get(key)
        if split_rec is None:
            continue
        val_ids = get_first(split_rec, VAL_ID_KEYS, [])
        margins = np.asarray(get_first(rec, cfg["margin_key_candidates"], []), dtype=float).ravel()
        if len(val_ids) == 0 or len(margins) != len(val_ids):
            continue
        risk_vec = None
        eval_times = np.asarray(rec.get("eval_times", []), dtype=float).ravel()
        probs_mat = np.asarray(get_first(rec, cfg["probs_key_candidates"], []), dtype=float)
        if eval_times.size > 0 and probs_mat.ndim == 2:
            j = np.where(np.isclose(eval_times, t))[0]
            if j.size > 0:
                jj = int(j[0])
                if probs_mat.shape[0] == len(val_ids) and probs_mat.shape[1] == eval_times.size:
                    risk_vec = 1.0 - probs_mat[:, jj]
                elif probs_mat.shape[1] == len(val_ids) and probs_mat.shape[0] == eval_times.size:
                    risk_vec = 1.0 - probs_mat[jj, :]
        if risk_vec is not None:
            risk_vec = np.asarray(risk_vec, dtype=float).ravel()
        if risk_vec is None and hz_rec is not None:
            times = get_first(hz_rec, cfg["hz_times_candidates"], [])
            h0_vals = get_first(hz_rec, cfg["hz_vals_candidates"], [])
            H0_t = h0_at_t(times, h0_vals, t)
            if np.isfinite(H0_t):
                surv = np.exp(-np.exp(margins) * H0_t)
                risk_vec = 1.0 - surv
        if risk_vec is not None and len(risk_vec) == len(val_ids):
            risk_vec = np.clip(np.asarray(risk_vec, dtype=float).ravel(), 0.0, 1.0)
            for i, pid in enumerate(val_ids):
                rv = float(risk_vec[i])
                if np.isfinite(rv):
                    rows.append((str(pid), rv))
    if not rows:
        return pd.DataFrame(columns=["id", "risk"])
    return pd.DataFrame(rows, columns=["id", "risk"]).groupby("id", as_index=False)["risk"].mean()


def save_current_figure(stem):
    fig = plt.gcf()
    png = os.path.join(FIG_DIR, f"{stem}.png")
    pdf = os.path.join(FIG_DIR, f"{stem}.pdf")
    fig.savefig(png, dpi=FIG_DPI, bbox_inches="tight")
    fig.savefig(pdf, bbox_inches="tight")
    return [png, pdf]


def corr_to_pvalue(r, n):
    if (not np.isfinite(r)) or (n <= 2) or (abs(r) >= 1):
        return np.nan
    tval = r * np.sqrt((n - 2) / max(1.0 - r * r, 1e-12))
    return float(2.0 * (1.0 - student_t.cdf(abs(tval), df=n - 2)))


def bh_fdr(pvals):
    p = np.asarray(pvals, dtype=float)
    out = np.full_like(p, np.nan, dtype=float)
    m = np.isfinite(p)
    pv = p[m]
    n = len(pv)
    if n == 0:
        return out
    order = np.argsort(pv)
    ranked = pv[order]
    q = ranked * n / np.arange(1, n + 1, dtype=float)
    q = np.minimum.accumulate(q[::-1])[::-1]
    q = np.clip(q, 0.0, 1.0)
    unsorted_q = np.empty_like(q)
    unsorted_q[order] = q
    out[m] = unsorted_q
    return out


def quartile_delta_direction(x, resid, min_n=500, min_group_n=40, eps=1e-12):
    x = np.asarray(x, dtype=float).ravel()
    r = np.asarray(resid, dtype=float).ravel()
    m = np.isfinite(x) & np.isfinite(r)
    n = int(m.sum())
    if n < min_n:
        return np.nan, "Unknown", n, 0, 0
    xx = x[m]
    rr = r[m]
    if np.std(xx) < eps or np.std(rr) < eps:
        return np.nan, "Unknown", n, 0, 0
    q1, q3 = np.quantile(xx, [0.25, 0.75])
    g1 = rr[xx <= q1]
    g4 = rr[xx >= q3]
    n1, n4 = int(len(g1)), int(len(g4))
    if n1 < min_group_n or n4 < min_group_n:
        return np.nan, "Unknown", n, n1, n4
    delta = float(np.mean(g4) - np.mean(g1))
    if abs(delta) < eps:
        d = "Neutral"
    elif delta > 0:
        d = "Positive"
    else:
        d = "Negative"
    return delta, d, n, n1, n4


def add_significance(df, family_cols, alpha=0.05):
    if len(df) == 0:
        return df
    out = df.copy()
    out["p_value"] = np.nan
    out["p_fdr"] = np.nan
    out[f"signif_fdr_{alpha:.2f}"] = False
    groups = out.groupby(family_cols, dropna=False).groups
    for _, idx in groups.items():
        idx = list(idx)
        p = np.array(
            [corr_to_pvalue(out.at[i, "corr_resid_vs_interactor"], int(out.at[i, "n_valid"])) for i in idx],
            dtype=float
        )
        q = bh_fdr(p)
        out.loc[idx, "p_value"] = p
        out.loc[idx, "p_fdr"] = q
        out.loc[idx, f"signif_fdr_{alpha:.2f}"] = q < alpha
    return out


# -----------------------------
# 2) Load latest bundle
# -----------------------------
source_tag, shap_file, raw_file, hz_file, split_file = pick_latest_complete_bundle(IN_DIR)
with open(shap_file, "rb") as f:
    shap_data = pickle.load(f)
with open(raw_file, "rb") as f:
    raw_data_log = pickle.load(f)
with open(hz_file, "rb") as f:
    baseline_hazards_log = pickle.load(f)
with open(split_file, "rb") as f:
    cv_splits_log = pickle.load(f)

if "X_all" not in shap_data or "feature_names" not in shap_data:
    raise KeyError("SHAP file must contain X_all and feature_names.")
X_all = shap_data["X_all"]
feature_names = list(shap_data["feature_names"])
if not isinstance(X_all, pd.DataFrame):
    X_all = pd.DataFrame(X_all, columns=feature_names)
if list(X_all.columns) != feature_names:
    X_all = X_all.reindex(columns=feature_names)
if not X_all.index.is_unique:
    raise ValueError("X_all index must be unique.")
# FIXED: Added duplicate feature name check
if len(feature_names) != len(set(feature_names)):
    dupes = [f for f in feature_names if feature_names.count(f) > 1]
    raise ValueError(f"Duplicate feature names detected: {set(dupes)}")

X_np = X_all.to_numpy(dtype=float)
id_to_row = {str(idx): i for i, idx in enumerate(X_all.index)}
split_map = {
    (int(s["imp_idx"]), int(s["fold_idx"])): s
    for s in cv_splits_log
    if "imp_idx" in s and "fold_idx" in s
}
hz_map = {
    (int(h["imp_idx"]), int(h["fold_idx"])): h
    for h in baseline_hazards_log
    if "imp_idx" in h and "fold_idx" in h
}

display(Markdown(
    f"### Step 11 Interaction Discovery (DUAL)\n"
    f"- Source bundle tag: **{source_tag}**\n"
    f"- Patients: **{X_all.shape[0]}**\n"
    f"- Features: **{X_all.shape[1]}**\n"
    f"- Horizons: **{', '.join(map(str, TARGET_HORIZONS))} months**\n"
    f"- Global interactions: **log-hazard SHAP scale**\n"
    f"- Horizon interactions: **risk-scale approximation via `dr/deta * SHAP`**"
))

# -----------------------------
# 3) Run per outcome
# -----------------------------
global_scores_all = []
horizon_scores_all = []
horizon_n_all = []
saved_plot_files = []
processed_outcomes = []

for out_code, cfg in OUTCOME_CFG.items():
    shap_key = find_first_key(shap_data, cfg["shap_key_candidates"])
    if shap_key is None:
        # FIXED: Use display(HTML(...)) instead of print
        display(HTML(f"<p>Skipping {cfg['label']}: SHAP key not found ({cfg['shap_key_candidates']}).</p>"))
        continue
    shap_vals = np.asarray(shap_data[shap_key], dtype=float)
    if shap_vals.shape != X_np.shape:
        raise ValueError(f"{cfg['label']} SHAP shape mismatch: {shap_vals.shape} vs {X_np.shape}")
    processed_outcomes.append(out_code)

    # Main features fixed by global importance (for horizon comparability)
    global_imp = np.abs(shap_vals).mean(axis=0)
    main_idx = np.argsort(-global_imp)[:MAIN_TOP_K]

    # Global interaction scan (time-invariant model structure)
    df_g = interaction_scores(
        X=X_np, shap_mat=shap_vals, feature_names=feature_names,
        main_indices=main_idx, main_importance=global_imp, min_n=MIN_VALID_N
    )
    df_g["outcome"] = cfg["label"]
    df_g["outcome_code"] = out_code
    df_g["scope"] = "global_log_hazard"
    df_g["horizon_months"] = np.nan
    df_g["n_patients"] = int(X_np.shape[0])
    global_scores_all.append(df_g)

    # Horizon-specific interaction scan (risk-scale approximation)
    for h in TARGET_HORIZONS:
        df_risk = collect_risk_by_id(raw_data_log, split_map, hz_map, cfg, h)
        df_risk = df_risk[df_risk["id"].isin(id_to_row.keys())].copy()
        n_h = int(len(df_risk))
        horizon_n_all.append({
            "outcome": cfg["label"],
            "outcome_code": out_code,
            "horizon_months": float(h),
            "n_patients_with_risk": n_h,
            "n_total_patients": int(X_np.shape[0]),
            "coverage_pct": (100.0 * n_h / X_np.shape[0]) if X_np.shape[0] > 0 else np.nan
        })
        if n_h < MIN_VALID_N:
            continue
        row_idx = df_risk["id"].map(id_to_row).to_numpy(dtype=int)
        risk = np.clip(df_risk["risk"].to_numpy(dtype=float), 1e-12, 1.0 - 1e-12)
        # Risk-gradient weight from Cox transform: dr/deta = (-ln(1-r))*(1-r)
        w = (-np.log(1.0 - risk)) * (1.0 - risk)
        if np.isfinite(w).sum() == 0:
            continue
        cap = np.nanpercentile(w, WEIGHT_CLIP_PCT)
        w = np.clip(w, 0.0, cap)
        X_h = X_np[row_idx, :]
        S_h = shap_vals[row_idx, :] * w[:, None]
        df_h = interaction_scores(
            X=X_h, shap_mat=S_h, feature_names=feature_names,
            main_indices=main_idx, main_importance=global_imp, min_n=MIN_VALID_N
        )
        df_h["outcome"] = cfg["label"]
        df_h["outcome_code"] = out_code
        df_h["scope"] = "horizon_risk_approx"
        df_h["horizon_months"] = float(h)
        df_h["n_patients"] = n_h
        df_h["weight_clip_pct"] = float(WEIGHT_CLIP_PCT)
        horizon_scores_all.append(df_h)

# -----------------------------
# 4) Combine + summarize
# -----------------------------
df_global_scores = pd.concat(global_scores_all, ignore_index=True) if global_scores_all else pd.DataFrame()
df_horizon_scores = pd.concat(horizon_scores_all, ignore_index=True) if horizon_scores_all else pd.DataFrame()
df_global_scores = add_significance(df_global_scores, ["outcome", "scope"], alpha=FDR_ALPHA) if len(df_global_scores) else df_global_scores
df_horizon_scores = add_significance(df_horizon_scores, ["outcome", "scope", "horizon_months"], alpha=FDR_ALPHA) if len(df_horizon_scores) else df_horizon_scores
df_horizon_n = pd.DataFrame(horizon_n_all)
df_global_top = top_interactions(df_global_scores, INTER_TOP_K) if len(df_global_scores) else pd.DataFrame()
df_horizon_top = top_interactions(df_horizon_scores, INTER_TOP_K) if len(df_horizon_scores) else pd.DataFrame()

if len(df_horizon_scores):
    tmp = df_horizon_scores.copy()
    tmp["rank_within_main_hz"] = (
        tmp.groupby(["outcome", "horizon_months", "main_feature"])["abs_corr"]
        .rank(method="min", ascending=False)
    )
    tmp["dir_pos"] = (tmp["direction_q4_q1"] == "Positive").astype(int)
    tmp["dir_neg"] = (tmp["direction_q4_q1"] == "Negative").astype(int)
    df_time = (
        tmp.groupby(["outcome", "main_feature", "interactor"], as_index=False)
        .agg(
            horizons_seen=("horizon_months", "nunique"),
            mean_abs_corr=("abs_corr", "mean"),
            sd_abs_corr=("abs_corr", "std"),
            min_abs_corr=("abs_corr", "min"),
            max_abs_corr=("abs_corr", "max"),
            min_rank=("rank_within_main_hz", "min"),
            max_rank=("rank_within_main_hz", "max"),
            n_rows=("abs_corr", "size"),
            n_positive=("dir_pos", "sum"),
            n_negative=("dir_neg", "sum"),
            min_p_fdr=("p_fdr", "min"),
        )
    )
    df_time["sd_abs_corr"] = df_time["sd_abs_corr"].fillna(0.0)
    df_time["cv_abs_corr_pct"] = 100.0 * df_time["sd_abs_corr"] / np.clip(df_time["mean_abs_corr"], 1e-12, None)
    df_time["abs_corr_delta"] = df_time["max_abs_corr"] - df_time["min_abs_corr"]
    df_time["rank_range"] = df_time["max_rank"] - df_time["min_rank"]
    df_time["n_direction_known"] = df_time["n_positive"] + df_time["n_negative"]
    df_time["direction_consistent"] = (df_time["n_direction_known"] > 0) & (
        (df_time["n_positive"] == 0) | (df_time["n_negative"] == 0)
    )
    df_time["direction_flip_flag"] = (df_time["n_positive"] > 0) & (df_time["n_negative"] > 0)
    df_time["dominant_direction"] = np.where(
        df_time["n_positive"] > df_time["n_negative"], "Positive",
        np.where(df_time["n_negative"] > df_time["n_positive"], "Negative", "Mixed")
    )
    top_keys = set(zip(df_horizon_top["outcome"], df_horizon_top["main_feature"], df_horizon_top["interactor"])) if len(df_horizon_top) else set()
    df_time["in_top_at_least_once"] = [
        (o, m, i) in top_keys
        for o, m, i in zip(df_time["outcome"], df_time["main_feature"], df_time["interactor"])
    ]
    df_time["time_dependent_flag"] = (
        (df_time["horizons_seen"] >= 2)
        & ((df_time["rank_range"] >= TIME_DEP_RANK_RANGE_MIN) | (df_time["abs_corr_delta"] >= TIME_DEP_ABS_DELTA_MIN))
    )
    df_time["time_dependent_top_flag"] = df_time["time_dependent_flag"] & df_time["in_top_at_least_once"]
    # Add these aliases (clearer terminology)
    df_time["horizon_salience_variation_flag"] = df_time["time_dependent_flag"]
    df_time["horizon_salience_variation_top_flag"] = df_time["time_dependent_top_flag"]
else:
    df_time = pd.DataFrame()

# -----------------------------
# 5) Plots: time profiles of top time-dependent interactions
# -----------------------------
if len(df_horizon_scores) and len(df_time):
    for out in sorted(df_horizon_scores["outcome"].unique()):
        td = df_time[
            (df_time["outcome"] == out) & (df_time["time_dependent_top_flag"])
        ].head(PLOT_TOP_TIME_DEP)
        if len(td) == 0:
            continue
        keys = set(zip(td["main_feature"], td["interactor"]))
        dsub = df_horizon_scores[
            (df_horizon_scores["outcome"] == out)
            & (df_horizon_scores.apply(lambda r: (r["main_feature"], r["interactor"]) in keys, axis=1))
        ].copy()
        pivot = dsub.pivot_table(
            index=["main_feature", "interactor"], columns="horizon_months",
            values="abs_corr", aggfunc="mean"
        )
        plt.figure(figsize=(12, 8))
        for (mf, it), row in pivot.iterrows():
            xs = [h for h in TARGET_HORIZONS if h in row.index and np.isfinite(row[h])]
            ys = [row[h] for h in xs]
            if len(xs) >= 2:
                plt.plot(xs, ys, marker="o", linewidth=1.8, label=f"{mf} × {it}")
        plt.title(f"{out}: Time-Dependent Interaction Profiles (abs corr of residual signal)")
        plt.xlabel("Horizon (months)")
        plt.ylabel("Interaction strength (abs corr)")
        plt.xticks(TARGET_HORIZONS)
        plt.grid(alpha=0.25)
        plt.legend(loc="best", fontsize=8, ncol=1)
        plt.tight_layout()
        saved_plot_files.extend(save_current_figure(f"xgb11_dual_{out.lower()}_time_profiles_{source_tag}"))
        plt.show()
        plt.close()

# -----------------------------
# 6) Save outputs
# -----------------------------
f_global_scores = os.path.join(OUT_DIR, f"xgb11_dual_interactions_global_scores_{source_tag}.csv")
f_global_top = os.path.join(OUT_DIR, f"xgb11_dual_interactions_global_top_{source_tag}.csv")
f_horizon_scores = os.path.join(OUT_DIR, f"xgb11_dual_interactions_horizon_scores_{source_tag}.csv")
f_horizon_top = os.path.join(OUT_DIR, f"xgb11_dual_interactions_horizon_top_{source_tag}.csv")
f_time = os.path.join(OUT_DIR, f"xgb11_dual_interactions_time_dependent_{source_tag}.csv")
f_hn = os.path.join(OUT_DIR, f"xgb11_dual_interactions_horizon_sample_sizes_{source_tag}.csv")
f_info = os.path.join(OUT_DIR, f"xgb11_dual_interactions_run_info_{source_tag}.json")

if len(df_global_scores):
    df_global_scores.to_csv(f_global_scores, index=False)
if len(df_global_top):
    df_global_top.to_csv(f_global_top, index=False)
if len(df_horizon_scores):
    df_horizon_scores.to_csv(f_horizon_scores, index=False)
if len(df_horizon_top):
    df_horizon_top.to_csv(f_horizon_top, index=False)
if len(df_time):
    df_time.to_csv(f_time, index=False)
if len(df_horizon_n):
    df_horizon_n.to_csv(f_hn, index=False)

run_info = {
    "source_bundle_tag": source_tag,
    "source_files": {
        "shap": shap_file,
        "raw": raw_file,
        "baseline_hazards": hz_file,
        "cv_splits": split_file
    },
    "outcomes_processed": processed_outcomes,
    "n_patients": int(X_all.shape[0]),
    "n_features": int(X_all.shape[1]),
    "main_top_k": int(MAIN_TOP_K),
    "inter_top_k": int(INTER_TOP_K),
    "horizons_months": [float(h) for h in TARGET_HORIZONS],
    "min_valid_n": int(MIN_VALID_N),
    "global_method": "Residual-correlation
interaction heuristic on SHAP log-hazard scale","horizon_method": "Approximate risk-scale interaction via weighted SHAP: SHAP * dr/deta, dr/deta=(-ln(1-r))*(1-r)","weight_clip_pct": float(WEIGHT_CLIP_PCT),"fdr_alpha": float(FDR_ALPHA),"direction_method": "quartile delta on residuals: mean(resid|Q4 interactor) - mean(resid|Q1 interactor)","time_dependent_rule": {"rank_range_min": int(TIME_DEP_RANK_RANGE_MIN),"abs_corr_delta_min": float(TIME_DEP_ABS_DELTA_MIN),"flag_definition": "time_dependent_flag = (rank_range >= threshold) OR (abs_corr_delta >= threshold), with >=2 horizons" },"terminology_note": ("'time_dependent_flag' / 'horizon_salience_variation_flag' indicates horizon-dependent ""interaction salience from risk transformation (and possible PH departures), not ""time-varying coefficients learned by the Cox model." ),"note": "Cox model structure is time-invariant; horizon-specific differences here represent risk-transform-dependent interaction salience, not different learned trees."}withopen(f_info, "w", encoding="utf-8") as f: json.dump(run_info, f, indent=2)# -----------------------------# 7) Display quick summaries# -----------------------------display(Markdown("### Global Interaction Top (log-hazard SHAP)"))display(df_global_top.head(30) iflen(df_global_top) else pd.DataFrame())display(Markdown("### Horizon Interaction Top (risk-scale approximation)"))display(df_horizon_top.head(30) iflen(df_horizon_top) else pd.DataFrame())display(Markdown("### Time-Dependent Interaction Candidates"))iflen(df_time): display(df_time[df_time["time_dependent_top_flag"]].head(50))else: display(pd.DataFrame())# FIXED: Use display(HTML(...)) instead of printdisplay(HTML("<br><b>Saved files:</b>"))for p in [f_global_scores, f_global_top, f_horizon_scores, f_horizon_top, f_time, f_hn, f_info]:if os.path.exists(p): display(HTML(f" - {p}"))if saved_plot_files: display(HTML("<br><b>Saved plots:</b>"))for p in saved_plot_files: display(HTML(f" - {p}"))
Step 11 Interaction Discovery (DUAL)
Source bundle tag: 20260306_1821
Patients: 70521
Features: 56
Horizons: 3, 6, 12, 36, 60, 96 months
Global interactions: log-hazard SHAP scale
Horizon interactions: risk-scale approximation via dr/deta * SHAP
Global Interaction Top (log-hazard SHAP)
All 30 rows shown: outcome = Death (`outcome_code` = death), scope = `global_log_hazard`, horizon_months = NaN, n_patients = n_valid = n_dir_valid = 70521, signif_fdr_0.01 = True. Linear fit (slope, intercept) per main feature: adm_age_rec3 (0.051331, -1.986121); primary_sub_mod_alcohol (0.556013, -0.243998); prim_sub_freq_rec (0.155660, -0.213759); occupation_condition_corr24_unemployed (0.217878, -0.079246); any_phys_dx (0.465739, -0.059663); eva_ocupacion (0.096627, -0.120813). Internal column indices (main_idx, inter_idx) are omitted for readability; abs_corr = |corr| and p_value <= p_fdr in every row.

| # | main_feature (rank, global imp.) | interactor | corr | delta_q4_q1_resid | direction | n_q1 | n_q4 | p_fdr |
|---|---|---|---|---|---|---|---|---|
| 0 | adm_age_rec3 (1, 0.497441) | ed_attainment_corr | 0.064949 | 0.011032 | Positive | 51972 | 18397 | 0 |
| 1 | adm_age_rec3 (1, 0.497441) | eva_fam | 0.060500 | 0.008246 | Positive | 39316 | 31202 | 0 |
| 2 | adm_age_rec3 (1, 0.497441) | occupation_condition_corr24_unemployed | 0.060183 | 0.012890 | Positive | 46023 | 24498 | 0 |
| 3 | adm_age_rec3 (1, 0.497441) | eva_relinterp | 0.059781 | 0.008260 | Positive | 39720 | 30796 | 0 |
| 4 | adm_age_rec3 (1, 0.497441) | eva_sm | 0.058153 | 0.006740 | Positive | 40420 | 30092 | 0 |
| 5 | primary_sub_mod_alcohol (2, 0.267121) | adm_age_rec3 | -0.145518 | -0.022639 | Negative | 17641 | 17640 | 0 |
| 6 | primary_sub_mod_alcohol (2, 0.267121) | prim_sub_freq_rec | -0.100948 | -0.012761 | Negative | 39308 | 31002 | 0 |
| 7 | primary_sub_mod_alcohol (2, 0.267121) | primary_sub_mod_cocaine_powder | 0.099220 | 0.002855 | Positive | 56727 | 70521 | 0 |
| 8 | primary_sub_mod_alcohol (2, 0.267121) | plan_type_corr_pg_pr | -0.084876 | -0.001739 | Negative | 62782 | 70521 | 0 |
| 9 | primary_sub_mod_alcohol (2, 0.267121) | primary_sub_mod_cocaine_paste | -0.079737 | -0.009595 | Negative | 43856 | 26665 | 0 |
| 10 | prim_sub_freq_rec (3, 0.114488) | primary_sub_mod_alcohol | -0.062824 | -0.008331 | Negative | 46546 | 23975 | 0 |
| 11 | prim_sub_freq_rec (3, 0.114488) | primary_sub_mod_cocaine_powder | 0.053674 | 0.001663 | Positive | 56727 | 70521 | 0 |
| 12 | prim_sub_freq_rec (3, 0.114488) | urbanicity_cat | -0.044370 | -0.001224 | Negative | 57504 | 70521 | 0 |
| 13 | prim_sub_freq_rec (3, 0.114488) | adm_motive_sanitary_sector | 0.029000 | 0.003935 | Positive | 48588 | 21933 | 4.774942e-14 |
| 14 | prim_sub_freq_rec (3, 0.114488) | sub_dep_icd10_status_drug_dependence | 0.028462 | 0.004015 | Positive | 19222 | 51299 | 1.423209e-13 |
| 15 | occupation_condition_corr24_unemployed (4, 0.099876) | occupation_condition_corr24_inactive | 0.145013 | 0.002365 | Positive | 59054 | 70521 | 0 |
| 16 | occupation_condition_corr24_unemployed (4, 0.099876) | adm_age_rec3 | 0.105125 | 0.009976 | Positive | 17641 | 17640 | 0 |
| 17 | occupation_condition_corr24_unemployed (4, 0.099876) | sex_rec_woman | 0.052826 | 0.004477 | Positive | 52439 | 18082 | 0 |
| 18 | occupation_condition_corr24_unemployed (4, 0.099876) | primary_sub_mod_cocaine_paste | -0.052387 | -0.003998 | Negative | 43856 | 26665 | 0 |
| 19 | occupation_condition_corr24_unemployed (4, 0.099876) | marital_status_rec_single | -0.042157 | -0.003140 | Negative | 31856 | 38601 | 0 |
| 20 | any_phys_dx (5, 0.092606) | adm_age_rec3 | -0.177597 | -0.011847 | Negative | 17641 | 17640 | 0 |
| 21 | any_phys_dx (5, 0.092606) | primary_sub_mod_alcohol | -0.151679 | -0.008320 | Negative | 46546 | 23975 | 0 |
| 22 | any_phys_dx (5, 0.092606) | first_sub_used_alcohol | -0.100611 | -0.005474 | Negative | 26547 | 39050 | 0 |
| 23 | any_phys_dx (5, 0.092606) | primary_sub_mod_cocaine_paste | 0.089074 | 0.004773 | Positive | 43856 | 26665 | 0 |
| 24 | any_phys_dx (5, 0.092606) | polysubstance_strict | 0.063871 | 0.003744 | Positive | 18940 | 51581 | 0 |
| 25 | eva_ocupacion (6, 0.074887) | primary_sub_mod_alcohol | 0.096788 | 0.005576 | Positive | 46546 | 23975 | 0 |
| 26 | eva_ocupacion (6, 0.074887) | primary_sub_mod_cocaine_paste | -0.087721 | -0.004937 | Negative | 43856 | 26665 | 0 |
| 27 | eva_ocupacion (6, 0.074887) | dit_m | -0.054846 | -0.002005 | Negative | 17697 | 17661 | 0 |
| 28 | eva_ocupacion (6, 0.074887) | dg_psiq_cie_10_instudy | -0.049600 | -0.000620 | Negative | 58292 | 70521 | 0 |
| 29 | eva_ocupacion (6, 0.074887) | tr_outcome_dropout | 0.046922 | 0.002567 | Positive | 32933 | 37588 | 0 |
Horizon Interaction Top (risk-scale approximation)
Interactions detected on log-hazard scale may differ in magnitude and direction from risk-scale interactions due to non-linear transformation (∂risk/∂η = (-ln(1-r))(1-r)). For mortality, 33% of top interactions changed direction between scales, while readmission showed only 3% direction changes. We report both scales for completeness and emphasize risk-scale interactions for clinical interpretation.
Time-dependent interaction thresholds were set at rank range ≥8 positions (91.8th percentile for death, 38.6th for readmission) and correlation delta ≥0.03 (100th percentile for death, 94.9th for readmission), based on empirical distributions of interaction variability across horizons.
Readmission interactions showed substantially more time-dependent variability (68% flagged) compared to mortality (11% flagged), suggesting readmission risk factors evolve more dynamically over follow-up while mortality risk factors remain relatively stable.
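The risk-scale weighting behind the horizon-specific scans follows from the Cox risk transform r = 1 − exp(−H₀(t)·exp(η)): differentiating with respect to the log-hazard η gives ∂r/∂η = (−ln(1−r))·(1−r). A minimal NumPy sketch of the weight used to rescale SHAP values (the function name and the 99th-percentile default are illustrative; the notebook's actual clip percentile is `WEIGHT_CLIP_PCT`):

```python
import numpy as np

def risk_gradient_weight(risk, clip_pct=99.0):
    """Weight for log-hazard SHAP values: dr/deta = (-ln(1-r)) * (1-r).

    Risks are clipped away from 0 and 1 for numerical stability, and the
    weight is capped at a high percentile to limit extreme values.
    """
    r = np.clip(np.asarray(risk, dtype=float), 1e-12, 1.0 - 1e-12)
    w = (-np.log(1.0 - r)) * (1.0 - r)
    cap = np.nanpercentile(w, clip_pct)
    return np.clip(w, 0.0, cap)

# The weight is small near r -> 0 and r -> 1 and peaks at intermediate risk,
# which is why interaction salience can shift across horizons even though the
# underlying trees are time-invariant.
r = np.array([0.01, 0.25, 0.50, 0.90])
w = risk_gradient_weight(r, clip_pct=100.0)  # no clipping for this tiny example
print(np.round(w, 4))
```

At r = 0.5 the weight equals 0.5·ln 2 ≈ 0.347, its neighborhood of maximal sensitivity; patients with near-certain or near-zero risk contribute little on the risk scale.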
Functional form
Code
#@title ⚡ Step 12: Functional Form Analysis (XGBoost - DUAL SHAP, Aggregated)
import os
import re
import glob
import pickle
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from IPython.display import display, Markdown, HTML  # FIXED: Added HTML

if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the root setup cell first.")
PROJECT_ROOT = Path(PROJECT_ROOT).resolve()
TABLE_DIR = PROJECT_ROOT / "_out_tabble"  # use your requested folder name
FIG_DIR = PROJECT_ROOT / "_figs"
TABLE_DIR.mkdir(parents=True, exist_ok=True)
FIG_DIR.mkdir(parents=True, exist_ok=True)
IN_DIR = PROJECT_ROOT / "_out"  # where Step 5 SHAP files are read from

try:
    from scipy.stats import f as f_dist
except Exception:
    f_dist = None

# --- 1) CONFIG ---
CONTINUOUS_VARS = [
    'adm_age_rec3', 'porc_pobr', 'dit_m', 'tenure_status_household', 'urbanicity_cat',
    'evaluacindelprocesoteraputico', 'eva_consumo', 'eva_fam', 'eva_relinterp',
    'eva_ocupacion', 'eva_sm', 'eva_fisica', 'eva_transgnorma', 'prim_sub_freq_rec',
    'ed_attainment_corr'
]
timestamp = datetime.now().strftime("%Y%m%d_%H%M")
EXCEL_PATH = TABLE_DIR / f"XGB12_corr_Functional_Forms_Dual_Aggregated_{timestamp}.xlsx"
PARQUET_PATH = TABLE_DIR / f"XGB12_corr_Functional_Forms_Dual_Aggregated_{timestamp}.parquet"
CSV_PATH = TABLE_DIR / f"XGB12_corr_Functional_Forms_Dual_Aggregated_{timestamp}.csv"
MAX_SCATTER_N = 10000  # max points shown in each scatter plot; larger sets are subsampled for speed/readability
POLY_DEGREE = 3        # polynomial degree used for trend fit (cubic)
RANDOM_STATE = 2125
SHAP_SCALE = 1.0       # no rescaling
SHAP_UNIT_LABEL = "Log-Hazard (model margin)"
N_BOOT = 500
CI_ALPHA = 0.05
EXCEL_FILENAME = f"XGB12_corr_Functional_Forms_Dual_Aggregated_{timestamp}.xlsx"
PARQUET_FILENAME = f"XGB12_corr_Functional_Forms_Dual_Aggregated_{timestamp}.parquet"

mpl.rcParams.update({
    "font.family": "serif",
    "font.serif": ["Times New Roman", "Times", "Nimbus Roman", "DejaVu Serif"],
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
    "font.size": 14,
    "axes.grid": True,
    "grid.alpha": 0.30
})

# --- 2) HELPERS ---
def sanitize_name(txt):
    return re.sub(r"[^A-Za-z0-9_.-]+", "_", str(txt))

def make_sheet_name(raw, used):
    s = re.sub(r"[\[\]\*\?\/\\:]", "_", raw)[:31]
    base = s
    i = 1
    while s in used:
        suffix = f"_{i}"
        s = (base[:31 - len(suffix)] + suffix)[:31]
        i += 1
    used.add(s)
    return s

def pick_latest_file():
    """Find latest SHAP file, handling both _mar26 and non-_mar26 suffixes."""
    # FIXED: Try both patterns
    patterns = [
        "xgb6_corr_DUAL_SHAP_Aggregated_*.pkl",
        "xgb6_corr_DUAL_SHAP_Aggregated_*_mar26.pkl"
    ]
    files = []
    for pat in patterns:
        files.extend(IN_DIR.glob(pat))
    if not files:
        raise FileNotFoundError(f"No xgb6_corr_DUAL_SHAP_Aggregated_*.pkl found in {IN_DIR}")
    return sorted(files, key=lambda p: p.stat().st_mtime)[-1]

def f_test_linear_vs_quadratic(x, y):
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    m = np.isfinite(x) & np.isfinite(y)
    x, y = x[m], y[m]
    n = x.size
    if n < 10:
        return np.nan, np.nan
    X1 = np.column_stack([np.ones(n), x])
    X2 = np.column_stack([np.ones(n), x, x**2])
    try:
        b1 = np.linalg.lstsq(X1, y, rcond=None)[0]
        b2 = np.linalg.lstsq(X2, y, rcond=None)[0]
        rss1 = np.sum((y - X1 @ b1) ** 2)
        rss2 = np.sum((y - X2 @ b2) ** 2)
    except Exception:
        return np.nan, np.nan
    df1 = 1
    df2 = n - X2.shape[1]
    if df2 <= 0 or rss2 <= 0 or rss1 < rss2:
        return np.nan, np.nan
    f_stat = ((rss1 - rss2) / df1) / (rss2 / df2)
    if f_dist is None:
        return float(f_stat), np.nan
    p_val = float(f_dist.sf(f_stat, df1, df2))
    return float(f_stat), p_val

def polyfit_bootstrap_ci(x, y, degree=3, n_boot=300, alpha=0.05, seed=2125):
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    m = np.isfinite(x) & np.isfinite(y)
    x, y = x[m], y[m]
    if x.size < 20 or np.unique(x).size < 4:
        return None
    rng = np.random.default_rng(seed)
    xg = np.linspace(np.min(x), np.max(x), 200)
    xc = x - np.mean(x)
    xgc = xg - np.mean(x)
    deg = min(degree, np.unique(xc).size - 1, xc.size - 1)
    if deg < 1:
        return None
    try:
        p = np.poly1d(np.polyfit(xc, y, deg))
        yhat = p(xgc)
    except Exception:
        return None
    boot = []
    n = xc.size
    for _ in range(n_boot):
        idx = rng.integers(0, n, n)
        xb, yb = xc[idx], y[idx]
        if np.unique(xb).size < (deg + 1):
            continue
        try:
            pb = np.poly1d(np.polyfit(xb, yb, deg))
            boot.append(pb(xgc))
        except Exception:
            continue
    if len(boot) < 20:
        return xg, yhat, None, None
    boot = np.vstack(boot)
    lo = np.quantile(boot, alpha / 2.0, axis=0)
    hi = np.quantile(boot, 1 - alpha / 2.0, axis=0)
    return xg, yhat, lo, hi

# --- 3) LOAD AGGREGATED SHAP ---
display(Markdown("### Loading DUAL SHAP aggregated file"))
shap_file = pick_latest_file()
display(Markdown(f"Using file: `{shap_file}`"))
with shap_file.open("rb") as f:
    data = pickle.load(f)
if not isinstance(data, dict):
    raise ValueError("Loaded SHAP file is not a dict.")
if "X_all" not in data or "feature_names" not in data:
    raise ValueError("Missing required keys: X_all and/or feature_names.")
X_all = data["X_all"]
feature_names = data["feature_names"]
if not isinstance(X_all, pd.DataFrame):
    X_all = pd.DataFrame(np.asarray(X_all), columns=feature_names)
else:
    if len(feature_names) == X_all.shape[1]:
        X_all = X_all.copy()
        X_all.columns = feature_names
# FIXED: Added duplicate feature check
if len(feature_names) != len(set(feature_names)):
    dupes = [f for f in feature_names if feature_names.count(f) > 1]
    raise ValueError(f"Duplicate feature names detected: {set(dupes)}")
shap_arrays = {}
if "shap_r_all" in data:
    shap_arrays["Readmission"] = np.asarray(data["shap_r_all"])
if "shap_d_all" in data:
    shap_arrays["Death"] = np.asarray(data["shap_d_all"])
if "shap_all" in data and not shap_arrays:
    shap_arrays["Readmission"] = np.asarray(data["shap_all"])
if not shap_arrays:
    raise ValueError("No SHAP arrays found (shap_r_all / shap_d_all / shap_all).")
display(Markdown(f"Outcomes found: `{list(shap_arrays.keys())}`"))
display(Markdown("Detected aggregated (non-horizon) SHAP structure."))

# --- 4) FUNCTIONAL FORM ANALYSIS ---
all_data_list = []
summary_rows = []
for outcome, shap_mat in shap_arrays.items():
    if shap_mat.ndim != 2:
        continue
    n = min(X_all.shape[0], shap_mat.shape[0])
    p = min(X_all.shape[1], shap_mat.shape[1])
    X_use = X_all.iloc[:n, :p].copy()
    S_use = shap_mat[:n, :p]
    for var in CONTINUOUS_VARS:
        if var not in X_use.columns:
            continue
        col_idx = X_use.columns.get_loc(var)
        x_vec = X_use.iloc[:, col_idx].to_numpy()
        y_vec = np.asarray(S_use[:, col_idx], dtype=float) * SHAP_SCALE
        mask = np.isfinite(x_vec) & np.isfinite(y_vec)
        x_vec = x_vec[mask]
        y_vec = y_vec[mask]
        if x_vec.size == 0:
            continue
        f_stat, p_val = f_test_linear_vs_quadratic(x_vec, y_vec)
        all_data_list.append(pd.DataFrame({
            "Feature_Value": x_vec,
            "SHAP_Impact": y_vec,
            "Predictor": var,
            "Outcome": outcome,
            "Scope": "Aggregated_All_Times"
        }))
        corr = np.nan
        if np.std(x_vec) > 0 and np.std(y_vec) > 0:
            corr = float(np.corrcoef(x_vec, y_vec)[0, 1])
        summary_rows.append({
            "Outcome": outcome,
            "Scope": "Aggregated_All_Times",
            "Predictor": var,
            "N": int(x_vec.size),
            "Mean_SHAP": float(np.mean(y_vec)),
            "MeanAbs_SHAP": float(np.mean(np.abs(y_vec))),
            "Q10_SHAP": float(np.quantile(y_vec, 0.10)),
            "Q50_SHAP": float(np.quantile(y_vec, 0.50)),
            "Q90_SHAP": float(np.quantile(y_vec, 0.90)),
            "F_linear_vs_quad": f_stat,
            "P_linear_vs_quad": p_val,
            "Nonlinear_p_lt_0_05": bool(pd.notna(p_val) and p_val < 0.05),
            "Corr_X_SHAP": corr
        })
        plt.figure(figsize=(8, 5))
        if x_vec.size > MAX_SCATTER_N:
            rng = np.random.default_rng(RANDOM_STATE)
            idx = rng.choice(x_vec.size, MAX_SCATTER_N, replace=False)
            plt.scatter(x_vec[idx], y_vec[idx], alpha=0.30, c="#1f77b4", s=15,
                        edgecolors="none", label="Patients (sample)")
        else:
            plt.scatter(x_vec, y_vec, alpha=0.45, c="#1f77b4", s=18,
                        edgecolors="none", label="Patients")
        fit = polyfit_bootstrap_ci(
            x_vec, y_vec, degree=POLY_DEGREE, n_boot=N_BOOT, alpha=CI_ALPHA, seed=RANDOM_STATE
        )
        if fit is not None:
            x_grid, y_hat, y_lo, y_hi = fit
            if y_lo is not None and y_hi is not None:
                plt.fill_between(x_grid, y_lo, y_hi, color="red", alpha=0.15,
                                 label=f"{int((1 - CI_ALPHA) * 100)}% CI")
            plt.plot(x_grid, y_hat, "r--", linewidth=2.2, label=f"Trend (poly-{POLY_DEGREE})")
        plt.axhline(0, color="k", linestyle=":", linewidth=1)
        plt.title(f"Functional Form: {var}\n({outcome}, Aggregated)", fontsize=13, fontweight="bold")
        plt.xlabel(f"Feature Value: {var}")
        plt.ylabel(f"SHAP Impact ({SHAP_UNIT_LABEL})")
        plt.legend(loc="best")
        fname = f"XGB12_corr_DUAL_FuncForm_{sanitize_name(outcome)}_{sanitize_name(var)}_{timestamp}"
        plt.savefig(FIG_DIR / f"{fname}.png", dpi=300, bbox_inches="tight")
        plt.savefig(FIG_DIR / f"{fname}.pdf", bbox_inches="tight")
        plt.show()

if not all_data_list:
    raise ValueError("No functional-form data created. Check predictor names and SHAP structure.")
full_df = pd.concat(all_data_list, ignore_index=True)
summary_df = pd.DataFrame(summary_rows).sort_values(["Outcome", "Predictor"]).reset_index(drop=True)

# --- 5) SAVE OUTPUTS ---
full_df.to_parquet(PARQUET_PATH, index=False)
full_df.to_csv(CSV_PATH, index=False)
used_sheet_names = set()
with pd.ExcelWriter(EXCEL_PATH, engine="xlsxwriter") as writer:
    summary_df.to_excel(writer, sheet_name="Effects_Summary", index=False)
    meta_df = pd.DataFrame({
        "Item": ["SHAP_SCALE", "SHAP_UNIT_LABEL", "Interpretation", "Scope"],
        "Value": [
            SHAP_SCALE,
            SHAP_UNIT_LABEL,
            "For survival:cox, SHAP is on log-hazard-ratio scale.",
            "Aggregated across time horizons (no horizon-specific SHAP in source file)."
        ]
    })
    meta_df.to_excel(writer, sheet_name="Meta", index=False)
    for (outcome, predictor), g in full_df.groupby(["Outcome", "Predictor"], sort=True):
        sheet_raw = f"{outcome[:1]}_{predictor[:18]}"
        sheet_name = make_sheet_name(sheet_raw, used_sheet_names)
        g_to_save = g[["Feature_Value", "SHAP_Impact"]]
        if len(g_to_save) > 100000:
            g_to_save = g_to_save.sample(100000, random_state=RANDOM_STATE)
        g_to_save.to_excel(writer, sheet_name=sheet_name, index=False)

# FIXED: Use display(HTML(...)) instead of print for final output
display(HTML(f"<b>Done.</b> Excel: <code>{EXCEL_PATH}</code>"))
display(HTML(f"<b>Done.</b> Parquet: <code>{PARQUET_PATH}</code>"))
display(HTML(f"<b>Done.</b> CSV: <code>{CSV_PATH}</code>"))
display(HTML(f"<b>Plots saved in:</b> <code>{FIG_DIR}</code>"))
global_functional_data = full_df
global_functional_summary = summary_df
Loading DUAL SHAP aggregated file
Using file: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb6_corr_DUAL_SHAP_Aggregated_20260306_1821_mar26.pkl
Death: 4×4 plot saved (15 vars + legend, no title)
PNG: XGB12_Faceted4x4_Death_20260306_1834.png
PDF: XGB12_Faceted4x4_Death_20260306_1834.pdf
We used XGBoost to test whether risk actually changes linearly with each predictor, and SHAP to visualize each feature's estimated effect. If the SHAP curve was approximately straight (e.g., age vs. death), we kept a simple linear term. If it was curved or U-shaped (e.g., age vs. readmission), we modeled the predictor with splines or categories. If SHAP revealed interactions (e.g., age amplifying the effect of alcohol), we added the corresponding interaction terms. The final statistical model therefore reflects the functional forms and interactions the machine learner identified.
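The "straight vs. curved" decision is operationalized in Step 12 as a nested-model F-test (linear vs. linear + quadratic fit to the SHAP values). A standalone sketch on synthetic data, to make the decision rule concrete (the data-generating lines are illustrative, not the study data):

```python
import numpy as np
from scipy.stats import f as f_dist

def f_test_linear_vs_quadratic(x, y):
    """Nested-model F-test: does adding x^2 significantly reduce the RSS?"""
    n = x.size
    X1 = np.column_stack([np.ones(n), x])          # intercept + linear
    X2 = np.column_stack([np.ones(n), x, x ** 2])  # intercept + linear + quadratic
    rss1 = np.sum((y - X1 @ np.linalg.lstsq(X1, y, rcond=None)[0]) ** 2)
    rss2 = np.sum((y - X2 @ np.linalg.lstsq(X2, y, rcond=None)[0]) ** 2)
    df2 = n - X2.shape[1]
    f_stat = (rss1 - rss2) / (rss2 / df2)  # 1 extra parameter in the larger model
    return f_stat, float(f_dist.sf(f_stat, 1, df2))

rng = np.random.default_rng(2125)
x = rng.uniform(-2, 2, 500)
y_lin = 0.5 * x + rng.normal(0, 0.1, 500)         # straight SHAP profile -> keep linear term
y_quad = 0.5 * x ** 2 + rng.normal(0, 0.1, 500)   # U-shaped profile -> spline/categorical term

_, p_lin = f_test_linear_vs_quadratic(x, y_lin)
_, p_quad = f_test_linear_vs_quadratic(x, y_quad)
print(f"p (linear profile) = {p_lin:.3g}, p (U-shaped profile) = {p_quad:.3g}")
```

A small p-value for the quadratic term flags curvature, prompting a spline or categorical coding in the final statistical model; this mirrors `f_test_linear_vs_quadratic` in the Step 12 cell above, minus its missing-data guards.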
Balance between train and test
For categorical variables with more than two levels, standardized mean differences (SMDs) were computed at the level of each category by recoding the factor into a set of binary indicators (one-vs-rest). For each level, the SMD was calculated as the difference in proportions between training and testing samples divided by the pooled standard deviation of a Bernoulli variable.
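The level-wise SMD described above reduces to a simple formula per one-vs-rest indicator. A minimal sketch (the proportions in the usage line are hypothetical, not study values):

```python
import math

def smd_binary(p_train: float, p_test: float) -> float:
    """Standardized mean difference for a one-vs-rest binary indicator.

    Divides the difference in proportions by the pooled Bernoulli SD:
    sqrt((p1*(1-p1) + p0*(1-p0)) / 2).
    """
    pooled = math.sqrt((p_train * (1 - p_train) + p_test * (1 - p_test)) / 2.0)
    if pooled == 0:
        return float("nan")
    return (p_train - p_test) / pooled

# Hypothetical proportions for one factor level in train vs. test
smd = smd_binary(0.35, 0.34)
print(round(smd, 4))
```

Values below the conventional 0.1 threshold indicate the level is well balanced between the training and testing samples.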
Code
from pathlib import Path
import pandas as pd
from IPython.display import display

BASE_DIR = Path(r"G:\My Drive\Alvacast\SISTRAT 2023\data\20241015_out\pred1")
imputation_nodum_1 = pd.read_parquet(
    BASE_DIR / "imputation_nondum_1.parquet",
    engine="fastparquet"
)

# Condition 1: tr_outcome_adm_discharge_adm_reasons == 1 AND death within ~7 days (0.23 months)
mask_adm_death = (
    imputation_nodum_1["tr_outcome"].str.contains("adm reasons", case=False, na=False)
    & imputation_nodum_1["death_time_from_disch_m"].notna()
    & (imputation_nodum_1["death_time_from_disch_m"] <= 0.23)
    & (imputation_nodum_1["death_event"] > 0)
)

# Condition 2: tr_outcome_other == 1 (any time)
mask_other = imputation_nodum_1["tr_outcome"] == "other"

# Combined exclusion mask
exclude = mask_adm_death | mask_other
keep = ~exclude

# ── Filter imputation_nodum_1 in place ──
imputation_nodum_1 = imputation_nodum_1[~exclude].copy()

# Only for imputation_nodum_1 (non-dummified)
col = "first_sub_used"

# To match your previous grouped dummies:
to_other = {
    "hallucinogens",
    "opioids",
    "amphetamine-type stimulants",
    "tranquilizers/hypnotics",
    "inhalants",
    "others",
}

# Keep original if needed
imputation_nodum_1["first_sub_used_original"] = imputation_nodum_1[col]

# Recode
s = imputation_nodum_1[col].astype("string").str.strip().str.lower()
imputation_nodum_1[col] = s.mask(s.isin(to_other), "other")

# Optional: back to categorical
imputation_nodum_1[col] = imputation_nodum_1[col].astype("category")

# Check result
display(imputation_nodum_1[col].value_counts(dropna=False))
from pathlib import Path
import os
import pandas as pd

PROJECT_ROOT = Path.cwd()  # current notebook directory
OUT_DIR = PROJECT_ROOT / "_out"
split_seed2125 = pd.read_parquet(
    OUT_DIR / "readm_split_seed2125_test20_mar26.parquet",
    engine="fastparquet"
)
Code
import numpy as np
import pandas as pd
from IPython.display import display
from pathlib import Path

if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the root setup cell first.")
PROJECT_ROOT = Path(PROJECT_ROOT).resolve()
OUT_DIR = PROJECT_ROOT / "_out_tabble"  # or "_out" if you prefer
OUT_DIR.mkdir(parents=True, exist_ok=True)

CONTINUOUS_VARS = ['adm_age_rec3', 'porc_pobr', 'dit_m']
# Optional: exclude outcome/time columns from baseline Table 1
EXCLUDE_COLS = {
    'readmit_time_from_disch_m', 'readmit_event', 'death_time_from_disch_m',
    'death_event', 'center_id', 'readmit_time_from_adm_m',
    'death_time_from_adm_m', 'first_sub_used_original'
}
MAX_LEVELS = 30  # safety cap to avoid huge tables for very high-cardinality vars

# --- Helpers ---
def fmt_median_iqr(x):
    x = pd.to_numeric(pd.Series(x), errors="coerce").dropna()
    if len(x) == 0:
        return "NA"
    q1, med, q3 = x.quantile(0.25), x.quantile(0.50), x.quantile(0.75)
    return f"{med:.2f} [{q1:.2f}, {q3:.2f}]"

def smd_cont(train, test):
    train = pd.to_numeric(pd.Series(train), errors="coerce").dropna()
    test = pd.to_numeric(pd.Series(test), errors="coerce").dropna()
    if len(train) < 2 or len(test) < 2:
        return np.nan
    pooled = np.sqrt((train.std(ddof=1) ** 2 + test.std(ddof=1) ** 2) / 2.0)
    if pooled == 0 or not np.isfinite(pooled):
        return np.nan
    return (train.mean() - test.mean()) / pooled

def smd_bin(train01, test01):
    train01 = pd.to_numeric(pd.Series(train01), errors="coerce").dropna()
    test01 = pd.to_numeric(pd.Series(test01), errors="coerce").dropna()
    if len(train01) == 0 or len(test01) == 0:
        return np.nan
    p1, p0 = train01.mean(), test01.mean()
    den = np.sqrt((p1 * (1 - p1) + p0 * (1 - p0)) / 2.0)
    if den == 0 or not np.isfinite(den):
        return np.nan
    return (p1 - p0) / den

def fmt_mean_sd(x):
    x = pd.to_numeric(pd.Series(x), errors="coerce").dropna()
    return "NA" if len(x) == 0 else f"{x.mean():.2f} +/- {x.std(ddof=1):.2f}"

def fmt_n_pct(x01):
    x01 = pd.to_numeric(pd.Series(x01), errors="coerce").dropna()
    if len(x01) == 0:
        return "NA"
    n = int((x01 == 1).sum())
    pct = 100.0 * x01.mean()
    return f"{n} ({pct:.1f}%)"

def indicator_for_level(series, level, string_mode=False):
    s = series.copy()
    if string_mode:
        s = s.astype("string").str.strip()
        lvl = str(level).strip()
        ind = (s == lvl).astype(float)
        ind[s.isna()] = np.nan
        return ind
    ind = (s == level).astype(float)
    ind[s.isna()] = np.nan
    return ind

# --- Split alignment ---
split = split_seed2125.copy()
if {"row_id", "is_train"}.issubset(split.columns):
    split = split.sort_values("row_id").reset_index(drop=True)
    if not np.array_equal(split["row_id"].to_numpy(), np.arange(len(split))):
        raise ValueError("split_seed2125$row_id is not 0..N-1 after sorting.")
    is_train = split["is_train"].astype(bool).to_numpy()
elif "is_train" in split.columns:
    is_train = split["is_train"].astype(bool).to_numpy()
else:
    raise ValueError("split_seed2125 must contain column 'is_train'.")

base = imputation_nodum_1.reset_index(drop=True).copy()
if len(base) != len(is_train):
    raise ValueError(f"Row mismatch: imputation_nodum_1={len(base)}, split={len(is_train)}")
train_df = base.loc[is_train].copy()
test_df = base.loc[~is_train].copy()

# --- Build Table 1 ---
rows = []
excluded_high_card = []
missing_cont = [c for c in CONTINUOUS_VARS if c not in base.columns]
for col in base.columns:
    if col in EXCLUDE_COLS:
        continue
    miss_tr = 100.0 * train_df[col].isna().mean()
    miss_te = 100.0 * test_df[col].isna().mean()
    # 1) forced continuous vars
    if col in CONTINUOUS_VARS:
        smd = smd_cont(train_df[col], test_df[col])
        rows.append({
            "Variable": col,
            "Type": "Continuous",
            "Level": "",
            "Train": fmt_mean_sd(train_df[col]),  # keep old display if you want
            "Test": fmt_mean_sd(test_df[col]),
            "Train_Median_IQR": fmt_median_iqr(train_df[col]),
            "Test_Median_IQR": fmt_median_iqr(test_df[col]),
            "SMD": smd,
            "|SMD|": abs(smd) if np.isfinite(smd) else np.nan,
            "%Missing_Train": round(miss_tr, 2),
            "%Missing_Test": round(miss_te, 2),
        })
        continue
    s_full = base[col]
    non_na = s_full.dropna()
    if non_na.empty:
        continue
    # 2) bool / numeric / categorical handling
    if pd.api.types.is_bool_dtype(s_full):
        levels_all = [False, True]
        levels_report = [True]  # one row for dichotomous variable
        string_mode = False
    elif pd.api.types.is_numeric_dtype(s_full):
        vals = pd.to_numeric(non_na, errors="coerce")
        vals = np.sort(pd.unique(vals[np.isfinite(vals)]))
        levels_all = list(vals)
        string_mode = False
        # 0/1 dichotomous -> one row (level=1)
        if len(levels_all) == 2 and set(np.round(levels_all, 10)).issubset({0.0, 1.0}):
            levels_report = [1.0]
        else:
            if len(levels_all) > MAX_LEVELS:
                excluded_high_card.append((col, len(levels_all)))
                continue
            levels_report = levels_all
    else:
        vals = sorted(non_na.astype("string").str.strip().dropna().unique().tolist())
        levels_all = vals
        string_mode = True
        if len(levels_all) > MAX_LEVELS:
            excluded_high_card.append((col, len(levels_all)))
            continue
        # dichotomous categorical -> one row (last level)
        if len(levels_all) == 2:
            levels_report = [levels_all[-1]]
        else:
            levels_report = levels_all
    for lvl in levels_report:
        ind_tr = indicator_for_level(train_df[col], lvl, string_mode=string_mode)
        ind_te = indicator_for_level(test_df[col], lvl, string_mode=string_mode)
        smd = smd_bin(ind_tr, ind_te)
        var_type = "Dichotomous" if len(levels_all) == 2 else "Categorical (level)"
        rows.append({
            "Variable": col,
            "Type": var_type,
            "Level": str(lvl),
            "Train": fmt_n_pct(ind_tr),
            "Test": fmt_n_pct(ind_te),
            "Train_Median_IQR": "",
            "Test_Median_IQR": "",
            "SMD": smd,
            "|SMD|": abs(smd) if np.isfinite(smd) else np.nan,
            "%Missing_Train": round(miss_tr, 2),
            "%Missing_Test": round(miss_te, 2),
        })

table1_split = pd.DataFrame(rows)
table1_split = table1_split.sort_values(
    ["Type", "Variable", "|SMD|"],
    ascending=[True, True, False],
    na_position="last"
).reset_index(drop=True)

summary_split = pd.DataFrame({
    "N_train": [len(train_df)],
    "N_test": [len(test_df)],
    "N_total": [len(base)],
    "N_rows_table1": [len(table1_split)],
    "N_missing_continuous_vars": [len(missing_cont)],
    "N_excluded_high_cardinality_vars": [len(excluded_high_card)],
})
display(summary_split)
if missing_cont:
    print("Missing CONTINUOUS_VARS in dataset:", missing_cont)
if excluded_high_card:
    print("Excluded high-cardinality variables (name, n_levels):", excluded_high_card)

# optional export
csv_path = OUT_DIR / "table1_split_seed2125_multilevel_mar26.csv"
table1_split.to_csv(csv_path, index=False)
print(f"Saved: {csv_path}")
| N_train | N_test | N_total | N_rows_table1 | N_missing_continuous_vars | N_excluded_high_cardinality_vars |
|---|---|---|---|---|---|
| 70521 | 17631 | 88152 | 89 | 0 | 0 |
Code
import pandas as pd
from IPython.display import HTML, display

# Reset options so Pandas doesn't truncate the table
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', None)

# Convert DataFrame to HTML and wrap in a scrollable div
html_table = table1_split.to_html()
scroll_box = f"""
<div style="max-height:500px; max-width:1000px; overflow-y:auto; overflow-x:auto; border:1px solid #ccc;">
{html_table}
</div>
"""
display(HTML(scroll_box))