XGBoost (consolidated version)

This notebook consolidates the evaluation of two fine-tuned XGBoost models for predicting hospital readmission and mortality. It prepares longitudinal data for survival analysis by eliminating immortal time bias. The workflow handles the competing outcomes (readmission and death) using a cause-specific hazards formulation. Finally, it constructs an evaluation grid to assess each model’s predictive performance for both events across defined time intervals.

Author

ags

Published

March 6, 2026

Best predictors based on XGBOOST (combination)

0. Package loading and installation

ML

Code
# Commented out IPython magic to ensure Python compatibility.
# For Jupyter/Colab notebooks
# NOTE(review): %reset -f wipes the whole namespace — this cell must stay first.
%reset -f
import gc
gc.collect()

import numpy as np
import pandas as pd
import time  # NOTE(review): `time` appears unused in this cell — confirm before removing.

# One-time environment setup, intentionally left commented out:
#conda install -c conda-forge \
#    numpy \
#    scipy \
#    pandas \
#    pyarrow \
#    scikit-survival \
#    spyder \
#    lifelines

# conda install -c conda-forge fastparquet
# conda install -c conda-forge xgboost
# conda install -c conda-forge pytorch cpuonly
# conda install -c pytorch pytorch cpuonly
# conda install -c conda-forge matplotlib
# conda install -c conda-forge seaborn
# conda install spyder-notebook -c spyder-ide
# conda install notebook nbformat nbconvert
# conda install -c conda-forge xlsxwriter
# conda install -c conda-forge shap

# Pip fallback for environments without conda (e.g. Colab):
# import subprocess, sys

# subprocess.check_call([
#     sys.executable,
#     "-m",
#     "pip",
#     "install",
#     "matplotlib"
# ])

# subprocess.check_call([
#     sys.executable,
#     "-m",
#     "pip",
#     "install",
#     "seaborn"
# ])

# Sanity check that the environment loaded the expected numpy build.
print("numpy:", np.__version__)


# Survival-analysis metrics and utilities used throughout the notebook.
from sksurv.metrics import (
    concordance_index_ipcw,
    brier_score,
    integrated_brier_score
)
from sksurv.util import Surv

#Dput
def dput_df(df, digits=6):
    """Print (and return) a compact dict representation of a DataFrame,
    similar to R's ``dput()``: column names plus row data, with float
    values rounded.

    Parameters
    ----------
    df : pd.DataFrame
        Frame to serialize.
    digits : int, default 6
        Decimal places kept for float values.

    Returns
    -------
    dict
        ``{"columns": [...], "data": [[...], ...]}`` — also printed to
        stdout, so existing callers that only relied on the print keep
        working.
    """
    data = {
        "columns": list(df.columns),
        "data": [
            [round(x, digits) if isinstance(x, (float, np.floating)) else x
             for x in row]
            for row in df.to_numpy()
        ]
    }
    print(data)
    # Returning the dict makes the representation reusable programmatically.
    return data

#Glimpse function
def glimpse(df, max_width=80):
    """Print a dplyr::glimpse-style summary of ``df``: a shape header,
    then one line per column showing its dtype and a truncated preview
    of the first five values."""
    n_rows, n_cols = df.shape
    print(f"Rows: {n_rows} | Columns: {n_cols}")
    for column in df.columns:
        sample = ", ".join(df[column].astype(str).head(5).tolist())
        # Truncate long previews so each column stays on one line.
        if len(sample) > max_width:
            sample = sample[:max_width] + "..."
        print(f"{column:<30} {str(df[column].dtype):<15} {sample}")
#Tabyl function
def tabyl(series):
    """Frequency table in the style of janitor::tabyl: one row per
    distinct value (NaN included), sorted by descending count.

    Parameters
    ----------
    series : pd.Series
        Values to tabulate.

    Returns
    -------
    pd.DataFrame
        Columns ``value``, ``n`` and ``percent`` (``percent`` is a
        proportion in [0, 1], matching the original output).
    """
    counts = series.value_counts(dropna=False)
    # Derive proportions from the counts instead of a second
    # value_counts pass; dropna=False means the denominator is len(series).
    return pd.DataFrame({
        "value": counts.index,
        "n": counts.values,
        "percent": counts.values / counts.values.sum(),
    })
#clean_names
import re

def clean_names(df):
    """
    Mimic janitor::clean_names for pandas DataFrames (mutates ``df``).
    - Lowercase
    - Replace runs of non-alphanumeric characters with a single underscore
    - Strip leading/trailing underscores

    Non-string column labels (e.g. integers) are converted to strings
    first, so frames with numeric column names no longer raise.

    Parameters
    ----------
    df : pd.DataFrame
        Frame whose columns are renamed in place.

    Returns
    -------
    pd.DataFrame
        The same frame, returned for chaining.
    """
    df.columns = [
        re.sub(r"[^\w]+", "_", str(col).lower()).strip("_")
        for col in df.columns
    ]
    return df

Load data

Code

from pathlib import Path

# NOTE(review): hardcoded absolute Windows path — consider an env var or
# a configurable DATA_DIR for portability across machines.
BASE_DIR = Path(
    r"G:\My Drive\Alvacast\SISTRAT 2023\data\20241015_out\pred1"
)

import pickle

# SECURITY: pickle.load can execute arbitrary code from a malicious file —
# only load .pkl files from trusted sources.
with open(BASE_DIR / "imputations_list_jan26.pkl", "rb") as f:
    imputations_list_jan26 = pickle.load(f)

# First dummy-encoded imputation, stored as parquet (fastparquet engine).
imputation_1 = pd.read_parquet(
    BASE_DIR / "imputation_1.parquet",
    engine="fastparquet"
)
Code

import pandas as pd

# Load the five non-dummy-encoded imputations into module-level variables
# imputation_nodum_1 ... imputation_nodum_5 via globals().
# NOTE(review): files are spelled "nondum" while the variables are "nodum" —
# presumably intentional, but confirm the mismatch against BASE_DIR contents.
for i in range(1, 6):
    globals()[f"imputation_nodum_{i}"] = pd.read_parquet(
        BASE_DIR / f"imputation_nondum_{i}.parquet",
        engine="fastparquet"
    )
Code
from IPython.display import display, HTML
import io
import sys

def fold_output(title, func):
    """Run ``func`` while capturing its stdout, then display the captured
    text inside a collapsible HTML ``<details>`` block.

    Parameters
    ----------
    title : str
        Text shown on the ``<summary>`` toggle.
    func : callable
        Zero-argument callable whose printed output is captured.
    """
    buffer = io.StringIO()
    # Save/restore the *current* stdout rather than sys.__stdout__: in
    # Jupyter, sys.stdout is the kernel's stream, and restoring to
    # sys.__stdout__ would silently redirect later prints to the terminal.
    # try/finally guarantees restoration even if func() raises.
    old_stdout = sys.stdout
    sys.stdout = buffer
    try:
        func()
    finally:
        sys.stdout = old_stdout

    html = f"""
    <details>
      <summary>{title}</summary>
      <pre>{buffer.getvalue()}</pre>
    </details>
    """
    display(HTML(html))


# Collapsible .info() dumps of the two freshly loaded frames, so the long
# column listings do not flood the rendered notebook.
fold_output(
    "Show imputation_nodum_1 structure",
    lambda: imputation_nodum_1.info()
)

fold_output(
    "Show imputation_1 structure",
    lambda: imputation_1.info()
)
Show imputation_nodum_1 structure

RangeIndex: 88504 entries, 0 to 88503
Data columns (total 43 columns):
 #   Column                         Non-Null Count  Dtype  
---  ------                         --------------  -----  
 0   readmit_time_from_adm_m        88504 non-null  float64
 1   death_time_from_adm_m          88504 non-null  float64
 2   adm_age_rec3                   88504 non-null  float64
 3   porc_pobr                      88504 non-null  float64
 4   dit_m                          88504 non-null  float64
 5   sex_rec                        88504 non-null  object 
 6   tenure_status_household        88504 non-null  object 
 7   cohabitation                   88504 non-null  object 
 8   sub_dep_icd10_status           88504 non-null  object 
 9   any_violence                   88504 non-null  object 
 10  prim_sub_freq_rec              88504 non-null  object 
 11  tr_outcome                     88504 non-null  object 
 12  adm_motive                     88504 non-null  object 
 13  first_sub_used                 88504 non-null  object 
 14  primary_sub_mod                88504 non-null  object 
 15  tipo_de_vivienda_rec2          88504 non-null  object 
 16  national_foreign               88504 non-null  int32  
 17  plan_type_corr                 88504 non-null  object 
 18  occupation_condition_corr24    88504 non-null  object 
 19  marital_status_rec             88504 non-null  object 
 20  urbanicity_cat                 88504 non-null  object 
 21  ed_attainment_corr             88504 non-null  object 
 22  evaluacindelprocesoteraputico  88504 non-null  object 
 23  eva_consumo                    88504 non-null  object 
 24  eva_fam                        88504 non-null  object 
 25  eva_relinterp                  88504 non-null  object 
 26  eva_ocupacion                  88504 non-null  object 
 27  eva_sm                         88504 non-null  object 
 28  eva_fisica                     88504 non-null  object 
 29  eva_transgnorma                88504 non-null  object 
 30  ethnicity                      88504 non-null  float64
 31  dg_psiq_cie_10_instudy         88504 non-null  bool   
 32  dg_psiq_cie_10_dg              88504 non-null  bool   
 33  dx_f3_mood                     88504 non-null  int32  
 34  dx_f6_personality              88504 non-null  int32  
 35  dx_f_any_severe_mental         88504 non-null  bool   
 36  any_phys_dx                    88504 non-null  bool   
 37  polysubstance_strict           88504 non-null  int32  
 38  readmit_event                  88504 non-null  float64
 39  death_event                    88504 non-null  int32  
 40  readmit_time_from_disch_m      88504 non-null  float64
 41  death_time_from_disch_m        88504 non-null  float64
 42  center_id                      88475 non-null  object 
dtypes: bool(4), float64(9), int32(5), object(25)
memory usage: 25.0+ MB
Show imputation_1 structure

RangeIndex: 88504 entries, 0 to 88503
Data columns (total 78 columns):
 #   Column                                                              Non-Null Count  Dtype  
---  ------                                                              --------------  -----  
 0   readmit_time_from_adm_m                                             88504 non-null  float64
 1   death_time_from_adm_m                                               88504 non-null  float64
 2   adm_age_rec3                                                        88504 non-null  float64
 3   porc_pobr                                                           88504 non-null  float64
 4   dit_m                                                               88504 non-null  float64
 5   national_foreign                                                    88504 non-null  int32  
 6   ethnicity                                                           88504 non-null  float64
 7   dg_psiq_cie_10_instudy                                              88504 non-null  bool   
 8   dg_psiq_cie_10_dg                                                   88504 non-null  bool   
 9   dx_f3_mood                                                          88504 non-null  int32  
 10  dx_f6_personality                                                   88504 non-null  int32  
 11  dx_f_any_severe_mental                                              88504 non-null  bool   
 12  any_phys_dx                                                         88504 non-null  bool   
 13  polysubstance_strict                                                88504 non-null  int32  
 14  readmit_time_from_disch_m                                           88504 non-null  float64
 15  readmit_event                                                       88504 non-null  float64
 16  death_time_from_disch_m                                             88504 non-null  float64
 17  death_event                                                         88504 non-null  int32  
 18  sex_rec_woman                                                       88504 non-null  float64
 19  tenure_status_household_illegal_settlement                          88504 non-null  float64
 20  tenure_status_household_owner_transferred_dwellings_pays_dividends  88504 non-null  float64
 21  tenure_status_household_renting                                     88504 non-null  float64
 22  tenure_status_household_stays_temporarily_with_a_relative           88504 non-null  float64
 23  cohabitation_alone                                                  88504 non-null  float64
 24  cohabitation_with_couple_children                                   88504 non-null  float64
 25  cohabitation_family_of_origin                                       88504 non-null  float64
 26  sub_dep_icd10_status_drug_dependence                                88504 non-null  float64
 27  any_violence_1_domestic_violence_sex_abuse                          88504 non-null  float64
 28  prim_sub_freq_rec_2_2_6_days_wk                                     88504 non-null  float64
 29  prim_sub_freq_rec_3_daily                                           88504 non-null  float64
 30  tr_outcome_adm_discharge_adm_reasons                                88504 non-null  float64
 31  tr_outcome_adm_discharge_rule_violation_undet                       88504 non-null  float64
 32  tr_outcome_completion                                               88504 non-null  float64
 33  tr_outcome_dropout                                                  88504 non-null  float64
 34  tr_outcome_referral                                                 88504 non-null  float64
 35  adm_motive_another_sud_facility_fonodrogas_senda_previene           88504 non-null  float64
 36  adm_motive_justice_sector                                           88504 non-null  float64
 37  adm_motive_sanitary_sector                                          88504 non-null  float64
 38  adm_motive_spontaneous_consultation                                 88504 non-null  float64
 39  first_sub_used_alcohol                                              88504 non-null  float64
 40  first_sub_used_cocaine_paste                                        88504 non-null  float64
 41  first_sub_used_cocaine_powder                                       88504 non-null  float64
 42  first_sub_used_marijuana                                            88504 non-null  float64
 43  first_sub_used_opioids                                              88504 non-null  float64
 44  first_sub_used_tranquilizers_hypnotics                              88504 non-null  float64
 45  primary_sub_mod_cocaine_paste                                       88504 non-null  float64
 46  primary_sub_mod_cocaine_powder                                      88504 non-null  float64
 47  primary_sub_mod_alcohol                                             88504 non-null  float64
 48  primary_sub_mod_marijuana                                           88504 non-null  float64
 49  tipo_de_vivienda_rec2_other_unknown                                 88504 non-null  float64
 50  plan_type_corr_m_pai                                                88504 non-null  float64
 51  plan_type_corr_m_pr                                                 88504 non-null  float64
 52  plan_type_corr_pg_pai                                               88504 non-null  float64
 53  plan_type_corr_pg_pr                                                88504 non-null  float64
 54  occupation_condition_corr24_inactive                                88504 non-null  float64
 55  occupation_condition_corr24_unemployed                              88504 non-null  float64
 56  marital_status_rec_separated_divorced_annulled_widowed              88504 non-null  float64
 57  marital_status_rec_single                                           88504 non-null  float64
 58  urbanicity_cat_1_rural                                              88504 non-null  float64
 59  urbanicity_cat_2_mixed                                              88504 non-null  float64
 60  ed_attainment_corr_2_completed_high_school_or_less                  88504 non-null  float64
 61  ed_attainment_corr_3_completed_primary_school_or_less               88504 non-null  float64
 62  evaluacindelprocesoteraputico_logro_intermedio                      88504 non-null  float64
 63  evaluacindelprocesoteraputico_logro_minimo                          88504 non-null  float64
 64  eva_consumo_logro_intermedio                                        88504 non-null  float64
 65  eva_consumo_logro_minimo                                            88504 non-null  float64
 66  eva_fam_logro_intermedio                                            88504 non-null  float64
 67  eva_fam_logro_minimo                                                88504 non-null  float64
 68  eva_relinterp_logro_intermedio                                      88504 non-null  float64
 69  eva_relinterp_logro_minimo                                          88504 non-null  float64
 70  eva_ocupacion_logro_intermedio                                      88504 non-null  float64
 71  eva_ocupacion_logro_minimo                                          88504 non-null  float64
 72  eva_sm_logro_intermedio                                             88504 non-null  float64
 73  eva_sm_logro_minimo                                                 88504 non-null  float64
 74  eva_fisica_logro_intermedio                                         88504 non-null  float64
 75  eva_fisica_logro_minimo                                             88504 non-null  float64
 76  eva_transgnorma_logro_intermedio                                    88504 non-null  float64
 77  eva_transgnorma_logro_minimo                                        88504 non-null  float64
dtypes: bool(4), float64(69), int32(5)
memory usage: 48.6 MB
Code
from IPython.display import display, Markdown

# Sanity check on the unpickled object: report the type of its first
# element and, depending on what that is, its keys (dict) or shape
# (DataFrame / ndarray).
if isinstance(imputations_list_jan26, list) and len(imputations_list_jan26) > 0:
    display(Markdown(f"**First element type:** `{type(imputations_list_jan26[0])}`"))

    if isinstance(imputations_list_jan26[0], dict):
        display(Markdown(f"**First element keys:** `{list(imputations_list_jan26[0].keys())}`"))

    elif isinstance(imputations_list_jan26[0], (pd.DataFrame, np.ndarray)):
        display(Markdown(f"**First element shape:** `{imputations_list_jan26[0].shape}`"))

First element type: <class 'pandas.DataFrame'>

First element shape: (88504, 56)

This code block:

  1. Imports the pickle library: pickle implements binary protocols for serializing and de-serializing Python object structures.
  2. Specifies the file path: BASE_DIR points to the directory containing the .pkl file.
  3. Opens the file in binary read mode ('rb'): this is required for loading pickle files.
  4. Loads the object: pickle.load(f) reads the serialized object from the file and reconstructs it in memory.
  5. Displays confirmation and basic information: it reports the type of the loaded object’s first element and, for common data structures, its keys (dict) or shape (DataFrame/array).

Format data

Due to inconsistencies and structural heterogeneity across previously merged datasets, we decided not to proceed with a direct inspection and comparison of column names between the first imputed dataset from imputations_list_jan26 (which likely included dummy-encoded variables) and imputation_nodum_1 (which likely retained non–dummy-encoded variables).

Instead, we reconstructed the analytic datasets de novo using the most recent source files available in the original directory (BASE_DIR). Time-to-event variables were re-derived to ensure internal consistency. Variables that could introduce information leakage (e.g., time from admission) were excluded, and the center identifier variable was removed prior to modeling.

Code
#1.2. Build Surv objects from df_final
# Builds structured survival arrays (event indicator + time) for readmission
# and death from each of the five imputed datasets. Times are measured from
# discharge; the from-admission variants were flagged as leakage elsewhere.
from IPython.display import display, Markdown
from sksurv.util import Surv

for i in range(1, 6):
    # Get the DataFrame
    df = globals()[f"imputation_nodum_{i}"]

    # Extract time and event arrays
    # readmit_event is float64 and death_event is int32 in the source data;
    # comparing to 1 coerces both to a clean boolean event mask.
    time_readm  = df["readmit_time_from_disch_m"].to_numpy()
    event_readm = (df["readmit_event"].to_numpy() == 1)
    time_death  = df["death_time_from_disch_m"].to_numpy()
    event_death = (df["death_event"].to_numpy() == 1)

    # Create survival objects
    y_surv_readm = Surv.from_arrays(event=event_readm, time=time_readm)
    y_surv_death = Surv.from_arrays(event=event_death, time=time_death)

    # Store in global variables (optional but matches your pattern)
    # NOTE: the stored names are suffixed (y_surv_readm_1 ... _5); any later
    # namespace cleanup must keep these suffixed names, not the bare ones.
    globals()[f"y_surv_readm_{i}"]  = y_surv_readm
    globals()[f"y_surv_death_{i}"]  = y_surv_death

    # Print info
    display(Markdown(f"\n--- Imputation {i} ---"))
    display(Markdown(
    f"**y_surv_readm dtype:** {y_surv_readm.dtype}  \n"
    f"**shape:** {y_surv_readm.shape}"
    ))
    display(Markdown(
    f"**y_surv_death dtype:** {y_surv_death.dtype}  \n"
    f"**shape:** {y_surv_death.shape}"
    ))

— Imputation 1 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

— Imputation 2 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

— Imputation 3 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

— Imputation 4 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

— Imputation 5 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

Code
# Collapsible glimpses comparing the newer non-dummy frame with the first
# (older, dummy-encoded) frame from the pickled list.
fold_output(
    "Show imputation_nodum_1 (newer database) glimpse",
    lambda: glimpse(imputation_nodum_1)
)
fold_output(
    "Show first db of imputations_list_jan26 (older) glimpse",
    lambda: glimpse(imputations_list_jan26[0])
)
Show imputation_nodum_1 (newer database) glimpse
Rows: 88504 | Columns: 43
readmit_time_from_adm_m        float64         84.93548387096774, 12.833333333333334, 13.733333333333333, 11.966666666666667, 1...
death_time_from_adm_m          float64         84.93548387096774, 87.16129032258064, 117.2258064516129, 98.93548387096774, 37.9...
adm_age_rec3                   float64         31.53, 20.61, 42.52, 60.61, 45.08
porc_pobr                      float64         0.175679117441177, 0.187835901975632, 0.130412444472313, 0.133759185671806, 0.08...
dit_m                          float64         15.967741935483872, 5.833333333333334, 0.4752688172043005, 6.966666666666667, 6....
sex_rec                        object          man, man, man, woman, man
tenure_status_household        object          stays temporarily with a relative, owner/transferred dwellings/pays dividends, s...
cohabitation                   object          alone, family of origin, with couple/children, with couple/children, family of o...
sub_dep_icd10_status           object          drug dependence, hazardous consumption, drug dependence, drug dependence, drug d...
any_violence                   object          0.No domestic violence/sex abuse, 0.No domestic violence/sex abuse, 0.No domesti...
prim_sub_freq_rec              object          2.2–6 days/wk, 3.Daily, 3.Daily, 3.Daily, 3.Daily
tr_outcome                     object          referral, dropout, adm discharge - rule violation/undet, dropout, completion
adm_motive                     object          sanitary sector, spontaneous consultation, sanitary sector, sanitary sector, spo...
first_sub_used                 object          alcohol, alcohol, alcohol, cocaine paste, alcohol
primary_sub_mod                object          alcohol, cocaine paste, cocaine paste, cocaine paste, cocaine paste
tipo_de_vivienda_rec2          object          other/unknown, formal housing, formal housing, formal housing, formal housing
national_foreign               int32           0, 0, 0, 0, 0
plan_type_corr                 object          pg-pab, pg-pab, pg-pr, m-pr, pg-pai
occupation_condition_corr24    object          unemployed, employed, employed, inactive, unemployed
marital_status_rec             object          single, single, single, married/cohabiting, single
urbanicity_cat                 object          3.Urban, 3.Urban, 3.Urban, 3.Urban, 3.Urban
ed_attainment_corr             object          2-Completed high school or less, 3-Completed primary school or less, 2-Completed...
evaluacindelprocesoteraputico  object          logro alto, logro minimo, logro minimo, logro minimo, logro alto
eva_consumo                    object          logro alto, logro minimo, logro minimo, logro intermedio, logro alto
eva_fam                        object          logro intermedio, logro minimo, logro minimo, logro intermedio, logro alto
eva_relinterp                  object          logro alto, logro minimo, logro minimo, logro intermedio, logro alto
eva_ocupacion                  object          logro alto, logro minimo, logro minimo, logro minimo, logro intermedio
eva_sm                         object          logro intermedio, logro minimo, logro minimo, logro intermedio, logro minimo
eva_fisica                     object          logro alto, logro minimo, logro intermedio, logro intermedio, logro alto
eva_transgnorma                object          logro alto, logro minimo, logro minimo, logro minimo, logro intermedio
ethnicity                      float64         0.0, 0.0, 0.0, 0.0, 0.0
dg_psiq_cie_10_instudy         bool            False, False, True, False, False
dg_psiq_cie_10_dg              bool            True, False, False, True, False
dx_f3_mood                     int32           0, 0, 0, 0, 0
dx_f6_personality              int32           0, 0, 0, 1, 0
dx_f_any_severe_mental         bool            True, False, False, False, False
any_phys_dx                    bool            False, False, False, True, False
polysubstance_strict           int32           0, 1, 1, 1, 1
readmit_event                  float64         0.0, 1.0, 1.0, 1.0, 1.0
death_event                    int32           0, 0, 0, 0, 0
readmit_time_from_disch_m      float64         68.96774193548387, 7.0, 13.258064516129032, 5.0, 7.354838709677419
death_time_from_disch_m        float64         68.96774193548387, 81.3225806451613, 116.74193548387096, 91.96774193548387, 31.0...
center_id                      object          330, 489, 341, 411, 176
Show first db of imputations_list_jan26 (older) glimpse
Rows: 88504 | Columns: 56
adm_age_rec3                   float64         31.53, 20.61, 42.52, 60.61, 45.08
porc_pobr                      float64         0.175679117441177, 0.187835901975632, 0.130412444472313, 0.133759185671806, 0.08...
dit_m                          float64         15.967741935483872, 5.833333333333334, 0.4752688172043005, 6.966666666666667, 6....
tenure_status_household        int64           3, 0, 3, 0, 3
prim_sub_freq_rec              int64           1, 2, 2, 2, 2
national_foreign               int32           0, 0, 0, 0, 0
urbanicity_cat                 int64           0, 0, 0, 0, 0
ed_attainment_corr             float64         1.0, 2.0, 1.0, 1.0, 2.0
evaluacindelprocesoteraputico  int64           0, 2, 2, 2, 0
eva_consumo                    int64           0, 2, 2, 1, 0
eva_fam                        int64           1, 2, 2, 1, 0
eva_relinterp                  int64           0, 2, 2, 1, 0
eva_ocupacion                  int64           0, 2, 2, 2, 1
eva_sm                         int64           1, 2, 2, 1, 2
eva_fisica                     int64           0, 2, 1, 1, 0
eva_transgnorma                int64           0, 2, 2, 2, 1
ethnicity                      float64         0.0, 0.0, 0.0, 0.0, 0.0
dg_psiq_cie_10_instudy         bool            False, False, True, False, False
dg_psiq_cie_10_dg              bool            True, False, False, True, False
dx_f3_mood                     int32           0, 0, 0, 0, 0
dx_f6_personality              int32           0, 0, 0, 1, 0
dx_f_any_severe_mental         bool            True, False, False, False, False
any_phys_dx                    bool            False, False, False, True, False
polysubstance_strict           int32           0, 1, 1, 1, 1
sex_rec_woman                  float64         0.0, 0.0, 0.0, 1.0, 0.0
cohabitation_family_of_origin  float64         0.0, 1.0, 0.0, 0.0, 1.0
cohabitation_with_couple_children float64         0.0, 0.0, 1.0, 1.0, 0.0
cohabitation_others            float64         0.0, 0.0, 0.0, 0.0, 0.0
sub_dep_icd10_status_drug_dependence float64         1.0, 0.0, 1.0, 1.0, 1.0
any_violence_1_domestic_violence_sex_abuse float64         0.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_referral            float64         1.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_dropout             float64         0.0, 1.0, 0.0, 1.0, 0.0
tr_outcome_adm_discharge_rule_violation_undet float64         0.0, 0.0, 1.0, 0.0, 0.0
tr_outcome_adm_discharge_adm_reasons float64         0.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_other               float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_sanitary_sector     float64         1.0, 0.0, 1.0, 1.0, 0.0
adm_motive_another_sud_facility_fonodrogas_senda_previene float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_justice_sector      float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_other               float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_alcohol         float64         1.0, 1.0, 1.0, 0.0, 1.0
first_sub_used_cocaine_paste   float64         0.0, 0.0, 0.0, 1.0, 0.0
first_sub_used_cocaine_powder  float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_cocaine_powder float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_alcohol        float64         1.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_marijuana      float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_others         float64         0.0, 0.0, 0.0, 0.0, 0.0
tipo_de_vivienda_rec2_other_unknown float64         1.0, 0.0, 0.0, 0.0, 0.0
plan_type_corr_pg_pr           float64         0.0, 0.0, 1.0, 0.0, 0.0
plan_type_corr_m_pr            float64         0.0, 0.0, 0.0, 1.0, 0.0
plan_type_corr_pg_pai          float64         0.0, 0.0, 0.0, 0.0, 1.0
plan_type_corr_m_pai           float64         0.0, 0.0, 0.0, 0.0, 0.0
occupation_condition_corr24_unemployed float64         1.0, 0.0, 0.0, 0.0, 1.0
occupation_condition_corr24_inactive float64         0.0, 0.0, 0.0, 1.0, 0.0
marital_status_rec_single      float64         1.0, 1.0, 1.0, 0.0, 1.0
marital_status_rec_separated_divorced_annulled_widowed float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_other           float64         0.0, 0.0, 0.0, 0.0, 0.0

For each imputed dataset (1–5), we identified and removed predictors with zero variance, as they provide no useful information and can destabilize models. We printed the dropped variables and produced a cleaned version of each design matrix. This ensures that all downstream analyses use only informative predictors.

Code
# Keep only these objects; every other non-callable, non-module, non-dunder
# global is deleted to free memory before modeling.
objects_to_keep = {
    "objects_to_keep",
    "imputation_nodum_1",
    "imputation_nodum_2",
    "imputation_nodum_3",
    "imputation_nodum_4",
    "imputation_nodum_5",
    # FIX: the Surv objects are stored with per-imputation suffixes
    # (y_surv_readm_1 ... y_surv_death_5). The bare names kept previously
    # never existed, so all survival targets were being deleted here.
    "y_surv_readm_1", "y_surv_readm_2", "y_surv_readm_3",
    "y_surv_readm_4", "y_surv_readm_5",
    "y_surv_death_1", "y_surv_death_2", "y_surv_death_3",
    "y_surv_death_4", "y_surv_death_5",
    # imputation_1 is displayed/used by earlier cells; keep it too.
    "imputation_1",
    "imputations_list_jan26",
}

import types

# Snapshot the keys first so the dict does not change size while iterating.
# Callables (helper functions) and modules are always preserved.
for name in list(globals().keys()):
    obj = globals()[name]
    if (
        name not in objects_to_keep
        and not name.startswith("_")
        and not callable(obj)
        and not isinstance(obj, types.ModuleType)  # <- protects ALL modules
    ):
        del globals()[name]
Code
from IPython.display import display, Markdown

# 1. Define columns to exclude (same as before)
# Outcome columns — must never enter the design matrix as predictors.
target_cols = [
    "readmit_time_from_disch_m",
    "readmit_event",
    "death_time_from_disch_m",
    "death_event",
]

# Times measured from admission leak post-baseline information
# (see narrative above): exclude them from the predictor set.
leak_time_cols = [
    "readmit_time_from_adm_m",
    "death_time_from_adm_m",
]

# Facility identifier — removed prior to modeling.
center_id = ["center_id"]

cols_to_exclude = target_cols + center_id  + leak_time_cols

# 2. Create list of your EXISTING imputation DataFrames (1-5)
imputed_dfs = [
    imputation_nodum_1,
    imputation_nodum_2,
    imputation_nodum_3,
    imputation_nodum_4,
    imputation_nodum_5
]

# 3. Preprocessing loop
# For each imputation: drop zero-variance predictors, then drop the
# target/leakage/ID columns, and collect the cleaned design matrices.
X_reduced_list = []

for d, df in enumerate(imputed_dfs):
    imputation_num = d + 1  # Convert 0-index to 1-index for display

    display(Markdown(f"\n=== Imputation dataset {imputation_num} ==="))

    # a) Identify and drop constant predictors
    # nunique(dropna=False) <= 1 also catches all-NaN columns.
    const_mask = (df.nunique(dropna=False) <= 1)
    dropped_const = df.columns[const_mask].tolist()
    display(Markdown(f"**Constant predictors dropped ({len(dropped_const)}):**"))
    display(Markdown(f"{dropped_const if dropped_const else 'None'}"))

    # b) Remove constant columns
    X_reduced = df.loc[:, ~const_mask]

    # c) Drop target/leakage columns (if present)
    cols_to_drop = [col for col in cols_to_exclude if col in X_reduced.columns]
    if cols_to_drop:
        X_reduced = X_reduced.drop(columns=cols_to_drop)
        display(Markdown(f"**Dropped target/leakage columns:** {cols_to_drop}"))
    else:
        display(Markdown("No target/leakage columns found to drop"))

    # d) Store cleaned DataFrame
    X_reduced_list.append(X_reduced)

    # e) Report shapes
    display(Markdown(f"**Original shape:** {df.shape}"))
    display(Markdown(
        f"**Cleaned shape:** {X_reduced.shape} "
        f"(removed {df.shape[1] - X_reduced.shape[1]} columns)"
    ))

display(Markdown("\n✅ **Preprocessing complete! X_reduced_list contains 5 cleaned DataFrames.**"))

=== Imputation dataset 1 ===

Constant predictors dropped (0):

None

Dropped target/leakage columns: [‘readmit_time_from_disch_m’, ‘readmit_event’, ‘death_time_from_disch_m’, ‘death_event’, ‘center_id’, ‘readmit_time_from_adm_m’, ‘death_time_from_adm_m’]

Original shape: (88504, 43)

Cleaned shape: (88504, 36) (removed 7 columns)

=== Imputation dataset 2 ===

Constant predictors dropped (0):

None

Dropped target/leakage columns: [‘readmit_time_from_disch_m’, ‘readmit_event’, ‘death_time_from_disch_m’, ‘death_event’, ‘center_id’, ‘readmit_time_from_adm_m’, ‘death_time_from_adm_m’]

Original shape: (88504, 43)

Cleaned shape: (88504, 36) (removed 7 columns)

=== Imputation dataset 3 ===

Constant predictors dropped (0):

None

Dropped target/leakage columns: [‘readmit_time_from_disch_m’, ‘readmit_event’, ‘death_time_from_disch_m’, ‘death_event’, ‘center_id’, ‘readmit_time_from_adm_m’, ‘death_time_from_adm_m’]

Original shape: (88504, 43)

Cleaned shape: (88504, 36) (removed 7 columns)

=== Imputation dataset 4 ===

Constant predictors dropped (0):

None

Dropped target/leakage columns: [‘readmit_time_from_disch_m’, ‘readmit_event’, ‘death_time_from_disch_m’, ‘death_event’, ‘center_id’, ‘readmit_time_from_adm_m’, ‘death_time_from_adm_m’]

Original shape: (88504, 43)

Cleaned shape: (88504, 36) (removed 7 columns)

=== Imputation dataset 5 ===

Constant predictors dropped (0):

None

Dropped target/leakage columns: [‘readmit_time_from_disch_m’, ‘readmit_event’, ‘death_time_from_disch_m’, ‘death_event’, ‘center_id’, ‘readmit_time_from_adm_m’, ‘death_time_from_adm_m’]

Original shape: (88504, 43)

Cleaned shape: (88504, 36) (removed 7 columns)

Preprocessing complete! X_reduced_list contains 5 cleaned DataFrames.

Dummify

A structured preprocessing pipeline was implemented prior to modeling. Ordered categorical variables (e.g., housing status, educational attainment, clinical evaluations, and substance use frequency) were manually mapped to numeric scales reflecting their natural ordering. For nominal categorical variables, prespecified reference categories were enforced to ensure consistent baseline comparisons across imputations. All remaining categorical predictors were then converted to dummy variables using one-hot encoding with the first category dropped to prevent multicollinearity. The procedure was applied consistently across all imputed datasets to ensure harmonized model inputs.

Code
import pandas as pd
import numpy as np
from sklearn.preprocessing import OrdinalEncoder
import pandas as pd
import numpy as np
from pandas.api.types import CategoricalDtype

def preprocess_features_robust(df):
    """Encode predictors for modeling.

    Three stages, applied to a copy of ``df``:
      1. Ordinal encoding of ordered categoricals (higher code = less
         favourable level), with mode imputation of unmapped values.
      2. Forcing prespecified reference levels for nominal categoricals so
         dummy baselines are identical across imputations.
      3. One-hot encoding with the first (reference) category dropped.

    Parameters
    ----------
    df : pd.DataFrame
        Raw (imputed) predictor matrix.

    Returns
    -------
    pd.DataFrame
        Fully numeric matrix; dummy columns are float.
    """
    df_clean = df.copy()

    # ---------------------------------------------------------
    # 1. Ordinal encoding
    # ---------------------------------------------------------
    ordered_mappings = {
        # --- NEW: Housing & Urbanicity ---
        "tenure_status_household": {
            "illegal settlement": 4,                       # Situación Calle
            "stays temporarily with a relative": 3,        # Allegado
            "others": 2,                                   # En pensión / Otros
            "renting": 1,                                  # Arrendando
            "owner/transferred dwellings/pays dividends": 0 # Vivienda Propia
        },
        "urbanicity_cat": {
            "1.Rural": 2,
            "2.Mixed": 1,
            "3.Urban": 0
        },

        # --- Clinical Evaluations (Minimo -> Intermedio -> Alto) ---
        "evaluacindelprocesoteraputico": {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_consumo":      {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_fam":          {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_relinterp":    {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_ocupacion":    {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_sm":           {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_fisica":       {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_transgnorma":  {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},

        # --- Frequency (Less freq -> More freq) ---
        "prim_sub_freq_rec": {
            "1.≤1 day/wk": 0,
            "2.2–6 days/wk": 1,
            "3.Daily": 2
        },

        # --- Education (Less -> More) ---
        "ed_attainment_corr": {
            "3-Completed primary school or less": 2,
            "2-Completed high school or less": 1,
            "1-More than high school": 0
        }
    }

    for col, mapping in ordered_mappings.items():
        if col in df_clean.columns:
            df_clean[col] = df_clean[col].astype(str).str.strip()
            df_clean[col] = df_clean[col].map(mapping)

            n_missing = df_clean[col].isnull().sum()
            if n_missing > 0:
                modes = df_clean[col].mode()
                if modes.empty:
                    # BUG FIX: when the mapping failed for every row, the old
                    # code warned and then crashed on mode()[0] (IndexError).
                    # Now we warn and leave the column as NaN.
                    print(f"⚠️ WARNING: Mapping failed completely for '{col}'.")
                else:
                    # Partial mapping failures are mode-imputed.
                    df_clean[col] = df_clean[col].fillna(modes[0])

    # ---------------------------------------------------------
    # 2. FORCE reference categories for dummies
    # ---------------------------------------------------------
    # BUG FIX: the original dict listed "plan_type_corr" twice
    # ("ambulatory" then "pg-pab"); Python keeps only the last entry, so the
    # effective reference was "pg-pab". The shadowed duplicate is removed to
    # make that explicit (behavior unchanged).
    dummy_reference = {
        "sex_rec": "man",
        "marital_status_rec": "married/cohabiting",
        "cohabitation": "alone",
        "sub_dep_icd10_status": "hazardous consumption",
        "tr_outcome": "completion",
        "adm_motive": "spontaneous consultation",
        "tipo_de_vivienda_rec2": "formal housing",
        "plan_type_corr": "pg-pab",
        "occupation_condition_corr24": "employed",
        "any_violence": "0.No domestic violence/sex abuse",
        "first_sub_used": "marijuana",
        "primary_sub_mod": "marijuana",
        }

    for col, ref in dummy_reference.items():
        if col in df_clean.columns:
            df_clean[col] = df_clean[col].astype(str).str.strip()
            cats = df_clean[col].unique().tolist()

            if ref in cats:
                # Put the reference first so get_dummies(drop_first=True)
                # drops exactly that level.
                new_order = [ref] + [c for c in cats if c != ref]
                cat_type = CategoricalDtype(categories=new_order, ordered=False)
                df_clean[col] = df_clean[col].astype(cat_type)
            else:
                print(f"⚠️ Reference '{ref}' not found in {col}")

    # ---------------------------------------------------------
    # 3. One-hot encoding
    # ---------------------------------------------------------
    df_final = pd.get_dummies(df_clean, drop_first=True, dtype=float)

    return df_final

# Encode every imputed dataset, then normalize column names
# (clean_names and X_reduced_list are defined earlier in the notebook).
X_encoded_list_final = [preprocess_features_robust(X) for X in X_reduced_list]
X_encoded_list_final = [clean_names(X) for X in X_encoded_list_final]
Code
from IPython.display import display, Markdown

# 1. DIAGNOSTIC: confirm the ordinal mappings produced numeric codes.
# Previously three copy-pasted if-blocks; consolidated into one loop, and the
# "column missing" report now covers every checked column (the original only
# reported it for 'tenure_status_household').
display(Markdown("### --- Diagnostic Check ---"))
sample_df = X_encoded_list_final[0]

for diag_col in ("tenure_status_household", "urbanicity_cat", "ed_attainment_corr"):
    if diag_col in sample_df.columns:
        display(Markdown(f"**Unique values in '{diag_col}':**"))
        display(Markdown(str(sample_df[diag_col].unique())))
    else:
        display(Markdown(f"❌ '{diag_col}' is missing entirely from input data!"))

— Diagnostic Check —

Unique values in ‘tenure_status_household’:

[3 0 1 2 4]

Unique values in ‘urbanicity_cat’:

[0 1 2]

Unique values in ‘ed_attainment_corr’:

[1 2 0]

We recoded first substance use so small categories are grouped into Others

Code
# Rare first-substance categories to collapse into a single "other" indicator
cols_to_group = [
    "first_sub_used_opioids",
    "first_sub_used_others",
    "first_sub_used_hallucinogens",
    "first_sub_used_inhalants",
    "first_sub_used_tranquilizers_hypnotics",
    "first_sub_used_amphetamine_type_stimulants",
]

# Loop over every imputed dataset and replace it with the grouped version.
# FIX: iterate over len(X_encoded_list_final) instead of a hardcoded range(5),
# and drop cols_to_group directly — the old filter
# `c != "first_sub_used_other"` was dead code (that name never appears in
# cols_to_group).
for i in range(len(X_encoded_list_final)):
    df = X_encoded_list_final[i].copy()
    # Collapse into one dummy: 1 if any of the grouped columns == 1
    df["first_sub_used_other"] = df[cols_to_group].max(axis=1)
    df = df.drop(columns=cols_to_group)
    # Replace the dataset in the original list
    X_encoded_list_final[i] = df
Code
import sys
# Collapsible preview of the first encoded imputation
# (fold_output and glimpse are notebook helpers defined earlier).
fold_output(
    "Show first db of X_encoded_list_final (newer) glimpse",
    lambda: glimpse(X_encoded_list_final[0])
)
Show first db of X_encoded_list_final (newer) glimpse
Rows: 88504 | Columns: 56
adm_age_rec3                   float64         31.53, 20.61, 42.52, 60.61, 45.08
porc_pobr                      float64         0.175679117441177, 0.187835901975632, 0.130412444472313, 0.133759185671806, 0.08...
dit_m                          float64         15.967741935483872, 5.833333333333334, 0.4752688172043005, 6.966666666666667, 6....
tenure_status_household        int64           3, 0, 3, 0, 3
prim_sub_freq_rec              int64           1, 2, 2, 2, 2
national_foreign               int32           0, 0, 0, 0, 0
urbanicity_cat                 int64           0, 0, 0, 0, 0
ed_attainment_corr             int64           1, 2, 1, 1, 2
evaluacindelprocesoteraputico  int64           0, 2, 2, 2, 0
eva_consumo                    int64           0, 2, 2, 1, 0
eva_fam                        int64           1, 2, 2, 1, 0
eva_relinterp                  int64           0, 2, 2, 1, 0
eva_ocupacion                  int64           0, 2, 2, 2, 1
eva_sm                         int64           1, 2, 2, 1, 2
eva_fisica                     int64           0, 2, 1, 1, 0
eva_transgnorma                int64           0, 2, 2, 2, 1
ethnicity                      float64         0.0, 0.0, 0.0, 0.0, 0.0
dg_psiq_cie_10_instudy         bool            False, False, True, False, False
dg_psiq_cie_10_dg              bool            True, False, False, True, False
dx_f3_mood                     int32           0, 0, 0, 0, 0
dx_f6_personality              int32           0, 0, 0, 1, 0
dx_f_any_severe_mental         bool            True, False, False, False, False
any_phys_dx                    bool            False, False, False, True, False
polysubstance_strict           int32           0, 1, 1, 1, 1
sex_rec_woman                  float64         0.0, 0.0, 0.0, 1.0, 0.0
cohabitation_family_of_origin  float64         0.0, 1.0, 0.0, 0.0, 1.0
cohabitation_with_couple_children float64         0.0, 0.0, 1.0, 1.0, 0.0
cohabitation_others            float64         0.0, 0.0, 0.0, 0.0, 0.0
sub_dep_icd10_status_drug_dependence float64         1.0, 0.0, 1.0, 1.0, 1.0
any_violence_1_domestic_violence_sex_abuse float64         0.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_referral            float64         1.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_dropout             float64         0.0, 1.0, 0.0, 1.0, 0.0
tr_outcome_adm_discharge_rule_violation_undet float64         0.0, 0.0, 1.0, 0.0, 0.0
tr_outcome_adm_discharge_adm_reasons float64         0.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_other               float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_sanitary_sector     float64         1.0, 0.0, 1.0, 1.0, 0.0
adm_motive_another_sud_facility_fonodrogas_senda_previene float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_justice_sector      float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_other               float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_alcohol         float64         1.0, 1.0, 1.0, 0.0, 1.0
first_sub_used_cocaine_paste   float64         0.0, 0.0, 0.0, 1.0, 0.0
first_sub_used_cocaine_powder  float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_alcohol        float64         1.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_cocaine_paste  float64         0.0, 1.0, 1.0, 1.0, 1.0
primary_sub_mod_cocaine_powder float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_others         float64         0.0, 0.0, 0.0, 0.0, 0.0
tipo_de_vivienda_rec2_other_unknown float64         1.0, 0.0, 0.0, 0.0, 0.0
plan_type_corr_pg_pr           float64         0.0, 0.0, 1.0, 0.0, 0.0
plan_type_corr_m_pr            float64         0.0, 0.0, 0.0, 1.0, 0.0
plan_type_corr_pg_pai          float64         0.0, 0.0, 0.0, 0.0, 1.0
plan_type_corr_m_pai           float64         0.0, 0.0, 0.0, 0.0, 0.0
occupation_condition_corr24_unemployed float64         1.0, 0.0, 0.0, 0.0, 1.0
occupation_condition_corr24_inactive float64         0.0, 0.0, 0.0, 1.0, 0.0
marital_status_rec_single      float64         1.0, 1.0, 1.0, 0.0, 1.0
marital_status_rec_separated_divorced_annulled_widowed float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_other           float64         0.0, 0.0, 0.0, 0.0, 0.0

For each imputed dataset, we fitted two regularized Cox models (one for readmission and one for death) using Coxnet, which applies elastic-net penalization with a strong LASSO component to enable variable selection. The loop fits both models on every imputation, prints basic model information, and stores all fitted models so they can later be combined or compared across imputations.

Create bins for followup (landmarks)

We extracted the observed event times and corresponding event indicators directly from the structured survival objects (y_surv_readm and y_surv_death). Using the observed event times, we constructed evaluation grids based on the 5th to 95th percentiles of the event-time distribution. These grids define standardized time points at which model performance is assessed for both readmission and mortality outcomes.

Code
import numpy as np
from IPython.display import display, Markdown

# Extract observed event times from the structured survival arrays
# (events only; censored rows are excluded by the boolean mask).
event_times_readm = y_surv_readm["time"][y_surv_readm["event"]]
event_times_death = y_surv_death["time"][y_surv_death["event"]]

# Build evaluation grids (5th–95th percentiles, 50 points).
# np.unique sorts the grid and removes duplicate quantiles.
times_eval_readm = np.unique(
    np.quantile(event_times_readm, np.linspace(0.05, 0.95, 50))
)

times_eval_death = np.unique(
    np.quantile(event_times_death, np.linspace(0.05, 0.95, 50))
)

# Display only final result
display(Markdown(
    f"**Eval times (readmission):** `{times_eval_readm[:5]}` ... `{times_eval_readm[-5:]}`"
))

display(Markdown(
    f"**Eval times (death):** `{times_eval_death[:5]}` ... `{times_eval_death[-5:]}`"
))

Eval times (readmission): [0.38709677 0.67741935 1.03225806 1.41935484 1.76666667][46.81833443 50.96030612 55.16129032 60.84848585 67.08322581]

Eval times (death): [0. 0.09677419 1.06666667 2.1691691 3.34812377][74.72632653 78.4516129 82.39472203 86.41935484 92.36311828]

Correct immortal time bias

First, we eliminated immortal time bias (without this correction, deceased patients would contribute follow-up time as if they were still at risk of readmission).

This correction is essentially the Cause-Specific Hazard preparation. It is the correct way to handle Aim 3 unless one switches to a Fine–Gray model (which codes death as a distinct competing event type, 2, rather than as censoring, 0). For RSF/Coxnet, censoring (0) is the correct approach.

Code
import numpy as np

# Step 3. Replicate outcome arrays across imputations (independent copies,
# so the competing-risk correction below cannot mutate the originals).
n_imputations = len(X_encoded_list_final)
y_surv_readm_list = [y_surv_readm.copy() for _ in range(n_imputations)]
y_surv_death_list = [y_surv_death.copy() for _ in range(n_imputations)]

def correct_competing_risks(y_readm_list, y_death_list):
    """Censor readmission outcomes at death (cause-specific hazard setup).

    For every imputation, subjects whose death is observed strictly before
    the recorded readmission/censoring time are re-coded as censored
    (event=False) at the death time, removing immortal person-time.
    Inputs are not mutated; corrected copies are returned.
    """
    corrected = []
    for y_readm, y_death in zip(y_readm_list, y_death_list):
        fixed = y_readm.copy()

        # Death observed AND strictly earlier than the readmission record
        died_first = (y_death["event"]) & (y_death["time"] < fixed["time"])

        fixed["event"] = np.where(died_first, False, fixed["event"])
        fixed["time"] = np.where(died_first, y_death["time"], fixed["time"])

        corrected.append(fixed)
    return corrected

# Step 4. Apply the censoring-at-death correction to every imputation's
# readmission outcomes.
y_surv_readm_list_corrected = correct_competing_risks(
    y_surv_readm_list,
    y_surv_death_list
)
Code
# Check type and length (note: only the LAST expression in the cell is
# displayed; this tuple is evaluated and discarded)
type(y_surv_readm_list_corrected), len(y_surv_readm_list_corrected)

# Look at the first element
y_surv_readm_list_corrected[0][:5]   # first 5 rows
array([(False, 68.96774194), ( True,  7.        ), ( True, 13.25806452), ( True,  5.        ), ( True,  7.35483871)],
      dtype=[('event', '?'), ('time', '<f8')])
Code
from IPython.display import display, HTML
import html

def nb_print(*args, sep=" "):
    """Render plain-text args as a preformatted HTML block (monospace, no margin)."""
    text = sep.join(str(item) for item in args)
    display(HTML("<pre style='margin:0'>" + html.escape(text) + "</pre>"))

The fully preprocessed and encoded feature matrices were renamed from X_encoded_list_final to imputations_list_mar26 to reflect the finalized March 2026 analytic version of the imputed datasets.

This object contains the harmonized, ordinal-encoded, and one-hot encoded predictor matrices for all five imputations and will serve as the definitive input for subsequent modeling procedures.

Code
# Freeze the finalized analytic version under a dated name; delete the old
# alias so stale references fail fast instead of silently using it.
imputations_list_mar26 = X_encoded_list_final
del X_encoded_list_final
Code
import numpy as np
import pandas as pd
from IPython.display import display, Markdown

# ── Build exclusion mask (same for all imputations, based on imputation 0) ──
df0 = imputations_list_mar26[0]

# Condition 1: administrative discharge for administrative reasons AND an
# observed death within 0.23 months (≈ 7 days) of discharge
mask_adm_death = (
    (df0["tr_outcome_adm_discharge_adm_reasons"] == 1)
    & (y_surv_death["event"] == True)
    & (y_surv_death["time"] <= 0.23)
)

# Condition 2: tr_outcome_other == 1 (any time)
mask_other = df0["tr_outcome_other"] == 1

# Combined exclusion mask
exclude = mask_adm_death | mask_other
keep = ~exclude

# ── Report ──
n_total = len(df0)
n_excl_adm = mask_adm_death.sum()
n_excl_other = mask_other.sum()
n_excl_both = (mask_adm_death & mask_other).sum()
n_excl_total = exclude.sum()
n_remaining = keep.sum()

report = f"""### Exclusion Report

| Criterion | n excluded |
|---|---:|
| `tr_outcome_adm_discharge_adm_reasons == 1` & death time ≤ 7 days | {n_excl_adm} |
| `tr_outcome_other == 1` (any time) | {n_excl_other} |
| Both criteria (overlap) | {n_excl_both} |
| **Total unique excluded** | **{n_excl_total}** |
| **Remaining observations** | **{n_remaining}** / {n_total} |
"""
display(Markdown(report))

# ── Apply filter to all imputation lists and outcome arrays ──
imputations_list_mar26 = [df.loc[keep].reset_index(drop=True) for df in imputations_list_mar26]
# NOTE(review): the next line slices the *corrected* readmission list into
# `y_surv_readm_list_mar26`, making it identical to
# `y_surv_readm_list_corrected_mar26` two lines below. If this name was meant
# to carry the UNCORRECTED outcomes it should index y_surv_readm_list — confirm.
y_surv_readm_list_mar26 = [y[keep] for y in y_surv_readm_list_corrected]
y_surv_death_list_mar26 = [y[keep] for y in y_surv_death_list]
y_surv_readm_list_corrected_mar26 = [y[keep] for y in y_surv_readm_list_corrected]

# Single (non-list) outcome arrays for convenience
y_surv_readm_mar26 = y_surv_readm[keep]
y_surv_death_mar26 = y_surv_death[keep]

# Rebuild eval time grids on the filtered data
event_times_readm_mar26 = y_surv_readm_mar26["time"][y_surv_readm_mar26["event"]]
event_times_death_mar26 = y_surv_death_mar26["time"][y_surv_death_mar26["event"]]

times_eval_readm_mar26 = np.unique(
    np.quantile(event_times_readm_mar26, np.linspace(0.05, 0.95, 50))
)
times_eval_death_mar26 = np.unique(
    np.quantile(event_times_death_mar26, np.linspace(0.05, 0.95, 50))
)

Exclusion Report

Criterion n excluded
tr_outcome_adm_discharge_adm_reasons == 1 & death time ≤ 7 days 137
tr_outcome_other == 1 (any time) 215
Both criteria (overlap) 0
Total unique excluded 352
Remaining observations 88152 / 88504

Train / test split (80/20)

  1. Sets a fixed random seed to make the 80/20 split exactly reproducible.

  2. Verifies required datasets exist (features and survival outcomes) before doing anything.

  3. Creates a “death-corrected” outcome list if it was not already available.

  4. Derives stratification labels from treatment plan and completion categories plus readmission/death events.

  5. Uses a step-down strategy if some strata are too rare, simplifying the stratification to keep it feasible.

  6. Caches a “full snapshot” of all imputations and outcomes so reruns don’t silently change the split.

  7. Checks row alignment so every imputation and every outcome has the same number of observations.

  8. Optionally checks stability across imputations for plan/completion columns (should not vary much).

  9. Loads split indices from disk when available, ensuring the exact same train/test split across sessions.

  10. Builds train/test datasets consistently for all imputations, then runs strict diagnostics to confirm balance.

Code
#@title 🧪 / 🎓 Reproducible 80/20 split before ML (integrated + idempotent + persisted)
# Stratification hierarchy:
#   1) plan + completion + readm_event + death_event
#   2) mixed fallback for rare full strata (<2) -> plan + readm + death
#   3) full fallback -> plan + readm + death

import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split
from IPython.display import display, Markdown

SEED = 2125                    # fixed RNG seed → reproducible 80/20 split
TEST_SIZE = 0.20               # held-out fraction
FORCE_RESPLIT = False          # True to force new split
STRICT_SPLIT_CHECKS = True     # CI-style hard checks
MAX_EVENT_GAP = 0.01           # 1% tolerance
PERSIST_SPLIT_INDICES = True   # save/load split indices under _out/

# --- Project-root anchored paths ---
from pathlib import Path

def find_project_root(markers=("AGENTS.md", ".git")):
    """Walk up from the CWD until a directory containing any marker file is found.

    Raises RuntimeError if the CWD is invalid or no ancestor holds a marker.
    """
    try:
        here = Path.cwd().resolve()
    except OSError as exc:
        raise RuntimeError(
            "Invalid working directory. Run this notebook from inside the project folder."
        ) from exc

    for candidate in (here, *here.parents):
        if any((candidate / marker).exists() for marker in markers):
            return candidate

    raise RuntimeError(
        f"Could not locate project root starting from {here}. "
        f"Expected one of markers: {markers}."
    )

# Anchor all outputs at the project root (no hardcoded absolute paths)
PROJECT_ROOT = find_project_root()
OUT_DIR = PROJECT_ROOT / "_out"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# One persisted index file per (seed, test size, data version) combination
SPLIT_FILE = OUT_DIR / f"comb_split_seed{SEED}_test{int(TEST_SIZE*100)}_mar26.npz"

def nb_print_md(msg):
    # Render a string as Markdown in the notebook output
    display(Markdown(str(msg)))

nb_print_md(f"**Project root:** `{PROJECT_ROOT}`")


# ---------- Requirements ----------
# Fail fast if the upstream cells (encoding, exclusions, corrections)
# have not been run in this kernel.
required = [
    "imputations_list_mar26",
    "y_surv_readm_list_mar26",
    "y_surv_readm_list_corrected_mar26",
    "y_surv_death_list_mar26",
]
missing = [v for v in required if v not in globals()]
if missing:
    raise ValueError(f"Missing required objects: {missing}")

# If no corrected death list exists yet, fall back to plain copies of the
# uncorrected death outcomes.
if "y_surv_death_list_corrected_mar26" not in globals():
    y_surv_death_list_corrected_mar26 = [y.copy() for y in y_surv_death_list_mar26]

# ---------- Helpers ----------
def get_plan_labels(df):
    """Collapse the plan_type_corr_* one-hot dummies into a single int label.

    0 = reference plan (no dummy set); later columns override earlier ones
    when more than one dummy equals 1.
    """
    coding = [
        ("plan_type_corr_pg_pr", 1),
        ("plan_type_corr_m_pr", 2),
        ("plan_type_corr_pg_pai", 3),
        ("plan_type_corr_m_pai", 4),
    ]
    labels = np.zeros(len(df), dtype=int)
    for col, code in coding:
        if col in df.columns:
            hits = pd.to_numeric(df[col], errors="coerce").fillna(0).to_numpy() == 1
            labels[hits] = code
    return labels

def get_completion_labels(df):
    """Collapse the tr_outcome_* one-hot dummies into a single int label.

    0 = reference outcome (no dummy set); later columns override earlier
    ones when more than one dummy equals 1.
    """
    coding = [
        ("tr_outcome_referral", 1),
        ("tr_outcome_dropout", 2),
        ("tr_outcome_adm_discharge_rule_violation_undet", 3),
        ("tr_outcome_adm_discharge_adm_reasons", 4),
        ("tr_outcome_other", 5),
    ]
    labels = np.zeros(len(df), dtype=int)
    for col, code in coding:
        if col in df.columns:
            hits = pd.to_numeric(df[col], errors="coerce").fillna(0).to_numpy() == 1
            labels[hits] = code
    return labels

def build_strata(X0, y_readm0, y_death0):
    """
    Build stratification labels with progressive fallback:

    1) full: plan + completion + readmission_event + death_event
    2) mixed: only rare full strata (<2 rows) are replaced by fallback labels
       (plan + readmission_event + death_event)
    3) fallback: plan + readmission_event + death_event for all rows

    Each level is accepted only if every stratum has >= 2 rows (the minimum
    a stratified train/test split needs to place one row on each side).

    Returns:
        strata (np.ndarray), mode (str), readm_evt (np.ndarray), death_evt (np.ndarray)
    """
    plan = get_plan_labels(X0)
    comp = get_completion_labels(X0)
    readm_evt = y_readm0["event"].astype(int)
    death_evt = y_death0["event"].astype(int)

    # Level 1: fully crossed strata
    full = pd.Series(plan.astype(str) + "_" + comp.astype(str) + "_" + readm_evt.astype(str) + "_" + death_evt.astype(str))
    if full.value_counts().min() >= 2:
        return full.to_numpy(), "full(plan+completion+readm+death)", readm_evt, death_evt

    # Level 2: coarsen only the rare cells.
    # `mixed.map(mixed.value_counts())` attaches each row's own stratum count;
    # `mixed[rare] = fb[rare]` relies on the two Series sharing the same index.
    fb = pd.Series(plan.astype(str) + "_" + readm_evt.astype(str) + "_" + death_evt.astype(str))
    mixed = full.copy()
    rare = mixed.map(mixed.value_counts()) < 2
    mixed[rare] = fb[rare]
    if mixed.value_counts().min() >= 2:
        return mixed.to_numpy(), "mixed(rare->plan+readm+death)", readm_evt, death_evt

    # Level 3: coarse labels everywhere
    if fb.value_counts().min() >= 2:
        return fb.to_numpy(), "fallback(plan+readm+death)", readm_evt, death_evt

    raise ValueError("Could not build stratification labels with >=2 rows per stratum.")

def split_df_list(df_list, tr_idx, te_idx):
    """Positionally split every DataFrame into (train, test) copies with reset indices."""
    def take(frame, positions):
        return frame.iloc[positions].reset_index(drop=True).copy()

    train = [take(frame, tr_idx) for frame in df_list]
    test = [take(frame, te_idx) for frame in df_list]
    return train, test

def split_surv_list(y_list, tr_idx, te_idx):
    """Positionally split every outcome array into independent (train, test) copies."""
    train, test = [], []
    for y in y_list:
        train.append(y[tr_idx].copy())
        test.append(y[te_idx].copy())
    return train, test

# ---------- Cache full data once (idempotent re-runs) ----------
# The snapshot guarantees that re-running this cell always slices from the
# SAME full data, so the split cannot silently change between runs.
if "_split_cache_death_mar26" not in globals():
    _split_cache_death_mar26 = {}
cache = _split_cache_death_mar26

if FORCE_RESPLIT:
    cache.pop("idx", None)

if FORCE_RESPLIT or "full" not in cache:
    cache["full"] = {
        "X": [df.reset_index(drop=True).copy() for df in imputations_list_mar26],
        "y_readm": [y.copy() for y in y_surv_readm_list_mar26],
        "y_readm_corr": [y.copy() for y in y_surv_readm_list_corrected_mar26],
        "y_death": [y.copy() for y in y_surv_death_list_mar26],
        "y_death_corr": [y.copy() for y in y_surv_death_list_corrected_mar26],
    }

full = cache["full"]

# ---------- Consistency checks ----------
# Every imputation and every outcome list must have identical row counts.
n_imp = len(full["X"])
n = len(full["X"][0])

if any(len(df) != n for df in full["X"]):
    raise ValueError("Row mismatch inside full X list.")

for name, obj in [
    ("y_readm", full["y_readm"]),
    ("y_readm_corr", full["y_readm_corr"]),
    ("y_death", full["y_death"]),
    ("y_death_corr", full["y_death_corr"]),
]:
    if len(obj) != n_imp:
        raise ValueError(f"{name} length ({len(obj)}) != n_imputations ({n_imp})")
    if any(len(y) != n for y in obj):
        raise ValueError(f"Row mismatch between X and {name}.")

# ---------- Optional diagnostic: plan/completion consistency across imputations ----------
# These columns drive the stratification, which is built from imputation 0 only;
# large differences across imputations would invalidate that choice.
plan_comp_cols = [
    c for c in [
        "plan_type_corr_pg_pr",
        "plan_type_corr_m_pr",
        "plan_type_corr_pg_pai",
        "plan_type_corr_m_pai",
        "tr_outcome_referral",
        "tr_outcome_dropout",
        "tr_outcome_adm_discharge_rule_violation_undet",
        "tr_outcome_adm_discharge_adm_reasons"#,
        #"tr_outcome_other", #2026-03-26: Excluded from consistency check since it was used as an exclusion criterion and thus may differ by design across imputations
    ] if c in full["X"][0].columns
]

max_diff_rows = 0
if plan_comp_cols:
    base_pc = full["X"][0][plan_comp_cols].astype("string").fillna("__NA__").reset_index(drop=True)
    for i in range(1, n_imp):
        cur_pc = full["X"][i][plan_comp_cols].astype("string").fillna("__NA__").reset_index(drop=True)
        diff_rows = int((base_pc != cur_pc).any(axis=1).sum())
        max_diff_rows = max(max_diff_rows, diff_rows)

# ---------- Try loading indices from disk ----------
# Persisted indices guarantee the identical split across kernel sessions.
loaded_from_disk = False
if PERSIST_SPLIT_INDICES and (not FORCE_RESPLIT) and SPLIT_FILE.exists() and ("idx" not in cache):
    z = np.load(SPLIT_FILE, allow_pickle=False)
    tr = z["train_idx"].astype(int)
    te = z["test_idx"].astype(int)
    n_disk = int(z["n_full"][0]) if "n_full" in z else n
    # Sanity check: stored indices must fit the current data size.
    # NOTE(review): tr.max()/te.max() would raise on an empty index array —
    # confirm persisted files always contain non-empty splits.
    if n_disk == n and tr.max() < n and te.max() < n:
        cache["idx"] = (np.sort(tr), np.sort(te))
        cache["strat_mode"] = str(z["strat_mode"][0]) if "strat_mode" in z else "loaded_from_disk"
        loaded_from_disk = True

# ---------- Compute or reuse split indices ----------
if FORCE_RESPLIT or "idx" not in cache:
    strata_used, strat_mode, readm_evt_all, death_evt_all = build_strata(
        full["X"][0], full["y_readm"][0], full["y_death"][0]
    )
    idx = np.arange(n)
    train_idx, test_idx = train_test_split(
        idx, test_size=TEST_SIZE, random_state=SEED, shuffle=True, stratify=strata_used
    )
    # Sorting makes the stored indices order-stable and easy to diff
    train_idx = np.sort(train_idx)
    test_idx = np.sort(test_idx)
    cache["idx"] = (train_idx, test_idx)
    cache["strat_mode"] = strat_mode

    if PERSIST_SPLIT_INDICES:
        np.savez_compressed(
            SPLIT_FILE,
            train_idx=train_idx,
            test_idx=test_idx,
            n_full=np.array([n], dtype=int),
            seed=np.array([SEED], dtype=int),
            test_size=np.array([TEST_SIZE], dtype=float),
            strat_mode=np.array([strat_mode], dtype="U64"),
        )
else:
    train_idx, test_idx = cache["idx"]
    train_idx = np.sort(train_idx)
    test_idx = np.sort(test_idx)
    # Event indicators are still needed for the balance diagnostics below
    readm_evt_all = full["y_readm"][0]["event"].astype(int)
    death_evt_all = full["y_death"][0]["event"].astype(int)

# ---------- Build train/test from full snapshot every run ----------
# Always slice from the cached snapshot so re-running this cell cannot
# compound earlier slicing (idempotent).
imputations_list_mar26_train, imputations_list_mar26_test = split_df_list(full["X"], train_idx, test_idx)

y_surv_readm_list_train, y_surv_readm_list_test = split_surv_list(full["y_readm"], train_idx, test_idx)
y_surv_readm_list_corrected_train, y_surv_readm_list_corrected_test = split_surv_list(full["y_readm_corr"], train_idx, test_idx)

y_surv_death_list_train, y_surv_death_list_test = split_surv_list(full["y_death"], train_idx, test_idx)
y_surv_death_list_corrected_train, y_surv_death_list_corrected_test = split_surv_list(full["y_death_corr"], train_idx, test_idx)

# Downstream code uses TRAIN only (the canonical names now point at TRAIN)
imputations_list_mar26 = imputations_list_mar26_train
y_surv_readm_list = y_surv_readm_list_train
y_surv_readm_list_corrected = y_surv_readm_list_corrected_train
y_surv_death_list = y_surv_death_list_train
y_surv_death_list_corrected = y_surv_death_list_corrected_train

# ---------- Diagnostics + strict checks ----------
strata_diag, strat_mode_diag, _, _ = build_strata(full["X"][0], full["y_readm"][0], full["y_death"][0])
sdiag = pd.Series(strata_diag)
train_strata = set(sdiag.iloc[train_idx].unique())
test_strata = set(sdiag.iloc[test_idx].unique())
missing_in_test = sorted(train_strata - test_strata)
missing_in_train = sorted(test_strata - train_strata)

# Absolute train-vs-test gap in event rates
readm_gap = abs(readm_evt_all[train_idx].mean() - readm_evt_all[test_idx].mean())
death_gap = abs(death_evt_all[train_idx].mean() - death_evt_all[test_idx].mean())

# full-strata rarity report (before fallback)
strata_full = pd.Series(
    get_plan_labels(full["X"][0]).astype(str) + "_" +
    get_completion_labels(full["X"][0]).astype(str) + "_" +
    full["y_readm"][0]["event"].astype(int).astype(str) + "_" +
    full["y_death"][0]["event"].astype(int).astype(str)
)
vc_full = strata_full.value_counts()
rare_rows = int((strata_full.map(vc_full) < 2).sum())

# CI-style hard guarantees about the split
if STRICT_SPLIT_CHECKS:
    assert len(np.intersect1d(train_idx, test_idx)) == 0, "Train/Test index overlap detected."
    assert (len(train_idx) + len(test_idx)) == n, "Train/Test sizes do not sum to n."
    assert len(missing_in_test) == 0, f"Strata missing in test: {missing_in_test}"
    assert len(missing_in_train) == 0, f"Strata missing in train: {missing_in_train}"
    assert readm_gap < MAX_EVENT_GAP, f"Readmission rate imbalance > {MAX_EVENT_GAP:.0%} (gap={readm_gap:.4f})"
    assert death_gap < MAX_EVENT_GAP, f"Death rate imbalance > {MAX_EVENT_GAP:.0%} (gap={death_gap:.4f})"

# ---------- Summary ----------
nb_print_md(f"**Loaded indices from disk:** `{loaded_from_disk}`")
nb_print_md(f"**Split file:** `{SPLIT_FILE}`")
nb_print_md(f"**Split mode used:** `{cache.get('strat_mode', strat_mode_diag)}`")
nb_print_md(f"**Plan/completion diff rows across imputations (max vs imp0):** `{max_diff_rows}`")
nb_print_md(f"**Full strata count:** `{vc_full.size}` | **Min full stratum size:** `{int(vc_full.min())}` | **Rows in rare full strata (<2):** `{rare_rows}`")
nb_print_md(f"**Train/Test sizes:** `{len(train_idx)}` ({len(train_idx)/n:.1%}) / `{len(test_idx)}` ({len(test_idx)/n:.1%})")
nb_print_md(
    "**Readmission rate all/train/test:** "
    f"`{readm_evt_all.mean():.3%}` / `{readm_evt_all[train_idx].mean():.3%}` / `{readm_evt_all[test_idx].mean():.3%}`"
)
nb_print_md(
    "**Death rate all/train/test:** "
    f"`{death_evt_all.mean():.3%}` / `{death_evt_all[train_idx].mean():.3%}` / `{death_evt_all[test_idx].mean():.3%}`"
)
nb_print_md(
    f"**Strata in train/test:** `{len(train_strata)}` / `{len(test_strata)}` | "
    f"**Missing train→test:** `{len(missing_in_test)}` | **Missing test→train:** `{len(missing_in_train)}`"
)

Project root: G:\My Drive\Alvacast\SISTRAT 2023\cons

Loaded indices from disk: True

Split file: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\comb_split_seed2125_test20_mar26.npz

Split mode used: fallback(plan+readm+death)

Plan/completion diff rows across imputations (max vs imp0): 0

Full strata count: 95 | Min full stratum size: 1 | Rows in rare full strata (<2): 4

Train/Test sizes: 70521 (80.0%) / 17631 (20.0%)

Readmission rate all/train/test: 21.622% / 21.621% / 21.627%

Death rate all/train/test: 4.310% / 4.309% / 4.311%

Strata in train/test: 20 / 20 | Missing train→test: 0 | Missing test→train: 0

Code
# counts per stratum in train/test (sdiag comes from the split cell above)
train_counts = sdiag.iloc[train_idx].value_counts()
test_counts  = sdiag.iloc[test_idx].value_counts()

min_train = int(train_counts.min())
min_test  = int(test_counts.min())

nb_print_md(f"**Min stratum count in TRAIN (used strata):** `{min_train}`")
nb_print_md(f"**Min stratum count in TEST (used strata):** `{min_test}`")

# strata that got 0 in test or 0 in train (absent from one side's counts)
zero_in_test = sorted(set(train_counts.index) - set(test_counts.index))
zero_in_train = sorted(set(test_counts.index) - set(train_counts.index))

nb_print_md(f"**Strata with 0 in TEST:** `{len(zero_in_test)}`")
nb_print_md(f"**Strata with 0 in TRAIN:** `{len(zero_in_train)}`")

# show examples with their full-data counts
if len(zero_in_test) > 0:
    ex = zero_in_test[:10]
    nb_print_md(f"**Examples 0 in TEST (up to 10):** `{ex}`")
    nb_print_md(f"**Full-data counts:** `{[int(sdiag.value_counts()[k]) for k in ex]}`")

Min stratum count in TRAIN (used strata): 18

Min stratum count in TEST (used strata): 5

Strata with 0 in TEST: 0

Strata with 0 in TRAIN: 0

Code
# Rebuild the fully-crossed strata (plan + completion + readm + death) to
# report how rare the finest-grained cells are.
strata_full = pd.Series(
    get_plan_labels(full["X"][0]).astype(str) + "_" +
    get_completion_labels(full["X"][0]).astype(str) + "_" +
    full["y_readm"][0]["event"].astype(int).astype(str) + "_" +
    full["y_death"][0]["event"].astype(int).astype(str)
)

vc = strata_full.value_counts()
display(Markdown(f"**# full strata:** `{vc.size}`"))
display(Markdown(f"**Min stratum size (full):** `{int(vc.min())}`"))
display(Markdown(f"**# strata with count < 2:** `{int((vc < 2).sum())}`"))

# full strata: 95

Min stratum size (full): 1

# strata with count < 2: 4

Code
# Coarse fallback labels (plan + readm + death) and the count of rows living
# in rare full strata (<2 members) — i.e. rows the mixed fallback would recode.
plan = get_plan_labels(full["X"][0])
readm_evt = full["y_readm"][0]["event"].astype(int)
death_evt = full["y_death"][0]["event"].astype(int)

fb = pd.Series(plan.astype(str) + "_" + readm_evt.astype(str) + "_" + death_evt.astype(str))

rare_mask = strata_full.map(strata_full.value_counts()) < 2
n_rare = int(rare_mask.sum())

display(Markdown(f"**Rows in rare full-strata (<2):** `{n_rare}`"))
if n_rare > 0:
    display(Markdown(
        f"**Rare rows proportion:** `{n_rare/len(strata_full):.3%}`"
    ))

Rows in rare full-strata (<2): 4

Rare rows proportion: 0.005%

Code
# Use the actual stratification mode that was used to split
# (re-derived deterministically from imputation 0)
strata_used, strat_mode, _, _ = build_strata(full["X"][0], full["y_readm"][0], full["y_death"][0])

s = pd.Series(strata_used)
train_strata = set(s.iloc[train_idx].unique())
test_strata = set(s.iloc[test_idx].unique())

# Both sets should be identical if the stratified split worked
missing_in_test = sorted(train_strata - test_strata)
missing_in_train = sorted(test_strata - train_strata)

display(Markdown(f"**Strata used:** `{strat_mode}`"))
display(Markdown(f"**# strata in train:** `{len(train_strata)}` | **# strata in test:** `{len(test_strata)}`"))
display(Markdown(f"**Strata present in train but missing in test:** `{len(missing_in_test)}`"))
display(Markdown(f"**Strata present in test but missing in train:** `{len(missing_in_train)}`"))

Strata used: fallback(plan+readm+death)

# strata in train: 20 | # strata in test: 20

Strata present in train but missing in test: 0

Strata present in test but missing in train: 0

Code
from pathlib import Path
import pandas as pd
import numpy as np
from IPython.display import display, Markdown

# Persist the train/test partition as a 2-column parquet (row_id, is_train)
# so the exact split can be reloaded in later sessions or other tools.
PROJECT_ROOT = find_project_root()   # no hardcoded absolute path
OUT_DIR = PROJECT_ROOT / "_out"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# NOTE(review): a later cell writes a 4-column frame to this exact same
# filename and overwrites this file — confirm which schema is the intended
# final artifact.
SPLIT_PARQUET = OUT_DIR / f"comb_split_seed{SEED}_test{int(TEST_SIZE*100)}_mar26.parquet"

split_df = pd.DataFrame({
    "row_id": np.arange(n),   # 0-based row ids (the later export uses 1-based)
    "is_train": np.isin(np.arange(n), train_idx)
})

split_df.to_parquet(SPLIT_PARQUET, index=False)

display(Markdown(f"**Project root:** `{PROJECT_ROOT}`"))
display(Markdown(f"**Saved split to:** `{SPLIT_PARQUET}`"))

Project root: G:\My Drive\Alvacast\SISTRAT 2023\cons

Saved split to: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\comb_split_seed2125_test20_mar26.parquet

Code
import pandas as pd
import numpy as np
from pathlib import Path

# Export a 4-column split file (1-based row id, train flag, death time from
# discharge, admission age) for downstream consumption in R.
SEED = 2125
TEST_SIZE = 0.20

# Use the first imputation (complete data)
X_full = full["X"][0]
y_death_full = full["y_death"][0]

# Find admission age column: exact name preferred, substring match as fallback
age_col = 'adm_age_rec3' if 'adm_age_rec3' in X_full.columns else \
          [c for c in X_full.columns if 'adm_age' in c][0]

n_rows = len(X_full)

# Create the 4-column split file.
# is_train is computed with np.isin (vectorized, consistent with the sibling
# split-save cell); the previous per-row `i in train_idx` comprehension was
# O(n * len(train_idx)) against an array.
split_export = pd.DataFrame({
    'row_id': np.arange(1, n_rows + 1),  # 1-based for R
    'is_train': np.isin(np.arange(n_rows), train_idx),
    'death_time_from_disch_m': np.round(y_death_full['time'], 2),
    'adm_age_rec3': X_full[age_col]
})

out_dir = PROJECT_ROOT / "_out"
out_dir.mkdir(exist_ok=True)

# Export
fname = out_dir / f"comb_split_seed{SEED}_test{int(TEST_SIZE*100)}_mar26.parquet"
split_export.to_parquet(fname, index=False)

nb_print(f"Exported: {fname}")
nb_print(f"Total: {len(split_export)} rows")
nb_print(f"Train: {split_export['is_train'].sum()} ({100*split_export['is_train'].mean():.1f}%)")
nb_print(f"Test: {(~split_export['is_train']).sum()} ({100*(~split_export['is_train']).mean():.1f}%)")
nb_print(f"\nFirst 5 rows:")
nb_print(split_export.head())
Exported: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\comb_split_seed2125_test20_mar26.parquet
Total: 88152 rows
Train: 70521 (80.0%)
Test: 17631 (20.0%)
First 5 rows:
   row_id  is_train  death_time_from_disch_m  adm_age_rec3
0       1      True                    68.97         31.53
1       2      True                    81.32         20.61
2       3      True                   116.74         42.52
3       4     False                    91.97         60.61
4       5      True                    31.03         45.08
Code
# Quick sanity figures on the first imputation: mean admission age and the
# number of foreign-national records.
df0 = imputations_list_mar26[0]

mean_age = df0["adm_age_rec3"].mean()
count_foreign = df0["national_foreign"].eq(1).sum()

nb_print(f"Mean of adm_age_rec3: {mean_age:.4f}")
nb_print(f"Count of national_foreign == 1: {count_foreign}")
Mean of adm_age_rec3: 35.7256
Count of national_foreign == 1: 453

We cleaned our environment safely so that:

  • Old models

  • Temporary objects

  • Large intermediate datasets

do not interfere with the next modeling block.

Code
# Safe cleanup before Readmission XGBoost blocks:
# purge large globals that are not whitelisted, while preserving modules,
# callables, and dunder/private names, so the next modeling block starts clean.
import types
import gc

# Ensure logger exists (some target cells expect it)
if "nb_print" not in globals():
    def nb_print(*args, **kwargs):
        print(*args, **kwargs)

# Compatibility: one Optuna/Bootstrap block checks jan26 naming
#if "imputations_list_jan26" not in globals() and "imputations_list_mar26" in globals():
#    imputations_list_jan26 = imputations_list_mar26

# Whitelist of global names that must survive the purge below.
KEEP = {
    "nb_print", "study",
    "imputations_list_mar26", #"imputations_list",#"imputations_list_jan26"
    "X_train", "y_surv_readm_list_corrected", "y_surv_readm_list", "y_surv_death_list",
    # Optional plot config objects:
    "plt", "sns", "matplotlib", "mpl", "rcParams", "PROJECT_ROOT"
}

# ensure both variants are kept
KEEP.update({
    "y_surv_readm_list_corrected_mar26", "y_surv_readm_list_corrected",
    "y_surv_readm_list_mar26", "y_surv_death_list_mar26"
})

# Sanity-check X/y row alignment before tuning.
# NOTE(review): despite the original wording, this runs BEFORE the purge below.
if "imputations_list_mar26" in globals() and "y_surv_readm_list_corrected" in globals():
    assert len(imputations_list_mar26[0]) == len(y_surv_readm_list_corrected[0]), \
        f"Row mismatch: X={len(imputations_list_mar26[0])}, y={len(y_surv_readm_list_corrected[0])}"

# Purge: snapshot globals first (list(...)) so deletion during iteration is safe.
for name, obj in list(globals().items()):
    if name in KEEP or name.startswith("_"):
        continue
    if isinstance(obj, types.ModuleType):   # keep imports
        continue
    if callable(obj):                        # keep functions/classes
        continue
    del globals()[name]

gc.collect()

# Post-cleanup verification that the modeling targets survived.
required = ["y_surv_readm_list_corrected", "y_surv_readm_list", "y_surv_death_list"]
missing = [x for x in required if x not in globals()]
nb_print("Missing required objects:", missing)
Missing required objects: []

Advanced Survival Modeling: XGBoost & Stratified Evaluation

In this section, we transition to a Gradient Boosted Decision Tree (GBDT) framework using XGBoost. This approach serves as a robust, non-linear benchmark to validate findings from the neural network, specifically optimized for high-imbalance survival data (approx. 4% death rate).

Full Metrics

The updated pipeline uses a cause-specific (death-censored) framework for readmission and reports discrimination/calibration metrics accordingly.

  1. Trains two cause-specific XGBoost survival models (survival:cox): one for death and one for readmission.
  2. Uses 5-fold stratified cross-validation across all imputations, with composite stratification (event type × treatment plan).
  3. Encodes survival targets for XGBoost as +time if event, −time if censored.
  4. Uses output_margin=True, clips risk scores to [-15, 15], then exponentiates for hazard-scale calculations.
  5. Converts fold-level risk scores to survival probabilities using the Breslow baseline hazard estimator (for both endpoints).
  6. Computes Global and time-dependent Uno’s C-index with a safe fallback (tau truncation when needed).
  7. Computes Global IBS and time-dependent IBS (only when at least 2 time points are available), plus horizon-specific point Brier Score.
  8. For horizon-specific classification metrics, restricts to valid case/control subjects at each horizon and estimates thresholds on training folds only (F1 for death, Youden for readmission).
  9. Uses pre-encoded plan variables (no re-dummying in this step) and applies early stopping (up to 5000 rounds), logging best_iteration per fold.
  10. Stores fold artifacts for reproducibility: predictions, baseline hazards, and exact CV train/validation splits.
  11. Aggregates out-of-fold SHAP values across folds and imputations by patient index, with index re-alignment before export.
  12. Saves a complete artifact set:
    • xgb6_corr_DUAL_metrics_<timestamp>.csv
    • xgb6_corr_DUAL_final_ev_hyp_<timestamp>.pkl
    • xgb6_corr_DUAL_BaselineHazards_<timestamp>.pkl
    • xgb6_corr_DUAL_CV_Splits_<timestamp>.pkl
    • xgb6_corr_DUAL_SHAP_Aggregated_<timestamp>.pkl (when SHAP is computed)

📌 5 Core Assumptions of This Pipeline (Cause-Specific Framework)

  1. Cause-specific hazard assumption
    Death and readmission are modeled separately; competing events are treated as censoring for each endpoint.

  2. Independent censoring assumption
    Censoring (including competing-event censoring) is assumed conditionally independent of the event process given covariates.

  3. Cox-type risk structure assumption
    survival:cox optimizes a Cox partial likelihood, so effects are interpreted on the log-risk scale.

  4. Breslow baseline hazard validity
    Absolute survival reconstruction relies on fold-specific Breslow baseline hazard estimation.

  5. Imputation and aggregation assumption
    Multiple imputations are analyzed via repeated CV and cross-imputation aggregation (not formal Rubin pooling for every metric).

Code
#@title ⚡ Step 5: Full XGBoost Analysis + Dual CV-SHAP (Fully Audited & 100% Reproducible)

#from pandas.core.indexes.accessors import _TDTotalSecondsReturnType
import numpy as np
import pandas as pd
import xgboost as xgb
import pickle
import glob
import time
import gc
import os
import re
import shap
from pathlib import Path
import shutil
import uuid
from datetime import datetime
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import brier_score_loss, f1_score, roc_curve, confusion_matrix
from sksurv.metrics import concordance_index_ipcw, integrated_brier_score

import warnings
# NOTE(review): blanket suppression hides ALL warnings (incl. deprecations)
# for the rest of the session — consider scoping with warnings.catch_warnings().
warnings.filterwarnings("ignore")

# Fallback logger: later cells assume `nb_print` exists even if the helper
# cell defining it was not run.
if 'nb_print' not in globals():
    def nb_print(*args, **kwargs):
        print(*args, **kwargs)

# Fast preflight mode: set True to only test paths and stop
DRY_RUN_PATHS_ONLY = False

def preflight_out_dir(out_dir_rel="_out", min_free_gb=1.0):
    """Validate the output directory before any heavy computation.

    Creates ``<cwd>/<out_dir_rel>`` if needed, performs a write/read/delete
    probe (surfacing permission, mount, or device problems that mkdir alone
    would not reveal), and checks available disk space.

    Parameters
    ----------
    out_dir_rel : str
        Output directory, resolved relative to the current working directory.
    min_free_gb : float
        Minimum free space (GB) required on the target filesystem.

    Returns
    -------
    (out_dir: Path, cwd: Path, free_gb: float)

    Raises
    ------
    RuntimeError on any failure (creation, probe, or low free space).
    """
    cwd = Path.cwd().resolve()
    out_dir = (cwd / out_dir_rel).resolve()

    try:
        out_dir.mkdir(parents=True, exist_ok=True)
    except OSError as e:
        raise RuntimeError(f"Cannot create OUT_DIR '{out_dir}' from cwd '{cwd}': {e}") from e

    # Unique probe filename so concurrent kernels cannot collide.
    probe = out_dir / f".probe_{uuid.uuid4().hex}.tmp"
    try:
        probe.write_text("ok", encoding="utf-8")
        readback = probe.read_text(encoding="utf-8")
        if readback != "ok":
            raise RuntimeError("Write/read probe mismatch.")
    except Exception as e:
        raise RuntimeError(f"OUT_DIR not writable/readable: '{out_dir}'. {e}") from e
    finally:
        # Best-effort cleanup of the probe file.
        try:
            probe.unlink(missing_ok=True)
        except Exception:
            pass

    free_gb = shutil.disk_usage(out_dir).free / (1024**3)
    if free_gb < min_free_gb:
        raise RuntimeError(f"Low free space in '{out_dir}': {free_gb:.2f} GB")

    return out_dir, cwd, free_gb

# Run the path preflight once; expose OUT_DIR as a plain string because the
# rest of this cell builds paths with f"{OUT_DIR}/..." string formatting.
OUT_DIR_PATH, CWD_PATH, FREE_GB = preflight_out_dir("_out", min_free_gb=1.0)
OUT_DIR = str(OUT_DIR_PATH)  # keeps your existing f"{OUT_DIR}/..." code working

nb_print(f"✅ Path preflight OK | cwd={CWD_PATH} | OUT_DIR={OUT_DIR} | free={FREE_GB:.2f} GB")

# Optional early exit: verify filesystem access without running the pipeline.
if DRY_RUN_PATHS_ONLY:
    nb_print("Path preflight passed. Stopping before heavy runtime.")
    raise SystemExit(0)

nb_print("Starting Full XGBoost Analysis (Cause-Specific Hazard / Death-Censored)...")

# --- 0. CHECK FOR EXISTING FILE (ATOMIC LOAD) ---
# If a *complete* set of artifacts (SHAP + metrics CSV + raw + hazards +
# CV splits) from a previous run shares one timestamp, load it all and skip
# the expensive recomputation; a partial set triggers a full recompute.
existing_shap_files = sorted(glob.glob(f"{OUT_DIR}/xgb6_corr_DUAL_SHAP_Aggregated_*_mar26.pkl"))
compute_needed = True
metrics_log = []
cv_splits_log = []

if existing_shap_files:
    # sorted() + [-1] picks the lexicographically-latest timestamp.
    latest_shap = existing_shap_files[-1]
    match = re.search(r'_(\d{8}_\d{4})_mar26\.pkl', latest_shap)
    if match:
        ts = match.group(1)
        csv_file = f"{OUT_DIR}/xgb6_corr_DUAL_metrics_{ts}_mar26.csv"
        raw_file = f"{OUT_DIR}/xgb6_corr_DUAL_final_ev_hyp_{ts}_mar26.pkl"
        hz_file = f"{OUT_DIR}/xgb6_corr_DUAL_BaselineHazards_{ts}_mar26.pkl"
        split_file = f"{OUT_DIR}/xgb6_corr_DUAL_CV_Splits_{ts}_mar26.pkl"

        # All four companion files must exist for the set to count as complete.
        if os.path.exists(csv_file) and os.path.exists(raw_file) and os.path.exists(hz_file) and os.path.exists(split_file):
            nb_print(f"✅ Found complete set of artifacts for timestamp {ts}")
            nb_print("⏭️ Skipping computation and loading artifacts into memory.")
            try:
                with open(latest_shap, "rb") as f: shap_data_dual = pickle.load(f)

                df_results = pd.read_csv(csv_file)
                metrics_log = df_results.to_dict('records')

                with open(raw_file, "rb") as f: raw_data_log = pickle.load(f)
                with open(hz_file, "rb") as f: baseline_hazards_log = pickle.load(f)
                with open(split_file, "rb") as f: cv_splits_log = pickle.load(f)

                nb_print("📦 Successfully loaded all artifacts, including CV Splits.")
                compute_needed = False
            except Exception as e:
                # Any load failure falls back to recomputation rather than crashing.
                nb_print(f"⚠️ Error loading artifacts: {e}. Will recompute.")
        else:
            nb_print(f"⚠️ Partial artifacts found for {ts} (missing CSV, Raw, Hz, or Splits). Recomputing...")

if compute_needed:
    start_time = time.time()

# --- 1. CONFIGURATION ---
    N_IMPUTATIONS = len(imputations_list_mar26)
    K_FOLDS = 5
    # Evaluation horizons — presumably months post-discharge; TODO confirm units.
    TIMES_EVAL = [3, 6, 9, 12, 24, 36, 48, 60, 72, 84, 96, 108]
    # --- 2026-02-24 ADD: pooled multi-horizon SHAP export ---
    SAVE_MULTIH_POOLED_SHAP = True # raw TreeSHAP by horizon does not make sense.
    TARGET_HORIZONS_EXPORT = [3, 6, 12, 36, 60, 96]
    COMPUTE_SHAP_ALL_IMPS = True
    SHAP_IMP_IDX = 0

    # Fixed, pre-tuned hyperparameters for the readmission model (suffix _r).
    BEST_PARAMS_r = {
        "objective": "survival:cox",
        "eval_metric": "cox-nloglik",
        "tree_method": "hist",
        "device": "cpu",
        "verbosity": 0,
        'nthread': 30,
        "seed": 2125,
        "learning_rate": 0.0033765926340982,
        "max_depth": 9,
        "min_child_weight": 4,
        "subsample": 0.6102749292690339,
        "colsample_bytree": 0.405146532644876,
        "reg_alpha": 0.2849441972783855,
        "reg_lambda": 0.6681588791085157,
        "gamma": 0.0431853685583679,
    }
    # Fixed, pre-tuned hyperparameters for the death model (suffix _d).
    BEST_PARAMS_d = {
        "objective": "survival:cox",
        "eval_metric": "cox-nloglik",
        "tree_method": "hist",
        "device": "cpu",
        "verbosity": 0,
        'nthread': 30,
        "seed": 2125,
        "learning_rate": 0.03096732484594198,
        "max_depth": 6,
        "min_child_weight": 15,
        "subsample": 0.7787433700020623,
        "colsample_bytree": 0.6168899316303945,
        "reg_alpha": 0.01370823890516611,
        "reg_lambda": 2.536795457202958,
        "gamma": 0.9822857777079363,
    }
    # --- 2. SAFE HELPERS ---
    # --- 2. SAFE HELPERS ---
    def safe_uno_ipcw(y_tr, y_va, risk, tau=None):
        """Uno's IPCW C-index with a graceful truncation fallback.

        If `tau` is given, evaluate truncated at that horizon (NaN on failure).
        Otherwise try the untruncated estimator first; if that fails, retry at
        shrinking training-time quantiles, capped just below the validation
        max follow-up time.

        Returns (c_index, tau_used): tau_used is None when no truncation was
        applied, and (NaN, None) when every attempt failed.

        Fix: bare `except:` clauses replaced with `except Exception` — the
        originals would also swallow KeyboardInterrupt/SystemExit.
        """
        if tau is not None:
            try:
                return concordance_index_ipcw(y_tr, y_va, risk, tau=tau)[0], tau
            except Exception:
                return np.nan, tau

        try:
            return concordance_index_ipcw(y_tr, y_va, risk)[0], None
        except Exception:
            pass

        max_va = float(np.max(y_va["time"]))
        for q in (0.95, 0.90, 0.85, 0.80):
            try:
                tau_candidate = float(np.quantile(y_tr["time"], q))
                # Keep tau strictly inside the validation follow-up range.
                tau_safe = min(tau_candidate, max_va - 1e-8)
                if tau_safe <= 0: continue
                c_idx = concordance_index_ipcw(y_tr, y_va, risk, tau=tau_safe)[0]
                return c_idx, tau_safe
            except Exception:
                continue
        return np.nan, None

    def get_survival_probs_breslow(risk_train, risk_val, y_train_struc, time_grid):
        """Survival probabilities via the Breslow baseline hazard estimator.

        `risk_train` / `risk_val` are Cox log-risk margins. The baseline
        cumulative hazard H0(t) is estimated on the training fold, linearly
        interpolated onto `time_grid`, and combined with validation risks as
        S(t|x) = exp(-exp(risk) * H0(t)).

        Returns (surv_prob_matrix [n_val x n_times], h0_times, h0_cumhaz).
        """
        exp_risk_tr = np.exp(risk_train)
        exp_risk_va = np.exp(risk_val)
        
        df_base = pd.DataFrame({'time': y_train_struc['time'], 'event': y_train_struc['event'], 'risk_score': exp_risk_tr})
        df_base = df_base.sort_values('time')
        unique_times = df_base['time'].unique()
        # Reverse cumulative sum over the time-sorted frame: position i holds
        # the total exp-risk of every subject with time >= time at position i.
        risk_sum = df_base['risk_score'][::-1].cumsum()[::-1]
        event_counts = df_base.groupby('time')['event'].sum()
        
        cum_hazard = 0
        h0_times = []
        h0_vals = []
        for t in unique_times:
            if t in event_counts.index:
                deaths = event_counts.loc[t]
                # First positional occurrence of t -> risk set of all subjects
                # still at risk at t (uses .iloc, matching searchsorted's
                # positional result).
                idx = df_base['time'].searchsorted(t, side='left')
                total_risk = risk_sum.iloc[idx]
                if total_risk > 0: cum_hazard += deaths / total_risk
            h0_times.append(t)
            h0_vals.append(cum_hazard)
            
        # np.interp requires increasing x — guaranteed by the sort above.
        h0_interp = np.interp(time_grid, h0_times, h0_vals)
        surv_probs = np.exp(-np.outer(exp_risk_va, h0_interp))
        return np.nan_to_num(surv_probs, nan=0.0), h0_times, h0_vals

    def get_binary_target_bool(event_bool_va, times_va, t_horizon):
        """Binary case/control labels at a fixed horizon.

        Cases: event observed at or before `t_horizon`. Controls: still at
        risk past `t_horizon`. Subjects censored before the horizon are
        excluded from the valid mask.

        Returns (labels_for_valid_subjects, valid_mask).
        """
        case_mask = (event_bool_va == True) & (times_va <= t_horizon)
        control_mask = times_va > t_horizon
        keep = case_mask | control_mask
        return case_mask[keep].astype(int), keep

    def find_optimal_threshold(y_true, probas, method='youden'):
        """Pick a probability cutoff from ROC thresholds.

        method='youden' maximizes (TPR - FPR); method='f1' maximizes F1 over
        the ROC threshold candidates. Returns 0.5 when `y_true` contains a
        single class (ROC undefined).

        Fix: an unknown `method` previously left `idx` unbound and crashed
        with UnboundLocalError; now raises an explicit ValueError.
        """
        if len(np.unique(y_true)) < 2: return 0.5
        fpr, tpr, thresholds = roc_curve(y_true, probas)
        if method == 'youden':
            idx = np.argmax(tpr - fpr)
        elif method == 'f1':
            f1_scores = [f1_score(y_true, (probas >= th).astype(int), zero_division=0) for th in thresholds]
            idx = np.argmax(f1_scores)
        else:
            raise ValueError(f"Unknown threshold method: {method!r} (expected 'youden' or 'f1')")
        return thresholds[idx]

    def evaluate_with_threshold(y_true, probas, threshold):
        """Classification metrics (Sens/Spec/PPV/NPV/F1) at a fixed cutoff."""
        y_pred = (probas >= threshold).astype(int)
        tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0,1]).ravel()

        # Guarded ratio: 0.0 whenever the denominator is empty.
        def _ratio(num, den):
            return num / den if den > 0 else 0.0

        with np.errstate(divide='ignore', invalid='ignore'):
            return {
                'Sens': _ratio(tp, tp + fn),
                'Spec': _ratio(tn, tn + fp),
                'PPV': _ratio(tp, tp + fp),
                'NPV': _ratio(tn, tn + fn),
                'F1': f1_score(y_true, y_pred, zero_division=0)
            }

    # --- 3. MAIN PIPELINE ---
    # Per-imputation loop -> stratified 5-fold CV -> two cause-specific Cox
    # XGBoost models per fold (death, readmission) -> SHAP, Breslow survival
    # probabilities, and metric logging.
    raw_data_log = []
    baseline_hazards_log = []
    cv_shap_r = []; cv_shap_d = []; cv_X_val = []

    for imp_idx in range(N_IMPUTATIONS):
        X_curr = imputations_list_mar26[imp_idx].copy()
        
        # Unique indices are required so out-of-fold SHAP can later be
        # re-aggregated per patient.
        if not X_curr.index.is_unique:
            raise ValueError("X_curr index is not unique. Please ensure patient indices are preserved.")
            
        plan_cols = [c for c in X_curr.columns if c.startswith("plan_type_corr")]
        X_curr[plan_cols] = X_curr[plan_cols].astype("float32")
        
        y_d = y_surv_death_list[imp_idx]
        y_r = y_surv_readm_list_corrected[imp_idx] 
        
        # Normalize times/events to plain numpy arrays whether the survival
        # containers hold arrays or pandas Series.
        t_d_full = y_d['time'] if isinstance(y_d['time'], np.ndarray) else y_d['time'].values
        e_d_full = y_d['event'] if isinstance(y_d['event'], np.ndarray) else y_d['event'].values.astype(bool)
        t_r_full = y_r['time'] if isinstance(y_r['time'], np.ndarray) else y_r['time'].values
        e_r_full = y_r['event'] if isinstance(y_r['event'], np.ndarray) else y_r['event'].values.astype(bool)
        
        # Decode the one-hot plan dummies back to an integer plan index 0-4.
        plan_idx = np.zeros(len(X_curr), dtype=int) 
        if "plan_type_corr_pg_pr" in X_curr.columns: plan_idx[X_curr["plan_type_corr_pg_pr"] == 1] = 1
        if "plan_type_corr_m_pr" in X_curr.columns: plan_idx[X_curr["plan_type_corr_m_pr"] == 1] = 2
        if "plan_type_corr_pg_pai" in X_curr.columns: plan_idx[X_curr["plan_type_corr_pg_pai"] == 1] = 3
        if "plan_type_corr_m_pai" in X_curr.columns: plan_idx[X_curr["plan_type_corr_m_pai"] == 1] = 4
        
        # Competing-event code: 0 = censored, 1 = death first, 2 = readmission
        # first; combined with the plan index into composite strata labels.
        events_cr = np.zeros(len(e_d_full), dtype=int)
        mask_death_first = e_d_full & (~e_r_full | (t_d_full < t_r_full))
        events_cr[mask_death_first] = 1
        mask_readm_first = e_r_full & (~e_d_full | (t_r_full < t_d_full))
        events_cr[mask_readm_first] = 2
        strat_labels = (events_cr * 10) + plan_idx

        nb_print(f"📂 Processing Imputation {imp_idx + 1}/{N_IMPUTATIONS}...")
        
        skf = StratifiedKFold(n_splits=K_FOLDS, shuffle=True, random_state=2125) 
        
        for fold_idx, (train_idx, val_idx) in enumerate(skf.split(X_curr, strat_labels)):
            X_tr, X_va = X_curr.iloc[train_idx], X_curr.iloc[val_idx]
            
            # --- STORE CV SPLITS --- (both positional indices and index labels,
            # for exact reproducibility)
            cv_splits_log.append({
                'imp_idx': imp_idx,
                'fold_idx': fold_idx,
                'train_indices': train_idx.tolist(),
                'val_indices': val_idx.tolist(),
                'train_ids': X_curr.index[train_idx].tolist(),
                'val_ids': X_curr.index[val_idx].tolist()
            })
            
            t_d_tr, e_d_tr = t_d_full[train_idx], e_d_full[train_idx]
            t_d_va, e_d_va = t_d_full[val_idx], e_d_full[val_idx]
            t_r_tr, e_r_tr = t_r_full[train_idx], e_r_full[train_idx]
            t_r_va, e_r_va = t_r_full[val_idx], e_r_full[val_idx]

            # Train Death.
            # survival:cox label encoding: +time if event observed, -time if censored.
            y_xgb_d_tr = np.where(e_d_tr, t_d_tr, -t_d_tr)
            y_xgb_d_va = np.where(e_d_va, t_d_va, -t_d_va)
            dtrain_d = xgb.DMatrix(X_tr, label=y_xgb_d_tr); dval_d = xgb.DMatrix(X_va, label=y_xgb_d_va)
            model_d = xgb.train(BEST_PARAMS_d, dtrain_d, num_boost_round=5000, evals=[(dval_d, 'val')], early_stopping_rounds=100, verbose_eval=False)
            
            # Train Readmission (same encoding, readmission-specific params).
            y_xgb_r_tr = np.where(e_r_tr, t_r_tr, -t_r_tr)
            y_xgb_r_va = np.where(e_r_va, t_r_va, -t_r_va)
            dtrain_r = xgb.DMatrix(X_tr, label=y_xgb_r_tr); dval_r = xgb.DMatrix(X_va, label=y_xgb_r_va)
            model_r = xgb.train(BEST_PARAMS_r, dtrain_r, num_boost_round=5000, evals=[(dval_r, 'val')], early_stopping_rounds=100, verbose_eval=False)

            # DUAL SHAP: out-of-fold TreeSHAP values for both models on the
            # validation rows, keyed by patient index for later aggregation.
            if COMPUTE_SHAP_ALL_IMPS or imp_idx == SHAP_IMP_IDX:
                X_va_indexed = X_va.copy()
                
                explainer_r = shap.TreeExplainer(model_r)
                shap_r_vals = explainer_r.shap_values(X_va)
                df_shap_r = pd.DataFrame(shap_r_vals, index=X_va.index, columns=X_va.columns)
                cv_shap_r.append(df_shap_r)
                
                explainer_d = shap.TreeExplainer(model_d)
                shap_d_vals = explainer_d.shap_values(X_va)
                df_shap_d = pd.DataFrame(shap_d_vals, index=X_va.index, columns=X_va.columns)
                cv_shap_d.append(df_shap_d)
                
                cv_X_val.append(X_va_indexed)

            # PREDICTIONS: raw margins (log-risk), clipped to [-15, 15] so the
            # later exp() in the Breslow step cannot overflow.
            risk_d_tr = np.clip(model_d.predict(dtrain_d, output_margin=True), -15, 15)
            risk_d_va = np.clip(model_d.predict(dval_d, output_margin=True), -15, 15)
            risk_r_tr = np.clip(model_r.predict(dtrain_r, output_margin=True), -15, 15)
            risk_r_va = np.clip(model_r.predict(dval_r, output_margin=True), -15, 15)
            
            y_tr_struc_d, y_va_struc_d = y_d[train_idx], y_d[val_idx]
            y_tr_struc_r, y_va_struc_r = y_r[train_idx], y_r[val_idx]

            # MATRIX CALCULATIONS (Both via Breslow)
            probs_d_tr, _, _ = get_survival_probs_breslow(risk_d_tr, risk_d_tr, y_tr_struc_d, TIMES_EVAL)
            probs_d_va, h0_times_d, h0_vals_d = get_survival_probs_breslow(risk_d_tr, risk_d_va, y_tr_struc_d, TIMES_EVAL)
            
            probs_r_tr, _, _ = get_survival_probs_breslow(risk_r_tr, risk_r_tr, y_tr_struc_r, TIMES_EVAL)
            probs_r_va, h0_times_r, h0_vals_r = get_survival_probs_breslow(risk_r_tr, risk_r_va, y_tr_struc_r, TIMES_EVAL)

            baseline_hazards_log.append({
                'imp_idx': imp_idx, 'fold_idx': fold_idx,
                'times_d': h0_times_d, 'h0_d': h0_vals_d,
                'times_r': h0_times_r, 'h0_r': h0_vals_r
            })

            # --- METRICS: GLOBAL ---
            c_idx_d, used_tau_d = safe_uno_ipcw(y_tr_struc_d, y_va_struc_d, risk_d_va)
            if np.isfinite(c_idx_d):
                metrics_log.append({'Imp': imp_idx, 'Fold': fold_idx, 'Risk': 'Death', 'Metric': "Uno's C-Index", 'Time': 'Global', 'Value': c_idx_d})
                if used_tau_d is not None:
                    nb_print(f"  ⚠️ [Fold {fold_idx}] Death Global Uno truncated to tau={used_tau_d:.1f}")
            else:
                nb_print(f"  ❌ [Fold {fold_idx}] Death Global Uno completely failed.")

            c_idx_r, used_tau_r = safe_uno_ipcw(y_tr_struc_r, y_va_struc_r, risk_r_va)
            if np.isfinite(c_idx_r):
                metrics_log.append({'Imp': imp_idx, 'Fold': fold_idx, 'Risk': 'Readmission', 'Metric': "Uno's C-Index", 'Time': 'Global', 'Value': c_idx_r})
                if used_tau_r is not None:
                    nb_print(f"  ⚠️ [Fold {fold_idx}] Readm Global Uno truncated to tau={used_tau_r:.1f}")
            else:
                nb_print(f"  ❌ [Fold {fold_idx}] Readm Global Uno completely failed.")
            
            try:
                ibs_d = integrated_brier_score(y_tr_struc_d, y_va_struc_d, probs_d_va, TIMES_EVAL)
                metrics_log.append({'Imp': imp_idx, 'Fold': fold_idx, 'Risk': 'Death', 'Metric': 'IBS', 'Time': 'Global', 'Value': ibs_d})
            except Exception as e: 
                nb_print(f"  ❌ [Fold {fold_idx}] Death Global IBS failed: {e}")
            
            try:
                ibs_r = integrated_brier_score(y_tr_struc_r, y_va_struc_r, probs_r_va, TIMES_EVAL)
                metrics_log.append({'Imp': imp_idx, 'Fold': fold_idx, 'Risk': 'Readmission', 'Metric': 'IBS', 'Time': 'Global', 'Value': ibs_r})
            except Exception as e: 
                nb_print(f"  ❌ [Fold {fold_idx}] Readm Global IBS failed: {e}")

            # --- METRICS: TIME-DEPENDENT ---
            # At each horizon t: Brier score, train-fold-estimated threshold,
            # classification metrics, truncated Uno C, and windowed IBS.
            for i, t in enumerate(TIMES_EVAL):
                # DEATH
                y_bin_d_tr, mask_d_tr = get_binary_target_bool(e_d_full[train_idx], t_d_tr, t)
                y_bin_d_va, mask_d_va = get_binary_target_bool(e_d_full[val_idx], t_d_va, t)
                prob_d_tr = np.nan_to_num(1 - probs_d_tr[:, i], nan=0.0)
                prob_d_va = np.nan_to_num(1 - probs_d_va[:, i], nan=0.0)
                
                if mask_d_va.sum() > 0 and mask_d_tr.sum() > 0:
                    # Threshold estimated on the training fold only (F1 for death).
                    opt_thresh_d = find_optimal_threshold(y_bin_d_tr, prob_d_tr[mask_d_tr], method='f1') 
                    bs_d = brier_score_loss(y_bin_d_va, prob_d_va[mask_d_va])
                    m_d = evaluate_with_threshold(y_bin_d_va, prob_d_va[mask_d_va], opt_thresh_d)
                    
                    metrics_log.append({'Imp': imp_idx, 'Fold': fold_idx, 'Risk': 'Death', 'Metric': 'Brier Score', 'Time': t, 'Value': bs_d})
                    metrics_log.append({'Imp': imp_idx, 'Fold': fold_idx, 'Risk': 'Death', 'Metric': 'Threshold', 'Time': t, 'Value': opt_thresh_d})
                    for k, v in m_d.items(): metrics_log.append({'Imp': imp_idx, 'Fold': fold_idx, 'Risk': 'Death', 'Metric': k, 'Time': t, 'Value': v})
                else:
                    nb_print(f"⚠️ [Imp {imp_idx}|Fold {fold_idx}|Time {t}] Death Brier/Thresholds failed: Not enough events/controls.")
                
                # Death Time-Dependent Uno
                uno_d_t, _ = safe_uno_ipcw(y_tr_struc_d, y_va_struc_d, risk_d_va, tau=t)
                if np.isfinite(uno_d_t):
                    metrics_log.append({'Imp': imp_idx, 'Fold': fold_idx, 'Risk': 'Death', 'Metric': "Uno's C-Index", 'Time': t, 'Value': uno_d_t})

                # Death Time-Dependent IBS (needs >= 2 grid points up to t)
                times_up_to_t = [tt for tt in TIMES_EVAL if tt <= t]
                if len(times_up_to_t) >= 2:
                    try:
                        idxs = [j for j, tt in enumerate(TIMES_EVAL) if tt <= t]
                        probs_d_slice = probs_d_va[:, idxs]
                        ibs_d_t = integrated_brier_score(y_tr_struc_d, y_va_struc_d, probs_d_slice, times_up_to_t)
                        metrics_log.append({'Imp': imp_idx, 'Fold': fold_idx, 'Risk': 'Death', 'Metric': 'IBS', 'Time': t, 'Value': ibs_d_t})
                    except Exception as e: 
                        nb_print(f"⚠️ [Imp {imp_idx}|Fold {fold_idx}|Time {t}] Death IBS failed: {e}")

                # READMISSION (same structure; Youden threshold instead of F1)
                y_bin_r_tr, mask_r_tr = get_binary_target_bool(e_r_full[train_idx], t_r_tr, t)
                y_bin_r_va, mask_r_va = get_binary_target_bool(e_r_full[val_idx], t_r_va, t)
                prob_r_tr = np.nan_to_num(1 - probs_r_tr[:, i], nan=0.0)
                prob_r_va = np.nan_to_num(1 - probs_r_va[:, i], nan=0.0)
                
                if mask_r_va.sum() > 0 and mask_r_tr.sum() > 0:
                    opt_thresh_r = find_optimal_threshold(y_bin_r_tr, prob_r_tr[mask_r_tr], method='youden') 
                    bs_r = brier_score_loss(y_bin_r_va, prob_r_va[mask_r_va])
                    m_r = evaluate_with_threshold(y_bin_r_va, prob_r_va[mask_r_va], opt_thresh_r)
                    
                    metrics_log.append({'Imp': imp_idx, 'Fold': fold_idx, 'Risk': 'Readmission', 'Metric': 'Brier Score', 'Time': t, 'Value': bs_r})
                    metrics_log.append({'Imp': imp_idx, 'Fold': fold_idx, 'Risk': 'Readmission', 'Metric': 'Threshold', 'Time': t, 'Value': opt_thresh_r})
                    for k, v in m_r.items(): metrics_log.append({'Imp': imp_idx, 'Fold': fold_idx, 'Risk': 'Readmission', 'Metric': k, 'Time': t, 'Value': v})
                else:
                    nb_print(f"⚠️ [Imp {imp_idx}|Fold {fold_idx}|Time {t}] Readm Brier/Thresholds failed: Not enough events/controls.")

                # Readmission Time-Dependent Uno
                uno_r_t, _ = safe_uno_ipcw(y_tr_struc_r, y_va_struc_r, risk_r_va, tau=t)
                if np.isfinite(uno_r_t):
                    metrics_log.append({'Imp': imp_idx, 'Fold': fold_idx, 'Risk': 'Readmission', 'Metric': "Uno's C-Index", 'Time': t, 'Value': uno_r_t})
                
                # Readmission Time-Dependent IBS
                times_up_to_t_r = [tt for tt in TIMES_EVAL if tt <= t]
                if len(times_up_to_t_r) >= 2: 
                    try:
                        idxs_r = [j for j, tt in enumerate(TIMES_EVAL) if tt <= t]
                        probs_r_slice = probs_r_va[:, idxs_r]
                        ibs_r_t = integrated_brier_score(y_tr_struc_r, y_va_struc_r, probs_r_slice, times_up_to_t_r)
                        metrics_log.append({'Imp': imp_idx, 'Fold': fold_idx, 'Risk': 'Readmission', 'Metric': 'IBS', 'Time': t, 'Value': ibs_r_t})
                    except Exception as e: 
                        nb_print(f"⚠️ [Imp {imp_idx}|Fold {fold_idx}|Time {t}] Readm IBS failed: {e}")
            
            # Per-fold artifact record: predictions, probability matrices,
            # targets, and the early-stopping best iterations.
            raw_data_log.append({
                'imp_idx': imp_idx, 'fold_idx': fold_idx,
                'best_iter_death': model_d.best_iteration,
                'best_iter_readm': model_r.best_iteration,
                'risk_pred_readm': risk_r_va,
                'risk_pred_death': risk_d_va,
                'probs_readm_matrix': probs_r_va, 
                'probs_death_matrix': probs_d_va,
                'eval_times': TIMES_EVAL,
                'y_val_r': y_va_struc_r,  
                'y_val_d': y_va_struc_d                 
            })
            
            # Free fold-level models/DMatrices — kernel memory persists across folds.
            del model_d, model_r, dtrain_d, dval_d, dtrain_r, dval_r
            gc.collect()

    # --- CONSOLIDATE & EXPORT ---
    nb_print("\nConsolidating Analysis Artifacts...")

    timestamp = datetime.now().strftime("%Y%m%d_%H%M")

    shap_name = None
    # --- 2026-02-24 ADD ---
    multih_name = None
    multih_summary_name = None
    
    if cv_X_val:
        X_all = pd.concat(cv_X_val, axis=0)
        df_shap_r_all = pd.concat(cv_shap_r)
        df_shap_d_all = pd.concat(cv_shap_d)
        
        X_all_avg = X_all.groupby(X_all.index).mean().sort_index()
        shap_r_avg = df_shap_r_all.groupby(df_shap_r_all.index).mean().reindex(X_all_avg.index)
        shap_d_avg = df_shap_d_all.groupby(df_shap_d_all.index).mean().reindex(X_all_avg.index)
        
        shap_data_dual = {
            'X_all': X_all_avg,
            'shap_r_all': shap_r_avg.values, 
            'shap_d_all': shap_d_avg.values, 
            'feature_names': X_all_avg.columns.tolist()
        }
        shap_name = f"{OUT_DIR}/xgb6_corr_DUAL_SHAP_Aggregated_{timestamp}_mar26.pkl"
        with open(shap_name, "wb") as f:
            pickle.dump(shap_data_dual, f)
        nb_print(f"Saved Averaged DUAL SHAP data: {shap_name}")

        # --- 2026-02-24 ADD: pooled Multi-Horizon SHAP artifact ---
        if SAVE_MULTIH_POOLED_SHAP:
            # Keep only requested export horizons that were actually
            # evaluated; fall back to a fixed default grid if none intersect.
            target_h = [int(h) for h in TARGET_HORIZONS_EXPORT if h in TIMES_EVAL]
            if not target_h:
                target_h = [3, 6, 12, 36, 60, 96]

            # Per-index contribution counts: how many fold/imputation rows
            # were averaged into each pooled SHAP row (0 after the
            # reindex+fillna for indices with no contributions).
            contrib_r = (
                df_shap_r_all.groupby(df_shap_r_all.index).size()
                .reindex(X_all_avg.index).fillna(0).astype(np.int32)
            )
            contrib_d = (
                df_shap_d_all.groupby(df_shap_d_all.index).size()
                .reindex(X_all_avg.index).fillna(0).astype(np.int32)
            )

            # Cox SHAP is on log-hazard margin scale (time-invariant by PH design)
            shap_r_matrix = shap_r_avg.to_numpy(dtype=np.float32, copy=False)
            shap_d_matrix = shap_d_avg.to_numpy(dtype=np.float32, copy=False)

            shap_data_dual_multih = {
                "X_all": X_all_avg,
                "feature_names": X_all_avg.columns.tolist(),
                "horizons": target_h,
                "meta": {
                    "model_objective": "survival:cox",
                    "shap_scale": "log-hazard margin",
                    "time_invariant_by_design": True,
                    "pooling_level": "patient-level pooled over all imputations/folds",
                    "note": "Same pooled SHAP matrix is referenced for each horizon."
                },
                "Readmission": {},
                "Death": {}
            }

            # The SAME matrix object is stored under every horizon key: under
            # the proportional-hazards objective the SHAP values do not vary
            # with time, so horizons here are labels only (see meta["note"]).
            for h in target_h:
                shap_data_dual_multih["Readmission"][h] = {
                    "shap_values": shap_r_matrix,
                    "n_contrib_per_patient": contrib_r.to_numpy(dtype=np.int32),
                    "n_contrib_total": int(len(df_shap_r_all))
                }
                shap_data_dual_multih["Death"][h] = {
                    "shap_values": shap_d_matrix,
                    "n_contrib_per_patient": contrib_d.to_numpy(dtype=np.int32),
                    "n_contrib_total": int(len(df_shap_d_all))
                }

            multih_name = f"{OUT_DIR}/xgb6_corr_DUAL_SHAP_MultiH_{timestamp}_mar26.pkl"
            with open(multih_name, "wb") as f:
                pickle.dump(shap_data_dual_multih, f, protocol=pickle.HIGHEST_PROTOCOL)

            # Small companion CSV summarising matrix shapes and pooled counts.
            multih_summary = pd.DataFrame({
                "Outcome": ["Readmission", "Death"],
                "N_patients": [shap_r_matrix.shape[0], shap_d_matrix.shape[0]],
                "N_features": [shap_r_matrix.shape[1], shap_d_matrix.shape[1]],
                "N_contrib_total": [int(len(df_shap_r_all)), int(len(df_shap_d_all))],
                "Horizons_exported": [",".join(map(str, target_h)), ",".join(map(str, target_h))]
            })
            multih_summary_name = f"{OUT_DIR}/xgb6_corr_DUAL_SHAP_MultiH_summary_{timestamp}_mar26.csv"
            multih_summary.to_csv(multih_summary_name, index=False)

            nb_print(f"💾 Saved pooled Multi-H SHAP data: {multih_name}")
            nb_print(f"💾 Saved pooled Multi-H SHAP summary: {multih_summary_name}")

    # Flatten the raw metric rows (one per imputation/fold/metric/time) and
    # export every remaining artifact. CSV for metrics; separate pickles for
    # the full evaluation payload, cause-specific baseline hazards, and the
    # CV split definitions so downstream notebooks can reload each on its own.
    df_results = pd.DataFrame(metrics_log)

    csv_name = f"{OUT_DIR}/xgb6_corr_DUAL_metrics_{timestamp}_mar26.csv"
    df_results.to_csv(csv_name, index=False)
    pkl_name = f"{OUT_DIR}/xgb6_corr_DUAL_final_ev_hyp_{timestamp}_mar26.pkl"
    with open(pkl_name, "wb") as f:
        pickle.dump(raw_data_log, f, protocol=pickle.HIGHEST_PROTOCOL)
    hz_name = f"{OUT_DIR}/xgb6_corr_DUAL_BaselineHazards_{timestamp}_mar26.pkl"
    with open(hz_name, "wb") as f:
        pickle.dump(baseline_hazards_log, f, protocol=pickle.HIGHEST_PROTOCOL)
    nb_print(f"Saved Cause-Specific Baseline Hazards pickle: {hz_name}")
    splits_name = f"{OUT_DIR}/xgb6_corr_DUAL_CV_Splits_{timestamp}_mar26.pkl"
    with open(splits_name, "wb") as f:
        pickle.dump(cv_splits_log, f, protocol=pickle.HIGHEST_PROTOCOL)
    nb_print(f"Saved CV splits: {splits_name}")

    # Wall-clock runtime in minutes since `start_time` (set at run start).
    total_time = (time.time() - start_time) / 60
    nb_print(f"\n🏁 Complete! Total execution time: {total_time:.2f} minutes.")
    # Best-effort browser downloads when running on Google Colab; a no-op
    # elsewhere because the `google.colab` import fails locally. Uses
    # `except Exception` instead of a bare `except:` so KeyboardInterrupt
    # and SystemExit still propagate.
    try:
        from google.colab import files
        # Mandatory artifacts first (always truthy strings), then the
        # optional SHAP artifacts, which are None when not produced.
        for artifact in (csv_name, pkl_name, hz_name, splits_name,
                         shap_name, multih_name, multih_summary_name):
            if artifact:
                files.download(artifact)
    except Exception:
        pass
✅ Path preflight OK | cwd=G:\My Drive\Alvacast\SISTRAT 2023\cons | OUT_DIR=G:\My Drive\Alvacast\SISTRAT 2023\cons\_out | free=450.88 GB
Starting Full XGBoost Analysis (Cause-Specific Hazard / Death-Censored)...
📂 Processing Imputation 1/5...
📂 Processing Imputation 2/5...
📂 Processing Imputation 3/5...
📂 Processing Imputation 4/5...
📂 Processing Imputation 5/5...
Consolidating Analysis Artifacts...
Saved Averaged DUAL SHAP data: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out/xgb6_corr_DUAL_SHAP_Aggregated_20260306_1821_mar26.pkl
💾 Saved pooled Multi-H SHAP data: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out/xgb6_corr_DUAL_SHAP_MultiH_20260306_1821_mar26.pkl
💾 Saved pooled Multi-H SHAP summary: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out/xgb6_corr_DUAL_SHAP_MultiH_summary_20260306_1821_mar26.csv
Saved Cause-Specific Baseline Hazards pickle: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out/xgb6_corr_DUAL_BaselineHazards_20260306_1821_mar26.pkl
Saved CV splits: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out/xgb6_corr_DUAL_CV_Splits_20260306_1821_mar26.pkl
🏁 Complete! Total execution time: 54.84 minutes.

Horizons are exported labels; SHAP values are identical across horizons by PH model design.

Code
# Cross-fold mean/SD of the two global summary metrics, per outcome.
(
    pd.DataFrame(metrics_log)
    .loc[lambda d: d['Metric'].isin(['IBS', "Uno's C-Index"]) & d['Time'].eq('Global')]
    .groupby(['Risk', 'Metric'])['Value']
    .agg(['mean', 'std'])
    .reset_index()
)
Risk Metric mean std
0 Death IBS 0.035996 0.000621
1 Death Uno's C-Index 0.744036 0.011530
2 Readmission IBS 0.168982 0.000973
3 Readmission Uno's C-Index 0.616215 0.006511
Code
from IPython.display import HTML, display

# Tidy the raw metric log into a rounded, deterministically ordered frame
# and preview both ends of it.
metrics_log_comb = (
    pd.DataFrame(metrics_log)
    .round(3)
    .sort_values(['Risk', 'Metric', 'Time'])
)

nb_print("First 10 rows")
display(metrics_log_comb.head(10))

nb_print("Last 10 rows")
display(metrics_log_comb.tail(10))
First 10 rows
Imp Fold Risk Metric Time Value
4 0 0 Death Brier Score 3 0.005
222 0 1 Death Brier Score 3 0.004
440 0 2 Death Brier Score 3 0.004
658 0 3 Death Brier Score 3 0.006
876 0 4 Death Brier Score 3 0.005
1094 1 0 Death Brier Score 3 0.005
1312 1 1 Death Brier Score 3 0.004
1530 1 2 Death Brier Score 3 0.004
1748 1 3 Death Brier Score 3 0.006
1966 1 4 Death Brier Score 3 0.005
Last 10 rows
Imp Fold Risk Metric Time Value
3271 3 0 Readmission Uno's C-Index Global 0.609
3489 3 1 Readmission Uno's C-Index Global 0.616
3707 3 2 Readmission Uno's C-Index Global 0.619
3925 3 3 Readmission Uno's C-Index Global 0.627
4143 3 4 Readmission Uno's C-Index Global 0.610
4361 4 0 Readmission Uno's C-Index Global 0.611
4579 4 1 Readmission Uno's C-Index Global 0.616
4797 4 2 Readmission Uno's C-Index Global 0.620
5015 4 3 Readmission Uno's C-Index Global 0.626
5233 4 4 Readmission Uno's C-Index Global 0.609
Code
# --- Aggregate metrics + bootstrap CIs (no export) ---

import numpy as np
import pandas as pd
from IPython.display import display, HTML

if "df_results" not in globals():
    df_results = pd.DataFrame(metrics_log_comb)

def bootstrap_ci_non_normal(values, n_boot=2000, alpha=0.05, random_state=2125):
    """Percentile-bootstrap CI for the mean of a sample (no normality assumption).

    Parameters
    ----------
    values : array-like of float
        Sample values; non-finite entries (NaN/inf) are dropped first.
    n_boot : int
        Number of bootstrap resamples.
    alpha : float
        Two-sided significance level (0.05 -> 95% interval).
    random_state : int
        Seed for the resampling RNG; results are deterministic per seed.

    Returns
    -------
    tuple of float
        (mean, lower, upper): the sample mean and the alpha/2 and 1-alpha/2
        quantiles of the resampled means. Returns (nan, nan, nan) for an
        empty sample and a degenerate (v, v, v) interval for a single value.
    """
    values = np.asarray(values, dtype=float)
    values = values[np.isfinite(values)]
    if values.size == 0:
        return np.nan, np.nan, np.nan
    if values.size == 1:
        v = float(values[0])
        return v, v, v

    rng = np.random.default_rng(random_state)
    n = values.size
    # Vectorized resampling: one (n_boot, n) draw instead of a Python-level
    # loop of n_boot separate rng.choice() calls — same estimator, far fewer
    # interpreter round-trips. (Exact bootstrap draws differ from the looped
    # version for a given seed, but the estimator is unchanged.)
    boot_means = rng.choice(values, size=(n_boot, n), replace=True).mean(axis=1)

    mean_val = float(values.mean())
    lower = float(np.quantile(boot_means, alpha / 2))
    upper = float(np.quantile(boot_means, 1 - alpha / 2))
    return mean_val, lower, upper

# One summary row per (Risk, Time, Metric) cell: mean plus bootstrap CI of
# the per-fold/per-imputation values, and a preformatted display string.
summary_stats = []
for (outcome, time_pt, metric), grp in df_results.groupby(["Risk", "Time", "Metric"], dropna=False):
    mean_val, lower, upper = bootstrap_ci_non_normal(grp["Value"].dropna().values)
    summary_stats.append({
        "Outcome": outcome,
        "Time": time_pt,
        "Metric": metric,
        "Mean": mean_val,
        "CI_Lower": lower,
        "CI_Upper": upper,
        "Format": f"{mean_val:.3f} [{lower:.3f}-{upper:.3f}]" if np.isfinite(mean_val) else "NA"
    })

df_summary = (
    pd.DataFrame(summary_stats)
    .sort_values(["Outcome", "Metric", "Time"])
    .reset_index(drop=True)
)

# Alias kept for downstream cells that reference this name.
ds_aj_df_summary = df_summary

# Scrollable HTML rendering of the full summary table.
display(HTML(f"""
<div style="
    height:500px;
    overflow:auto;
    border:1px solid #ccc;
    padding:10px;
    background-color:white;
    font-family:'Times New Roman';
    font-size:13px;
">
    {ds_aj_df_summary.to_html(index=False)}
</div>
"""))
Outcome Time Metric Mean CI_Lower CI_Upper Format
Death 3 Brier Score 0.004945 0.004612 0.005257 0.005 [0.005-0.005]
Death 6 Brier Score 0.007008 0.006659 0.007352 0.007 [0.007-0.007]
Death 9 Brier Score 0.008884 0.008462 0.009308 0.009 [0.008-0.009]
Death 12 Brier Score 0.010960 0.010549 0.011402 0.011 [0.011-0.011]
Death 24 Brier Score 0.020957 0.020484 0.021453 0.021 [0.020-0.021]
Death 36 Brier Score 0.031790 0.031482 0.032110 0.032 [0.031-0.032]
Death 48 Brier Score 0.043711 0.043224 0.044131 0.044 [0.043-0.044]
Death 60 Brier Score 0.059104 0.058334 0.059801 0.059 [0.058-0.060]
Death 72 Brier Score 0.078487 0.078061 0.078886 0.078 [0.078-0.079]
Death 84 Brier Score 0.106293 0.105474 0.107134 0.106 [0.105-0.107]
Death 96 Brier Score 0.144751 0.143523 0.145963 0.145 [0.144-0.146]
Death 108 Brier Score 0.211055 0.208989 0.212809 0.211 [0.209-0.213]
Death 3 F1 0.068438 0.059066 0.076917 0.068 [0.059-0.077]
Death 6 F1 0.090598 0.082561 0.097869 0.091 [0.083-0.098]
Death 9 F1 0.093074 0.087420 0.098551 0.093 [0.087-0.099]
Death 12 F1 0.104675 0.100188 0.109545 0.105 [0.100-0.110]
Death 24 F1 0.168633 0.163657 0.173721 0.169 [0.164-0.174]
Death 36 F1 0.214948 0.208818 0.221350 0.215 [0.209-0.221]
Death 48 F1 0.267459 0.260657 0.274847 0.267 [0.261-0.275]
Death 60 F1 0.331021 0.323626 0.338681 0.331 [0.324-0.339]
Death 72 F1 0.388608 0.379958 0.397145 0.389 [0.380-0.397]
Death 84 F1 0.468009 0.458922 0.476851 0.468 [0.459-0.477]
Death 96 F1 0.540100 0.536832 0.542847 0.540 [0.537-0.543]
Death 108 F1 0.619781 0.616600 0.622888 0.620 [0.617-0.623]
Death 6 IBS 0.005873 0.005540 0.006192 0.006 [0.006-0.006]
Death 9 IBS 0.006815 0.006461 0.007158 0.007 [0.006-0.007]
Death 12 IBS 0.007735 0.007376 0.008099 0.008 [0.007-0.008]
Death 24 IBS 0.011713 0.011349 0.012097 0.012 [0.011-0.012]
Death 36 IBS 0.015711 0.015371 0.016056 0.016 [0.015-0.016]
Death 48 IBS 0.019491 0.019210 0.019779 0.019 [0.019-0.020]
Death 60 IBS 0.023095 0.022847 0.023330 0.023 [0.023-0.023]
Death 72 IBS 0.026554 0.026332 0.026749 0.027 [0.026-0.027]
Death 84 IBS 0.029854 0.029642 0.030041 0.030 [0.030-0.030]
Death 96 IBS 0.032991 0.032773 0.033192 0.033 [0.033-0.033]
Death 108 IBS 0.035996 0.035764 0.036214 0.036 [0.036-0.036]
Death Global IBS 0.035996 0.035764 0.036214 0.036 [0.036-0.036]
Death 3 NPV 0.995363 0.995091 0.995651 0.995 [0.995-0.996]
Death 6 NPV 0.993501 0.993181 0.993821 0.994 [0.993-0.994]
Death 9 NPV 0.991718 0.991282 0.992147 0.992 [0.991-0.992]
Death 12 NPV 0.989882 0.989426 0.990303 0.990 [0.989-0.990]
Death 24 NPV 0.981661 0.981110 0.982221 0.982 [0.981-0.982]
Death 36 NPV 0.972075 0.971542 0.972575 0.972 [0.972-0.973]
Death 48 NPV 0.962999 0.962080 0.963994 0.963 [0.962-0.964]
Death 60 NPV 0.950596 0.949245 0.952055 0.951 [0.949-0.952]
Death 72 NPV 0.934634 0.933077 0.936078 0.935 [0.933-0.936]
Death 84 NPV 0.915857 0.913871 0.917921 0.916 [0.914-0.918]
Death 96 NPV 0.887886 0.885714 0.889960 0.888 [0.886-0.890]
Death 108 NPV 0.835326 0.832469 0.838487 0.835 [0.832-0.838]
Death 3 PPV 0.067215 0.054227 0.082390 0.067 [0.054-0.082]
Death 6 PPV 0.087502 0.078775 0.095623 0.088 [0.079-0.096]
Death 9 PPV 0.088143 0.081616 0.094176 0.088 [0.082-0.094]
Death 12 PPV 0.097710 0.091665 0.103859 0.098 [0.092-0.104]
Death 24 PPV 0.146176 0.142082 0.150749 0.146 [0.142-0.151]
Death 36 PPV 0.201815 0.198084 0.205706 0.202 [0.198-0.206]
Death 48 PPV 0.245981 0.241822 0.250430 0.246 [0.242-0.250]
Death 60 PPV 0.319628 0.316053 0.323383 0.320 [0.316-0.323]
Death 72 PPV 0.394243 0.387175 0.401486 0.394 [0.387-0.401]
Death 84 PPV 0.465118 0.458690 0.471548 0.465 [0.459-0.472]
Death 96 PPV 0.543168 0.534446 0.551405 0.543 [0.534-0.551]
Death 108 PPV 0.631305 0.624096 0.637678 0.631 [0.624-0.638]
Death 3 Sens 0.080606 0.067581 0.094392 0.081 [0.068-0.094]
Death 6 Sens 0.100929 0.088628 0.112897 0.101 [0.089-0.113]
Death 9 Sens 0.103534 0.093639 0.113492 0.104 [0.094-0.113]
Death 12 Sens 0.120013 0.109057 0.132570 0.120 [0.109-0.133]
Death 24 Sens 0.202683 0.190770 0.215128 0.203 [0.191-0.215]
Death 36 Sens 0.231618 0.220016 0.243037 0.232 [0.220-0.243]
Death 48 Sens 0.297630 0.280047 0.315795 0.298 [0.280-0.316]
Death 60 Sens 0.346366 0.329959 0.363818 0.346 [0.330-0.364]
Death 72 Sens 0.387783 0.369331 0.405186 0.388 [0.369-0.405]
Death 84 Sens 0.472560 0.457647 0.487720 0.473 [0.458-0.488]
Death 96 Sens 0.539193 0.527727 0.550238 0.539 [0.528-0.550]
Death 108 Sens 0.609921 0.600583 0.619762 0.610 [0.601-0.620]
Death 3 Spec 0.993788 0.992735 0.994831 0.994 [0.993-0.995]
Death 6 Spec 0.992104 0.991131 0.993147 0.992 [0.991-0.993]
Death 9 Spec 0.989916 0.988799 0.990989 0.990 [0.989-0.991]
Death 12 Spec 0.986863 0.985048 0.988510 0.987 [0.985-0.989]
Death 24 Spec 0.972946 0.971108 0.974715 0.973 [0.971-0.975]
Death 36 Spec 0.966928 0.965487 0.968410 0.967 [0.965-0.968]
Death 48 Spec 0.952315 0.949272 0.955421 0.952 [0.949-0.955]
Death 60 Spec 0.944587 0.941735 0.947369 0.945 [0.942-0.947]
Death 72 Spec 0.935648 0.931368 0.940132 0.936 [0.931-0.940]
Death 84 Spec 0.913394 0.910374 0.916351 0.913 [0.910-0.916]
Death 96 Spec 0.888774 0.883139 0.894325 0.889 [0.883-0.894]
Death 108 Spec 0.846775 0.840001 0.852920 0.847 [0.840-0.853]
Death 3 Threshold 0.041669 0.039075 0.044525 0.042 [0.039-0.045]
Death 6 Threshold 0.053463 0.051052 0.056256 0.053 [0.051-0.056]
Death 9 Threshold 0.061844 0.059680 0.064173 0.062 [0.060-0.064]
Death 12 Threshold 0.068603 0.065582 0.071651 0.069 [0.066-0.072]
Death 24 Threshold 0.089627 0.087115 0.092270 0.090 [0.087-0.092]
Death 36 Threshold 0.112737 0.110178 0.115305 0.113 [0.110-0.115]
Death 48 Threshold 0.117295 0.114277 0.120669 0.117 [0.114-0.121]
Death 60 Threshold 0.125761 0.123246 0.128246 0.126 [0.123-0.128]
Death 72 Threshold 0.129895 0.125660 0.134369 0.130 [0.126-0.134]
Death 84 Threshold 0.115646 0.112946 0.118207 0.116 [0.113-0.118]
Death 96 Threshold 0.105166 0.102354 0.108214 0.105 [0.102-0.108]
Death 108 Threshold 0.094071 0.091803 0.096606 0.094 [0.092-0.097]
Death 3 Uno's C-Index 0.767110 0.760641 0.773511 0.767 [0.761-0.774]
Death 6 Uno's C-Index 0.756177 0.750597 0.762084 0.756 [0.751-0.762]
Death 9 Uno's C-Index 0.756303 0.752364 0.760643 0.756 [0.752-0.761]
Death 12 Uno's C-Index 0.750297 0.744458 0.756585 0.750 [0.744-0.757]
Death 24 Uno's C-Index 0.761329 0.753689 0.769762 0.761 [0.754-0.770]
Death 36 Uno's C-Index 0.761965 0.757232 0.767204 0.762 [0.757-0.767]
Death 48 Uno's C-Index 0.761785 0.758554 0.765202 0.762 [0.759-0.765]
Death 60 Uno's C-Index 0.761619 0.757943 0.765825 0.762 [0.758-0.766]
Death 72 Uno's C-Index 0.757146 0.754946 0.759374 0.757 [0.755-0.759]
Death 84 Uno's C-Index 0.751659 0.749061 0.754439 0.752 [0.749-0.754]
Death 96 Uno's C-Index 0.746719 0.743942 0.749778 0.747 [0.744-0.750]
Death 108 Uno's C-Index 0.741849 0.738676 0.745284 0.742 [0.739-0.745]
Death Global Uno's C-Index 0.744036 0.739990 0.748918 0.744 [0.740-0.749]
Readmission 3 Brier Score 0.041263 0.040599 0.041930 0.041 [0.041-0.042]
Readmission 6 Brier Score 0.066882 0.066353 0.067381 0.067 [0.066-0.067]
Readmission 9 Brier Score 0.085881 0.085201 0.086542 0.086 [0.085-0.087]
Readmission 12 Brier Score 0.102176 0.101411 0.102980 0.102 [0.101-0.103]
Readmission 24 Brier Score 0.147658 0.147211 0.148064 0.148 [0.147-0.148]
Readmission 36 Brier Score 0.181864 0.181525 0.182231 0.182 [0.182-0.182]
Readmission 48 Brier Score 0.211390 0.211036 0.211761 0.211 [0.211-0.212]
Readmission 60 Brier Score 0.240791 0.239848 0.241696 0.241 [0.240-0.242]
Readmission 72 Brier Score 0.269659 0.268780 0.270524 0.270 [0.269-0.271]
Readmission 84 Brier Score 0.301690 0.300837 0.302418 0.302 [0.301-0.302]
Readmission 96 Brier Score 0.329145 0.327956 0.330263 0.329 [0.328-0.330]
Readmission 108 Brier Score 0.356582 0.355161 0.358016 0.357 [0.355-0.358]
Readmission 3 F1 0.138473 0.136281 0.140658 0.138 [0.136-0.141]
Readmission 6 F1 0.198539 0.195041 0.202141 0.199 [0.195-0.202]
Readmission 9 F1 0.239052 0.235558 0.242819 0.239 [0.236-0.243]
Readmission 12 F1 0.272598 0.269234 0.276225 0.273 [0.269-0.276]
Readmission 24 F1 0.349717 0.347942 0.351593 0.350 [0.348-0.352]
Readmission 36 F1 0.398286 0.396795 0.399875 0.398 [0.397-0.400]
Readmission 48 F1 0.433459 0.430962 0.435887 0.433 [0.431-0.436]
Readmission 60 F1 0.462770 0.458308 0.466872 0.463 [0.458-0.467]
Readmission 72 F1 0.484893 0.479287 0.490143 0.485 [0.479-0.490]
Readmission 84 F1 0.503913 0.498252 0.509392 0.504 [0.498-0.509]
Readmission 96 F1 0.513907 0.508757 0.519116 0.514 [0.509-0.519]
Readmission 108 F1 0.538655 0.531830 0.545534 0.539 [0.532-0.546]
Readmission 6 IBS 0.053288 0.052755 0.053804 0.053 [0.053-0.054]
Readmission 9 IBS 0.064085 0.063572 0.064573 0.064 [0.064-0.065]
Readmission 12 IBS 0.073151 0.072607 0.073673 0.073 [0.073-0.074]
Readmission 24 IBS 0.098118 0.097698 0.098542 0.098 [0.098-0.099]
Readmission 36 IBS 0.115671 0.115430 0.115912 0.116 [0.115-0.116]
Readmission 48 IBS 0.128974 0.128808 0.129140 0.129 [0.129-0.129]
Readmission 60 IBS 0.139564 0.139388 0.139754 0.140 [0.139-0.140]
Readmission 72 IBS 0.148358 0.148134 0.148581 0.148 [0.148-0.149]
Readmission 84 IBS 0.156026 0.155749 0.156299 0.156 [0.156-0.156]
Readmission 96 IBS 0.162884 0.162525 0.163220 0.163 [0.163-0.163]
Readmission 108 IBS 0.168982 0.168608 0.169354 0.169 [0.169-0.169]
Readmission Global IBS 0.168982 0.168608 0.169354 0.169 [0.169-0.169]
Readmission 3 NPV 0.969817 0.968956 0.970769 0.970 [0.969-0.971]
Readmission 6 NPV 0.946961 0.945872 0.947928 0.947 [0.946-0.948]
Readmission 9 NPV 0.928690 0.927568 0.929755 0.929 [0.928-0.930]
Readmission 12 NPV 0.911759 0.910730 0.912839 0.912 [0.911-0.913]
Readmission 24 NPV 0.853715 0.852891 0.854556 0.854 [0.853-0.855]
Readmission 36 NPV 0.797979 0.797177 0.798713 0.798 [0.797-0.799]
Readmission 48 NPV 0.738125 0.737334 0.738938 0.738 [0.737-0.739]
Readmission 60 NPV 0.667033 0.665299 0.668948 0.667 [0.665-0.669]
Readmission 72 NPV 0.582881 0.580992 0.584853 0.583 [0.581-0.585]
Readmission 84 NPV 0.475512 0.473676 0.477622 0.476 [0.474-0.478]
Readmission 96 NPV 0.367861 0.364731 0.371248 0.368 [0.365-0.371]
Readmission 108 NPV 0.246878 0.243429 0.250513 0.247 [0.243-0.251]
Readmission 3 PPV 0.080334 0.078843 0.081794 0.080 [0.079-0.082]
Readmission 6 PPV 0.124139 0.121684 0.126680 0.124 [0.122-0.127]
Readmission 9 PPV 0.156870 0.153920 0.159874 0.157 [0.154-0.160]
Readmission 12 PPV 0.186790 0.184091 0.189492 0.187 [0.184-0.189]
Readmission 24 PPV 0.270044 0.267610 0.272584 0.270 [0.268-0.273]
Readmission 36 PPV 0.335512 0.333390 0.337762 0.336 [0.333-0.338]
Readmission 48 PPV 0.395238 0.393442 0.397079 0.395 [0.393-0.397]
Readmission 60 PPV 0.462874 0.460389 0.465122 0.463 [0.460-0.465]
Readmission 72 PPV 0.535164 0.532841 0.537553 0.535 [0.533-0.538]
Readmission 84 PPV 0.613916 0.610895 0.617122 0.614 [0.611-0.617]
Readmission 96 PPV 0.700124 0.697419 0.702626 0.700 [0.697-0.703]
Readmission 108 PPV 0.799959 0.798391 0.801592 0.800 [0.798-0.802]
Readmission 3 Sens 0.505393 0.490221 0.520168 0.505 [0.490-0.520]
Readmission 6 Sens 0.498901 0.482245 0.514220 0.499 [0.482-0.514]
Readmission 9 Sens 0.505117 0.491260 0.517730 0.505 [0.491-0.518]
Readmission 12 Sens 0.505612 0.494935 0.515579 0.506 [0.495-0.516]
Readmission 24 Sens 0.496590 0.492012 0.501145 0.497 [0.492-0.501]
Readmission 36 Sens 0.490383 0.485162 0.495805 0.490 [0.485-0.496]
Readmission 48 Sens 0.480327 0.473336 0.486905 0.480 [0.473-0.487]
Readmission 60 Sens 0.463273 0.454634 0.471816 0.463 [0.455-0.472]
Readmission 72 Sens 0.443770 0.435035 0.452139 0.444 [0.435-0.452]
Readmission 84 Sens 0.427713 0.419767 0.435330 0.428 [0.420-0.435]
Readmission 96 Sens 0.406159 0.400146 0.412225 0.406 [0.400-0.412]
Readmission 108 Sens 0.406462 0.398319 0.414624 0.406 [0.398-0.415]
Readmission 3 Spec 0.732800 0.723879 0.742025 0.733 [0.724-0.742]
Readmission 6 Spec 0.716941 0.707456 0.727374 0.717 [0.707-0.727]
Readmission 9 Spec 0.702908 0.693726 0.713276 0.703 [0.694-0.713]
Readmission 12 Spec 0.698538 0.691865 0.705377 0.699 [0.692-0.705]
Readmission 24 Spec 0.686220 0.681299 0.691345 0.686 [0.681-0.691]
Readmission 36 Spec 0.674374 0.668452 0.680241 0.674 [0.668-0.680]
Readmission 48 Spec 0.665715 0.659707 0.672385 0.666 [0.660-0.672]
Readmission 60 Spec 0.666480 0.659169 0.674126 0.666 [0.659-0.674]
Readmission 72 Spec 0.668310 0.660614 0.676173 0.668 [0.661-0.676]
Readmission 84 Spec 0.658438 0.650781 0.666506 0.658 [0.651-0.667]
Readmission 96 Spec 0.665050 0.658963 0.671946 0.665 [0.659-0.672]
Readmission 108 Spec 0.656479 0.646391 0.666462 0.656 [0.646-0.666]
Readmission 3 Threshold 0.049195 0.048553 0.049856 0.049 [0.049-0.050]
Readmission 6 Threshold 0.081218 0.080218 0.082314 0.081 [0.080-0.082]
Readmission 9 Threshold 0.105623 0.104460 0.106976 0.106 [0.104-0.107]
Readmission 12 Threshold 0.126767 0.125723 0.127899 0.127 [0.126-0.128]
Readmission 24 Threshold 0.185619 0.184857 0.186452 0.186 [0.185-0.186]
Readmission 36 Threshold 0.226624 0.225452 0.227781 0.227 [0.225-0.228]
Readmission 48 Threshold 0.257982 0.256596 0.259463 0.258 [0.257-0.259]
Readmission 60 Threshold 0.283977 0.282011 0.286185 0.284 [0.282-0.286]
Readmission 72 Threshold 0.308629 0.306267 0.311200 0.309 [0.306-0.311]
Readmission 84 Threshold 0.329349 0.326938 0.331927 0.329 [0.327-0.332]
Readmission 96 Threshold 0.350132 0.348534 0.351806 0.350 [0.349-0.352]
Readmission 108 Threshold 0.362603 0.359899 0.365414 0.363 [0.360-0.365]
Readmission 3 Uno's C-Index 0.673061 0.667260 0.678170 0.673 [0.667-0.678]
Readmission 6 Uno's C-Index 0.657483 0.652649 0.662183 0.657 [0.653-0.662]
Readmission 9 Uno's C-Index 0.650135 0.645974 0.654420 0.650 [0.646-0.654]
Readmission 12 Uno's C-Index 0.645736 0.642092 0.649310 0.646 [0.642-0.649]
Readmission 24 Uno's C-Index 0.634680 0.632041 0.637588 0.635 [0.632-0.638]
Readmission 36 Uno's C-Index 0.630349 0.627824 0.633105 0.630 [0.628-0.633]
Readmission 48 Uno's C-Index 0.627410 0.625420 0.629568 0.627 [0.625-0.630]
Readmission 60 Uno's C-Index 0.624965 0.622992 0.627068 0.625 [0.623-0.627]
Readmission 72 Uno's C-Index 0.623570 0.621243 0.626010 0.624 [0.621-0.626]
Readmission 84 Uno's C-Index 0.622149 0.619806 0.624602 0.622 [0.620-0.625]
Readmission 96 Uno's C-Index 0.619740 0.617483 0.622041 0.620 [0.617-0.622]
Readmission 108 Uno's C-Index 0.618638 0.616162 0.621170 0.619 [0.616-0.621]
Readmission Global Uno's C-Index 0.616215 0.613803 0.618803 0.616 [0.614-0.619]
Code
#@title 📊 Step 7: Final Reporting & Plots (XGB7 Readm Corr - Local/Positron Version)

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib as mpl
from datetime import datetime
from IPython.display import display, Markdown
from pathlib import Path

if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the root setup cell first.")

PROJECT_ROOT = Path(PROJECT_ROOT).resolve()
FIGS_DIR = PROJECT_ROOT / "_figs"
FIGS_DIR.mkdir(parents=True, exist_ok=True)

display(Markdown(f"📁 Using PROJECT_ROOT: `{PROJECT_ROOT}`"))

# --- 1. CONFIGURATION ---
# Font Settings: Times New Roman (standard scientific publishing style)
mpl.rcParams.update({
    "font.family": "serif",
    "font.serif": ["Times New Roman", "Times", "Nimbus Roman", "DejaVu Serif"],
    # fonttype 42 embeds TrueType outlines so text stays editable in PDF/PS.
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
    "axes.labelsize": 14,
    "axes.titlesize": 16,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "figure.dpi": 300
})

# NOTE(review): timestamp is recomputed in this cell, so figure filenames can
# differ from the timestamps of artifacts saved by earlier cells.
timestamp = datetime.now().strftime("%Y%m%d_%H%M")

# --- 2. ROBUST DATA LOADING ---
# Accept metrics_log_comb in whatever form it survives in kernel memory:
# a DataFrame (from the display cell above), a {'columns', 'data'} dict
# (dput_df-style serialization), or the raw list of metric dicts.
try:
    # At notebook top level locals() is the module namespace, so this checks
    # whether the variable exists; fall back to the raw metrics_log if not.
    if 'metrics_log_comb' not in locals():
        if 'metrics_log' in locals():
             metrics_log_comb = metrics_log
        else:
            raise NameError("metrics_log_comb not found in memory.")

    if isinstance(metrics_log_comb, pd.DataFrame):
        display(Markdown("✅ metrics_log_comb is already a DataFrame. Using directly."))
        # Copy so later mutations don't alter the frame displayed earlier.
        df_res = metrics_log_comb.copy()
    elif isinstance(metrics_log_comb, dict):
        display(Markdown("✅ metrics_log_comb is a dictionary. Converting to DataFrame..."))
        df_res = pd.DataFrame(metrics_log_comb['data'], columns=metrics_log_comb['columns'])
    elif isinstance(metrics_log_comb, list):
         display(Markdown("✅ metrics_log_comb is a list. Converting to DataFrame..."))
         df_res = pd.DataFrame(metrics_log_comb)
    else:
        raise TypeError(f"Unexpected type for metrics_log_comb: {type(metrics_log_comb)}")

    display(Markdown(f"📊 Data loaded: {len(df_res)} rows."))

except NameError:
    # The NameError raised above (or by the lookups themselves) lands here;
    # report and re-raise so the cell fails visibly.
    display(Markdown("❌ Error: 'metrics_log_comb' variable not found in memory."))
    raise

# --- 3. AGGREGATE STATS (Mean, Std, SE, CI) ---
grouped = df_res.groupby(['Risk', 'Metric', 'Time'])['Value']
df_agg = grouped.agg(['mean', 'std', 'count']).reset_index()
# 'Global' rows become NaN here and are excluded from the time-axis plots.
df_agg['Time_num'] = pd.to_numeric(df_agg['Time'], errors='coerce')

# Empirical 2.5/97.5 percentiles of the per-fold/per-imputation values for
# each (Risk, Metric, Time) cell; one observation gives a degenerate interval.
ci_rows = []
for (risk_v, metric_v, time_v), vals in grouped:
    arr = vals.dropna().values
    if arr.size > 1:
        lo, hi = np.percentile(arr, [2.5, 97.5])
    else:
        lo = hi = arr[0] if arr.size else np.nan
    ci_rows.append((risk_v, metric_v, time_v, lo, hi))

ci_df = pd.DataFrame(ci_rows, columns=['Risk', 'Metric', 'Time', 'lower', 'upper'])
df_agg = df_agg.merge(ci_df, on=['Risk', 'Metric', 'Time'])
df_agg['Display'] = df_agg.apply(lambda x: f"{x['mean']:.3f} ({x['lower']:.3f}-{x['upper']:.3f})", axis=1)


summary_filename = FIGS_DIR / f"XGB7_comb_final_summary_{timestamp}.csv"
df_agg.to_csv(summary_filename, index=False)
display(Markdown(f"💾 Saved summary stats to local directory: {summary_filename}"))

# --- 4. PLOT 1: CLASSIFICATION METRICS ---
print("\n" + "="*80)
display(Markdown("📈 PLOT: Classification Metrics"))
print("="*80)

fig, axes = plt.subplots(1, 2, figsize=(18, 6), sharey=True)
# (color, marker) per classification metric.
styles = {
    'Sens': ('#1f77b4', 'o'),
    'Spec': ('#ff7f0e', 's'),
    'PPV': ('#2ca02c', '^'),
    'NPV': ('#d62728', 'v'),
}

# One panel per outcome; horizons beyond 96 months are excluded.
for ax, risk in zip(axes, ['Death', 'Readmission']):
    subset = df_agg[(df_agg['Risk'] == risk) & (df_agg['Time_num'] <= 96) & (df_agg['Metric'].isin(styles.keys()))]
    for metric_name, (color, marker) in styles.items():
        rows = subset[subset['Metric'] == metric_name]
        if rows.empty:
            continue
        ax.plot(rows['Time_num'], rows['mean'], label=metric_name, color=color, marker=marker, linewidth=2)
        ax.fill_between(rows['Time_num'], rows['lower'], rows['upper'], color=color, alpha=0.15)
    ax.set_title(None)  # titles intentionally suppressed for publication
    ax.set_xlabel("Time (Months)")
    ax.set_ylabel("Score")
    ax.legend(loc='lower right')
    ax.grid(True, linestyle='--', alpha=0.5)

plt.tight_layout()
fname_cls = f"XGB7_comb_Classification_Metrics_{timestamp}"
# Save both vector (PDF) and raster (PNG) versions.
plt.savefig(FIGS_DIR / f"{fname_cls}.pdf", bbox_inches="tight")
plt.savefig(FIGS_DIR / f"{fname_cls}.png", dpi=300, bbox_inches="tight")
display(Markdown(f"💾 Saved plot: `{FIGS_DIR / f'{fname_cls}.png'}`"))
plt.show()

# --- 5. PLOT 2: UNO'S C-INDEX ---
print("\n" + "="*80)
display(Markdown("📈 PLOT: Discrimination (Uno's C-Index)"))
print("="*80)

plt.figure(figsize=(10, 5))
uno_df = df_agg[df_agg['Metric'] == "Uno's C-Index"].copy()
uno_df['Time_num'] = pd.to_numeric(uno_df['Time'], errors='coerce')
# Also used by the IBS plot below — keep the name and values unchanged.
risk_colors = {'Death': '#D62728', 'Readmission': '#1F77B4'}

for risk in ['Death', 'Readmission']:
    rows = uno_df[uno_df['Risk'] == risk]
    if rows.empty:
        continue
    # Drop non-numeric times ('Global') and order along the time axis.
    rows = rows.sort_values('Time_num').dropna(subset=['Time_num'])
    plt.plot(rows['Time_num'], rows['mean'], marker='o', label=f"{risk} (Mean)", color=risk_colors[risk], linewidth=2)
    plt.fill_between(rows['Time_num'], rows['lower'], rows['upper'], color=risk_colors[risk], alpha=0.15, label=f"{risk} (95% CI)")

plt.title(None)  # title intentionally suppressed for publication
plt.xlabel("Months")
plt.ylabel("Concordance Index")
plt.ylim(0.4, 1.0)
plt.legend()
plt.grid(True, linestyle='--', alpha=0.5)
plt.tight_layout()

fname_uno = f"XGB7_comb_Uno_CIndex_CI_{timestamp}"
plt.savefig(FIGS_DIR / f"{fname_uno}.pdf", bbox_inches="tight")
plt.savefig(FIGS_DIR / f"{fname_uno}.png", dpi=300, bbox_inches="tight")
display(Markdown(f"💾 Saved plot: `{FIGS_DIR / f'{fname_uno}.png'}`"))
plt.show()

# --- 6. PLOT 3: INTEGRATED BRIER SCORE (IBS) ---
print("\n" + "="*80)
display(Markdown("📉 PLOT: Time-Dependent Integrated Brier Score (IBS)"))
print("="*80)

plt.figure(figsize=(10, 5))
ibs_df = df_agg[df_agg['Metric'] == 'IBS'].copy()
ibs_df['Time_num'] = pd.to_numeric(ibs_df['Time'], errors='coerce')

for risk in ['Death', 'Readmission']:
    rows = ibs_df[ibs_df['Risk'] == risk]
    if rows.empty:
        continue
    # Drop non-numeric times ('Global') and order along the time axis.
    rows = rows.sort_values('Time_num').dropna(subset=['Time_num'])
    plt.plot(rows['Time_num'], rows['mean'], marker='o', label=f"{risk} (Mean)", color=risk_colors[risk], linewidth=2)
    plt.fill_between(rows['Time_num'], rows['lower'], rows['upper'], color=risk_colors[risk], alpha=0.15, label=f"{risk} (95% CI)")

plt.title(None)  # title intentionally suppressed for publication
plt.xlabel("Months")
plt.ylabel("Integrated Brier Score (lower is better)")
# Dynamic y-limits padded by 0.01 and clipped at 0; defaults if no IBS rows.
if ibs_df.empty:
    ymin, ymax = 0.0, 1.0
else:
    ymin = max(0.0, ibs_df['lower'].min() - 0.01)
    ymax = ibs_df['upper'].max() + 0.01
plt.ylim(ymin, ymax)
plt.legend()
plt.grid(True, linestyle='--', alpha=0.5)
plt.tight_layout()

fname_ibs = f"XGB7_comb_IBS_CI_{timestamp}"
plt.savefig(FIGS_DIR / f"{fname_ibs}.pdf", bbox_inches="tight")
plt.savefig(FIGS_DIR / f"{fname_ibs}.png", dpi=300, bbox_inches="tight")
display(Markdown(f"💾 Saved plot: `{FIGS_DIR / f'{fname_ibs}.png'}`"))
plt.show()

# --- 7. TABLES ---
print("\n" + "="*80)
display(Markdown("📋 TABLES"))
print("="*80)

display(Markdown("\n>>> GLOBAL PERFORMANCE"))
# Explicit .loc indexing (rows by condition, fixed column subset).
global_df = df_agg.loc[df_agg['Time'] == 'Global', ['Risk', 'Metric', 'Display']].reset_index(drop=True)
if global_df.empty:
    display(Markdown("No global metrics found in aggregation."))
else:
    display(global_df)

display(Markdown("\n>>> DETAILED METRICS BY TIME"))
# Wide table: one formatted "mean (CI)" string per metric per (Risk, time).
pivot_df = df_agg.pivot_table(index=['Risk', 'Time_num'], columns='Metric', values='Display', aggfunc='first')
display(pivot_df)

📁 Using PROJECT_ROOT: G:\My Drive\Alvacast\SISTRAT 2023\cons

✅ metrics_log_comb is already a DataFrame. Using directly.

📊 Data loaded: 5450 rows.

💾 Saved summary stats to local directory: G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\XGB7_comb_final_summary_20260306_1822.csv

📈 PLOT: Classification Metrics

💾 Saved plot: G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\XGB7_comb_Classification_Metrics_20260306_1822.png

📈 PLOT: Discrimination (Uno’s C-Index)

💾 Saved plot: G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\XGB7_comb_Uno_CIndex_CI_20260306_1822.png

📉 PLOT: Time-Dependent Integrated Brier Score (IBS)

💾 Saved plot: G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\XGB7_comb_IBS_CI_20260306_1822.png

📋 TABLES

GLOBAL PERFORMANCE

Risk Metric Display
0 Death IBS 0.036 (0.035-0.037)
1 Death Uno's C-Index 0.744 (0.733-0.766)
2 Readmission IBS 0.169 (0.167-0.170)
3 Readmission Uno's C-Index 0.616 (0.609-0.627)

DETAILED METRICS BY TIME

Metric Brier Score F1 IBS NPV PPV Sens Spec Threshold Uno's C-Index
Risk Time_num
Death 3.0 0.005 (0.004-0.006) 0.068 (0.026-0.104) NaN 0.995 (0.994-0.996) 0.067 (0.027-0.164) 0.081 (0.025-0.146) 0.994 (0.989-0.998) 0.042 (0.032-0.056) 0.767 (0.738-0.791)
6.0 0.007 (0.006-0.008) 0.091 (0.056-0.118) 0.006 (0.005-0.007) 0.993 (0.992-0.995) 0.088 (0.047-0.122) 0.101 (0.053-0.155) 0.992 (0.988-0.996) 0.053 (0.045-0.069) 0.756 (0.735-0.779)
9.0 0.009 (0.008-0.011) 0.093 (0.070-0.114) 0.007 (0.006-0.008) 0.992 (0.990-0.993) 0.088 (0.058-0.109) 0.104 (0.064-0.149) 0.990 (0.985-0.994) 0.062 (0.052-0.074) 0.756 (0.740-0.770)
12.0 0.011 (0.010-0.013) 0.105 (0.088-0.130) 0.008 (0.007-0.009) 0.990 (0.988-0.991) 0.098 (0.073-0.123) 0.120 (0.076-0.182) 0.987 (0.978-0.993) 0.069 (0.054-0.082) 0.750 (0.726-0.775)
24.0 0.021 (0.020-0.023) 0.169 (0.145-0.189) 0.012 (0.011-0.013) 0.982 (0.979-0.983) 0.146 (0.130-0.174) 0.203 (0.157-0.247) 0.973 (0.965-0.979) 0.090 (0.079-0.100) 0.761 (0.736-0.790)
36.0 0.032 (0.031-0.033) 0.215 (0.183-0.240) 0.016 (0.015-0.017) 0.972 (0.970-0.974) 0.202 (0.188-0.219) 0.232 (0.173-0.282) 0.967 (0.960-0.974) 0.113 (0.102-0.125) 0.762 (0.747-0.780)
48.0 0.044 (0.042-0.045) 0.267 (0.244-0.300) 0.020 (0.019-0.021) 0.963 (0.960-0.967) 0.246 (0.229-0.267) 0.298 (0.233-0.381) 0.952 (0.940-0.965) 0.117 (0.106-0.134) 0.762 (0.750-0.772)
60.0 0.059 (0.056-0.062) 0.331 (0.304-0.361) 0.023 (0.022-0.024) 0.951 (0.946-0.956) 0.320 (0.301-0.334) 0.346 (0.284-0.425) 0.945 (0.930-0.956) 0.126 (0.115-0.135) 0.762 (0.747-0.778)
72.0 0.078 (0.077-0.080) 0.389 (0.343-0.415) 0.027 (0.026-0.027) 0.934 (0.927-0.939) 0.394 (0.360-0.420) 0.388 (0.292-0.449) 0.936 (0.917-0.956) 0.130 (0.112-0.149) 0.757 (0.747-0.764)
84.0 0.106 (0.103-0.110) 0.468 (0.422-0.494) 0.030 (0.029-0.030) 0.916 (0.908-0.923) 0.465 (0.438-0.490) 0.473 (0.396-0.524) 0.913 (0.899-0.927) 0.116 (0.102-0.124) 0.752 (0.741-0.761)
96.0 0.145 (0.141-0.150) 0.540 (0.524-0.551) 0.033 (0.032-0.034) 0.888 (0.877-0.896) 0.543 (0.513-0.586) 0.539 (0.479-0.585) 0.889 (0.866-0.917) 0.105 (0.094-0.121) 0.747 (0.735-0.759)
108.0 0.211 (0.203-0.217) 0.620 (0.604-0.632) 0.036 (0.035-0.037) 0.835 (0.823-0.851) 0.631 (0.596-0.657) 0.610 (0.577-0.659) 0.847 (0.814-0.867) 0.094 (0.086-0.105) 0.742 (0.730-0.756)
Readmission 3.0 0.041 (0.039-0.043) 0.138 (0.129-0.148) NaN 0.970 (0.967-0.974) 0.080 (0.073-0.087) 0.505 (0.440-0.574) 0.733 (0.690-0.769) 0.049 (0.047-0.052) 0.673 (0.645-0.687)
6.0 0.067 (0.064-0.069) 0.199 (0.186-0.213) 0.053 (0.051-0.055) 0.947 (0.943-0.951) 0.124 (0.113-0.135) 0.499 (0.440-0.571) 0.717 (0.672-0.758) 0.081 (0.077-0.085) 0.657 (0.638-0.674)
9.0 0.086 (0.083-0.088) 0.239 (0.227-0.255) 0.064 (0.062-0.066) 0.929 (0.924-0.933) 0.157 (0.144-0.169) 0.505 (0.450-0.555) 0.703 (0.665-0.752) 0.106 (0.101-0.112) 0.650 (0.634-0.666)
12.0 0.102 (0.100-0.106) 0.273 (0.262-0.288) 0.073 (0.071-0.075) 0.912 (0.907-0.916) 0.187 (0.175-0.200) 0.506 (0.450-0.542) 0.699 (0.670-0.731) 0.127 (0.122-0.132) 0.646 (0.630-0.658)
24.0 0.148 (0.146-0.149) 0.350 (0.343-0.358) 0.098 (0.097-0.100) 0.854 (0.850-0.857) 0.270 (0.260-0.281) 0.497 (0.475-0.515) 0.686 (0.662-0.708) 0.186 (0.183-0.189) 0.635 (0.628-0.646)
36.0 0.182 (0.181-0.183) 0.398 (0.392-0.406) 0.116 (0.115-0.117) 0.798 (0.794-0.801) 0.335 (0.327-0.348) 0.490 (0.466-0.511) 0.674 (0.650-0.701) 0.227 (0.222-0.232) 0.630 (0.624-0.642)
48.0 0.211 (0.210-0.213) 0.433 (0.420-0.442) 0.129 (0.128-0.130) 0.738 (0.735-0.742) 0.395 (0.389-0.405) 0.480 (0.442-0.502) 0.666 (0.640-0.699) 0.258 (0.252-0.266) 0.627 (0.622-0.637)
60.0 0.241 (0.237-0.244) 0.463 (0.441-0.479) 0.139 (0.139-0.140) 0.667 (0.659-0.677) 0.463 (0.451-0.471) 0.463 (0.422-0.498) 0.666 (0.634-0.701) 0.284 (0.276-0.295) 0.625 (0.620-0.634)
72.0 0.270 (0.266-0.273) 0.485 (0.454-0.507) 0.148 (0.148-0.149) 0.583 (0.574-0.591) 0.535 (0.525-0.545) 0.444 (0.397-0.477) 0.668 (0.640-0.709) 0.309 (0.300-0.321) 0.624 (0.617-0.634)
84.0 0.302 (0.298-0.305) 0.504 (0.478-0.530) 0.156 (0.155-0.157) 0.476 (0.469-0.483) 0.614 (0.603-0.628) 0.428 (0.394-0.467) 0.658 (0.622-0.688) 0.329 (0.318-0.339) 0.622 (0.616-0.630)
96.0 0.329 (0.325-0.334) 0.514 (0.489-0.537) 0.163 (0.161-0.164) 0.368 (0.356-0.380) 0.700 (0.690-0.712) 0.406 (0.375-0.432) 0.665 (0.644-0.696) 0.350 (0.344-0.359) 0.620 (0.614-0.627)
108.0 0.357 (0.351-0.363) 0.539 (0.510-0.573) 0.169 (0.167-0.170) 0.247 (0.236-0.259) 0.800 (0.795-0.808) 0.406 (0.373-0.448) 0.656 (0.605-0.692) 0.363 (0.350-0.376) 0.619 (0.609-0.627)
Code
#@title 📊 Step 5b: Robust Time-to-Event Calibration Plots (Fixed Colors + Timestamp File Pick)

import re
import pickle
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter
from IPython.display import display, Markdown

# Horizons (months) at which calibration curves are drawn.
TARGET_TIMES = [3, 6, 12, 36, 60, 96]
# Number of quantile-based risk bins per calibration curve.
RISK_GROUPS = 10
PREFER = "latest"  # "latest" or "earliest"

# Glob patterns for the per-fold evaluation pickles written upstream.
FILE_PATTERNS = [
    "xgb6_corr_DUAL_final_ev_hyp_*_mar26.pkl",
]

timestamp_now = datetime.now().strftime("%Y%m%d_%H%M")
# Extracts the YYYYMMDD_HHMM stamp embedded in each pickle's filename.
TS_RE = re.compile(r"_(\d{8}_\d{4})_mar26\.pkl$")

# PROJECT_ROOT must have been defined by the root setup cell of this notebook.
if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the root setup cell first.")

PROJECT_ROOT = Path(PROJECT_ROOT).resolve()
IN_DIR = PROJECT_ROOT / "_out"    # evaluation pickles are read from here
OUT_DIR = PROJECT_ROOT / "_figs"  # figures are written here
OUT_DIR.mkdir(parents=True, exist_ok=True)

if not IN_DIR.exists():
    raise FileNotFoundError(f"Input directory does not exist: {IN_DIR}")

def pick_file_by_timestamp(in_dir: Path, patterns, prefer="latest"):
    """Select the evaluation pickle whose embedded timestamp is the most
    recent (or oldest) among files matching ``patterns`` inside ``in_dir``.

    Returns a ``(datetime, Path)`` tuple; raises FileNotFoundError when no
    file carries a parseable ``YYYYMMDD_HHMM`` stamp, ValueError for a bad
    ``prefer`` value.
    """
    mode = str(prefer).lower().strip()
    if mode not in {"latest", "earliest"}:
        raise ValueError("PREFER must be 'latest' or 'earliest'.")

    stamped = []
    for pattern in patterns:
        for candidate in in_dir.glob(pattern):
            match = TS_RE.search(candidate.name)
            if match is None:
                continue  # filename carries no timestamp suffix
            try:
                stamp = datetime.strptime(match.group(1), "%Y%m%d_%H%M")
            except ValueError:
                continue  # malformed timestamp
            stamped.append((stamp, candidate))

    if not stamped:
        raise FileNotFoundError(f"No valid files found in '{in_dir}' for patterns: {patterns}")

    # Stable sort by timestamp only, then take the requested end.
    stamped.sort(key=lambda pair: pair[0])
    return stamped[0] if mode == "earliest" else stamped[-1]


def get_time_idx(eval_times, t):
    """Return the index of horizon ``t`` inside ``eval_times`` using
    floating-point tolerance (``np.isclose``); ``None`` when absent."""
    times = np.asarray(eval_times, dtype=float)
    matches = np.flatnonzero(np.isclose(times, float(t)))
    if matches.size == 0:
        return None
    return int(matches[0])

def get_calibration_data(raw_log, risk_type, time_point, n_groups=10):
    """Pool fold-level predictions and compute per-bin observed event
    frequencies (Kaplan-Meier) at ``time_point``.

    Parameters: raw_log is a list of per-fold dicts holding survival-probability
    matrices and validation outcomes; risk_type selects "Readmission" vs. any
    other value (treated as Death); n_groups is the max number of quantile bins.
    Returns a DataFrame with columns group/n/mean_pred/obs_freq sorted by
    mean predicted risk, or None when no usable data exists.
    """
    parts = []

    for entry in raw_log:
        # Locate the column of the survival matrix matching this horizon.
        idx = get_time_idx(entry.get("eval_times", []), time_point)
        if idx is None:
            continue

        if risk_type == "Readmission":
            mat = entry.get("probs_readm_matrix")
            y_val = entry.get("y_val_r")
        else:
            mat = entry.get("probs_death_matrix")
            y_val = entry.get("y_val_d")

        if mat is None or y_val is None:
            continue

        mat = np.asarray(mat, dtype=float)
        if mat.ndim != 2 or idx >= mat.shape[1]:
            continue  # malformed or too-narrow matrix for this horizon

        surv = mat[:, idx]
        time_arr = np.asarray(y_val["time"], dtype=float)
        event_arr = np.asarray(y_val["event"], dtype=bool).astype(int)

        # Guard against length mismatches between matrix rows and outcomes.
        n = min(len(surv), len(time_arr), len(event_arr))
        if n == 0:
            continue

        # Predicted event probability = 1 - S(t), clipped into [0, 1].
        prob = 1.0 - np.clip(surv[:n], 0.0, 1.0)
        parts.append(pd.DataFrame({
            "prob": prob,
            "time": time_arr[:n],
            "event": event_arr[:n],
        }))

    if not parts:
        return None

    df = pd.concat(parts, ignore_index=True).dropna(subset=["prob", "time", "event"])
    uniq = int(df["prob"].nunique())
    bins = min(n_groups, uniq)
    if bins < 2:
        return None

    # Rank-based qcut guarantees (near-)equal bin sizes even with tied probs.
    df["group"] = pd.qcut(df["prob"].rank(method="first"), q=bins, labels=False)

    kmf = KaplanMeierFitter()
    out = []

    for g in sorted(df["group"].unique()):
        grp = df[df["group"] == g]
        if grp.empty:
            continue

        # Observed frequency = 1 - KM survival at the horizon, per risk bin.
        kmf.fit(durations=grp["time"], event_observed=grp["event"])
        obs = 1.0 - float(kmf.survival_function_at_times([time_point]).iloc[0])

        out.append({
            "group": int(g),
            "n": int(len(grp)),
            "mean_pred": float(grp["prob"].mean()),
            "obs_freq": float(np.clip(obs, 0.0, 1.0)),
        })

    out_df = pd.DataFrame(out).sort_values("mean_pred")
    return out_df if not out_df.empty else None

# Prefer the raw_data_log already present from an earlier cell in this session;
# fall back to the newest matching pickle on disk.
if 'raw_data_log' in globals() and isinstance(raw_data_log, list) and len(raw_data_log) > 0:
    display(Markdown(f"✅ Using in-memory `raw_data_log` ({len(raw_data_log)} entries)."))
else:
    picked_ts, picked_file = pick_file_by_timestamp(IN_DIR, FILE_PATTERNS, prefer=PREFER)
    with picked_file.open("rb") as f:
        raw_data_log = pickle.load(f)
    display(Markdown(f"Loaded file: `{picked_file}` ({picked_ts:%Y-%m-%d %H:%M})"))

display(Markdown(f"Using PROJECT_ROOT: `{PROJECT_ROOT}`"))
display(Markdown(f"Saving figures to: `{OUT_DIR}`"))

# Assign one fixed viridis color per horizon so all figures are consistent.
times_sorted = sorted(TARGET_TIMES)
cmap = plt.get_cmap("viridis", len(times_sorted))
time_colors = {t: cmap(i) for i, t in enumerate(times_sorted)}

def plot_calibration(risk_name, time_horizons):
    """Draw Kaplan-Meier calibration curves for one risk across horizons.

    Parameters
    ----------
    risk_name : str
        "Death" or "Readmission"; selects which matrices inside the global
        ``raw_data_log`` are used and which axis limits apply.
    time_horizons : iterable of int
        Horizons (months) to plot; each must appear in the folds' eval_times.

    Side effects: reads globals ``raw_data_log``, ``RISK_GROUPS``,
    ``time_colors``, ``OUT_DIR`` and ``timestamp_now``; saves a 300-dpi PNG
    plus a PDF and shows the figure.
    """
    fig, ax = plt.subplots(figsize=(10, 8))
    any_line = False

    for t in sorted(time_horizons):
        cal = get_calibration_data(raw_data_log, risk_name, t, n_groups=RISK_GROUPS)
        if cal is None or cal.empty:
            continue  # horizon missing from eval_times or too few distinct preds

        ax.plot(
            cal["mean_pred"], cal["obs_freq"],
            marker="o", linestyle="-", linewidth=2,
            color=time_colors[t], label=f"{t} Months"
        )
        any_line = True

    # Identity line = perfect calibration reference.
    ax.plot([0, 1], [0, 1], "k--", alpha=0.5, label="Perfect Calibration")
    ax.set_title(None)
    ax.set_xlabel("Predicted Probability", fontsize=13)
    ax.set_ylabel("Observed Frequency (Kaplan-Meier)", fontsize=13)
    ax.grid(True, alpha=0.3)

    # Death risks cluster near zero, so zoom the axes for legibility.
    if risk_name == "Death":
        ax.set_xlim(0, 0.35)
        ax.set_ylim(0, 0.45)
    else:
        ax.set_xlim(0, 1.0)
        ax.set_ylim(0, 1.0)

    if any_line:
        ax.legend(fontsize=14, frameon=True, title="Horizon")

    fig.tight_layout()

    png = OUT_DIR / f"XGB7_comb_Calibration_{risk_name}_{timestamp_now}.png"
    pdf = OUT_DIR / f"XGB7_comb_Calibration_{risk_name}_{timestamp_now}.pdf"
    fig.savefig(png, dpi=300, bbox_inches="tight")
    fig.savefig(pdf, bbox_inches="tight")
    display(Markdown(f"Saved: `{png}`"))
    plt.show()
    plt.close(fig)  # release the figure so repeated calls do not accumulate memory

✅ Using in-memory raw_data_log (25 entries).

Using PROJECT_ROOT: G:\My Drive\Alvacast\SISTRAT 2023\cons

Saving figures to: G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs

Code
plot_calibration("Readmission", TARGET_TIMES)

Saved: G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\XGB7_comb_Calibration_Readmission_20260306_1822.png

Code
plot_calibration("Death", TARGET_TIMES)

Saved: G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\XGB7_comb_Calibration_Death_20260306_1822.png

10 Take-Home Messages

  1. This script produces competing-risk calibration curves for readmission using Aalen-Johansen.

  2. Predictions come from cross-validated folds, avoiding optimistic bias.

  3. It correctly treats death as a competing event, not simple censoring.

  4. Observed risks are estimated using the Aalen-Johansen estimator, which is appropriate for CIF.

  5. Predicted risk is defined as: 1 − S_readmission(t) from the cause-specific Cox model.

  6. Patients are grouped into quantile-based risk bins (default 10).

  7. Calibration is evaluated at multiple time horizons simultaneously.

  8. A patient-level master dataset is saved for future bootstrap calibration inference.

  9. The pipeline separates modeling and calibration — improving reproducibility.

  10. Output figures are publication-ready (PNG 300dpi + PDF).

5 Assumptions of This Code

  1. Cause-Specific Hazard Validity
    The readmission survival probabilities generated via Breslow are correctly specified under a cause-specific Cox model.

  2. Independence of Competing Events
    Death and readmission are assumed to follow the standard competing risks framework (non-informative censoring conditional on covariates).

  3. Proper Cross-Validation Aggregation
    Combining all validation folds into a single master dataset assumes that pooling cross-validated predictions is unbiased.
    (This is generally acceptable.)

  4. Quantile Binning Adequacy
    Risk binning assumes that quantile groups meaningfully represent calibration strata.
    Calibration results can change with:

    • Different number of bins
    • Different minimum bin size
  5. Aalen-Johansen Stability
    Observed CIF estimation assumes:

    • Enough events per bin
    • Stable risk sets
    • No extreme sparsity in late follow-up
Code
#@title 📈 Step 7a: Calibration Plot (No bootstrap)

import os, re, glob, pickle, time
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from lifelines import AalenJohansenFitter
from IPython.display import display, Markdown

# Horizons (months) for the Aalen-Johansen calibration curves.
TARGET_TIMES = [3, 6, 12, 36, 60, 96]
RISK_GROUPS = 10
IN_DIR = "_out"    # NOTE(review): placeholder; overwritten with PROJECT_ROOT path below
OUT_DIR = "_figs"  # NOTE(review): placeholder; overwritten with PROJECT_ROOT path below
PREFER = "latest"

# NEW: master calibration data export config
MASTER_SAVE_MODE = "both"   # "parquet", "pkl", or "both"
MASTER_PARQUET_COMPRESSION = "snappy"

FILE_PATTERNS = ["xgb6_corr_DUAL_final_ev_hyp_*_mar26.pkl"]
# Extracts the YYYYMMDD_HHMM stamp embedded in the pickle filenames.
TS_RE = re.compile(r"_(\d{8}_\d{4})_mar26\.pkl$")
timestamp_now = datetime.now().strftime("%Y%m%d_%H%M")

# `Path` comes from the Step 5b cell's import — assumes cells run in order.
if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the root setup cell first.")

PROJECT_ROOT = Path(PROJECT_ROOT).resolve()
IN_DIR = PROJECT_ROOT / "_out"
OUT_DIR = PROJECT_ROOT / "_figs"
IN_DIR.mkdir(parents=True, exist_ok=True)
OUT_DIR.mkdir(parents=True, exist_ok=True)

def log(msg):
    """Print ``msg`` prefixed with the current HH:MM:SS wall-clock time."""
    print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}", flush=True)

def pick_file_by_timestamp(in_dir, patterns, prefer="latest"):
    """Select the evaluation pickle with the newest/oldest embedded timestamp.

    Consistency fix: validates ``prefer`` like the Step 5b variant does —
    previously any value other than "latest" silently behaved as "earliest".

    Returns a ``(datetime, path)`` tuple; raises FileNotFoundError when no
    matching file carries a parseable ``YYYYMMDD_HHMM`` stamp.
    """
    prefer = str(prefer).lower().strip()
    if prefer not in {"latest", "earliest"}:
        raise ValueError("PREFER must be 'latest' or 'earliest'.")

    candidates = []
    for pat in patterns:
        for p in glob.glob(os.path.join(in_dir, pat)):
            m = TS_RE.search(os.path.basename(p))
            if not m:
                continue  # filename has no timestamp suffix
            try:
                ts = datetime.strptime(m.group(1), "%Y%m%d_%H%M")
            except ValueError:
                continue  # malformed timestamp
            candidates.append((ts, p))
    if not candidates:
        raise FileNotFoundError(f"No valid files found in '{in_dir}' for {patterns}")
    candidates.sort(key=lambda x: x[0])
    return candidates[-1] if prefer == "latest" else candidates[0]

def get_time_idx(eval_times, t):
    """Index of horizon ``t`` in ``eval_times`` (float-tolerant match),
    or ``None`` when no entry is close to ``t``."""
    grid = np.asarray(eval_times, dtype=float)
    hits = np.nonzero(np.isclose(grid, float(t)))[0]
    return None if hits.size == 0 else int(hits[0])

def build_competing_outcome(y_r, y_d):
    """Collapse separate readmission/death outcomes into a single
    competing-risks outcome.

    Event coding: 0 = censored, 1 = readmission first, 2 = death first
    (death wins exact ties). Returns ``(time, event)`` arrays where time is
    the earlier of the two follow-up times.
    """
    readm_time = np.asarray(y_r["time"], dtype=float)
    death_time = np.asarray(y_d["time"], dtype=float)
    readm_evt = np.asarray(y_r["event"], dtype=bool)
    death_evt = np.asarray(y_d["event"], dtype=bool)

    followup = np.minimum(readm_time, death_time)
    code = np.zeros(len(followup), dtype=np.int8)

    # Readmission is recorded only when strictly earlier than death (or death
    # never observed); death takes ties — the two masks are disjoint.
    code[readm_evt & (~death_evt | (readm_time < death_time))] = 1
    code[death_evt & (~readm_evt | (death_time <= readm_time))] = 2
    return followup, code

def aj_cif_at_time(ajf, t):
    """Read the Aalen-Johansen cumulative incidence at time ``t`` from a
    fitted lifelines estimator, tolerating API differences across versions.

    Returns 0.0 when ``t`` precedes the first step of the curve and
    ``np.nan`` when the incidence cannot be extracted at all.
    """
    # Preferred path: the dedicated accessor, when the fitter provides one.
    if hasattr(ajf, "cumulative_density_at_times"):
        try:
            at_t = ajf.cumulative_density_at_times([t])
            return float(np.asarray(at_t).reshape(-1)[0])
        except Exception:
            pass  # fall through to the step-function lookup below

    # Fallback: right-continuous step lookup on the fitted curve.
    try:
        curve = ajf.cumulative_density_
        series = curve.iloc[:, 0] if isinstance(curve, pd.DataFrame) else pd.Series(curve)
        pos = series.index.searchsorted(t, side="right") - 1
        if pos < 0:
            return 0.0  # t is before the first event time
        return float(series.iloc[pos])
    except Exception:
        return np.nan

def build_master_df(raw_log, target_times):
    """Assemble a patient-level master frame from all cross-validation folds.

    For each fold entry, builds the competing-risks outcome (0/1/2) and one
    ``pred_<t>`` column of predicted readmission risk per requested horizon.
    Returns the pooled DataFrame, or None when no fold yields usable data.
    Raises ValueError if any event code outside {0, 1, 2} slips through.
    """
    parts = []
    tset = sorted(set(int(x) for x in target_times))
    n_total = len(raw_log)

    for k, e in enumerate(raw_log, 1):
        y_r = e.get("y_val_r")
        y_d = e.get("y_val_d")
        if y_r is None or y_d is None:
            continue

        # Map each horizon to its column index in this fold's eval grid.
        idx_map = {t: get_time_idx(e.get("eval_times", []), t) for t in tset}
        if all(v is None for v in idx_map.values()):
            continue

        t, ev = build_competing_outcome(y_r, y_d)
        pmat = np.asarray(e.get("probs_readm_matrix"), dtype=float)

        # Rows of the probability matrix must align 1:1 with patients.
        if pmat.ndim != 2 or pmat.shape[0] != len(t):
            continue

        block = {
            "time": np.asarray(t, dtype=float),
            "event": np.asarray(ev, dtype=np.int8)
        }

        # pred_<t> = 1 - S(t), clipped so numeric noise stays in [0, 1].
        for tt, idx in idx_map.items():
            if idx is not None and idx < pmat.shape[1]:
                block[f"pred_{tt}"] = 1.0 - np.clip(pmat[:, idx], 0.0, 1.0)

        parts.append(pd.DataFrame(block))

        if (k % 5 == 0) or (k == n_total):
            log(f"Prepared fold-block {k}/{n_total}")

    if not parts:
        return None

    df = pd.concat(parts, ignore_index=True)
    df = df.dropna(subset=["time", "event"])

    # Sanity guard: the downstream AJ fits assume the 0/1/2 coding.
    if not df["event"].isin([0, 1, 2]).all():
        bad = df.loc[~df["event"].isin([0, 1, 2]), "event"].unique()
        raise ValueError(f"Unexpected event codes: {bad}")

    return df

# NEW: save patient-level master df for later bootstrap
def save_master_df(df_master, out_dir, stamp, mode="both", compression="snappy"):
    """Persist the patient-level master dataframe for later bootstrap runs.

    Writes ``xgb7_calibration_master_<stamp>_mar26`` as parquet and/or pickle
    (per ``mode``); falls back to pickle when a parquet-only save fails.
    Returns the list of paths actually written.
    """
    os.makedirs(out_dir, exist_ok=True)
    base = os.path.join(out_dir, f"xgb7_calibration_master_{stamp}_mar26")

    slim = df_master.copy()
    # Shrink the on-disk footprint: float32 is plenty for times/probabilities,
    # int8 for the 0/1/2 event code.
    if "time" in slim.columns:
        slim["time"] = slim["time"].astype("float32")
    if "event" in slim.columns:
        slim["event"] = slim["event"].astype("int8")
    for col in slim.columns:
        if col.startswith("pred_"):
            slim[col] = slim[col].astype("float32")

    saved_paths = []
    parquet_ok = False

    if mode in ("parquet", "both"):
        parquet_path = base + ".parquet"
        try:
            slim.to_parquet(parquet_path, index=False, compression=compression)
            log(f"Saved master df: {parquet_path}")
            saved_paths.append(parquet_path)
            parquet_ok = True
        except Exception as e:
            log(f"Parquet save failed: {e}")

    # Pickle serves both as an explicit target and as the parquet fallback.
    want_pickle = mode in ("pkl", "both")
    if want_pickle or (mode == "parquet" and not parquet_ok):
        pkl_path = base + ".pkl"
        slim.to_pickle(pkl_path)
        log(f"Saved master df: {pkl_path}")
        saved_paths.append(pkl_path)

    return saved_paths

def get_calibration_data_aj(df_master, time_point, n_groups=5, min_bin_n=30, round_decimals=3):
    """Per-bin Aalen-Johansen observed CIF vs. mean predicted risk at one horizon.

    Bins patients into prediction quantiles, then fits an AJ estimator per bin
    with death (event=2) as the competing risk for readmission (event=1).
    Returns a DataFrame (bin/n/n_readm/n_death/mean_pred/obs_cif_aj) sorted by
    mean predicted risk, or None when fewer than two usable bins exist.
    """
    col = f"pred_{int(time_point)}"
    if col not in df_master.columns:
        return None

    df = df_master[["time", "event", col]].rename(columns={col: "pred"}).dropna().copy()

    bins = min(n_groups, int(df["pred"].nunique()))
    if bins < 2:
        return None

    # binning by prediction only
    df["bin"] = pd.qcut(
        df["pred"].rank(method="first"),
        q=bins,
        labels=False,
        duplicates="drop"
    )

    n_bins_real = int(df["bin"].nunique())
    if n_bins_real < 2:
        return None

    # Older lifelines versions lack the calculate_variance kwarg.
    try:
        ajf = AalenJohansenFitter(calculate_variance=False)  # speed-up
    except TypeError:
        ajf = AalenJohansenFitter()

    rows = []
    t0 = time.perf_counter()

    for j, b in enumerate(sorted(df["bin"].dropna().unique()), 1):
        g = df[df["bin"] == b].copy()
        n = len(g)
        if n < min_bin_n:
            continue  # too few patients for a stable AJ fit

        # Rounding collapses near-tied times, reducing AJ jitter warnings.
        g["time"] = np.round(g["time"].to_numpy(dtype=float), round_decimals)

        n_readm = int((g["event"] == 1).sum())
        n_death = int((g["event"] == 2).sum())

        if n_readm == 0:
            obs_cif = 0.0  # no events of interest -> incidence is exactly zero
        else:
            try:
                ajf.fit(g["time"], g["event"], event_of_interest=1)
                obs_cif = aj_cif_at_time(ajf, time_point)
            except Exception as ex:
                log(f"[t={time_point}] bin={int(b)} AJ failed: {ex}")
                continue

        rows.append({
            "bin": int(b),
            "n": int(n),
            "n_readm": n_readm,
            "n_death": n_death,
            "mean_pred": float(g["pred"].mean()),
            "obs_cif_aj": float(np.clip(obs_cif, 0.0, 1.0)),
        })

        log(f"[t={time_point}] bin {j}/{n_bins_real} done (n={n}, readm={n_readm}, death={n_death})")

    log(f"[t={time_point}] finished in {time.perf_counter() - t0:.1f}s")

    out = pd.DataFrame(rows).sort_values("mean_pred")
    return out if not out.empty else None

def plot_calibration_aj(raw_log, target_times, stamp, out_dir="_figs"):
    """Build the master dataframe, persist it, and draw one AJ calibration
    figure covering all requested horizons.

    Side effects: writes the master df (parquet/pickle) to the global IN_DIR,
    one per-horizon CSV of bin summaries, and a PNG + PDF figure to out_dir.
    Raises RuntimeError when no usable calibration data can be assembled.
    """
    os.makedirs(out_dir, exist_ok=True)
    os.makedirs(IN_DIR, exist_ok=True)

    log("Building master dataframe once (for all horizons)...")
    t0 = time.perf_counter()
    df_master = build_master_df(raw_log, target_times)
    if df_master is None or df_master.empty:
        raise RuntimeError("No usable calibration data could be assembled.")
    log(f"Master DF ready: {len(df_master):,} rows in {time.perf_counter() - t0:.1f}s")

    # NEW: save master data for later bootstrap
    save_master_df(
        df_master=df_master,
        out_dir=IN_DIR,
        stamp=stamp,
        mode=MASTER_SAVE_MODE,
        compression=MASTER_PARQUET_COMPRESSION
    )

    fig, ax = plt.subplots(figsize=(10, 8))
    cmap = plt.get_cmap("viridis", len(target_times))
    any_line = False

    for i, t in enumerate(sorted(target_times)):
        log(f"Starting horizon t={t}m")
        cal = get_calibration_data_aj(
            df_master, time_point=t, n_groups=RISK_GROUPS, min_bin_n=30, round_decimals=3
        )
        if cal is None or cal.empty:
            log(f"[t={t}] skipped (no valid bins)")
            continue

        # Keep a CSV of the bin-level numbers behind each plotted curve.
        cal.to_csv(os.path.join(IN_DIR, f"xgb7_calibration_aj_t{t}m_{timestamp_now}.csv"), index=False)

        ax.plot(
            cal["mean_pred"], cal["obs_cif_aj"],
            marker="o", linewidth=2, color=cmap(i), label=f"{t} months"
        )
        any_line = True

    # Identity line = perfect calibration reference.
    ax.plot([0, 1], [0, 1], "k--", alpha=0.5, label="Perfect calibration")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.set_xlabel("Predicted probability (raw 1 - S_readmission)")
    ax.set_ylabel("Observed CIF (Aalen-Johansen)")
    ax.grid(True, alpha=0.3)
    if any_line:
        ax.legend()
    fig.tight_layout()

    png = os.path.join(out_dir, f"XGB7_comb_Calibration_AJ_Readmission_{timestamp_now}.png")
    pdf = png.replace(".png", ".pdf")
    fig.savefig(png, dpi=300, bbox_inches="tight")
    fig.savefig(pdf, bbox_inches="tight")
    log(f"Saved: {png}")
    log(f"Saved: {pdf}")
    #plt.show()
    plt.close(fig)   

# -------- Load + run --------
# Reuse the in-memory raw_data_log from Step 5b when present; otherwise load
# the newest matching pickle. (The previous version unconditionally re-read
# the file first, which wasted I/O and forced the "in-memory" branch to fire
# every time, discarding the picked file's timestamp.)
if 'raw_data_log' in globals() and isinstance(raw_data_log, list) and len(raw_data_log) > 0:
    log(f"Using in-memory raw_data_log ({len(raw_data_log)} entries).")
    stamp = timestamp_now
else:
    picked_ts, picked_file = pick_file_by_timestamp(IN_DIR, FILE_PATTERNS, prefer=PREFER)
    with open(picked_file, "rb") as f:
        raw_data_log = pickle.load(f)
    log(f"Loaded: {picked_file}")
    stamp = picked_ts.strftime("%Y%m%d_%H%M")

plot_calibration_aj(raw_data_log, TARGET_TIMES, stamp, out_dir=OUT_DIR)
Code
#@title 📈 Step 7b: Bootstrap Calibration Plot (Aalen-Johansen with 95% CI)

import os, re, glob, time
from datetime import datetime
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from lifelines import AalenJohansenFitter
from joblib import Parallel, delayed
from IPython.display import display, Markdown
import warnings
import shutil
from pathlib import Path

# Silence lifelines/pandas chatter during the bootstrap loops.
warnings.filterwarnings("ignore")

# --- 1. CONFIGURATION ---
TARGET_TIMES = [3, 6, 12, 36, 60, 96] # Reduced to key horizons to avoid visual clutter
RISK_GROUPS = 10
N_BOOTSTRAPS = 500  # resamples per horizon; drop to ~200 for fast exploratory runs, 500-1000 for the final paper
if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the root setup cell first.")

PROJECT_ROOT = Path(PROJECT_ROOT).resolve()
IN_DIR = PROJECT_ROOT / "_out"    # master parquet/pickle lives here
OUT_DIR = PROJECT_ROOT / "_figs"  # bootstrap figures are written here
PREFER = "latest"

os.makedirs(IN_DIR, exist_ok=True)
os.makedirs(OUT_DIR, exist_ok=True)

timestamp_now = datetime.now().strftime("%Y%m%d_%H%M")

# --- 2. AJ HELPER FUNCTIONS ---
def aj_cif_at_time(ajf, t):
    """Safely extracts the cumulative incidence at a specific time point.

    Tries the lifelines accessor first, then falls back to a right-continuous
    step lookup on the fitted curve. Returns 0.0 before the first step and
    ``np.nan`` when nothing can be extracted.
    """
    if hasattr(ajf, "cumulative_density_at_times"):
        try:
            vals = ajf.cumulative_density_at_times([t])
            return float(np.asarray(vals).reshape(-1)[0])
        # Was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit.
        except Exception:
            pass
    try:
        cd = ajf.cumulative_density_
        s = cd.iloc[:, 0] if isinstance(cd, pd.DataFrame) else pd.Series(cd)
        i = s.index.searchsorted(t, side="right") - 1
        return float(s.iloc[i]) if i >= 0 else 0.0
    except Exception:
        return np.nan

def calculate_aj_for_sample(df_sample, time_point, n_groups, min_bin_n=30):
    """Calculates observed AJ CIF per bin.

    Quantile-bins ``df_sample['pred']``, fits one Aalen-Johansen estimator per
    bin (readmission = event 1, death = competing event 2) and returns a dict
    {bin -> observed CIF at ``time_point``}. Bins that are too small or whose
    fit fails map to NaN; returns None when fewer than two bins can be formed.
    """
    bins = min(n_groups, int(df_sample["pred"].nunique()))
    if bins < 2:
        return None

    df = df_sample.copy()
    # Rank-based qcut yields (near-)equal bin sizes even with tied predictions.
    df["bin"] = pd.qcut(
        df["pred"].rank(method="first"),
        q=bins,
        labels=False,
        duplicates="drop"
    )

    bins_real = sorted(df["bin"].dropna().unique())
    if len(bins_real) < 2:
        return None

    results = {}

    for b in bins_real:
        b = int(b)
        g = df[df["bin"] == b]
        if len(g) < min_bin_n:
            results[b] = np.nan  # too sparse for a stable estimate
            continue

        n_readm = int((g["event"] == 1).sum())
        if n_readm == 0:
            results[b] = 0.0  # no events of interest -> incidence is zero
            continue

        try:
            # Older lifelines versions lack one or both constructor kwargs.
            try:
                ajf = AalenJohansenFitter(calculate_variance=False, jitter_level=1e-5)
            except TypeError:
                ajf = AalenJohansenFitter(calculate_variance=False)

            # no rounding of time during bootstrap
            ajf.fit(
                g["time"].to_numpy(dtype=float),
                g["event"].to_numpy(dtype=int),
                event_of_interest=1
            )
            v = aj_cif_at_time(ajf, time_point)
            results[b] = float(np.clip(v, 0.0, 1.0)) if np.isfinite(v) else np.nan
        except Exception:
            results[b] = np.nan  # failed fits are NaN so bootstrap filtering can skip them

    return results if results else None

def bootstrap_iteration(df_horizon, time_point, n_groups, random_state=None):
    """Performs a single complete bootstrap iteration.

    Resamples patients with replacement and recomputes the per-bin
    Aalen-Johansen CIF at ``time_point``.

    Parameters
    ----------
    random_state : int, numpy Generator, or None, optional
        Seed forwarded to ``DataFrame.sample`` so runs can be made
        reproducible; the default ``None`` preserves the original unseeded
        behavior.
    """
    df_boot = df_horizon.sample(frac=1.0, replace=True, random_state=random_state)
    return calculate_aj_for_sample(df_boot, time_point, n_groups)

# --- 3. LOAD MASTER DATAFRAME ---
display(Markdown("**Searching for the latest Master Parquet file...**"))
parquet_files = sorted(glob.glob(os.path.join(IN_DIR, "xgb7_calibration_master_*_mar26.parquet")))

if not parquet_files:
    raise FileNotFoundError("No Master Parquet (_mar26) file found. Run Step 7a first.")

# Lexicographic sort equals chronological order here because the embedded
# YYYYMMDD_HHMM stamps are zero-padded, so [-1] is the newest file.
latest_master = parquet_files[-1]
display(Markdown(f"📦 **Loading Master DF:** `{latest_master}`"))
df_master = pd.read_parquet(latest_master)

# --- 3.5. Preflight block ---

def preflight_bootstrap_inputs(
    df_master,
    target_times,
    n_bootstraps,
    risk_groups,
    in_dir,
    out_dir,
    min_rows_per_horizon=200,
    min_readm_events=10,
    min_free_gb=0.5,
):
    """Validate inputs before the expensive bootstrap run.

    Checks the output directories, the master dataframe's schema/values, the
    bootstrap settings, and each requested horizon's data sufficiency.
    Returns ``(valid_times, errors, warns)``: horizons that passed, fatal
    error messages, and non-fatal warnings.

    Note: ``min_free_gb`` is accepted but not currently checked —
    presumably a planned disk-space guard (TODO confirm/implement).
    """
    errors, warns = [], []

    # Path checks
    for name, p in [("IN_DIR", in_dir), ("OUT_DIR", out_dir)]:
        p_str = str(p)
        if not os.path.isabs(p_str):
            warns.append(f"{name} is not absolute: {p_str}")
        try:
            os.makedirs(p_str, exist_ok=True)
        except Exception as e:
            errors.append(f"Cannot create {name}={p_str}: {e}")

    # Global dataframe checks
    if df_master is None or df_master.empty:
        errors.append("df_master is empty.")
        return [], errors, warns

    required = {"time", "event"}
    missing = required - set(df_master.columns)
    if missing:
        errors.append(f"Missing required columns: {sorted(missing)}")

    # Events must follow the 0=censored / 1=readmission / 2=death coding.
    if "event" in df_master.columns:
        ev = pd.Series(df_master["event"]).dropna()
        bad_events = sorted(set(ev.unique()) - {0, 1, 2})
        if bad_events:
            errors.append(f"Invalid event codes found: {bad_events}")

    if "time" in df_master.columns:
        t = pd.to_numeric(df_master["time"], errors="coerce")
        if t.isna().all():
            errors.append("Column 'time' is all NaN after numeric conversion.")
        if np.isinf(t.to_numpy(dtype=float)).any():
            errors.append("Column 'time' contains inf values.")
        if (t.dropna() < 0).any():
            warns.append("Column 'time' contains negative values.")

    if n_bootstraps < 10:
        warns.append(f"N_BOOTSTRAPS={n_bootstraps} is very low.")
    if risk_groups < 2:
        errors.append(f"RISK_GROUPS must be >=2. Current: {risk_groups}")

    # Horizon-level checks: each pred_<t> column must have enough rows,
    # enough distinct predictions to bin, and enough readmission events.
    valid_times = []
    for t in sorted(target_times):
        col = f"pred_{t}"
        if col not in df_master.columns:
            warns.append(f"Missing column: {col}")
            continue

        d = df_master[["time", "event", col]].rename(columns={col: "pred"}).dropna()
        if d.empty:
            warns.append(f"{col}: no non-null rows")
            continue

        uniq_pred = int(d["pred"].nunique())
        bins_eff = min(risk_groups, uniq_pred)
        n_readm = int((d["event"] == 1).sum())

        if len(d) < min_rows_per_horizon:
            warns.append(f"{col}: too few rows ({len(d)})")
            continue
        if bins_eff < 2:
            warns.append(f"{col}: <2 effective bins (uniq_pred={uniq_pred})")
            continue
        if n_readm < min_readm_events:
            warns.append(f"{col}: too few readmission events ({n_readm})")
            continue

        valid_times.append(int(t))

    if not valid_times:
        errors.append("No valid horizons left after preflight checks.")

    return valid_times, errors, warns


# Run the preflight validation and abort early on fatal problems, so the
# expensive bootstrap loop never starts on bad inputs.
valid_times, preflight_errors, preflight_warnings = preflight_bootstrap_inputs(
    df_master=df_master,
    target_times=TARGET_TIMES,
    n_bootstraps=N_BOOTSTRAPS,
    risk_groups=RISK_GROUPS,
    in_dir=IN_DIR,
    out_dir=OUT_DIR,
)

for w in preflight_warnings:
    display(Markdown(f"⚠️ {w}"))

if preflight_errors:
    raise RuntimeError("Preflight failed:\n- " + "\n- ".join(preflight_errors))

# Use only validated horizons
TARGET_TIMES = valid_times
display(Markdown(f"✅ Preflight passed. Horizons used: `{TARGET_TIMES}`"))

# --- 4. CALCULATION & PLOTTING ---
os.makedirs(OUT_DIR, exist_ok=True)
fig, ax = plt.subplots(figsize=(10, 8))
cmap = plt.get_cmap("viridis", len(TARGET_TIMES))
any_line = False

display(Markdown(f"<br>🚀 **Starting Bootstrap Calibration** ({N_BOOTSTRAPS} resamples per horizon, CPUs: n-1)..."))

# Rows for the per-decile export table (original sample, not bootstrap).
export_rows = []

for i, t in enumerate(sorted(TARGET_TIMES)):
    col = f"pred_{t}"
    if col not in df_master.columns:
        display(Markdown(f"⚠️ Missing `{col}`"))
        continue

    # Horizon-specific working frame: follow-up, 0/1/2 event code, prediction.
    df_h = df_master[["time", "event", col]].rename(columns={col: "pred"}).dropna().copy()
    bins_eff = min(RISK_GROUPS, int(df_h["pred"].nunique()))
    evt = df_h["event"].value_counts().to_dict()
    display(Markdown(f"- t={t}: n={len(df_h):,}, unique_pred={df_h['pred'].nunique():,}, bins_eff={bins_eff}, events={evt}"))

    if bins_eff < 2:
        display(Markdown(f"  - ❌ skipped: <2 unique predictions"))
        continue

    # Point estimates on the original (non-resampled) data.
    original_results = calculate_aj_for_sample(df_h, time_point=t, n_groups=bins_eff, min_bin_n=30)
    if not original_results:
        display(Markdown(f"  - ❌ skipped: no original AJ results"))
        continue

    # Re-derive the binning so mean predictions line up with the AJ bins.
    df_h["bin"] = pd.qcut(df_h["pred"].rank(method="first"), q=bins_eff, labels=False, duplicates="drop")
    mean_preds = df_h.groupby("bin")["pred"].mean().to_dict()

    # --- Export table rows for this horizon (original sample, not bootstrap) ---
    bin_summary = (
        df_h.groupby("bin", dropna=True)
            .agg(
                n_patients=("event", "size"),
                n_readmissions=("event", lambda s: int((s == 1).sum())),
                n_deaths=("event", lambda s: int((s == 2).sum())),
                mean_pred=("pred", "mean"),
            )
            .reset_index()
    )
    for _, r in bin_summary.iterrows():
        b = int(r["bin"])
        export_rows.append({
            "Time (mo)": int(t),
            "Decile": b + 1,  # 1..10 instead of 0..9
            "N Patients": int(r["n_patients"]),
            "N Readmissions": int(r["n_readmissions"]),
            "N Deaths": int(r["n_deaths"]),
            "Mean Pred": float(r["mean_pred"]),
            "Obs CIF": float(original_results.get(b, np.nan)),  # AJ observed CIF at this horizon
        })

    # Bootstrap resampling in parallel; n_jobs=-2 leaves one CPU free.
    boot_results = Parallel(n_jobs=-2)(
        delayed(bootstrap_iteration)(df_h, t, bins_eff) for _ in range(N_BOOTSTRAPS)
    )
    # Diagnostic: how many bootstrap replicates yielded a finite CIF per bin.
    for b in sorted(original_results.keys()):
        valid_n = sum(1 for res in boot_results if res and b in res and np.isfinite(res[b]))
        print(f"[t={t}] bin={b} valid={valid_n}/{N_BOOTSTRAPS}, orig={original_results[b]}")

    # Percentile 95% CI per bin, keeping only bins with enough valid replicates.
    x_vals, y_vals, y_lower, y_upper = [], [], [], []
    for b in sorted(original_results.keys()):
        if np.isnan(original_results[b]) or b not in mean_preds:
            continue
        b_ests = [res[b] for res in boot_results if res and b in res and not np.isnan(res[b])]
        if len(b_ests) >= int(0.3 * N_BOOTSTRAPS):   # relaxed from 50%
            x_vals.append(mean_preds[b])
            y_vals.append(original_results[b])
            y_lower.append(np.percentile(b_ests, 2.5))
            y_upper.append(np.percentile(b_ests, 97.5))

    if not x_vals:
        display(Markdown(f"  - ❌ skipped: no bins passed CI filter"))
        continue

    ax.plot(x_vals, y_vals, marker="o", linewidth=2, color=cmap(i), label=f"{t} months")
    ax.fill_between(x_vals, y_lower, y_upper, color=cmap(i), alpha=0.15)
    any_line = True
    display(Markdown(f"  - ✅ t={t} plotted"))

# --- 5. FINAL PLOT FORMATTING ---
ax.plot([0, 1], [0, 1], "k--", alpha=0.5, label="Perfect Calibration")

# Restored absolute axis limits
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)

ax.set_xlabel("Predicted Probability (XGBoost)", fontsize=14)
ax.set_ylabel("Observed Cumulative Incidence (Aalen-Johansen)", fontsize=14)
ax.set_title(f"Calibration with 95% Bootstrap Confidence Intervals ({N_BOOTSTRAPS} resamples)", fontsize=16, fontweight='bold')
ax.grid(True, alpha=0.3)

if any_line:
    ax.legend(fontsize=12, loc="upper left")
fig.tight_layout()

# Save figures
png = os.path.join(OUT_DIR, f"XGB7_comb_Calibration_AJ_Bootstrap_CI_{timestamp_now}.png")
pdf = png.replace(".png", ".pdf")
fig.savefig(png, dpi=300, bbox_inches="tight")
fig.savefig(pdf, bbox_inches="tight")

display(Markdown(f"<br>✅ **Master plot successfully saved at:** `{png}`"))

plt.show()
# --- Save calibration table ---
df_export = pd.DataFrame(export_rows).sort_values(["Time (mo)", "Decile"])

# optional rounding for readability
df_export["Mean Pred"] = df_export["Mean Pred"].round(6)
df_export["Obs CIF"] = df_export["Obs CIF"].round(6)

csv_out = os.path.join(IN_DIR, f"xgb7b_calibration_table_{timestamp_now}.csv")
df_export.to_csv(csv_out, index=False)

display(Markdown(f"✅ **Calibration table exported:** `{csv_out}`"))

Searching for the latest Master Parquet file…

📦 Loading Master DF: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb7_calibration_master_20260306_1822_mar26.parquet

✅ Preflight passed. Horizons used: [3, 6, 12, 36, 60, 96]


🚀 Starting Bootstrap Calibration (500 resamples per horizon, CPUs: n-1)…

  • t=3: n=352,605, unique_pred=349,092, bins_eff=10, events={0: 264150, 1: 76235, 2: 12220}
  • ✅ t=3 plotted
  • t=6: n=352,605, unique_pred=349,042, bins_eff=10, events={0: 264150, 1: 76235, 2: 12220}
  • ✅ t=6 plotted
  • t=12: n=352,605, unique_pred=349,014, bins_eff=10, events={0: 264150, 1: 76235, 2: 12220}
  • ✅ t=12 plotted
  • t=36: n=352,605, unique_pred=348,676, bins_eff=10, events={0: 264150, 1: 76235, 2: 12220}
  • ✅ t=36 plotted
  • t=60: n=352,605, unique_pred=348,573, bins_eff=10, events={0: 264150, 1: 76235, 2: 12220}
  • ✅ t=60 plotted
  • t=96: n=352,605, unique_pred=348,351, bins_eff=10, events={0: 264150, 1: 76235, 2: 12220}
  • ✅ t=96 plotted


Master plot successfully saved at: G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\XGB7_comb_Calibration_AJ_Bootstrap_CI_20260306_1823.png

Calibration table exported: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb7b_calibration_table_20260306_1823.csv

Code
# Render the calibration table inside a fixed-height, scrollable HTML box
# so the full 60-row table doesn't flood the notebook output.
display(HTML(f"""
<div style="
    height:400px;
    overflow:auto;
    border:1px solid #ccc;
    padding:10px;
    background-color:white;
    font-family:'Times New Roman';
    font-size:13px;
">
    {df_export.to_html(index=False)}
</div>
"""))
Time (mo) Decile N Patients N Readmissions N Deaths Mean Pred Obs CIF
3 1 35261 2538 1545 0.017346 0.013659
3 2 35260 4035 1282 0.024775 0.019695
3 3 35261 4828 1275 0.029154 0.024180
3 4 35260 5821 1270 0.032860 0.025993
3 5 35261 6746 1206 0.036450 0.032112
3 6 35260 8095 1009 0.040248 0.037654
3 7 35260 8857 1063 0.044703 0.045491
3 8 35261 9592 1209 0.050590 0.049526
3 9 35260 11052 1277 0.060188 0.061958
3 10 35261 14671 1084 0.093380 0.115821
6 1 35261 2538 1544 0.029678 0.025639
6 2 35260 4047 1291 0.042280 0.032649
6 3 35261 4832 1272 0.049659 0.043336
6 4 35260 5786 1266 0.055884 0.047909
6 5 35261 6766 1204 0.061898 0.059807
6 6 35260 8084 1020 0.068255 0.068233
6 7 35260 8877 1060 0.075682 0.076554
6 8 35261 9577 1200 0.085453 0.083704
6 9 35260 11068 1274 0.101311 0.104663
6 10 35261 14660 1089 0.154752 0.172018
12 1 35261 2540 1547 0.048375 0.041944
12 2 35260 4054 1286 0.068653 0.059924
12 3 35261 4831 1289 0.080426 0.074465
12 4 35260 5768 1254 0.090322 0.080238
12 5 35261 6785 1211 0.099839 0.093869
12 6 35260 8112 1015 0.109867 0.112419
12 7 35260 8847 1063 0.121506 0.122139
12 8 35261 9584 1188 0.136739 0.133118
12 9 35260 11068 1282 0.161247 0.164548
12 10 35261 14646 1085 0.240676 0.247609
36 1 35261 2536 1547 0.090605 0.081354
36 2 35260 4058 1280 0.127503 0.119576
36 3 35261 4826 1281 0.148501 0.133476
36 4 35260 5755 1277 0.165955 0.150882
36 5 35261 6814 1209 0.182591 0.178312
36 6 35260 8078 1018 0.199953 0.204205
36 7 35260 8860 1065 0.219891 0.223172
36 8 35261 9589 1167 0.245574 0.238593
36 9 35260 11077 1289 0.286075 0.274627
36 10 35261 14642 1087 0.406567 0.376104
60 1 35261 2533 1551 0.113559 0.100495
60 2 35260 4054 1282 0.159034 0.146277
60 3 35261 4840 1270 0.184628 0.168417
60 4 35260 5744 1297 0.205787 0.191951
60 5 35261 6819 1200 0.225844 0.218311
60 6 35260 8081 1016 0.246680 0.251036
60 7 35260 8862 1070 0.270424 0.273051
60 8 35261 9583 1162 0.300777 0.291441
60 9 35260 11086 1286 0.348096 0.327848
60 10 35261 14633 1086 0.482516 0.430065
96 1 35261 2527 1547 0.136227 0.120792
96 2 35260 4055 1282 0.189879 0.177972
96 3 35261 4835 1283 0.219787 0.201211
96 4 35260 5755 1280 0.244340 0.229822
96 5 35261 6821 1207 0.267492 0.259661
96 6 35260 8085 1014 0.291388 0.293867
96 7 35260 8844 1070 0.318450 0.314101
96 8 35261 9585 1168 0.352765 0.330523
96 9 35260 11075 1283 0.405561 0.377877
96 10 35261 14653 1086 0.549056 0.483297

SHAP

For XGB Cox, SHAP is on log-hazard scale, so case ranking should use absolute risk from baseline hazard:

Baseline hazards let you compute absolute risk per patient/horizon. However, this does not automatically convert TreeSHAP values into additive risk-point SHAP — that requires a different explainer setup (a probability-output function, usually model-agnostic and slower).

Use baseline hazards only to compute absolute risk at horizon for ranking/labeling patients and calibration.

Avoid labeling SHAP axes as “% risk points”, since SHAP values here are on the log-hazard scale, not the probability scale.

  • SHAP Values - Log-hazard - ✅ Correct (not converted)

  • Patient Ranking - Absolute risk - ✅ Correct (uses baseline hazard)

  • Waterfall Titles - Log-hazard - ✅ Correct (labeled properly)

  • Absolute Risk Computation - Probability [0-1] - ✅ Correct (via collect_risk_by_id())

Feature contributions do NOT sum on the absolute-risk scale because of the non-linear link function: converting individual SHAP values to absolute risk breaks additivity. Therefore, waterfall plots should not show absolute-risk contributions. - Log-Hazard: f(x) = base + Σ SHAP_i –> Linear: Additive

  • Absolute Risk: f(x) = 1 - exp(-H₀(t) × exp(base + Σ SHAP_i)) –> Non-linear: NOT Additive

SHAP (SHapley Additive exPlanations) values were computed for both readmission and mortality models. To prevent data leakage, SHAP values were calculated on out-of-sample cross-validated predictions: for each patient, explanations were generated from models trained on the 4/5 of folds that did not include that patient. SHAP values were averaged across 5-fold cross-validation and 5 multiple imputations, providing internally validated feature importance estimates. Note, however, that this constitutes internal validation only, not external validation.

Code
#@title Step 8: SHAP Analysis & Plots (DUAL, Multi-Horizon, source_tag-traceable)
# Corrected version - compatible with Step 5 outputs from XGboost_combined_mar26.ipynb

import os
import re
import glob
import json
import pickle
import numpy as np
import pandas as pd
import shap
import matplotlib.pyplot as plt
import matplotlib as mpl
from IPython.display import display, Markdown

# -----------------------------
# 0) Config
# -----------------------------
# PROJECT_ROOT must be defined by the notebook's root setup cell; every
# input/output path below is derived from it.
if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the root setup cell first.")

PROJECT_ROOT = os.path.abspath(str(PROJECT_ROOT))

IN_DIR = os.path.join(PROJECT_ROOT, "_out")    # Step 5 artifacts are read from here
OUT_DIR = os.path.join(PROJECT_ROOT, "_out")   # CSV / table outputs
FIG_DIR = os.path.join(PROJECT_ROOT, "_figs")  # figures (PNG + PDF)

# Set None to auto-use all eval_times found in raw logs
TARGET_HORIZONS = [12, 60]
MAX_BEESWARM_N = 90000              # cap on points shown in beeswarm plots
RNG = np.random.RandomState(2125)   # RNG used only for beeswarm subsampling

# CI config (optional but enabled by default)
BOOTSTRAP_CI = True
N_BOOTSTRAP = 500        # bootstrap resamples for mean |SHAP| CIs
CI_ALPHA = 0.05          # two-sided alpha -> 95% intervals
BOOTSTRAP_MAX_N = 90000  # row cap for the bootstrap sample

# Multicollinearity check
CORR_THRESHOLD = 0.90

# Unified DPI
FIG_DPI = 300

os.makedirs(IN_DIR, exist_ok=True)
os.makedirs(OUT_DIR, exist_ok=True)
os.makedirs(FIG_DIR, exist_ok=True)

# NOTE: These key candidates are matched against Step 5 output structure
# Step 5 saves: 'risk_pred_readm', 'risk_pred_death', 'probs_readm_matrix', 'probs_death_matrix', etc.
# Candidate-key lists make lookups robust to naming drift across Step 5 versions.
OUTCOME_CFG = {
    "readm": {
        "label": "Readmission",
        "shap_key_candidates": ["shap_r_all", "shap_readm_all"],
        "margin_key_candidates": ["risk_pred_readm", "risk_pred_r", "margin_pred_readm"],
        "probs_key_candidates": ["probs_readm_matrix", "probs_r_matrix", "surv_probs_readm_matrix"],
        "hz_times_candidates": ["times_r", "times_readm"],
        "hz_vals_candidates": ["h0_r", "H0_r", "h0_readm"],
    },
    "death": {
        "label": "Death",
        "shap_key_candidates": ["shap_d_all", "shap_death_all", "shap_mort_all"],
        "margin_key_candidates": ["risk_pred_death", "risk_pred_d", "risk_pred_mort", "margin_pred_death"],
        "probs_key_candidates": ["probs_death_matrix", "probs_d_matrix", "probs_mort_matrix", "surv_probs_death_matrix"],
        "hz_times_candidates": ["times_d", "times_death", "times_mort"],
        "hz_vals_candidates": ["h0_d", "H0_d", "h0_death", "h0_mort"],
    },
}

# Candidate keys under which CV-split records may store their validation IDs.
VAL_ID_KEYS = ["val_ids", "valid_ids", "val_idx", "val_index", "idx_val"]

# Publication-style typography; fonttype 42 embeds TrueType fonts in PDF/PS output.
mpl.rcParams.update({
    "font.family": "serif",
    "font.serif": ["Times New Roman", "Times", "Nimbus Roman", "DejaVu Serif"],
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
    "axes.labelsize": 14,
    "axes.titlesize": 15,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "figure.dpi": FIG_DPI
})

# -----------------------------
# 1) Pick latest complete Step 5 bundle (simplified)
# -----------------------------
TS_RE = re.compile(r"(\d{8}_\d{4})")  # Just find the timestamp anywhere

def pick_latest_complete_bundle(in_dir):
    """Locate the newest complete set of Step 5 artifacts in `in_dir`.

    A bundle is "complete" when the timestamp tag extracted from a
    SHAP_Aggregated pickle's filename also has companion final_ev_hyp,
    BaselineHazards and CV_Splits pickles bearing the same tag (with or
    without an extra "_mar26" suffix).

    Returns the tuple (timestamp, tag, shap_path, raw_path, hz_path,
    split_path) for the most recent tag; raises FileNotFoundError when
    no complete bundle exists.
    """
    shap_pattern = os.path.join(in_dir, "xgb6_corr_DUAL_SHAP_Aggregated_*.pkl")

    complete = []
    for shap_path in glob.glob(shap_pattern):
        match = TS_RE.search(os.path.basename(shap_path))
        if match is None:
            continue
        tag = match.group(1)

        # Companion files may or may not carry the "_mar26" suffix.
        for suffix in (f"_{tag}_mar26.pkl", f"_{tag}.pkl"):
            raw_path = os.path.join(in_dir, f"xgb6_corr_DUAL_final_ev_hyp{suffix}")
            hz_path = os.path.join(in_dir, f"xgb6_corr_DUAL_BaselineHazards{suffix}")
            split_path = os.path.join(in_dir, f"xgb6_corr_DUAL_CV_Splits{suffix}")

            if os.path.exists(raw_path) and os.path.exists(hz_path) and os.path.exists(split_path):
                stamp = pd.to_datetime(tag, format="%Y%m%d_%H%M", errors="coerce")
                if pd.notna(stamp):
                    complete.append((stamp, tag, shap_path, raw_path, hz_path, split_path))
                    break  # complete set found for this tag; stop probing suffixes

    if not complete:
        raise FileNotFoundError(
            f"No complete Step 5 bundle found in '{in_dir}'. "
            f"Need: xgb6_corr_DUAL_SHAP_Aggregated_*.pkl + "
            f"final_ev_hyp + BaselineHazards + CV_Splits"
        )

    complete.sort(key=lambda item: item[0])
    return complete[-1]  # newest bundle wins

source_tag, shap_file, raw_file, hz_file, split_file = pick_latest_complete_bundle(IN_DIR)[1:]

# For traceability, use source_tag in filenames
FILE_TAG = source_tag
RUN_TS = pd.Timestamp.now().strftime("%Y%m%d_%H%M")

# -----------------------------
# 2) Load artifacts
# -----------------------------
def _load_pickle(path, what):
    """Load one pickle artifact, wrapping any failure in a RuntimeError.

    NOTE: pickle.load can execute arbitrary code; these are trusted
    Step 5 artifacts produced by this project, not external input.
    """
    try:
        with open(path, "rb") as f:
            return pickle.load(f)
    except Exception as e:
        raise RuntimeError(f"Failed to load {what} file {path}: {e}") from e

# Deduplicated: one loader instead of four copy-pasted try/except blocks.
shap_data = _load_pickle(shap_file, "SHAP")
raw_data_log = _load_pickle(raw_file, "raw data")
baseline_hazards_log = _load_pickle(hz_file, "baseline hazards")
cv_splits_log = _load_pickle(split_file, "CV splits")

# The SHAP bundle must at least carry the design matrix and feature names.
required_keys = {"X_all", "feature_names"}
missing = required_keys - set(shap_data.keys())
if missing:
    raise KeyError(f"Missing keys in SHAP file: {missing}")

X_all = shap_data["X_all"]
feature_names = list(shap_data["feature_names"])

# Normalize X_all into a DataFrame whose columns match feature_names exactly.
if not isinstance(X_all, pd.DataFrame):
    X_all = pd.DataFrame(X_all, columns=feature_names)

if list(X_all.columns) != feature_names:
    X_all = X_all.reindex(columns=feature_names)

# The downstream id_to_row mapping requires unique patient indices.
if not X_all.index.is_unique:
    raise ValueError("X_all index must be unique.")
# -----------------------------
# 3) Helpers
# -----------------------------
def get_first(dct, keys, default=None):
    """Return the value for the first key in `keys` that maps to a
    non-None value in `dct`; fall back to `default` otherwise."""
    hits = (dct[k] for k in keys if k in dct and dct[k] is not None)
    return next(hits, default)

def find_first_key(dct, keys):
    """Return the first entry of `keys` that is present in `dct`
    (regardless of its value), or None if none are."""
    return next((k for k in keys if k in dct), None)

def h0_at_t(times, h0_vals, t):
    """Baseline cumulative hazard H0 at time `t` via step-function lookup.

    Treats (times, h0_vals) as a step function and returns the last value
    whose time is <= t. Assumes `times` is sorted ascending (searchsorted
    precondition) — TODO confirm upstream. Returns 0.0 before the first
    event time and NaN on empty or length-mismatched input.
    """
    t_arr = np.asarray(times, dtype=float).ravel()
    h_arr = np.asarray(h0_vals, dtype=float).ravel()
    if t_arr.size == 0 or h_arr.size == 0 or t_arr.size != h_arr.size:
        return np.nan

    query = float(t)
    if query < t_arr[0]:
        # No events observed yet -> cumulative hazard is zero.
        return 0.0

    pos = np.searchsorted(t_arr, query, side="right") - 1
    pos = max(0, min(pos, h_arr.size - 1))  # clip to valid index range
    return float(h_arr[pos])

def fmt_horizon(h):
    """Render a horizon as a clean label: integral values drop the
    decimal part ("12"), anything else uses general float formatting."""
    value = float(h)
    is_integral = abs(value - round(value)) < 1e-9
    return str(int(value)) if is_integral else f"{value:g}"

def horizon_token(h):
    """Filename-safe horizon token, e.g. 12 -> "12m", 1.5 -> "1p5m"."""
    return fmt_horizon(h).replace(".", "p") + "m"

def save_current_figure(stem, outcome, horizon=None):
    """Save the current matplotlib figure as PNG + PDF under FIG_DIR.

    Filenames embed the outcome, plot stem, optional horizon token and
    FILE_TAG (the Step 5 bundle tag) for traceability. Returns the two
    saved paths as a list [png_path, pdf_path].
    """
    fig = plt.gcf()
    suffix = "" if horizon is None else f"_{horizon_token(horizon)}"
    stem_full = f"xgb8_dual_{outcome}_{stem}{suffix}_{FILE_TAG}"
    png_path = os.path.join(FIG_DIR, f"{stem_full}.png")
    pdf_path = os.path.join(FIG_DIR, f"{stem_full}.pdf")
    fig.savefig(png_path, dpi=FIG_DPI, bbox_inches="tight")
    fig.savefig(pdf_path, bbox_inches="tight")
    return [png_path, pdf_path]

def discover_horizons(raw_log):
    """Return the sorted set of finite eval_times found across all
    records of `raw_log` (records without eval_times contribute nothing)."""
    found = set()
    for rec in raw_log:
        times = np.asarray(rec.get("eval_times", []), dtype=float).ravel()
        found.update(t for t in times if np.isfinite(t))
    return sorted(found)

def bootstrap_mean_abs_shap(shap_vals, n_boot=200, alpha=0.05, seed=2026, max_n=None):
    """Point estimate and bootstrap percentile CI for per-feature mean |SHAP|.

    Returns (point, lo, hi, n_eff): `point` is the mean absolute SHAP per
    feature over the (optionally subsampled) rows; [lo, hi] is the
    (1 - alpha) bootstrap interval. Deterministic for a fixed seed.
    """
    n_rows, n_feats = shap_vals.shape
    rng = np.random.RandomState(seed)

    # Optional row subsample to bound the bootstrap cost on huge matrices.
    if max_n is not None and n_rows > max_n:
        keep = rng.choice(n_rows, size=max_n, replace=False)
        sample = shap_vals[keep, :]
    else:
        sample = shap_vals

    n_eff = sample.shape[0]
    point = np.abs(sample).mean(axis=0)

    draws = np.empty((n_boot, n_feats), dtype=float)
    for k in range(n_boot):
        resample = rng.choice(n_eff, size=n_eff, replace=True)
        draws[k] = np.abs(sample[resample]).mean(axis=0)

    lo = np.quantile(draws, alpha / 2.0, axis=0)
    hi = np.quantile(draws, 1.0 - alpha / 2.0, axis=0)
    return point, lo, hi, n_eff

def collect_margin_by_id(raw_log, split_map, cfg):
    """Average out-of-fold margin (log-hazard) predictions per patient ID.

    Each raw-log record contributes its validation-fold margins, paired
    with the matching CV split's validation IDs; the result averages over
    every (imputation, fold) in which a patient appeared. Returns a
    DataFrame with columns ["id", "margin"] (possibly empty).
    """
    records = []
    for rec in raw_log:
        if "imp_idx" not in rec or "fold_idx" not in rec:
            continue
        fold_key = (int(rec["imp_idx"]), int(rec["fold_idx"]))
        split_rec = split_map.get(fold_key)
        if split_rec is None:
            continue

        val_ids = get_first(split_rec, VAL_ID_KEYS, [])
        margins = np.asarray(
            get_first(rec, cfg["margin_key_candidates"], []), dtype=float
        ).ravel()

        # Lengths must line up, otherwise the id/margin pairing is meaningless.
        if len(val_ids) == 0 or len(val_ids) != len(margins):
            continue

        # String IDs keep matching consistent with X_all's index.
        records.extend((str(pid), float(m)) for pid, m in zip(val_ids, margins))

    if not records:
        return pd.DataFrame(columns=["id", "margin"])

    df = pd.DataFrame(records, columns=["id", "margin"])
    return df.groupby("id", as_index=False)["margin"].mean()

def collect_risk_by_id(raw_log, split_map, hz_map, cfg, horizon):
    """Collect absolute risk predictions by patient ID at a specific horizon.

    Thin aggregation wrapper over collect_risk_samples_by_id(): gathers
    every per-(imputation, fold) risk sample for `horizon` and averages
    them per patient. Returns a DataFrame with columns ["id", "risk"]
    (empty if no record yields risk at this horizon).

    Previously this function duplicated ~50 lines of the sample
    collector's probability-matrix / baseline-hazard fallback logic;
    delegating keeps the two code paths from drifting apart.
    """
    df_samples = collect_risk_samples_by_id(raw_log, split_map, hz_map, cfg, horizon)
    if df_samples.empty:
        return pd.DataFrame(columns=["id", "risk"])
    # Mean across all (imputation, fold) samples per patient.
    return df_samples.groupby("id", as_index=False)["risk"].mean()

def collect_risk_samples_by_id(raw_log, split_map, hz_map, cfg, horizon):
    """Collect all risk samples by patient ID (for CI computation).

    For every (imputation, fold) record in `raw_log`, derive each
    validation patient's absolute risk at `horizon` and emit one row per
    (patient, record) pair. Unlike an averaged collector, no aggregation
    is done here, so callers can compute CIs across the raw samples.

    Risk is taken from the record's survival-probability matrix when the
    horizon matches one of its eval_times; otherwise it falls back to the
    Cox relation risk = 1 - exp(-exp(margin) * H0(t)) using the fold's
    baseline cumulative hazard. Records where neither source is usable
    are skipped silently.

    Returns a DataFrame with columns ["id", "risk"] (possibly empty).
    """
    rows = []
    t = float(horizon)

    for rec in raw_log:
        # Records must identify their (imputation, fold) to be matched
        # against the CV split and baseline-hazard logs.
        if "imp_idx" not in rec or "fold_idx" not in rec:
            continue
        key = (int(rec["imp_idx"]), int(rec["fold_idx"]))
        split_rec = split_map.get(key)
        hz_rec = hz_map.get(key)
        if split_rec is None:
            continue

        val_ids = get_first(split_rec, VAL_ID_KEYS, [])
        margins = np.asarray(get_first(rec, cfg["margin_key_candidates"], []), dtype=float).ravel()
        # id/margin pairing is positional, so lengths must agree.
        if len(val_ids) != len(margins) or len(val_ids) == 0:
            continue

        risk_vec = None
        eval_times = np.asarray(rec.get("eval_times", []), dtype=float).ravel()
        probs_mat = np.asarray(get_first(rec, cfg["probs_key_candidates"], []), dtype=float)

        # Try survival probability matrix
        # Orientation is not guaranteed: accept either (patients x times)
        # or (times x patients), matched by shape.
        if eval_times.size > 0 and probs_mat.ndim == 2:
            j = np.where(np.isclose(eval_times, t))[0]
            if j.size > 0:
                jj = int(j[0])
                if probs_mat.shape[0] == len(val_ids) and probs_mat.shape[1] == eval_times.size:
                    risk_vec = 1.0 - probs_mat[:, jj]
                elif probs_mat.shape[1] == len(val_ids) and probs_mat.shape[0] == eval_times.size:
                    risk_vec = 1.0 - probs_mat[jj, :]

        # Fallback to baseline hazard
        # Cox model: S(t|x) = exp(-H0(t) * exp(margin)); risk = 1 - S.
        if risk_vec is None and hz_rec is not None:
            times = get_first(hz_rec, cfg["hz_times_candidates"], [])
            h0_vals = get_first(hz_rec, cfg["hz_vals_candidates"], [])
            H0_t = h0_at_t(times, h0_vals, t)
            if np.isfinite(H0_t):
                risk_vec = 1.0 - np.exp(-np.exp(margins) * H0_t)

        if risk_vec is None:
            continue

        # Clamp to the probability range and keep only finite samples,
        # with string IDs for consistent matching against X_all's index.
        risk_vec = np.clip(np.asarray(risk_vec, dtype=float).ravel(), 0.0, 1.0)
        if len(risk_vec) != len(val_ids):
            continue

        for i, pid in enumerate(val_ids):
            rv = float(risk_vec[i])
            if np.isfinite(rv):
                rows.append((str(pid), rv))

    if not rows:
        return pd.DataFrame(columns=["id", "risk"])

    return pd.DataFrame(rows, columns=["id", "risk"])

def summarize_risk_ci(df_samples, alpha=0.05):
    """Summarize per-patient risk samples with mean and percentile CI.

    Parameters
    ----------
    df_samples : DataFrame with columns ["id", "risk"], one row per sample.
    alpha : two-sided CI level (default 0.05 -> 95% interval).

    Returns a DataFrame with one row per id and columns
    ["id", "risk_mean", "risk_ci_low", "risk_ci_high", "n_samples"].
    """
    cols = ["id", "risk_mean", "risk_ci_low", "risk_ci_high", "n_samples"]
    if df_samples.empty:
        return pd.DataFrame(columns=cols)

    g = df_samples.groupby("id")["risk"]
    out = g.agg(risk_mean="mean", n_samples="size").reset_index()
    # groupby emits groups in sorted-id order, so positional assignment of
    # the per-group quantiles lines up with `out`'s rows.
    out["risk_ci_low"] = g.quantile(alpha / 2.0).values
    out["risk_ci_high"] = g.quantile(1.0 - alpha / 2.0).values
    # BUGFIX: enforce the same column order as the empty-input branch
    # (previously non-empty results placed n_samples before the CI columns).
    return out[cols]

def correlation_pairs_report(X_df, threshold=0.85):
    """Report feature pairs whose |Pearson r| reaches `threshold`.

    Columns are coerced to numeric first; constant or all-NaN columns
    (std not > 0) are excluded. Returns (pairs_df, corr): the flagged
    pairs sorted by |r| descending, plus the full correlation matrix of
    the retained columns.
    """
    numeric = X_df.apply(pd.to_numeric, errors="coerce")
    informative = [c for c in numeric.columns if numeric[c].std(skipna=True) > 0]
    corr = numeric[informative].corr(method="pearson")

    # Strict upper triangle -> each unordered pair appears exactly once;
    # stack() drops NaN correlations automatically.
    mask = np.triu(np.ones(corr.shape, dtype=bool), k=1)
    candidates = corr.where(mask).stack()
    strong = candidates[candidates.abs() >= threshold]

    flagged = [
        (f1, f2, float(r), float(abs(r)))
        for (f1, f2), r in strong.items()
    ]
    pairs_df = pd.DataFrame(flagged, columns=["feature_1", "feature_2", "pearson_r", "abs_r"])
    if len(pairs_df):
        pairs_df = pairs_df.sort_values("abs_r", ascending=False).reset_index(drop=True)
    return pairs_df, corr

# -----------------------------
# 4) Prepare maps/horizons + multicollinearity check
# -----------------------------
# Index CV splits and baseline hazards by (imputation, fold) for O(1) lookup.
split_map = {
    (int(s["imp_idx"]), int(s["fold_idx"])): s
    for s in cv_splits_log
    if "imp_idx" in s and "fold_idx" in s
}
hz_map = {
    (int(h["imp_idx"]), int(h["fold_idx"])): h
    for h in baseline_hazards_log
    if "imp_idx" in h and "fold_idx" in h
}

available_horizons = discover_horizons(raw_data_log)

# FIXED: Handle empty TARGET_HORIZONS and validate requested horizons
if TARGET_HORIZONS is None or len(TARGET_HORIZONS) == 0:
    horizons = available_horizons if available_horizons else [12.0]
else:
    horizons = sorted(set(float(h) for h in TARGET_HORIZONS))
    # Warn about missing horizons
    # NOTE(review): nb_print is not defined in this cell — presumably a
    # logging helper from an earlier cell; confirm it exists, otherwise
    # this branch raises NameError.
    missing_h = set(horizons) - set(available_horizons)
    if missing_h:
        nb_print(f"Warning: Requested horizons not in data: {sorted(missing_h)}")
        nb_print(f"Available horizons: {available_horizons}")

# Multicollinearity report — informs SHAP interpretation, since highly
# correlated features can split attribution between them.
corr_pairs_df, corr_mat = correlation_pairs_report(X_all, threshold=CORR_THRESHOLD)
corr_pairs_file = os.path.join(OUT_DIR, f"xgb8_dual_feature_corr_pairs_{FILE_TAG}.csv")
corr_mat_file = os.path.join(OUT_DIR, f"xgb8_dual_feature_corr_matrix_{FILE_TAG}.csv")
corr_pairs_df.to_csv(corr_pairs_file, index=False)
corr_mat.to_csv(corr_mat_file)

display(Markdown(
    f"### Step 8 SHAP (DUAL, Multi-Horizon)\n"
    f"- Source bundle tag: **{source_tag}**\n"
    f"- Run time: **{RUN_TS}**\n"
    f"- Patients: **{X_all.shape[0]}**\n"
    f"- Features: **{X_all.shape[1]}**\n"
    f"- Horizons (months): **{', '.join(fmt_horizon(h) for h in horizons)}**\n"
    f"- SHAP scale: **log-hazard**\n"
    f"- Multicollinearity threshold: **|r| >= {CORR_THRESHOLD:.2f}**"
))

if len(corr_pairs_df) > 0:
    display(Markdown(f"Found **{len(corr_pairs_df)}** correlated feature pairs (|r| >= {CORR_THRESHOLD:.2f})."))
    display(corr_pairs_df.head(20))
else:
    display(Markdown(f"No feature pairs above |r| >= {CORR_THRESHOLD:.2f}."))

# -----------------------------
# 5) Run SHAP per outcome and horizon
# -----------------------------
# Accumulators for exported files / tables built by the main loop below.
saved_plot_files = []
saved_out_files = [corr_pairs_file, corr_mat_file]
all_case_rows = []
horizon_rows = []
processed_outcomes = []

# FIXED: id_to_row uses string keys for consistent matching
id_to_row = {str(idx): i for i, idx in enumerate(X_all.index)}

# Main loop: for each outcome (readmission, death) build global SHAP
# summaries (bar + beeswarm on the log-hazard scale), then per horizon
# rank patients by absolute risk and render individual waterfalls.
for outcome_name, cfg in OUTCOME_CFG.items():
    shap_key = find_first_key(shap_data, cfg["shap_key_candidates"])
    if shap_key is None:
        print(f"Skipping {cfg['label']}: missing SHAP key among {cfg['shap_key_candidates']}.")
        continue

    shap_vals = np.asarray(shap_data[shap_key], dtype=float)
    if shap_vals.shape != X_all.shape:
        raise ValueError(f"{cfg['label']} SHAP shape mismatch: {shap_vals.shape} vs X_all {X_all.shape}")

    # Recover baseline on margin (log-hazard) scale
    # Since margin = base + sum(SHAP), averaging (margin - shap_sum) over
    # matched patients estimates the model's base value on that scale.
    df_margin = collect_margin_by_id(raw_data_log, split_map, cfg)
    df_shap_sum = pd.DataFrame({
        "id": X_all.index.astype(str),
        "shap_sum": shap_vals.sum(axis=1)
    })
    tmp = df_shap_sum.merge(df_margin, on="id", how="inner")
    base_margin = float((tmp["margin"] - tmp["shap_sum"]).mean()) if len(tmp) > 0 else 0.0

    explanation = shap.Explanation(
        values=shap_vals,
        base_values=np.full(X_all.shape[0], base_margin, dtype=float),
        data=X_all.to_numpy(),
        feature_names=feature_names
    )

    # Global mean |SHAP| + CI
    if BOOTSTRAP_CI:
        point, ci_low, ci_high, ci_n = bootstrap_mean_abs_shap(
            shap_vals,
            n_boot=N_BOOTSTRAP,
            alpha=CI_ALPHA,
            seed=2125,
            max_n=BOOTSTRAP_MAX_N
        )
    else:
        point = np.abs(shap_vals).mean(axis=0)
        ci_low = np.full_like(point, np.nan, dtype=float)
        ci_high = np.full_like(point, np.nan, dtype=float)
        ci_n = int(shap_vals.shape[0])

    # Per-feature importance table, sorted by mean |SHAP| (log-hazard).
    df_top = pd.DataFrame({
        "outcome": cfg["label"],
        "feature": feature_names,
        "mean_abs_shap_log_hazard": point,
        "ci95_low": ci_low,
        "ci95_high": ci_high,
        "bootstrap_n": int(ci_n),
        "n_bootstrap": int(N_BOOTSTRAP if BOOTSTRAP_CI else 0),
        "ci_alpha": float(CI_ALPHA if BOOTSTRAP_CI else np.nan),
    }).sort_values("mean_abs_shap_log_hazard", ascending=False).reset_index(drop=True)

    top_file = os.path.join(OUT_DIR, f"xgb8_dual_{outcome_name}_shap_top_features_{FILE_TAG}.csv")
    df_top.to_csv(top_file, index=False)
    saved_out_files.append(top_file)

    processed_outcomes.append(outcome_name)

    display(Markdown(f"## {cfg['label']}"))
    display(Markdown("SHAP values and global importance are on the **log-hazard** scale."))
    display(df_top.head(20))

    # Bar plot with CI (log-hazard SHAP)
    # barh plots bottom-up, so sort ascending to show the largest on top.
    df_bar = df_top.head(20).sort_values("mean_abs_shap_log_hazard", ascending=True)

    # FIXED: Explicit dtype specification
    x = df_bar["mean_abs_shap_log_hazard"].to_numpy(dtype=float)
    has_ci = BOOTSTRAP_CI and np.isfinite(df_bar["ci95_low"]).all() and np.isfinite(df_bar["ci95_high"]).all()

    plt.figure(figsize=(11, 8))
    if has_ci:
        lo = df_bar["ci95_low"].to_numpy(dtype=float)
        hi = df_bar["ci95_high"].to_numpy(dtype=float)
        # xerr expects non-negative (lower, upper) distances from the point.
        xerr = np.vstack([np.clip(x - lo, 0.0, None), np.clip(hi - x, 0.0, None)])
        plt.barh(df_bar["feature"], x, xerr=xerr, color="#4C72B0", alpha=0.9, ecolor="black", capsize=2)
        plt.title(f"{cfg['label']} Global mean |SHAP| (log-hazard) with 95% bootstrap CI")
    else:
        plt.barh(df_bar["feature"], x, color="#4C72B0", alpha=0.9)
        plt.title(f"{cfg['label']} Global mean |SHAP| (log-hazard)")
    plt.xlabel("mean |SHAP| (log-hazard)")
    plt.tight_layout()
    saved_plot_files.extend(save_current_figure("bar_ci", outcome_name))
    plt.show()
    plt.close()

    # Beeswarm (distribution on log-hazard SHAP scale)
    # Subsample rows above MAX_BEESWARM_N to keep rendering tractable.
    if X_all.shape[0] > MAX_BEESWARM_N:
        idx = RNG.choice(X_all.shape[0], MAX_BEESWARM_N, replace=False)
        exp_bee = explanation[idx]
    else:
        exp_bee = explanation

    plt.figure(figsize=(12, 8))
    shap.plots.beeswarm(exp_bee, max_display=20, show=False)
    plt.title(f"{cfg['label']} SHAP Beeswarm (log-hazard scale)")
    plt.tight_layout()
    saved_plot_files.extend(save_current_figure("beeswarm", outcome_name))
    plt.show()
    plt.close()

    # Horizon-specific risk ranking + waterfalls
    for h in horizons:
        df_risk_samples = collect_risk_samples_by_id(raw_data_log, split_map, hz_map, cfg, h)
        # FIXED: Use CI_ALPHA instead of hardcoded 0.05
        df_risk = summarize_risk_ci(df_risk_samples, alpha=CI_ALPHA)
        
        # FIXED: Ensure string IDs for consistent matching
        df_risk["id"] = df_risk["id"].astype(str)
        df_risk = df_risk[df_risk["id"].isin(id_to_row.keys())]
        
        n_h = int(len(df_risk))
        n_total = int(X_all.shape[0])

        # Coverage bookkeeping: how many patients got a risk at this horizon.
        horizon_rows.append({
            "outcome": cfg["label"],
            "horizon_months": float(h),
            "n_patients_with_risk": n_h,
            "n_total_patients": n_total,
            "coverage_pct": (100.0 * n_h / n_total) if n_total > 0 else np.nan
        })

        if n_h == 0:
            # NOTE(review): nb_print is not defined in this cell — presumably
            # a logging helper from an earlier cell; confirm it exists.
            nb_print(f"{cfg['label']} @ {fmt_horizon(h)}m: no absolute risk available; skipping waterfalls.")
            continue

        # Strictly rank by absolute risk (not SHAP score/log-hazard score)
        # NOTE: `hi`/`lo` here reuse the names of the bar-plot CI arrays
        # above; safe because those are no longer needed at this point.
        hi = df_risk.sort_values("risk_mean", ascending=False).iloc[0]
        lo = df_risk.sort_values("risk_mean", ascending=True).iloc[0]
        
        high_id = str(hi["id"])
        low_id = str(lo["id"])

        high_risk = float(hi["risk_mean"])
        low_risk = float(lo["risk_mean"])

        high_low = float(hi["risk_ci_low"])
        high_high = float(hi["risk_ci_high"])
        low_low = float(lo["risk_ci_low"])
        low_high = float(lo["risk_ci_high"])

        high_n = int(hi["n_samples"])
        low_n = int(lo["n_samples"])

        high_row = id_to_row[high_id]
        low_row = id_to_row[low_id]

        all_case_rows.append({
            "outcome": cfg["label"],
            "horizon_months": float(h),
            "case": "highest",
            "id": high_id,
            "risk_at_horizon": high_risk,
            "n_patients_with_risk": n_h
        })
        all_case_rows.append({
            "outcome": cfg["label"],
            "horizon_months": float(h),
            "case": "lowest",
            "id": low_id,
            "risk_at_horizon": low_risk,
            "n_patients_with_risk": n_h
        })

        # Waterfall high risk (SHAP still log-hazard contribution)
        plt.figure(figsize=(10, 7))
        shap.plots.waterfall(explanation[high_row], max_display=12, show=False)
        plt.title(
            f"{cfg['label']} highest absolute risk @ {fmt_horizon(h)}m "
            f"(ID {high_id}, risk={high_risk:.3f} [{high_low:.3f}, {high_high:.3f}], n={high_n})\n"
            f"Waterfall shows SHAP contributions on log-hazard scale"
        )
        plt.tight_layout()
        saved_plot_files.extend(save_current_figure("waterfall_high", outcome_name, h))
        plt.show()
        plt.close()

        # Waterfall low risk
        plt.figure(figsize=(10, 7))
        shap.plots.waterfall(explanation[low_row], max_display=12, show=False)
        plt.title(
            f"{cfg['label']} lowest absolute risk @ {fmt_horizon(h)}m "
            f"(ID {low_id}, risk={low_risk:.3f} [{low_low:.3f}, {low_high:.3f}], n={low_n})\n"
            f"Waterfall shows SHAP contributions on log-hazard scale"
        )
        plt.tight_layout()
        saved_plot_files.extend(save_current_figure("waterfall_low", outcome_name, h))
        plt.show()
        plt.close()

# -----------------------------
# 6) Export combined outputs
# -----------------------------
# All artifact names carry FILE_TAG so repeated runs never overwrite each other.
cases_file = os.path.join(OUT_DIR, f"xgb8_dual_shap_extreme_cases_{FILE_TAG}.csv")
horizon_file = os.path.join(OUT_DIR, f"xgb8_dual_horizon_sample_sizes_{FILE_TAG}.csv")
info_file = os.path.join(OUT_DIR, f"xgb8_dual_shap_run_info_{FILE_TAG}.json")

# Tables accumulated by the per-outcome loop above.
df_cases = pd.DataFrame(all_case_rows)
df_horizon = pd.DataFrame(horizon_rows)

df_cases.to_csv(cases_file, index=False)
df_horizon.to_csv(horizon_file, index=False)

saved_out_files.extend([cases_file, horizon_file])

# Machine-readable provenance sidecar: input artifacts, cohort size,
# horizons, bootstrap/multicollinearity settings for this run.
run_info = {
    "source_bundle_tag": source_tag,
    "run_timestamp": RUN_TS,
    "file_tag_used_for_outputs": FILE_TAG,
    "source_files": {
        "shap": shap_file,
        "raw": raw_file,
        "baseline_hazards": hz_file,
        "cv_splits": split_file
    },
    "n_patients": int(X_all.shape[0]),
    "n_features": int(X_all.shape[1]),
    "horizons_months": [float(h) for h in horizons],
    "available_eval_times_months": [float(h) for h in available_horizons],
    "outcomes_processed": processed_outcomes,
    "shap_scale": "log-hazard",
    "global_importance_metric": "mean absolute SHAP (log-hazard)",
    "bootstrap_ci": {
        "enabled": bool(BOOTSTRAP_CI),
        # When bootstrap is disabled, the related fields are zeroed/None so
        # the JSON stays parseable without special-casing downstream.
        "n_bootstrap": int(N_BOOTSTRAP) if BOOTSTRAP_CI else 0,
        "alpha": float(CI_ALPHA) if BOOTSTRAP_CI else None,
        "max_n_for_bootstrap": int(BOOTSTRAP_MAX_N) if BOOTSTRAP_CI else None
    },
    "multicollinearity_check": {
        "method": "pairwise Pearson correlation",
        "threshold_abs_r": float(CORR_THRESHOLD),
        "pairs_file": corr_pairs_file,
        "corr_matrix_file": corr_mat_file
    },
    "note": "Absolute risk is used for extreme-case ranking at each horizon. SHAP values remain log-hazard contributions."
}
with open(info_file, "w", encoding="utf-8") as f:
    json.dump(run_info, f, indent=2)

saved_out_files.append(info_file)

# NOTE(review): headers use print while items use nb_print — looks intentional
# (nb_print presumably also logs elsewhere), but worth confirming.
print("\nSaved plots to _figs (PNG/PDF):")
for p in saved_plot_files:
    nb_print(" -", p)

print("\nSaved tables/metadata to _out:")
for p in saved_out_files:
    nb_print(" -", p)

Step 8 SHAP (DUAL, Multi-Horizon)

  • Source bundle tag: 20260306_1821
  • Run time: 20260306_1828
  • Patients: 70521
  • Features: 56
  • Horizons (months): 12, 60
  • SHAP scale: log-hazard
  • Multicollinearity threshold: |r| >= 0.90

No feature pairs above |r| >= 0.90.

Readmission

SHAP values and global importance are on the log-hazard scale.

outcome feature mean_abs_shap_log_hazard ci95_low ci95_high bootstrap_n n_bootstrap ci_alpha
0 Readmission primary_sub_mod_cocaine_paste 0.090015 0.089795 0.090213 70521 500 0.05
1 Readmission adm_age_rec3 0.078743 0.077885 0.079599 70521 500 0.05
2 Readmission porc_pobr 0.073050 0.072710 0.073402 70521 500 0.05
3 Readmission sex_rec_woman 0.070474 0.070155 0.070807 70521 500 0.05
4 Readmission plan_type_corr_pg_pr 0.069722 0.069139 0.070305 70521 500 0.05
5 Readmission plan_type_corr_m_pr 0.056074 0.055319 0.056821 70521 500 0.05
6 Readmission ethnicity 0.054893 0.054140 0.055594 70521 500 0.05
7 Readmission dit_m 0.047543 0.047252 0.047825 70521 500 0.05
8 Readmission eva_consumo 0.046565 0.046266 0.046839 70521 500 0.05
9 Readmission ed_attainment_corr 0.041594 0.041322 0.041856 70521 500 0.05
10 Readmission occupation_condition_corr24_unemployed 0.036454 0.036315 0.036597 70521 500 0.05
11 Readmission dg_psiq_cie_10_dg 0.028379 0.028299 0.028446 70521 500 0.05
12 Readmission sub_dep_icd10_status_drug_dependence 0.028375 0.028231 0.028495 70521 500 0.05
13 Readmission polysubstance_strict 0.025277 0.025137 0.025427 70521 500 0.05
14 Readmission primary_sub_mod_alcohol 0.024515 0.024404 0.024631 70521 500 0.05
15 Readmission eva_sm 0.023360 0.023182 0.023537 70521 500 0.05
16 Readmission primary_sub_mod_cocaine_powder 0.023091 0.022960 0.023230 70521 500 0.05
17 Readmission evaluacindelprocesoteraputico 0.021649 0.021516 0.021774 70521 500 0.05
18 Readmission tr_outcome_adm_discharge_rule_violation_undet 0.021562 0.021317 0.021812 70521 500 0.05
19 Readmission prim_sub_freq_rec 0.019463 0.019331 0.019586 70521 500 0.05

Death

SHAP values and global importance are on the log-hazard scale.

outcome feature mean_abs_shap_log_hazard ci95_low ci95_high bootstrap_n n_bootstrap ci_alpha
0 Death adm_age_rec3 0.497441 0.495344 0.499279 70521 500 0.05
1 Death primary_sub_mod_alcohol 0.267121 0.266603 0.267636 70521 500 0.05
2 Death prim_sub_freq_rec 0.114488 0.114180 0.114753 70521 500 0.05
3 Death occupation_condition_corr24_unemployed 0.099876 0.099530 0.100233 70521 500 0.05
4 Death any_phys_dx 0.092606 0.091780 0.093319 70521 500 0.05
5 Death eva_ocupacion 0.074887 0.074562 0.075242 70521 500 0.05
6 Death cohabitation_with_couple_children 0.058190 0.058002 0.058380 70521 500 0.05
7 Death eva_fisica 0.049436 0.049125 0.049758 70521 500 0.05
8 Death adm_motive_sanitary_sector 0.048966 0.048799 0.049165 70521 500 0.05
9 Death ed_attainment_corr 0.046026 0.045788 0.046245 70521 500 0.05
10 Death sex_rec_woman 0.045834 0.045591 0.046071 70521 500 0.05
11 Death first_sub_used_alcohol 0.041504 0.041370 0.041647 70521 500 0.05
12 Death occupation_condition_corr24_inactive 0.041420 0.041117 0.041719 70521 500 0.05
13 Death primary_sub_mod_cocaine_paste 0.035601 0.035465 0.035751 70521 500 0.05
14 Death dit_m 0.034467 0.034197 0.034742 70521 500 0.05
15 Death plan_type_corr_pg_pr 0.033982 0.033714 0.034284 70521 500 0.05
16 Death porc_pobr 0.033410 0.033156 0.033610 70521 500 0.05
17 Death eva_sm 0.033211 0.032937 0.033489 70521 500 0.05
18 Death polysubstance_strict 0.028167 0.028007 0.028305 70521 500 0.05
19 Death tenure_status_household 0.028083 0.027949 0.028233 70521 500 0.05

 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_readm_bar_ci_20260306_1821.png
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_readm_bar_ci_20260306_1821.pdf
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_readm_beeswarm_20260306_1821.png
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_readm_beeswarm_20260306_1821.pdf
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_readm_waterfall_high_12m_20260306_1821.png
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_readm_waterfall_high_12m_20260306_1821.pdf
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_readm_waterfall_low_12m_20260306_1821.png
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_readm_waterfall_low_12m_20260306_1821.pdf
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_readm_waterfall_high_60m_20260306_1821.png
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_readm_waterfall_high_60m_20260306_1821.pdf
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_readm_waterfall_low_60m_20260306_1821.png
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_readm_waterfall_low_60m_20260306_1821.pdf
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_death_bar_ci_20260306_1821.png
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_death_bar_ci_20260306_1821.pdf
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_death_beeswarm_20260306_1821.png
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_death_beeswarm_20260306_1821.pdf
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_death_waterfall_high_12m_20260306_1821.png
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_death_waterfall_high_12m_20260306_1821.pdf
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_death_waterfall_low_12m_20260306_1821.png
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_death_waterfall_low_12m_20260306_1821.pdf
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_death_waterfall_high_60m_20260306_1821.png
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_death_waterfall_high_60m_20260306_1821.pdf
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_death_waterfall_low_60m_20260306_1821.png
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb8_dual_death_waterfall_low_60m_20260306_1821.pdf
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb8_dual_feature_corr_pairs_20260306_1821.csv
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb8_dual_feature_corr_matrix_20260306_1821.csv
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb8_dual_readm_shap_top_features_20260306_1821.csv
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb8_dual_death_shap_top_features_20260306_1821.csv
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb8_dual_shap_extreme_cases_20260306_1821.csv
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb8_dual_horizon_sample_sizes_20260306_1821.csv
 - G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb8_dual_shap_run_info_20260306_1821.json

Stability of predictors

  1. Assesses SHAP rank stability via bootstrap resampling.

  2. Uses mean absolute SHAP for global importance.

  3. Re-ranks features across 500 bootstrap samples.

  4. Estimates variability of feature ranks.

  5. Calculates frequency of inclusion in Top-45.

  6. Flags a feature as stable if it falls within the Top-45 in ≥90% of bootstrap samples.

  7. Reports SD, CV, and 95% rank intervals.

  8. Applies composite rule for Cox selection.

  9. Diagnoses CV denominator artifacts.

  10. Exports detailed, summary, and audit files.

5 Key Assumptions
  1. Mean |SHAP| reflects true importance.

  2. Bootstrap captures rank variability.

  3. SHAP values come from unbiased models.

  4. Top-45 is meaningful for Cox screening.

  5. Rank stability reflects true signal.

Code
import os
import re
import glob
import json
import pickle
import numpy as np
import pandas as pd
from IPython.display import display, Markdown

# -----------------------------
# 0) Preconditions + config
# -----------------------------
# This cell must run after the root setup cell, which defines PROJECT_ROOT
# (and optionally OUT_DIR); fail fast with a clear message otherwise.
if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the root setup cell first.")

PROJECT_ROOT = os.path.abspath(str(PROJECT_ROOT))

# Resolve the output directory: honor a pre-existing OUT_DIR global, fall back
# to <PROJECT_ROOT>/_out, and normalize relative paths against PROJECT_ROOT.
OUT_DIR_LOCAL = globals().get("OUT_DIR", os.path.join(PROJECT_ROOT, "_out"))
if not os.path.isabs(str(OUT_DIR_LOCAL)):
    OUT_DIR_LOCAL = os.path.join(PROJECT_ROOT, str(OUT_DIR_LOCAL))
OUT_DIR_LOCAL = os.path.abspath(str(OUT_DIR_LOCAL))
os.makedirs(OUT_DIR_LOCAL, exist_ok=True)

# Timestamp format embedded in artifact file names, e.g. 20260306_1821.
TS_FMT = "%Y%m%d_%H%M"
TS_RE = re.compile(r"(\d{8}_\d{4})")  # Finds timestamp anywhere in filename

def _extract_ts(path):
    """Pull a run timestamp out of a file name.

    Searches the basename of ``path`` for the first ``YYYYMMDD_HHMM`` token
    (module-level ``TS_RE``). Returns a ``(parsed_datetime, raw_string)``
    pair, or ``None`` when no token is present or it fails to parse under
    ``TS_FMT``.
    """
    match = TS_RE.search(os.path.basename(path))
    if match is None:
        return None
    token = match.group(1)
    try:
        parsed = pd.to_datetime(token, format=TS_FMT)
    except Exception:
        return None
    return parsed, token

def _latest_file(patterns):
    """Locate the newest timestamped file matching any of several globs.

    Globs each pattern under ``OUT_DIR_LOCAL``, keeps candidates whose
    basename carries a parseable timestamp, and returns
    ``(filepath, timestamp_str)`` for the most recent one — or
    ``(None, None)`` when nothing qualifies.
    """
    candidates = []
    for pattern in patterns:
        candidates += glob.glob(os.path.join(OUT_DIR_LOCAL, pattern))

    stamped = []
    for candidate in candidates:
        parsed = _extract_ts(candidate)
        if parsed is not None:
            stamped.append((parsed[0], parsed[1], candidate))

    if not stamped:
        return None, None

    # max() keeps the first maximal entry, matching the original stable sort.
    newest = max(stamped, key=lambda entry: entry[0])
    return newest[2], newest[1]

# Auto-load latest SHAP if missing
# Track which file (if any) this cell loaded, so the next cell can report it.
LOADED_SHAP_PATH = None
if "shap_data" not in globals() or not isinstance(globals().get("shap_data"), dict):
    # Patterns match both: _20260305_2232.pkl and _20260305_2232_mar26.pkl
    shap_path, shap_ts = _latest_file([
        "xgb6_corr_DUAL_SHAP_Aggregated_*.pkl",
        "xgb6_corr_DUAL_final_ev_hyp_*.pkl",
    ])
    if shap_path is None:
        raise RuntimeError(
            f"`shap_data` missing and no SHAP artifact found in {OUT_DIR_LOCAL} "
            f"(expected xgb6_corr_DUAL_*_{TS_FMT}[*_mar26].pkl)."
        )
    # NOTE(review): pickle.load executes arbitrary code on malicious input;
    # acceptable here only because the artifact is produced by this pipeline.
    with open(shap_path, "rb") as f:
        shap_data = pickle.load(f)
    LOADED_SHAP_PATH = shap_path
    if not isinstance(shap_data, dict):
        raise RuntimeError(f"Loaded SHAP artifact is not a dict: {shap_path}")
else:
    # shap_data already lives in the kernel; no on-disk timestamp known yet.
    shap_ts = None

# Source-tag preference order: explicit `source_tag` global > timestamp of the
# newest SHAP artifact on disk > current wall-clock time as a last resort.
SOURCE_TAG_LOCAL = globals().get("source_tag")
if not SOURCE_TAG_LOCAL:
    if shap_ts is None:
        _, shap_ts = _latest_file([
            "xgb6_corr_DUAL_SHAP_Aggregated_*.pkl",
            "xgb6_corr_DUAL_final_ev_hyp_*.pkl",
        ])
    SOURCE_TAG_LOCAL = shap_ts if shap_ts else pd.Timestamp.now().strftime(TS_FMT)

RUN_TS = pd.Timestamp.now().strftime(TS_FMT)
Code
# FIXED: Add this after the auto-loading block (after "print(f"Auto-loaded SHAP data: {shap_path}")" line)

# Report provenance of `shap_data`: whether this run loaded it from disk or
# reused an object already present in the kernel.
if LOADED_SHAP_PATH:
    display(Markdown(f"- Auto-loaded latest SHAP: `{LOADED_SHAP_PATH}`"))
elif "shap_data" in globals():
    display(Markdown("- Using existing `shap_data` from memory"))
  • Using existing shap_data from memory
Code
#@title Step 8c: SHAP Rank Stability for Cox Candidate Selection (Bootstrap + Top-K Inclusion Frequency)

import os
import re
import glob
import json
import pickle
import numpy as np
import pandas as pd
from IPython.display import display, Markdown, HTML  # FIXED: Added HTML import

# -----------------------------
# 0) Preconditions + config
# -----------------------------
# NOTE(review): this preamble duplicates the setup of the previous cell so
# Step 8c can run standalone after a kernel restart.

# FIXED: Add robust PROJECT_ROOT check
if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the root setup cell first.")

PROJECT_ROOT = os.path.abspath(str(PROJECT_ROOT))

# Resolve the output directory exactly as in Step 8: prefer OUT_DIR, fall
# back to <PROJECT_ROOT>/_out, normalize to absolute, create if absent.
OUT_DIR_LOCAL = globals().get("OUT_DIR", os.path.join(PROJECT_ROOT, "_out"))
if not os.path.isabs(str(OUT_DIR_LOCAL)):
    OUT_DIR_LOCAL = os.path.join(PROJECT_ROOT, str(OUT_DIR_LOCAL))
OUT_DIR_LOCAL = os.path.abspath(str(OUT_DIR_LOCAL))
os.makedirs(OUT_DIR_LOCAL, exist_ok=True)

# Timestamp format embedded in artifact file names, e.g. 20260306_1821.
TS_FMT = "%Y%m%d_%H%M"
TS_RE = re.compile(r"(\d{8}_\d{4})")  # FIXED: Same pattern as Step 8

def _extract_ts(path):
    """Parse the ``YYYYMMDD_HHMM`` run stamp embedded in a file name.

    Searches the basename of ``path`` with the module-level ``TS_RE`` regex.
    Returns ``(datetime_obj, timestamp_str)``; ``None`` when the name carries
    no stamp or the stamp does not parse under ``TS_FMT``.
    """
    hit = TS_RE.search(os.path.basename(path))
    if hit is None:
        return None
    stamp = hit.group(1)
    try:
        when = pd.to_datetime(stamp, format=TS_FMT)
    except Exception:
        return None
    return when, stamp

def _latest_file(patterns):
    """Return (path, timestamp_str) of the newest timestamped match, else (None, None)."""
    matches = []
    for pattern in patterns:
        matches.extend(glob.glob(os.path.join(OUT_DIR_LOCAL, pattern)))

    # Keep only candidates with a parseable stamp, as (datetime, stamp, path).
    dated = [
        (stamp_info[0], stamp_info[1], match)
        for match in matches
        for stamp_info in (_extract_ts(match),)
        if stamp_info is not None
    ]
    if not dated:
        return None, None

    dated.sort(key=lambda entry: entry[0], reverse=True)
    return dated[0][2], dated[0][1]

# FIXED: Add auto-loading logic consistent with Step 8
# Track which file (if any) this cell loaded, for the provenance message.
LOADED_SHAP_PATH = None
if "shap_data" not in globals() or not isinstance(globals().get("shap_data"), dict):
    shap_path, shap_ts = _latest_file([
        "xgb6_corr_DUAL_SHAP_Aggregated_*.pkl",
        "xgb6_corr_DUAL_final_ev_hyp_*.pkl",
    ])
    if shap_path is None:
        raise RuntimeError(
            f"`shap_data` missing and no SHAP artifact found in {OUT_DIR_LOCAL} "
            f"(expected xgb6_corr_DUAL_*_{TS_FMT}[*_mar26].pkl)."
        )
    # NOTE(review): pickle.load can execute arbitrary code; safe only because
    # this artifact is produced locally by an earlier step of the pipeline.
    with open(shap_path, "rb") as f:
        shap_data = pickle.load(f)
    LOADED_SHAP_PATH = shap_path
    if not isinstance(shap_data, dict):
        raise RuntimeError(f"Loaded SHAP artifact is not a dict: {shap_path}")
    print(f"Auto-loaded SHAP data: {shap_path}")
else:
    # shap_data already in memory; no on-disk timestamp is known yet.
    shap_ts = None

# FIXED: Proper SOURCE_TAG_LOCAL handling
# Preference: explicit `source_tag` global > newest artifact timestamp > now().
SOURCE_TAG_LOCAL = globals().get("source_tag")
if not SOURCE_TAG_LOCAL:
    if shap_ts is None:
        _, shap_ts = _latest_file([
            "xgb6_corr_DUAL_SHAP_Aggregated_*.pkl",
            "xgb6_corr_DUAL_final_ev_hyp_*.pkl",
        ])
    SOURCE_TAG_LOCAL = shap_ts if shap_ts else pd.Timestamp.now().strftime(TS_FMT)

RUN_TS = pd.Timestamp.now().strftime(TS_FMT)

def show_scroll_df(df, height_px=420):
    """Render a DataFrame inside a fixed-height, vertically scrollable box.

    Keeps wide stability tables browsable in the notebook without flooding
    the cell output. ``escape=False`` passes cell contents through as raw
    HTML, so only feed this trusted, locally produced frames.
    """
    table_html = df.to_html(index=False, escape=False)
    wrapper = f"""
        <div style="max-height:{height_px}px; overflow-y:auto; border:1px solid #ddd; border-radius:6px;">
            {table_html}
        </div>
        """
    display(HTML(wrapper))

# Your requested settings
N_BOOTSTRAP = 500     # number of bootstrap replicates per outcome
BOOT_SEED = 2125      # RandomState seed for reproducible resampling
MAX_N = 90000         # cap of patients used in each bootstrap run; if data has <= MAX_N, all are used
TOP_K = 45            # Top-K cutoff used for the inclusion-frequency metric

# Stability thresholds
FREQ_THRESHOLD_PCT = 90.0       # primary rule: % of bootstraps a feature must land in the Top-K
CV_THRESHOLD_PCT = 30.0         # secondary: rank coefficient of variation, in percent
SD_THRESHOLD_RANK = 5.0         # secondary: absolute rank SD, in rank units
RANK_CI_UPPER_THRESHOLD = 45.0  # optional stricter upper CI bound for top-30-style stability

# Candidate keys under which each outcome's SHAP matrix may be stored in the
# artifact; the first key found wins (see first_existing_key below).
OUTCOME_KEYS = {
    "readm": {"label": "Readmission", "keys": ["shap_r_all", "shap_readm_all"]},
    "death": {"label": "Death", "keys": ["shap_d_all", "shap_death_all", "shap_mort_all"]},
}

# FIXED: More robust feature_names extraction with error handling
feature_names = list(shap_data.get("feature_names", []))
if not feature_names:
    # Fall back to the stored design matrix: any object exposing `.columns`
    # (e.g. a pandas DataFrame) supplies the names. The previous version also
    # had an `elif isinstance(X_all_temp, pd.DataFrame)` branch, but it was
    # unreachable — every DataFrame has a `.columns` attribute, so the
    # `hasattr` branch always fired first. Removed as dead code.
    X_all_temp = shap_data.get("X_all")
    if X_all_temp is not None and hasattr(X_all_temp, "columns"):
        feature_names = list(X_all_temp.columns)

if not feature_names:
    raise RuntimeError(
        "Feature names not found. Need shap_data['feature_names'] or "
        "shap_data['X_all'] with columns."
    )

# Render the run configuration up front so the notebook output permanently
# records the exact bootstrap settings used for this stability analysis.
display(Markdown(
    f"### Step 8c: Bootstrap Rank Stability\n"
    f"- Source tag: **{SOURCE_TAG_LOCAL}**\n"
    f"- Run time: **{RUN_TS}**\n"
    f"- Bootstraps: **{N_BOOTSTRAP}**\n"
    f"- Seed: **{BOOT_SEED}**\n"
    f"- Top-K: **{TOP_K}**\n"
    f"- max_n: **{MAX_N}** (max patients used per bootstrap replicate)\n"
    f"- Primary stability: **freq_in_top_{TOP_K} >= {FREQ_THRESHOLD_PCT:.0f}%**"
))

# -----------------------------
# 1) Helpers
# -----------------------------
def first_existing_key(dct, keys):
    """Return the first entry of `keys` present in `dct`, or None if none are."""
    return next((key for key in keys if key in dct), None)

def denominator_diagnostic(df, n_bins=5):
    """Summarize rank-stability metrics across quantile bins of mean rank.

    Bins features by bootstrap mean rank (the CV denominator) and reports
    per-bin medians, exposing whether large CV values at the top of the
    ranking are merely a small-denominator artifact rather than genuine
    instability. Returns a one-row frame when too few distinct mean ranks
    exist to form bins.
    """
    denom = df["cv_denominator_mean_rank"]
    n_eff_bins = min(n_bins, int(denom.nunique()))

    stable_col = f"stable_freq_ge_{int(FREQ_THRESHOLD_PCT)}pct"

    # Degenerate case: not enough distinct values for quantile binning.
    if n_eff_bins < 2:
        row = {
            "rank_bin": "all",
            "n_features": int(len(df)),
            "median_mean_rank": float(df["mean_rank"].median()),
            "median_sd_rank": float(df["sd_rank"].median()),
            "median_cv_rank_pct": float(df["cv_rank_pct"].median()),
            stable_col: float(100.0 * df[stable_col].mean()),
        }
        return pd.DataFrame([row])

    binned = df.copy()
    binned["rank_bin"] = pd.qcut(binned["mean_rank"], q=n_eff_bins, duplicates="drop")
    summary = (
        binned.groupby("rank_bin", as_index=False)
              .agg(
                  n_features=("feature", "count"),
                  median_mean_rank=("mean_rank", "median"),
                  median_sd_rank=("sd_rank", "median"),
                  median_cv_rank_pct=("cv_rank_pct", "median"),
                  median_freq_topk_pct=(f"freq_in_top_{TOP_K}_pct", "median"),
                  stable_freq_pct=(stable_col, lambda s: 100.0 * float(np.mean(s))),
              )
    )
    # Interval categories are not CSV-friendly; stringify for export.
    summary["rank_bin"] = summary["rank_bin"].astype(str)
    return summary

def bootstrap_rank_stability(shap_vals, feature_names, n_boot, seed, max_n, top_k):
    """Bootstrap the importance RANKS implied by mean |SHAP| per feature.

    Resamples patients with replacement `n_boot` times; within each
    replicate, features are re-ranked by mean absolute SHAP (rank 1 = most
    important). Returns (detail_df, n_patients_used), where detail_df holds
    one row per feature with the point rank, bootstrap rank-distribution
    statistics, Top-K inclusion frequency, and stability flags derived from
    the module-level thresholds (FREQ_THRESHOLD_PCT, CV_THRESHOLD_PCT,
    SD_THRESHOLD_RANK, RANK_CI_UPPER_THRESHOLD).

    Raises ValueError when `shap_vals` is not 2D (patients x features) or its
    width does not match `feature_names`.
    """
    X = np.asarray(shap_vals, dtype=np.float32)
    if X.ndim != 2:
        raise ValueError(f"Expected 2D SHAP array, got shape {X.shape}")

    n, p = X.shape
    if p != len(feature_names):
        raise ValueError(f"SHAP has {p} features but feature_names has {len(feature_names)}.")

    rng = np.random.RandomState(seed)

    # Optional subsample for speed on very large cohorts
    if (max_n is not None) and (n > max_n):
        keep = rng.choice(n, size=max_n, replace=False)
        X = X[keep]
    n_eff = int(X.shape[0])

    absX = np.abs(X)
    # Point estimate of global importance: mean |SHAP| per feature.
    point_importance = absX.mean(axis=0)

    # Stable (mergesort) argsort so ties break deterministically; ranks 1..p.
    point_order = np.argsort(-point_importance, kind="mergesort")
    point_rank = np.empty(p, dtype=np.float32)
    point_rank[point_order] = np.arange(1, p + 1, dtype=np.float32)

    rank_mat = np.empty((n_boot, p), dtype=np.float32)

    # One bootstrap replicate per row: resample patients, recompute ranks.
    for b in range(n_boot):
        idx = rng.randint(0, n_eff, size=n_eff)
        imp_b = absX[idx].mean(axis=0)
        ord_b = np.argsort(-imp_b, kind="mergesort")
        rb = np.empty(p, dtype=np.float32)
        rb[ord_b] = np.arange(1, p + 1, dtype=np.float32)
        rank_mat[b] = rb

    mean_rank = rank_mat.mean(axis=0)
    median_rank = np.median(rank_mat, axis=0)
    sd_rank = rank_mat.std(axis=0, ddof=1 if n_boot > 1 else 0)
    # Clip the denominator to avoid division by zero (ranks are >= 1 anyway).
    cv_rank_pct = 100.0 * sd_rank / np.clip(mean_rank, 1e-12, None)
    q025 = np.quantile(rank_mat, 0.025, axis=0)
    q975 = np.quantile(rank_mat, 0.975, axis=0)

    # Key metric for Cox candidate stability
    freq_in_top_k = 100.0 * np.mean(rank_mat <= top_k, axis=0)

    df = pd.DataFrame({
        "feature": feature_names,
        "point_importance_mean_abs_shap": point_importance,
        "point_rank": point_rank,
        "mean_rank": mean_rank,
        "median_rank": median_rank,
        "sd_rank": sd_rank,
        "cv_rank_pct": cv_rank_pct,
        "rank_ci2p5": q025,
        "rank_ci97p5": q975,

        f"freq_in_top_{top_k}_pct": freq_in_top_k,

        # Explicit CV components (artifact audit)
        "cv_numerator_sd_rank": sd_rank,
        "cv_denominator_mean_rank": mean_rank,
        "mean_rank_pct_of_feature_count": 100.0 * mean_rank / p,

        # Stability flags
        f"stable_freq_ge_{int(FREQ_THRESHOLD_PCT)}pct": freq_in_top_k >= FREQ_THRESHOLD_PCT,
        "stable_cv_lt_30pct": cv_rank_pct < CV_THRESHOLD_PCT,
        "stable_abs_sd_le_5": sd_rank <= SD_THRESHOLD_RANK,
        "stable_composite_for_cox": (
            (freq_in_top_k >= FREQ_THRESHOLD_PCT) &
            (q975 <= RANK_CI_UPPER_THRESHOLD)
        ),

        "n_bootstrap": int(n_boot),
        "n_patients_used": int(n_eff),
    })

    df = df.sort_values(["mean_rank", "point_rank"], ascending=[True, True]).reset_index(drop=True)
    return df, n_eff

# -----------------------------
# 2) Run for readmission + death
# -----------------------------
details_all = []     # per-outcome detailed rank-stability tables
diags_all = []       # per-outcome denominator diagnostics
summary_rows = []    # one summary dict per processed outcome
saved_files = []     # all CSV/JSON paths written by this cell

for outcome_code, cfg in OUTCOME_KEYS.items():
    # Locate this outcome's SHAP matrix under any of its known artifact keys.
    shap_key = first_existing_key(shap_data, cfg["keys"])
    if shap_key is None:
        print(f"Skipping {cfg['label']}: no SHAP key found in {cfg['keys']}")
        continue

    shap_vals = np.asarray(shap_data[shap_key], dtype=np.float32)

    df_detail, n_used = bootstrap_rank_stability(
        shap_vals=shap_vals,
        feature_names=feature_names,
        n_boot=N_BOOTSTRAP,
        seed=BOOT_SEED,
        max_n=MAX_N,
        top_k=TOP_K
    )
    df_detail.insert(0, "outcome", cfg["label"])

    df_diag = denominator_diagnostic(df_detail, n_bins=5)
    df_diag.insert(0, "outcome", cfg["label"])

    # Quantify denominator artifact directly
    # (negative correlation means CV shrinks as mean rank grows, i.e. the CV
    # is partly a small-denominator artifact rather than true instability).
    pearson_r = float(df_detail["cv_denominator_mean_rank"].corr(df_detail["cv_rank_pct"], method="pearson"))
    spearman_r = float(df_detail["cv_denominator_mean_rank"].corr(df_detail["cv_rank_pct"], method="spearman"))

    summary_rows.append({
        "outcome": cfg["label"],
        "n_features": int(len(df_detail)),
        "n_patients_used": int(n_used),
        f"stable_freq_ge_{int(FREQ_THRESHOLD_PCT)}pct_n": int(df_detail[f"stable_freq_ge_{int(FREQ_THRESHOLD_PCT)}pct"].sum()),
        f"stable_freq_ge_{int(FREQ_THRESHOLD_PCT)}pct_pct": float(100.0 * df_detail[f"stable_freq_ge_{int(FREQ_THRESHOLD_PCT)}pct"].mean()),
        "stable_composite_for_cox_n": int(df_detail["stable_composite_for_cox"].sum()),
        "stable_composite_for_cox_pct": float(100.0 * df_detail["stable_composite_for_cox"].mean()),
        "pearson_meanRank_vs_cvRankPct": pearson_r,
        "spearman_meanRank_vs_cvRankPct": spearman_r
    })

    # Save per outcome
    f_detail = os.path.join(OUT_DIR_LOCAL, f"xgb8c_dual_{outcome_code}_rank_stability_{SOURCE_TAG_LOCAL}.csv")
    f_diag = os.path.join(OUT_DIR_LOCAL, f"xgb8c_dual_{outcome_code}_rank_denominator_diag_{SOURCE_TAG_LOCAL}.csv")
    df_detail.to_csv(f_detail, index=False)
    df_diag.to_csv(f_diag, index=False)
    saved_files.extend([f_detail, f_diag])

    details_all.append(df_detail)
    diags_all.append(df_diag)

if not details_all:
    raise RuntimeError("No outcome SHAP arrays found in shap_data.")

# -----------------------------
# 3) Save combined
# -----------------------------
# Concatenate the per-outcome tables into single long-format files.
df_all = pd.concat(details_all, ignore_index=True)
df_diag_all = pd.concat(diags_all, ignore_index=True)
df_summary = pd.DataFrame(summary_rows)

f_all = os.path.join(OUT_DIR_LOCAL, f"xgb8c_dual_rank_stability_all_{SOURCE_TAG_LOCAL}.csv")
f_diag_all = os.path.join(OUT_DIR_LOCAL, f"xgb8c_dual_rank_denominator_diag_all_{SOURCE_TAG_LOCAL}.csv")
f_summary = os.path.join(OUT_DIR_LOCAL, f"xgb8c_dual_rank_stability_summary_{SOURCE_TAG_LOCAL}.csv")
f_info = os.path.join(OUT_DIR_LOCAL, f"xgb8c_dual_rank_stability_run_info_{SOURCE_TAG_LOCAL}.json")

df_all.to_csv(f_all, index=False)
df_diag_all.to_csv(f_diag_all, index=False)
df_summary.to_csv(f_summary, index=False)

# JSON sidecar recording every setting needed to reproduce this analysis.
info = {
    "source_tag": SOURCE_TAG_LOCAL,
    "run_timestamp": RUN_TS,
    "n_bootstrap": int(N_BOOTSTRAP),
    "seed": int(BOOT_SEED),
    "max_n": int(MAX_N) if MAX_N is not None else None,
    "max_n_definition": "Maximum patients used in bootstrap stability analysis; if dataset size <= max_n, all patients are used.",
    "top_k": int(TOP_K),
    "primary_stability_rule": f"freq_in_top_{TOP_K}_pct >= {FREQ_THRESHOLD_PCT}",
    "secondary_metrics_reported": [
        "cv_rank_pct",
        "rank_ci2p5",
        "rank_ci97p5",
        "cv_numerator_sd_rank",
        "cv_denominator_mean_rank",
        "mean_rank_pct_of_feature_count"
    ]
}
with open(f_info, "w", encoding="utf-8") as f:
    json.dump(info, f, indent=2)

saved_files.extend([f_all, f_diag_all, f_summary, f_info])

# -----------------------------
# 4) Display
# -----------------------------
display(Markdown("### Rank Stability Summary"))
show_scroll_df(df_summary, height_px=260)

# Column order for the per-outcome stability tables (hoisted out of the loop
# since it does not depend on the outcome).
display_cols = [
    "feature",
    "point_rank",
    "mean_rank",
    f"freq_in_top_{TOP_K}_pct",
    "sd_rank",
    "rank_ci2p5",
    "rank_ci97p5",
    "cv_rank_pct",
    "cv_denominator_mean_rank",
    f"stable_freq_ge_{int(FREQ_THRESHOLD_PCT)}pct",
    "stable_composite_for_cox"
]

for outcome_label in df_all["outcome"].unique():
    display(Markdown(f"### {outcome_label}: Top {TOP_K} features by mean rank"))
    outcome_slice = df_all[df_all["outcome"] == outcome_label]
    ranked_head = outcome_slice.sort_values("mean_rank", ascending=True).head(TOP_K)
    show_scroll_df(ranked_head[display_cols], height_px=320)

print("Saved files:")
for saved_path in saved_files:
    print(" -", saved_path)

Step 8c: Bootstrap Rank Stability

  • Source tag: 20260306_1821
  • Run time: 20260306_1831
  • Bootstraps: 500
  • Seed: 2125
  • Top-K: 45
  • max_n: 90000 (max patients used per bootstrap replicate)
  • Primary stability: freq_in_top_45 >= 90%

Rank Stability Summary

outcome n_features n_patients_used stable_freq_ge_90pct_n stable_freq_ge_90pct_pct stable_composite_for_cox_n stable_composite_for_cox_pct pearson_meanRank_vs_cvRankPct spearman_meanRank_vs_cvRankPct
Readmission 56 70521 45 80.357143 45 80.357143 -0.363039 -0.311011
Death 56 70521 45 80.357143 45 80.357143 -0.348093 -0.293735

Readmission: Top 45 features by mean rank

feature point_rank mean_rank freq_in_top_45_pct sd_rank rank_ci2p5 rank_ci97p5 cv_rank_pct cv_denominator_mean_rank stable_freq_ge_90pct stable_composite_for_cox
primary_sub_mod_cocaine_paste 1.0 1.000000 100.0 0.000000 1.000000 1.000000 0.000000 1.000000 True True
adm_age_rec3 2.0 2.000000 100.0 0.000000 2.000000 2.000000 0.000000 2.000000 True True
porc_pobr 3.0 3.000000 100.0 0.000000 3.000000 3.000000 0.000000 3.000000 True True
sex_rec_woman 4.0 4.018000 100.0 0.133084 4.000000 4.000000 3.312200 4.018000 True True
plan_type_corr_pg_pr 5.0 4.982000 100.0 0.133084 5.000000 5.000000 2.671300 4.982000 True True
plan_type_corr_m_pr 6.0 6.006000 100.0 0.077304 6.000000 6.000000 1.287118 6.006000 True True
ethnicity 7.0 6.994000 100.0 0.077304 7.000000 7.000000 1.105294 6.994000 True True
dit_m 8.0 8.000000 100.0 0.000000 8.000000 8.000000 0.000000 8.000000 True True
eva_consumo 9.0 9.000000 100.0 0.000000 9.000000 9.000000 0.000000 9.000000 True True
ed_attainment_corr 10.0 10.000000 100.0 0.000000 10.000000 10.000000 0.000000 10.000000 True True
occupation_condition_corr24_unemployed 11.0 11.000000 100.0 0.000000 11.000000 11.000000 0.000000 11.000000 True True
dg_psiq_cie_10_dg 12.0 12.480000 100.0 0.500101 12.000000 13.000000 4.007217 12.480000 True True
sub_dep_icd10_status_drug_dependence 13.0 12.520000 100.0 0.500101 12.000000 13.000000 3.994414 12.520000 True True
polysubstance_strict 14.0 14.000000 100.0 0.000000 14.000000 14.000000 0.000000 14.000000 True True
primary_sub_mod_alcohol 15.0 15.000000 100.0 0.000000 15.000000 15.000000 0.000000 15.000000 True True
eva_sm 16.0 16.007999 100.0 0.089173 16.000000 16.000000 0.557055 16.007999 True True
primary_sub_mod_cocaine_powder 17.0 16.992001 100.0 0.089173 17.000000 17.000000 0.524796 16.992001 True True
evaluacindelprocesoteraputico 18.0 18.260000 100.0 0.439073 18.000000 19.000000 2.404560 18.260000 True True
tr_outcome_adm_discharge_rule_violation_undet 19.0 18.740000 100.0 0.439073 18.000000 19.000000 2.342970 18.740000 True True
prim_sub_freq_rec 20.0 20.000000 100.0 0.000000 20.000000 20.000000 0.000000 20.000000 True True
eva_transgnorma 21.0 21.000000 100.0 0.000000 21.000000 21.000000 0.000000 21.000000 True True
tr_outcome_referral 23.0 22.497999 100.0 0.553723 22.000000 24.000000 2.461209 22.497999 True True
eva_ocupacion 22.0 22.532000 100.0 0.499474 22.000000 23.000000 2.216732 22.532000 True True
adm_motive_justice_sector 24.0 24.068001 100.0 0.357254 23.000000 25.000000 1.484354 24.068001 True True
cohabitation_with_couple_children 25.0 24.902000 100.0 0.297613 24.000000 25.000000 1.195135 24.902000 True True
tenure_status_household 26.0 26.000000 100.0 0.000000 26.000000 26.000000 0.000000 26.000000 True True
any_violence_1_domestic_violence_sex_abuse 27.0 27.002001 100.0 0.044721 27.000000 27.000000 0.165622 27.002001 True True
eva_fam 28.0 27.997999 100.0 0.044721 28.000000 28.000000 0.159730 27.997999 True True
dx_f6_personality 29.0 29.040001 100.0 0.196156 29.000000 30.000000 0.675468 29.040001 True True
tr_outcome_dropout 30.0 29.959999 100.0 0.196156 29.000000 30.000000 0.654726 29.959999 True True
any_phys_dx 31.0 31.000000 100.0 0.000000 31.000000 31.000000 0.000000 31.000000 True True
adm_motive_sanitary_sector 32.0 32.000000 100.0 0.000000 32.000000 32.000000 0.000000 32.000000 True True
urbanicity_cat 33.0 33.026001 100.0 0.159295 33.000000 33.525024 0.482332 33.026001 True True
eva_relinterp 34.0 33.973999 100.0 0.159295 33.474998 34.000000 0.468873 33.973999 True True
occupation_condition_corr24_inactive 35.0 35.000000 100.0 0.000000 35.000000 35.000000 0.000000 35.000000 True True
first_sub_used_alcohol 36.0 36.000000 100.0 0.000000 36.000000 36.000000 0.000000 36.000000 True True
cohabitation_others 37.0 37.000000 100.0 0.000000 37.000000 37.000000 0.000000 37.000000 True True
plan_type_corr_pg_pai 38.0 38.000000 100.0 0.000000 38.000000 38.000000 0.000000 38.000000 True True
cohabitation_family_of_origin 39.0 39.000000 100.0 0.000000 39.000000 39.000000 0.000000 39.000000 True True
marital_status_rec_single 40.0 40.000000 100.0 0.000000 40.000000 40.000000 0.000000 40.000000 True True
eva_fisica 41.0 41.000000 100.0 0.000000 41.000000 41.000000 0.000000 41.000000 True True
dx_f3_mood 42.0 42.000000 100.0 0.000000 42.000000 42.000000 0.000000 42.000000 True True
dg_psiq_cie_10_instudy 43.0 43.304001 100.0 0.460442 43.000000 44.000000 1.063278 43.304001 True True
plan_type_corr_m_pai 44.0 43.695999 100.0 0.460442 43.000000 44.000000 1.053740 43.695999 True True
adm_motive_another_sud_facility_fonodrogas_senda_previene 45.0 45.000000 100.0 0.000000 45.000000 45.000000 0.000000 45.000000 True True

Death: Top 45 features by mean rank

feature point_rank mean_rank freq_in_top_45_pct sd_rank rank_ci2p5 rank_ci97p5 cv_rank_pct cv_denominator_mean_rank stable_freq_ge_90pct stable_composite_for_cox
adm_age_rec3 1.0 1.000000 100.0 0.000000 1.0 1.0 0.000000 1.000000 True True
primary_sub_mod_alcohol 2.0 2.000000 100.0 0.000000 2.0 2.0 0.000000 2.000000 True True
prim_sub_freq_rec 3.0 3.000000 100.0 0.000000 3.0 3.0 0.000000 3.000000 True True
occupation_condition_corr24_unemployed 4.0 4.000000 100.0 0.000000 4.0 4.0 0.000000 4.000000 True True
any_phys_dx 5.0 5.000000 100.0 0.000000 5.0 5.0 0.000000 5.000000 True True
eva_ocupacion 6.0 6.000000 100.0 0.000000 6.0 6.0 0.000000 6.000000 True True
cohabitation_with_couple_children 7.0 7.000000 100.0 0.000000 7.0 7.0 0.000000 7.000000 True True
eva_fisica 8.0 8.002000 100.0 0.044721 8.0 8.0 0.558877 8.002000 True True
adm_motive_sanitary_sector 9.0 8.998000 100.0 0.044721 9.0 9.0 0.497014 8.998000 True True
ed_attainment_corr 10.0 10.154000 100.0 0.361310 10.0 11.0 3.558305 10.154000 True True
sex_rec_woman 11.0 10.846000 100.0 0.361310 10.0 11.0 3.331277 10.846000 True True
first_sub_used_alcohol 12.0 12.308000 100.0 0.462129 12.0 13.0 3.754705 12.308000 True True
occupation_condition_corr24_inactive 13.0 12.692000 100.0 0.462129 12.0 13.0 3.641105 12.692000 True True
primary_sub_mod_cocaine_paste 14.0 14.000000 100.0 0.000000 14.0 14.0 0.000000 14.000000 True True
dit_m 15.0 15.006000 100.0 0.077304 15.0 15.0 0.515156 15.006000 True True
plan_type_corr_pg_pr 16.0 15.998000 100.0 0.118423 16.0 16.0 0.740238 15.998000 True True
porc_pobr 17.0 17.139999 100.0 0.353057 17.0 18.0 2.059844 17.139999 True True
eva_sm 18.0 17.856001 100.0 0.351442 17.0 18.0 1.968203 17.856001 True True
polysubstance_strict 19.0 19.174000 100.0 0.379490 19.0 20.0 1.979188 19.174000 True True
tenure_status_household 20.0 19.826000 100.0 0.379490 19.0 20.0 1.914100 19.826000 True True
marital_status_rec_separated_divorced_annulled_widowed 21.0 21.000000 100.0 0.000000 21.0 21.0 0.000000 21.000000 True True
tipo_de_vivienda_rec2_other_unknown 22.0 22.000000 100.0 0.000000 22.0 22.0 0.000000 22.000000 True True
sub_dep_icd10_status_drug_dependence 23.0 23.000000 100.0 0.000000 23.0 23.0 0.000000 23.000000 True True
primary_sub_mod_cocaine_powder 24.0 24.000000 100.0 0.000000 24.0 24.0 0.000000 24.000000 True True
marital_status_rec_single 25.0 25.000000 100.0 0.000000 25.0 25.0 0.000000 25.000000 True True
tr_outcome_dropout 26.0 26.042000 100.0 0.200790 26.0 27.0 0.771025 26.042000 True True
tr_outcome_adm_discharge_rule_violation_undet 27.0 27.018000 100.0 0.319186 26.0 28.0 1.181384 27.018000 True True
eva_fam 28.0 27.940001 100.0 0.237724 27.0 28.0 0.850838 27.940001 True True
dg_psiq_cie_10_instudy 29.0 29.002001 100.0 0.044721 29.0 29.0 0.154201 29.002001 True True
tr_outcome_referral 30.0 29.997999 100.0 0.044721 30.0 30.0 0.149081 29.997999 True True
dx_f_any_severe_mental 31.0 31.038000 100.0 0.220575 31.0 32.0 0.710660 31.038000 True True
eva_transgnorma 32.0 31.986000 100.0 0.223392 31.0 32.0 0.698407 31.986000 True True
evaluacindelprocesoteraputico 33.0 32.976002 100.0 0.153203 33.0 33.0 0.464588 32.976002 True True
eva_consumo 34.0 34.000000 100.0 0.000000 34.0 34.0 0.000000 34.000000 True True
plan_type_corr_m_pr 35.0 35.133999 100.0 0.340994 35.0 36.0 0.970552 35.133999 True True
ethnicity 36.0 35.866001 100.0 0.340994 35.0 36.0 0.950743 35.866001 True True
dx_f6_personality 37.0 37.428001 100.0 0.495283 37.0 38.0 1.323295 37.428001 True True
any_violence_1_domestic_violence_sex_abuse 38.0 37.571999 100.0 0.495283 37.0 38.0 1.318223 37.571999 True True
eva_relinterp 39.0 39.000000 100.0 0.000000 39.0 39.0 0.000000 39.000000 True True
dg_psiq_cie_10_dg 40.0 40.000000 100.0 0.000000 40.0 40.0 0.000000 40.000000 True True
plan_type_corr_pg_pai 41.0 41.000000 100.0 0.000000 41.0 41.0 0.000000 41.000000 True True
adm_motive_another_sud_facility_fonodrogas_senda_previene 42.0 42.001999 100.0 0.044721 42.0 42.0 0.106474 42.001999 True True
urbanicity_cat 43.0 42.998001 100.0 0.044721 43.0 43.0 0.104008 42.998001 True True
cohabitation_others 44.0 44.000000 100.0 0.000000 44.0 44.0 0.000000 44.000000 True True
dx_f3_mood 45.0 45.000000 100.0 0.000000 45.0 45.0 0.000000 45.000000 True True
  1. Shifts focus from rank stability to predictor profiling.

  2. Ranks features by mean absolute SHAP.

  3. Adds importance tiers (Dominant to Minor).

  4. Merges Step 8 stability into ranking table.

  5. Flags predictors stable at ≥90% Top-K.

  6. Adds exploratory SHAP direction analysis.

  7. Uses bootstrap CI to label risk direction.

  8. Identifies risk-decreasing predictors separately.

  9. Produces structured CSV and Excel outputs.

  10. Keeps the analysis data-driven, with no clinical interpretation.

Assumptions:

  1. Mean |SHAP| reflects overall influence.

  2. Stability from Step 8 remains valid here.

  3. Direction via SHAP delta is informative.

  4. Quartile contrasts capture numeric effects.

  5. Direction labels are exploratory, not causal.

Code
#@title Step 9: Predictor Analysis (No Clinical Interpretation, Step-8 Aligned)

import glob
import json
import os
import pickle
import re
import zlib

import numpy as np
import pandas as pd
from IPython.display import display, Markdown, HTML

# -----------------------------
# 0) Config + nb_print definition
# -----------------------------
# Define nb_print for notebook output
def nb_print(*args):
    """Render the given arguments as one HTML line in the notebook output.

    Each argument is stringified and the pieces are joined with single
    spaces, then shown through IPython's rich-display machinery (unlike
    plain print, this integrates with the notebook display system).
    """
    text = " ".join(map(str, args))
    display(HTML(text))

if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the root setup cell first.")

PROJECT_ROOT = os.path.abspath(str(PROJECT_ROOT))
IN_DIR = os.path.join(PROJECT_ROOT, "_out")
OUT_DIR = os.path.join(PROJECT_ROOT, "_out")
os.makedirs(IN_DIR, exist_ok=True)
os.makedirs(OUT_DIR, exist_ok=True)

nb_print("PROJECT_ROOT:", PROJECT_ROOT)
nb_print("IN_DIR:", IN_DIR)
nb_print("OUT_DIR:", OUT_DIR)
nb_print("SHAP files found:", len(glob.glob(os.path.join(IN_DIR, "xgb6_corr_DUAL_SHAP_Aggregated_*.pkl"))))

TOP_N_KEY = 15  # how many top predictors to show
TOP_N_PROTECTIVE = 10  # max number of predictors shown in risk-decreasing
TOP_K_STABILITY = 15  # top-k stability columns from 8c
STABILITY_FREQ_THRESHOLD = 90.0  # if it appears in top-k at least 90% bootstraps

# Direction analysis (exploratory) for top-N by importance only
DIRECTION_TOP_N = 20  # only top 20 features by importance get direction analysis
DIRECTION_N_BOOT = 500  # number of bootstraps for direction analysis
DIRECTION_SEED = 2125
DIRECTION_MAX_N = 71000  # cap on patients used in direction bootstrap
DIRECTION_MIN_VALID_N = 300  # minimum valid rows required to run direction for a feature
DIRECTION_MIN_GROUP_N = 30  # minimum size for each comparison group

OUTCOME_CFG = {
    "readm": {
        "label": "Readmission",
        "shap_key_candidates": ["shap_r_all", "shap_readm_all"],
    },
    "death": {
        "label": "Death",
        "shap_key_candidates": ["shap_d_all", "shap_death_all", "shap_mort_all"],
    },
}

# FIXED: Use same pattern as Step 8 to match both _20260305_2232.pkl and _20260305_2232_mar26.pkl
TS_RE = re.compile(r"(\d{8}_\d{4})")

# -----------------------------
# 1) Helpers
# -----------------------------
def _tag_from_path(path):
    """Extract the YYYYMMDD_HHMM timestamp tag from a file name, or None."""
    match = TS_RE.search(os.path.basename(path))
    if match is None:
        return None
    return match.group(1)

def pick_latest_shap_bundle(in_dir):
    """Return (tag, path) of the newest aggregated SHAP bundle in `in_dir`.

    Handles both plain and suffixed names (e.g. ..._2232.pkl and
    ..._2232_mar26.pkl) because matching is done on the embedded
    timestamp tag. Raises FileNotFoundError when no parseable bundle
    is present.
    """
    pattern = os.path.join(in_dir, "xgb6_corr_DUAL_SHAP_Aggregated_*.pkl")
    parsed = []
    for path in glob.glob(pattern):
        stamp = _tag_from_path(path)
        if not stamp:
            continue
        when = pd.to_datetime(stamp, format="%Y%m%d_%H%M", errors="coerce")
        if pd.notna(when):
            parsed.append((when, stamp, path))
    if not parsed:
        raise FileNotFoundError(f"No xgb6_corr_DUAL_SHAP_Aggregated_*.pkl found in '{in_dir}'.")
    # Stable ascending sort: on timestamp ties, the last-globbed file wins.
    parsed.sort(key=lambda item: item[0])
    _, newest_tag, newest_path = parsed[-1]
    return newest_tag, newest_path

def find_first_key(dct, keys):
    """Return the first element of `keys` that is a key of `dct`, else None."""
    return next((k for k in keys if k in dct), None)

def assign_importance_tier(rel_pct):
    """Map importance relative to the max feature (in %) to a tier label.

    Thresholds: >=50% DOMINANT, >=20% MAJOR, >=10% MODERATE, else MINOR.
    Non-finite input (e.g. NaN) fails every comparison and lands in MINOR.
    """
    tiers = (
        (50.0, "DOMINANT (>=50% max)"),
        (20.0, "MAJOR (20-50% max)"),
        (10.0, "MODERATE (10-20% max)"),
    )
    for cutoff, label in tiers:
        if rel_pct >= cutoff:
            return label
    return "MINOR (<10% max)"

def pick_freq_col(columns, top_k=30):
    """Choose the `freq_in_top_<k>_pct` column whose k is closest to `top_k`.

    Prefers an exact match; otherwise ranks candidates by distance of
    their k from `top_k`, breaking ties with the smaller k. Returns None
    when no matching column exists.
    """
    preferred = f"freq_in_top_{top_k}_pct"
    if preferred in columns:
        return preferred
    ranked = sorted(
        (abs(int(m.group(1)) - top_k), int(m.group(1)), col)
        for col in columns
        for m in [re.match(r"^freq_in_top_(\d+)_pct$", str(col))]
        if m
    )
    return ranked[0][2] if ranked else None

def load_stability_table(outcome_code, source_tag, out_dir):
    """Load the Step-8c rank-stability CSV for one outcome.

    Tries exact file names first — both the 'xgb8c_dual_' (Step 8c) and
    the older 'xgb8_dual_' prefixes, with and without the '_cv' infix —
    then falls back to wildcard matches, taking the most recently
    modified file. Returns (DataFrame, path) or (None, None).
    """
    prefixes = ("xgb8c_dual_", "xgb8_dual_")
    infixes = ("rank_stability", "rank_stability_cv")

    # Exact names for this source tag, in priority order.
    for prefix in prefixes:
        for infix in infixes:
            candidate = os.path.join(out_dir, f"{prefix}{outcome_code}_{infix}_{source_tag}.csv")
            if os.path.exists(candidate):
                return pd.read_csv(candidate), candidate

    # Wildcard fallback: any tag, newest modification time wins.
    matches = []
    for prefix in prefixes:
        for infix in infixes:
            matches.extend(glob.glob(os.path.join(out_dir, f"{prefix}{outcome_code}_{infix}_*.csv")))
    if not matches:
        return None, None

    newest = sorted(matches, key=os.path.getmtime)[-1]
    return pd.read_csv(newest), newest

def bootstrap_diff_means(group_hi, group_lo, n_boot=300, seed=2125):
    """Bootstrap the difference in means (hi - lo) between two samples.

    Non-finite values are dropped first. Returns a 4-tuple of
    (point estimate, 2.5% CI bound, 97.5% CI bound, combined n).
    When either cleaned sample is empty, returns (nan, nan, nan, 0).
    """
    arr_hi = np.asarray(group_hi, dtype=float).ravel()
    arr_lo = np.asarray(group_lo, dtype=float).ravel()
    arr_hi = arr_hi[np.isfinite(arr_hi)]
    arr_lo = arr_lo[np.isfinite(arr_lo)]
    if arr_hi.size == 0 or arr_lo.size == 0:
        return np.nan, np.nan, np.nan, 0

    rng = np.random.RandomState(seed)
    estimate = float(np.mean(arr_hi) - np.mean(arr_lo))

    n_hi, n_lo = arr_hi.size, arr_lo.size
    draws = np.empty(n_boot, dtype=float)
    for b in range(n_boot):
        # Resample each group with replacement at its original size.
        resampled_hi = arr_hi[rng.randint(0, n_hi, size=n_hi)]
        resampled_lo = arr_lo[rng.randint(0, n_lo, size=n_lo)]
        draws[b] = np.mean(resampled_hi) - np.mean(resampled_lo)

    lo_q, hi_q = np.quantile(draws, [0.025, 0.975])
    return estimate, float(lo_q), float(hi_q), int(n_hi + n_lo)

def direction_from_ci(ci_low, ci_high):
    """Classify effect direction from a bootstrap CI on the SHAP delta.

    CI entirely above zero -> Risk-Increasing; entirely below zero ->
    Risk-Decreasing; CI spanning zero -> Mixed/Uncertain; otherwise
    (non-finite bounds) -> Unknown.
    """
    low_finite = np.isfinite(ci_low)
    high_finite = np.isfinite(ci_high)
    if low_finite and ci_low > 0:
        return "Risk-Increasing"
    if high_finite and ci_high < 0:
        return "Risk-Decreasing"
    if low_finite and high_finite:
        return "Mixed/Uncertain"
    return "Unknown"

def direction_bootstrap_feature(x_raw, s_raw, n_boot, seed, max_n, min_valid_n, min_group_n):
    """Exploratory direction analysis for a single feature.

    Compares SHAP values between high and low feature groups: x=1 vs x=0
    for binary features, top vs bottom quartile otherwise. Returns a dict
    with the comparison method, delta point estimate, bootstrap CI,
    sample size, and a direction label. Returns 'Unknown' early when
    there are too few valid rows or a comparison group is too small.
    """
    def _result(method, point, ci_lo, ci_hi, n, label):
        # Uniform payload for every exit path.
        return {
            "direction_method": method,
            "delta_point": point,
            "delta_ci_low": ci_lo,
            "delta_ci_high": ci_hi,
            "direction_n": int(n),
            "direction_label": label,
        }

    feat = pd.to_numeric(pd.Series(x_raw), errors="coerce").to_numpy(dtype=float)
    shap_vals = np.asarray(s_raw, dtype=float).ravel()
    valid = np.isfinite(feat) & np.isfinite(shap_vals)
    feat = feat[valid]
    shap_vals = shap_vals[valid]

    if len(feat) < min_valid_n:
        return _result("insufficient_n", np.nan, np.nan, np.nan, len(feat), "Unknown")

    # Optional subsampling cap for very large cohorts (reproducible via seed).
    rng = np.random.RandomState(seed)
    if (max_n is not None) and (len(feat) > max_n):
        keep = rng.choice(len(feat), size=max_n, replace=False)
        feat = feat[keep]
        shap_vals = shap_vals[keep]

    # Binary detection: at most two distinct values, all in {0, 1}.
    distinct = np.unique(np.round(np.unique(feat[np.isfinite(feat)]), 10))
    treat_as_binary = (len(distinct) <= 2) and set(distinct).issubset({0.0, 1.0})

    if treat_as_binary:
        low_group = shap_vals[feat == 0.0]
        high_group = shap_vals[feat == 1.0]
        method = "binary_x1_minus_x0"
    else:
        q_low = np.nanquantile(feat, 0.25)
        q_high = np.nanquantile(feat, 0.75)
        low_group = shap_vals[feat <= q_low]    # bottom quartile (Q1)
        high_group = shap_vals[feat >= q_high]  # top quartile (Q4)
        method = "quartile_q4_minus_q1"

    if min(len(low_group), len(high_group)) < min_group_n:
        return _result(method, np.nan, np.nan, np.nan, len(feat), "Unknown")

    point, ci_lo, ci_hi, n_used = bootstrap_diff_means(high_group, low_group, n_boot=n_boot, seed=seed)
    return _result(method, point, ci_lo, ci_hi, n_used, direction_from_ci(ci_lo, ci_hi))

def format_key_list(df, n=10, freq_col=None):
    """Format the top-n rows of a predictor table as a numbered list.

    Each line shows rank, feature name, and mean |SHAP|; when `freq_col`
    is given and the row has a non-missing value there, the stability
    frequency is appended. Returns a newline-joined string.
    """
    lines = []
    for rank, row in enumerate(df.head(n).itertuples(index=False), start=1):
        entry = f"{rank}. {row.feature} (|SHAP|={row.mean_abs_shap_log_hazard:.4f}"
        freq = getattr(row, freq_col, None) if freq_col else None
        if freq is not None and pd.notna(freq):
            entry += f", freq={freq:.1f}%)"
        else:
            entry += ")"
        lines.append(entry)
    return "\n".join(lines)

# -----------------------------
# 2) Load SHAP source
# -----------------------------
# Pick the newest aggregated SHAP bundle and unpickle it.
source_tag, shap_file = pick_latest_shap_bundle(IN_DIR)
nb_print(f"Using SHAP bundle: <b>{source_tag}</b>")

# NOTE(review): pickle.load can execute arbitrary code on untrusted input;
# acceptable here because the bundle is produced locally by Step 6.
with open(shap_file, "rb") as f:
    shap_data = pickle.load(f)

if "X_all" not in shap_data or "feature_names" not in shap_data:
    raise KeyError("SHAP file missing X_all and/or feature_names.")

X_all = shap_data["X_all"]
feature_names = list(shap_data["feature_names"])

# Normalize X_all to a DataFrame whose columns follow feature_names order,
# so it lines up column-for-column with the SHAP matrices.
if not isinstance(X_all, pd.DataFrame):
    X_all = pd.DataFrame(X_all, columns=feature_names)

if list(X_all.columns) != feature_names:
    X_all = X_all.reindex(columns=feature_names)

# Duplicate names would silently corrupt the reindex/merge logic below.
if len(feature_names) != len(set(feature_names)):
    dupes = [f for f in feature_names if feature_names.count(f) > 1]
    raise ValueError(f"Duplicate feature names detected: {set(dupes)}")

# Numeric matrix view used by the direction analysis; non-numeric entries
# become NaN and are filtered per feature later.
X_num = X_all.apply(pd.to_numeric, errors="coerce")
X_mat = X_num.to_numpy(dtype=float)
f2i = {f: i for i, f in enumerate(feature_names)}  # feature name -> column index

display(Markdown(
    f"### Step 9 Predictor Analysis\n"
    f"- Source tag: **{source_tag}**\n"
    f"- Patients: **{X_all.shape[0]}**\n"
    f"- Features: **{X_all.shape[1]}**\n"
    f"- Direction method: **bootstrap CI on SHAP delta** (top {DIRECTION_TOP_N} features per outcome)"
))

# -----------------------------
# 3) Build per-outcome predictor tables
# -----------------------------
# Accumulators filled once per outcome and combined in section 4.
per_outcome_tables = []
tier_rows = []
summary_rows = []
take_home_rows = []
protective_rows = []
stability_files_used = []

for oc, cfg in OUTCOME_CFG.items():
    # Resolve this outcome's SHAP matrix under any of its historical key names.
    shap_key = find_first_key(shap_data, cfg["shap_key_candidates"])
    if shap_key is None:
        nb_print(f"Skipping {cfg['label']}: SHAP key not found ({cfg['shap_key_candidates']}).")
        continue

    S = np.asarray(shap_data[shap_key], dtype=float)
    if S.shape != X_mat.shape:
        raise ValueError(f"{cfg['label']} SHAP shape mismatch: {S.shape} vs {X_mat.shape}")

    # Global importance = mean |SHAP| per feature (log-hazard scale), descending.
    imp = np.abs(S).mean(axis=0)
    df = pd.DataFrame({
        "feature": feature_names,
        "mean_abs_shap_log_hazard": imp,
    }).sort_values("mean_abs_shap_log_hazard", ascending=False).reset_index(drop=True)

    df["rank"] = np.arange(1, len(df) + 1, dtype=int)
    max_imp = float(df["mean_abs_shap_log_hazard"].max()) if len(df) else np.nan
    # Importance relative to the strongest feature; drives tier assignment.
    df["rel_to_max_pct"] = 100.0 * df["mean_abs_shap_log_hazard"] / max_imp if np.isfinite(max_imp) and max_imp > 0 else np.nan
    df["importance_tier"] = df["rel_to_max_pct"].apply(assign_importance_tier)

    # Optional stability merge from Step 8c (left join on feature name).
    stab_df, stab_path = load_stability_table(oc, source_tag, OUT_DIR)
    freq_col = None
    if stab_df is not None and len(stab_df):
        # Step 8c stores the full outcome label in an 'outcome' column
        # (not 'outcome_code'); filter to this outcome when present.
        if "outcome" in stab_df.columns:
            tmp = stab_df[stab_df["outcome"].astype(str).str.lower() == cfg["label"].lower()]
            if len(tmp):
                stab_df = tmp

        freq_col = pick_freq_col(stab_df.columns, TOP_K_STABILITY)
        keep = ["feature"]
        for c in ["mean_rank", "cv_rank_pct", "rank_ci2p5", "rank_ci97p5", freq_col]:
            if c and c in stab_df.columns:
                keep.append(c)

        stab_df = stab_df[keep].drop_duplicates("feature")
        df = df.merge(stab_df, on="feature", how="left")

        if freq_col and freq_col in df.columns:
            df[f"stable_freq_ge_{int(STABILITY_FREQ_THRESHOLD)}pct"] = df[freq_col] >= STABILITY_FREQ_THRESHOLD

        stability_files_used.append(stab_path)

    # Direction bootstrap for the top-N features only (exploratory).
    df["direction_method"] = "not_evaluated"
    df["delta_point"] = np.nan
    df["delta_ci_low"] = np.nan
    df["delta_ci_high"] = np.nan
    df["direction_n"] = np.nan
    df["direction_label"] = "Not evaluated"

    top_features_for_direction = df.head(DIRECTION_TOP_N)["feature"].tolist()
    for feat in top_features_for_direction:
        j = f2i[feat]
        # FIXED: derive the per-feature seed from a *stable* checksum.
        # Python's built-in hash() of a str is salted per process
        # (PYTHONHASHSEED), so the previous hash(feat)-based seed was not
        # reproducible across kernel restarts; zlib.crc32 is deterministic
        # for a given feature name and stays within the valid seed range.
        feat_seed = (DIRECTION_SEED + zlib.crc32(feat.encode("utf-8")) % 10000) % (2**31)
        res = direction_bootstrap_feature(
            x_raw=X_mat[:, j],
            s_raw=S[:, j],
            n_boot=DIRECTION_N_BOOT,
            seed=feat_seed,
            max_n=DIRECTION_MAX_N,
            min_valid_n=DIRECTION_MIN_VALID_N,
            min_group_n=DIRECTION_MIN_GROUP_N,
        )
        idx = df.index[df["feature"] == feat][0]
        df.loc[idx, "direction_method"] = res["direction_method"]
        df.loc[idx, "delta_point"] = res["delta_point"]
        df.loc[idx, "delta_ci_low"] = res["delta_ci_low"]
        df.loc[idx, "delta_ci_high"] = res["delta_ci_high"]
        df.loc[idx, "direction_n"] = res["direction_n"]
        df.loc[idx, "direction_label"] = res["direction_label"]

    # Dominance stats: how far ahead the #1 feature is over #2.
    top1 = float(df.loc[0, "mean_abs_shap_log_hazard"]) if len(df) > 0 else np.nan
    top2 = float(df.loc[1, "mean_abs_shap_log_hazard"]) if len(df) > 1 else np.nan
    top2_ratio = (top1 / top2) if np.isfinite(top1) and np.isfinite(top2) and top2 > 0 else np.nan

    df.insert(0, "outcome", cfg["label"])
    per_outcome_tables.append(df)

    # Tier summary: one row per non-empty tier, listing its features.
    tier_order = ["DOMINANT (>=50% max)", "MAJOR (20-50% max)", "MODERATE (10-20% max)", "MINOR (<10% max)"]
    for t in tier_order:
        dft = df[df["importance_tier"] == t].copy()
        if len(dft) == 0:
            continue
        tier_rows.append({
            "outcome": cfg["label"],
            "tier": t,
            "n_features": int(len(dft)),
            "importance_min": float(dft["mean_abs_shap_log_hazard"].min()),
            "importance_max": float(dft["mean_abs_shap_log_hazard"].max()),
            "features": ", ".join(dft["feature"].tolist()),
        })

    summary_row = {
        "outcome": cfg["label"],
        "n_features": int(len(df)),
        "top1_feature": df.loc[0, "feature"] if len(df) else None,
        "top1_importance": top1,
        "top2_feature": df.loc[1, "feature"] if len(df) > 1 else None,
        "top2_importance": top2,
        "top1_to_top2_ratio": top2_ratio,
    }

    # Count how many of the displayed top-N meet the stability threshold.
    if freq_col and freq_col in df.columns:
        summary_row[f"top{TOP_N_KEY}_stable_freq_ge_{int(STABILITY_FREQ_THRESHOLD)}pct_n"] = int(
            df.head(TOP_N_KEY)[f"stable_freq_ge_{int(STABILITY_FREQ_THRESHOLD)}pct"].fillna(False).sum()
        )
    summary_rows.append(summary_row)

    # Take-home list (data only, no interpretation).
    key_txt = format_key_list(df, n=TOP_N_KEY, freq_col=freq_col if freq_col in df.columns else None)
    take_home_rows.append({
        "Risk Profile": cfg["label"],
        "Key Predictors (Ranked; mean |SHAP| on log-hazard scale)": key_txt,
        "Notes": "Purely data-driven ranking from aggregated CV-SHAP."
    })

    # Protective list (exploratory direction labels only).
    prot = df[df["direction_label"] == "Risk-Decreasing"].sort_values("mean_abs_shap_log_hazard", ascending=False)
    prot_txt = format_key_list(prot, n=TOP_N_PROTECTIVE, freq_col=freq_col if freq_col in prot.columns else None) if len(prot) else "None identified among evaluated features."
    protective_rows.append({
        "Outcome": cfg["label"],
        "Top Risk-Decreasing Predictors (exploratory)": prot_txt,
        "Direction basis": "Bootstrap CI of SHAP delta (Q4-Q1 for numeric, x=1-x=0 for binary)."
    })

# -----------------------------
# 4) Combine + display
# -----------------------------
# At least one outcome must have produced a table; otherwise nothing to report.
if not per_outcome_tables:
    raise RuntimeError("No outcome tables created (missing SHAP arrays).")

# One long predictor table (both outcomes stacked) plus small summary frames.
predictor_table = pd.concat(per_outcome_tables, ignore_index=True)
tier_table = pd.DataFrame(tier_rows)
summary_table = pd.DataFrame(summary_rows)
take_home_msg = pd.DataFrame(take_home_rows)
protective_msg = pd.DataFrame(protective_rows)

display(Markdown("### Predictor Table"))
display(predictor_table.head(20))

display(Markdown("### Importance Tier Summary"))
display(tier_table)

display(Markdown("### Outcome Summary"))
display(summary_table)

display(Markdown("### Key Predictors (No Clinical Interpretation)"))
# Show full cell text: the key-predictor column holds multi-line strings.
pd.set_option("display.max_colwidth", None)
display(take_home_msg.style.set_properties(**{"text-align": "left", "white-space": "pre-wrap"}))

display(Markdown("### Risk-Decreasing Predictors (Exploratory Direction Only)"))
display(protective_msg.style.set_properties(**{"text-align": "left", "white-space": "pre-wrap"}))

# Quick sanity check: show direction results for a few known features, if present.
check_features = ["porc_pobr", "adm_age_rec3", "sex_rec_woman"]
check_df = predictor_table[predictor_table["feature"].isin(check_features)][
    ["outcome", "feature", "rank", "mean_abs_shap_log_hazard", "direction_method", "delta_point", "delta_ci_low", "delta_ci_high", "direction_label"]
].sort_values(["outcome", "rank"])
if len(check_df):
    display(Markdown("### Direction Sanity Check"))
    display(check_df)

# -----------------------------
# 5) Save outputs
# -----------------------------
# All output file names embed the source tag so runs remain traceable.
base = f"xgb9_dual_predictor_analysis_no_clinical_{source_tag}"
f_pred = os.path.join(OUT_DIR, f"{base}.csv")
f_tier = os.path.join(OUT_DIR, f"xgb9_dual_importance_tiers_{source_tag}.csv")
f_sum = os.path.join(OUT_DIR, f"xgb9_dual_predictor_summary_{source_tag}.csv")
f_take = os.path.join(OUT_DIR, f"xgb9_dual_takehome_no_clinical_{source_tag}.csv")
f_prot = os.path.join(OUT_DIR, f"xgb9_dual_protective_no_clinical_{source_tag}.csv")
f_xlsx = os.path.join(OUT_DIR, f"{base}.xlsx")
f_info = os.path.join(OUT_DIR, f"xgb9_dual_predictor_run_info_{source_tag}.json")

predictor_table.to_csv(f_pred, index=False)
tier_table.to_csv(f_tier, index=False)
summary_table.to_csv(f_sum, index=False)
take_home_msg.to_csv(f_take, index=False)
protective_msg.to_csv(f_prot, index=False)

# Single workbook with one sheet per table. NOTE(review): relies on an
# installed Excel engine (e.g. xlsxwriter/openpyxl) — pandas picks it.
with pd.ExcelWriter(f_xlsx) as writer:
    predictor_table.to_excel(writer, sheet_name="Predictors", index=False)
    tier_table.to_excel(writer, sheet_name="Importance_Tiers", index=False)
    summary_table.to_excel(writer, sheet_name="Summary", index=False)
    take_home_msg.to_excel(writer, sheet_name="TakeHome_NoClinical", index=False)
    protective_msg.to_excel(writer, sheet_name="Protective_NoClinical", index=False)

# Machine-readable provenance record for this run.
run_info = {
    "source_tag": source_tag,
    "shap_file": shap_file,
    "stability_files_used": sorted(set([p for p in stability_files_used if p])),
    "top_n_key": int(TOP_N_KEY),
    "top_k_stability": int(TOP_K_STABILITY),
    "stability_freq_threshold_pct": float(STABILITY_FREQ_THRESHOLD),
    "direction_analysis": {
        "top_n_features_per_outcome": int(DIRECTION_TOP_N),
        "n_bootstrap": int(DIRECTION_N_BOOT),
        "seed": int(DIRECTION_SEED),
        "max_n": int(DIRECTION_MAX_N),
        "min_valid_n": int(DIRECTION_MIN_VALID_N),
        "min_group_n": int(DIRECTION_MIN_GROUP_N),
        "rule": "Risk-Increasing if CI_low>0; Risk-Decreasing if CI_high<0; else Mixed/Uncertain."
    },
    "notes": [
        "No clinical interpretation text included.",
        "SHAP magnitude is on log-hazard scale.",
        "Direction labels are exploratory and not causal."
    ]
}
with open(f_info, "w", encoding="utf-8") as f:
    json.dump(run_info, f, indent=2)

# Use nb_print (HTML display) rather than print so paths render in the notebook.
nb_print("<b>Saved:</b>")
for p in [f_pred, f_tier, f_sum, f_take, f_prot, f_xlsx, f_info]:
    nb_print(" -", p)
PROJECT_ROOT: G:\My Drive\Alvacast\SISTRAT 2023\cons
IN_DIR: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out
OUT_DIR: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out
SHAP files found: 2
Using SHAP bundle: 20260306_1821

Step 9 Predictor Analysis

  • Source tag: 20260306_1821
  • Patients: 70521
  • Features: 56
  • Direction method: bootstrap CI on SHAP delta (top 20 features per outcome)

Predictor Table

outcome feature mean_abs_shap_log_hazard rank rel_to_max_pct importance_tier mean_rank cv_rank_pct rank_ci2p5 rank_ci97p5 freq_in_top_45_pct stable_freq_ge_90pct direction_method delta_point delta_ci_low delta_ci_high direction_n direction_label
0 Readmission primary_sub_mod_cocaine_paste 0.090015 1 100.000000 DOMINANT (>=50% max) 1.000 0.000000 1.0 1.0 100.0 True binary_x1_minus_x0 0.184797 0.184398 0.185186 70521.0 Risk-Increasing
1 Readmission adm_age_rec3 0.078743 2 87.477933 DOMINANT (>=50% max) 2.000 0.000000 2.0 2.0 100.0 True quartile_q4_minus_q1 -0.163150 -0.166379 -0.159930 35281.0 Risk-Decreasing
2 Readmission porc_pobr 0.073050 3 81.153816 DOMINANT (>=50% max) 3.000 0.000000 3.0 3.0 100.0 True quartile_q4_minus_q1 0.119147 0.117414 0.120873 35755.0 Risk-Increasing
3 Readmission sex_rec_woman 0.070474 4 78.291204 DOMINANT (>=50% max) 4.018 3.312200 4.0 4.0 100.0 True binary_x1_minus_x0 0.183644 0.183199 0.184149 70521.0 Risk-Increasing
4 Readmission plan_type_corr_pg_pr 0.069722 5 77.456557 DOMINANT (>=50% max) 4.982 2.671300 5.0 5.0 100.0 True binary_x1_minus_x0 0.324924 0.323363 0.326535 70521.0 Risk-Increasing
5 Readmission plan_type_corr_m_pr 0.056074 6 62.293789 DOMINANT (>=50% max) 6.006 1.287118 6.0 6.0 100.0 True binary_x1_minus_x0 0.591616 0.587229 0.596070 70521.0 Risk-Increasing
6 Readmission ethnicity 0.054893 7 60.982778 DOMINANT (>=50% max) 6.994 1.105294 7.0 7.0 100.0 True binary_x1_minus_x0 0.442310 0.440231 0.444277 70521.0 Risk-Increasing
7 Readmission dit_m 0.047543 8 52.816541 DOMINANT (>=50% max) 8.000 0.000000 8.0 8.0 100.0 True quartile_q4_minus_q1 0.078665 0.077371 0.079864 35358.0 Risk-Increasing
8 Readmission eva_consumo 0.046565 9 51.730951 DOMINANT (>=50% max) 9.000 0.000000 9.0 9.0 100.0 True quartile_q4_minus_q1 0.130131 0.129766 0.130504 50240.0 Risk-Increasing
9 Readmission ed_attainment_corr 0.041594 10 46.208175 MAJOR (20-50% max) 10.000 0.000000 10.0 10.0 100.0 True quartile_q4_minus_q1 -0.105240 -0.105754 -0.104720 70369.0 Risk-Decreasing
10 Readmission occupation_condition_corr24_unemployed 0.036454 11 40.498109 MAJOR (20-50% max) 11.000 0.000000 11.0 11.0 100.0 True binary_x1_minus_x0 0.080797 0.080547 0.081077 70521.0 Risk-Increasing
11 Readmission dg_psiq_cie_10_dg 0.028379 12 31.527343 MAJOR (20-50% max) 12.480 4.007217 12.0 13.0 100.0 True binary_x1_minus_x0 0.057477 0.057349 0.057613 70521.0 Risk-Increasing
12 Readmission sub_dep_icd10_status_drug_dependence 0.028375 13 31.522556 MAJOR (20-50% max) 12.520 3.994414 12.0 13.0 100.0 True binary_x1_minus_x0 0.072180 0.071856 0.072502 70521.0 Risk-Increasing
13 Readmission polysubstance_strict 0.025277 14 28.080779 MAJOR (20-50% max) 14.000 0.000000 14.0 14.0 100.0 True binary_x1_minus_x0 0.065937 0.065635 0.066265 70521.0 Risk-Increasing
14 Readmission primary_sub_mod_alcohol 0.024515 15 27.234327 MAJOR (20-50% max) 15.000 0.000000 15.0 15.0 100.0 True binary_x1_minus_x0 -0.056816 -0.057048 -0.056602 70521.0 Risk-Decreasing
15 Readmission eva_sm 0.023360 16 25.950871 MAJOR (20-50% max) 16.008 0.557055 16.0 16.0 100.0 True quartile_q4_minus_q1 0.028271 0.027898 0.028629 70512.0 Risk-Increasing
16 Readmission primary_sub_mod_cocaine_powder 0.023091 17 25.652840 MAJOR (20-50% max) 16.992 0.524796 17.0 17.0 100.0 True binary_x1_minus_x0 0.071789 0.071470 0.072106 70521.0 Risk-Increasing
17 Readmission evaluacindelprocesoteraputico 0.021649 18 24.050369 MAJOR (20-50% max) 18.260 2.404560 18.0 19.0 100.0 True quartile_q4_minus_q1 0.032717 0.032414 0.033054 70520.0 Risk-Increasing
18 Readmission tr_outcome_adm_discharge_rule_violation_undet 0.021562 19 23.953403 MAJOR (20-50% max) 18.740 2.342970 18.0 19.0 100.0 True binary_x1_minus_x0 0.143383 0.142511 0.144313 70521.0 Risk-Increasing
19 Readmission prim_sub_freq_rec 0.019463 20 21.621934 MAJOR (20-50% max) 20.000 0.000000 20.0 20.0 100.0 True quartile_q4_minus_q1 0.037986 0.037728 0.038277 70310.0 Risk-Increasing

Importance Tier Summary

outcome tier n_features importance_min importance_max features
0 Readmission DOMINANT (>=50% max) 9 0.046565 0.090015 primary_sub_mod_cocaine_paste, adm_age_rec3, porc_pobr, sex_rec_woman, plan_type_corr_pg_pr, plan_type_corr_m_pr, ethnicity, dit_m, eva_consumo
1 Readmission MAJOR (20-50% max) 11 0.019463 0.041594 ed_attainment_corr, occupation_condition_corr24_unemployed, dg_psiq_cie_10_dg, sub_dep_icd10_status_drug_dependence, polysubstance_strict, primary_sub_mod_alcohol, eva_sm, primary_sub_mod_cocaine_powder, evaluacindelprocesoteraputico, tr_outcome_adm_discharge_rule_violation_undet, prim_sub_freq_rec
2 Readmission MODERATE (10-20% max) 16 0.009266 0.017740 eva_transgnorma, eva_ocupacion, tr_outcome_referral, adm_motive_justice_sector, cohabitation_with_couple_children, tenure_status_household, any_violence_1_domestic_violence_sex_abuse, eva_fam, dx_f6_personality, tr_outcome_dropout, any_phys_dx, adm_motive_sanitary_sector, urbanicity_cat, eva_relinterp, occupation_condition_corr24_inactive, first_sub_used_alcohol
3 Readmission MINOR (<10% max) 20 0.000000 0.008886 cohabitation_others, plan_type_corr_pg_pai, cohabitation_family_of_origin, marital_status_rec_single, eva_fisica, dx_f3_mood, dg_psiq_cie_10_instudy, plan_type_corr_m_pai, adm_motive_another_sud_facility_fonodrogas_senda_previene, first_sub_used_cocaine_powder, marital_status_rec_separated_divorced_annulled_widowed, dx_f_any_severe_mental, tipo_de_vivienda_rec2_other_unknown, primary_sub_mod_others, adm_motive_other, first_sub_used_cocaine_paste, first_sub_used_other, tr_outcome_adm_discharge_adm_reasons, national_foreign, tr_outcome_other
4 Death DOMINANT (>=50% max) 2 0.267121 0.497441 adm_age_rec3, primary_sub_mod_alcohol
5 Death MAJOR (20-50% max) 2 0.099876 0.114488 prim_sub_freq_rec, occupation_condition_corr24_unemployed
6 Death MODERATE (10-20% max) 3 0.058190 0.092606 any_phys_dx, eva_ocupacion, cohabitation_with_couple_children
7 Death MINOR (<10% max) 49 0.000000 0.049436 eva_fisica, adm_motive_sanitary_sector, ed_attainment_corr, sex_rec_woman, first_sub_used_alcohol, occupation_condition_corr24_inactive, primary_sub_mod_cocaine_paste, dit_m, plan_type_corr_pg_pr, porc_pobr, eva_sm, polysubstance_strict, tenure_status_household, marital_status_rec_separated_divorced_annulled_widowed, tipo_de_vivienda_rec2_other_unknown, sub_dep_icd10_status_drug_dependence, primary_sub_mod_cocaine_powder, marital_status_rec_single, tr_outcome_dropout, tr_outcome_adm_discharge_rule_violation_undet, eva_fam, dg_psiq_cie_10_instudy, tr_outcome_referral, dx_f_any_severe_mental, eva_transgnorma, evaluacindelprocesoteraputico, eva_consumo, plan_type_corr_m_pr, ethnicity, dx_f6_personality, any_violence_1_domestic_violence_sex_abuse, eva_relinterp, dg_psiq_cie_10_dg, plan_type_corr_pg_pai, adm_motive_another_sud_facility_fonodrogas_senda_previene, urbanicity_cat, cohabitation_others, dx_f3_mood, cohabitation_family_of_origin, adm_motive_other, adm_motive_justice_sector, first_sub_used_cocaine_paste, plan_type_corr_m_pai, first_sub_used_other, first_sub_used_cocaine_powder, primary_sub_mod_others, national_foreign, tr_outcome_adm_discharge_adm_reasons, tr_outcome_other

Outcome Summary

outcome n_features top1_feature top1_importance top2_feature top2_importance top1_to_top2_ratio top15_stable_freq_ge_90pct_n
0 Readmission 56 primary_sub_mod_cocaine_paste 0.090015 adm_age_rec3 0.078743 1.143145 15
1 Death 56 adm_age_rec3 0.497441 primary_sub_mod_alcohol 0.267121 1.862228 15

Key Predictors (No Clinical Interpretation)

  Risk Profile Key Predictors (Ranked; mean |SHAP| on log-hazard scale) Notes
0 Readmission 1. primary_sub_mod_cocaine_paste (|SHAP|=0.0900, freq=100.0%) 2. adm_age_rec3 (|SHAP|=0.0787, freq=100.0%) 3. porc_pobr (|SHAP|=0.0731, freq=100.0%) 4. sex_rec_woman (|SHAP|=0.0705, freq=100.0%) 5. plan_type_corr_pg_pr (|SHAP|=0.0697, freq=100.0%) 6. plan_type_corr_m_pr (|SHAP|=0.0561, freq=100.0%) 7. ethnicity (|SHAP|=0.0549, freq=100.0%) 8. dit_m (|SHAP|=0.0475, freq=100.0%) 9. eva_consumo (|SHAP|=0.0466, freq=100.0%) 10. ed_attainment_corr (|SHAP|=0.0416, freq=100.0%) 11. occupation_condition_corr24_unemployed (|SHAP|=0.0365, freq=100.0%) 12. dg_psiq_cie_10_dg (|SHAP|=0.0284, freq=100.0%) 13. sub_dep_icd10_status_drug_dependence (|SHAP|=0.0284, freq=100.0%) 14. polysubstance_strict (|SHAP|=0.0253, freq=100.0%) 15. primary_sub_mod_alcohol (|SHAP|=0.0245, freq=100.0%) Purely data-driven ranking from aggregated CV-SHAP.
1 Death 1. adm_age_rec3 (|SHAP|=0.4974, freq=100.0%) 2. primary_sub_mod_alcohol (|SHAP|=0.2671, freq=100.0%) 3. prim_sub_freq_rec (|SHAP|=0.1145, freq=100.0%) 4. occupation_condition_corr24_unemployed (|SHAP|=0.0999, freq=100.0%) 5. any_phys_dx (|SHAP|=0.0926, freq=100.0%) 6. eva_ocupacion (|SHAP|=0.0749, freq=100.0%) 7. cohabitation_with_couple_children (|SHAP|=0.0582, freq=100.0%) 8. eva_fisica (|SHAP|=0.0494, freq=100.0%) 9. adm_motive_sanitary_sector (|SHAP|=0.0490, freq=100.0%) 10. ed_attainment_corr (|SHAP|=0.0460, freq=100.0%) 11. sex_rec_woman (|SHAP|=0.0458, freq=100.0%) 12. first_sub_used_alcohol (|SHAP|=0.0415, freq=100.0%) 13. occupation_condition_corr24_inactive (|SHAP|=0.0414, freq=100.0%) 14. primary_sub_mod_cocaine_paste (|SHAP|=0.0356, freq=100.0%) 15. dit_m (|SHAP|=0.0345, freq=100.0%) Purely data-driven ranking from aggregated CV-SHAP.

Risk-Decreasing Predictors (Exploratory Direction Only)

  Outcome Top Risk-Decreasing Predictors (exploratory) Direction basis
0 Readmission 1. adm_age_rec3 (|SHAP|=0.0787, freq=100.0%) 2. ed_attainment_corr (|SHAP|=0.0416, freq=100.0%) 3. primary_sub_mod_alcohol (|SHAP|=0.0245, freq=100.0%) Bootstrap CI of SHAP delta (Q4-Q1 for numeric, x=1-x=0 for binary).
1 Death 1. cohabitation_with_couple_children (|SHAP|=0.0582, freq=100.0%) 2. sex_rec_woman (|SHAP|=0.0458, freq=100.0%) 3. primary_sub_mod_cocaine_paste (|SHAP|=0.0356, freq=100.0%) 4. porc_pobr (|SHAP|=0.0334, freq=100.0%) 5. polysubstance_strict (|SHAP|=0.0282, freq=100.0%) Bootstrap CI of SHAP delta (Q4-Q1 for numeric, x=1-x=0 for binary).

Direction Sanity Check

outcome feature rank mean_abs_shap_log_hazard direction_method delta_point delta_ci_low delta_ci_high direction_label
56 Death adm_age_rec3 1 0.497441 quartile_q4_minus_q1 1.372988 1.368431 1.377258 Risk-Increasing
66 Death sex_rec_woman 11 0.045834 binary_x1_minus_x0 -0.122627 -0.123087 -0.122155 Risk-Decreasing
72 Death porc_pobr 17 0.033410 quartile_q4_minus_q1 -0.008292 -0.009465 -0.007206 Risk-Decreasing
1 Readmission adm_age_rec3 2 0.078743 quartile_q4_minus_q1 -0.163150 -0.166379 -0.159930 Risk-Decreasing
2 Readmission porc_pobr 3 0.073050 quartile_q4_minus_q1 0.119147 0.117414 0.120873 Risk-Increasing
3 Readmission sex_rec_woman 4 0.070474 binary_x1_minus_x0 0.183644 0.183199 0.184149 Risk-Increasing
Saved:
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb9_dual_predictor_analysis_no_clinical_20260306_1821.csv
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb9_dual_importance_tiers_20260306_1821.csv
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb9_dual_predictor_summary_20260306_1821.csv
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb9_dual_takehome_no_clinical_20260306_1821.csv
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb9_dual_protective_no_clinical_20260306_1821.csv
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb9_dual_predictor_analysis_no_clinical_20260306_1821.xlsx
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb9_dual_predictor_run_info_20260306_1821.json
Code
#@title Composed Academic Figure: Bar (A) + Beeswarm (B) per Outcome

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib import rcParams
import shap
import numpy as np
import pandas as pd
from pathlib import Path
from IPython.display import display, HTML

# Academic (journal-style) matplotlib defaults: serif fonts, print-quality DPI.
rcParams.update({
    "font.family": "serif",
    "font.serif": ["Times New Roman", "Times", "Nimbus Roman"],
    "font.size": 12,
    "axes.labelsize": 12,
    "axes.titlesize": 13,
    "xtick.labelsize": 10,
    "ytick.labelsize": 10,
    "legend.fontsize": 11,
    "figure.dpi": 300,
    "savefig.dpi": 300,
})

TOP_N = 20  # how many top features to show in bar and beeswarm
MAX_BEE_N = 3000  # subsample cap for the beeswarm panel (keeps rendering fast)
FIG_SIZE = (14, 8)  # default (width, height) in inches for the composed figure

def composed_figure(outcome_name, shap_mat, X_df, feat_names, color_bar, figsize=FIG_SIZE):
    """Two-panel academic figure: (A) mean-|SHAP| bar chart, (B) SHAP beeswarm.

    Parameters
    ----------
    outcome_name : str — kept for API symmetry (the suptitle is disabled).
    shap_mat : (n_samples, n_features) SHAP value matrix.
    X_df : DataFrame of the corresponding feature values.
    feat_names : feature names aligned with the columns of shap_mat.
    color_bar : bar colour for panel A.
    figsize : figure size in inches.

    Returns the assembled matplotlib Figure.
    """
    # Align rows/columns defensively in case SHAP matrix and X differ slightly.
    n_rows = min(X_df.shape[0], shap_mat.shape[0])
    n_cols = min(X_df.shape[1], shap_mat.shape[1])
    shap_arr = shap_mat[:n_rows, :n_cols].astype(float)
    X_sub = X_df.iloc[:n_rows, :n_cols].copy()

    # Rank features by global mean |SHAP|; keep the TOP_N strongest.
    mean_abs = np.abs(shap_arr).mean(axis=0)
    top_idx = np.argsort(mean_abs)[::-1][:TOP_N]
    top_names = [feat_names[i] for i in top_idx]
    top_vals = mean_abs[top_idx]
    stderr = np.std(np.abs(shap_arr[:, top_idx]), axis=0) / np.sqrt(n_rows)

    # Deterministic subsample so the beeswarm stays legible and fast.
    sampler = np.random.RandomState(42)
    if n_rows > MAX_BEE_N:
        keep = sampler.choice(n_rows, size=min(n_rows, MAX_BEE_N), replace=False)
    else:
        keep = np.arange(n_rows)
    bee_vals = shap_arr[keep][:, top_idx]
    bee_X = X_sub.iloc[keep, list(top_idx)].copy()
    bee_X.columns = top_names

    # Layout: panel B gets extra width since its y-axis labels are hidden.
    fig = plt.figure(figsize=figsize)
    grid = gridspec.GridSpec(1, 2, figure=fig, wspace=0.25,
                      left=0.06, right=0.98, top=0.94, bottom=0.08,
                      width_ratios=[1, 1.9])

    # Panel A: horizontal bar chart of mean |SHAP| with standard-error bars.
    ax_bar = fig.add_subplot(grid[0, 0])
    bar_pos = np.arange(TOP_N)[::-1]  # highest-ranked feature at the top

    ax_bar.barh(bar_pos, top_vals, xerr=stderr, color=color_bar, alpha=0.85,
                edgecolor="white", linewidth=0.5,
                error_kw=dict(ecolor="black", elinewidth=1, capsize=3),
                height=0.6)

    ax_bar.set_yticks(bar_pos)
    ax_bar.set_yticklabels([f.replace("_", " ").title() for f in top_names], fontsize=10)
    ax_bar.set_xlabel("Mean |SHAP| (log-hazard)", fontsize=12)
    ax_bar.set_title("A", loc="left", fontweight="bold", fontsize=16, pad=6)
    for side in ("top", "right"):
        ax_bar.spines[side].set_visible(False)
    ax_bar.tick_params(axis="x", labelsize=10)

    # Panel B: beeswarm over the same TOP_N features.
    ax_bee = fig.add_subplot(grid[0, 1])
    expl = shap.Explanation(values=bee_vals, data=bee_X.to_numpy(), feature_names=top_names)
    plt.sca(ax_bee)  # shap draws into the current axes
    shap.plots.beeswarm(expl, max_display=TOP_N, show=False, color=plt.get_cmap("RdBu_r"))

    ax_bee.set_title("B", loc="left", fontweight="bold", fontsize=16, pad=6)
    for side in ("top", "right"):
        ax_bee.spines[side].set_visible(False)
    ax_bee.tick_params(axis="both", labelsize=12)
    ax_bee.set_xlabel("SHAP value (log-hazard)", fontsize=12)
    # Hide y-axis feature names — panel A already lists them.
    ax_bee.set_yticklabels([])
    ax_bee.tick_params(left=False)
    # shap.beeswarm may append its own colorbar axes; drop it for a clean layout.
    if len(fig.axes) > 2:
        fig.axes[-1].remove()

    return fig

# Generate and save one composed figure (bar + beeswarm) per outcome.
outcomes = {
    "Readmission": ("shap_r_all", "#2166ac"),
    "Death": ("shap_d_all", "#b2182b")
}

# Re-run guards: earlier cells normally define FIG_DIR / timestamp, but the
# original `isinstance(FIG_DIR, str)` check raised NameError when FIG_DIR was
# undefined. Mirror the defensive setup used by the waterfall cell instead.
from pathlib import Path
if "FIG_DIR" not in globals():
    FIG_DIR = Path(PROJECT_ROOT) / "_figs"
FIG_DIR = Path(FIG_DIR)
FIG_DIR.mkdir(exist_ok=True, parents=True)
if "timestamp" not in globals():
    timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M")

for name, (key, color) in outcomes.items():
    # Skip outcomes whose aggregated SHAP matrix is absent from the bundle.
    if key not in shap_data:
        continue

    fig = composed_figure(
        outcome_name=name,
        shap_mat=np.asarray(shap_data[key]),
        X_df=X_all,
        feat_names=feature_names,
        color_bar=color
    )

    # Save as both PNG (raster) and PDF (vector) for publication use.
    stem = name.lower().replace(" ", "_")
    png = FIG_DIR / f"XGB_Composed_{stem}_{timestamp}.png"
    pdf = FIG_DIR / f"XGB_Composed_{stem}_{timestamp}.pdf"

    fig.savefig(png, dpi=300, bbox_inches="tight", facecolor="white")
    fig.savefig(pdf, dpi=300, bbox_inches="tight", facecolor="white")

    plt.show()
    display(HTML(f"<b>{name}</b> saved"))
    plt.close(fig)  # free figure memory when generating many plots

Readmission saved

Death saved
Code
#@title Composed Waterfall Plots: Highest vs. Lowest Risk at Month 60

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib import rcParams
import shap
import numpy as np
import pandas as pd
from pathlib import Path
from IPython.display import display, HTML
import pickle

# Academic (journal-style) matplotlib defaults for the waterfall figures.
rcParams.update({
    "font.family": "serif",
    "font.serif": ["Times New Roman", "Times", "Nimbus Roman"],
    "font.size": 13,
    "axes.labelsize": 14,
    "axes.titlesize": 14,
    "xtick.labelsize": 12,
    "ytick.labelsize": 12,
    "legend.fontsize": 12,
    "figure.dpi": 300,
    "savefig.dpi": 300,
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.02,
})

# Re-run guards: earlier cells normally define FIG_DIR / timestamp / source_tag,
# but recreate them here so this cell also works on a partial re-run.
if "FIG_DIR" not in globals():
    FIG_DIR = Path(PROJECT_ROOT) / "_figs"
FIG_DIR = Path(FIG_DIR)
FIG_DIR.mkdir(exist_ok=True, parents=True)

if "timestamp" not in globals():
    timestamp = pd.Timestamp.now().strftime("%Y%m%d_%H%M")

if "source_tag" not in globals():
    source_tag = timestamp

# --- Recreate necessary variables from Step 8 outputs ---

# Candidate-key fallback scheme: upstream pickles have used several naming
# conventions over time, so each logical field lists every known key and
# get_first() picks the first one present in a given record.
OUTCOME_CFG = {
    "readm": {
        "label": "Readmission",
        "margin_key_candidates": ["risk_pred_readm", "risk_pred_r", "margin_pred_readm"],
        "probs_key_candidates": ["probs_readm_matrix", "probs_r_matrix", "surv_probs_readm_matrix"],
        "hz_times_candidates": ["times_r", "times_readm"],
        "hz_vals_candidates": ["h0_r", "H0_r", "h0_readm"],
    },
    "death": {
        "label": "Death",
        "margin_key_candidates": ["risk_pred_death", "risk_pred_d", "risk_pred_mort", "margin_pred_death"],
        "probs_key_candidates": ["probs_death_matrix", "probs_d_matrix", "probs_mort_matrix", "surv_probs_death_matrix"],
        "hz_times_candidates": ["times_d", "times_death", "times_mort"],
        "hz_vals_candidates": ["h0_d", "H0_d", "h0_death", "h0_mort"],
    },
}

# Known keys under which CV-split records store validation-set patient ids.
VAL_ID_KEYS = ["val_ids", "valid_ids", "val_idx", "val_index", "idx_val"]

def get_first(dct, keys, default=None):
    """Return dct[k] for the first key in `keys` present with a non-None value,
    falling back to `default` when none qualifies."""
    hits = (dct[k] for k in keys if k in dct and dct[k] is not None)
    return next(hits, default)

def h0_at_t(times, h0_vals, t):
    """Step-function lookup of the baseline cumulative hazard at time t.

    `times` and `h0_vals` are parallel grids; returns the value at the largest
    grid time <= t, 0.0 before the first grid point, NaN on malformed input.
    """
    grid = np.ravel(np.asarray(times, dtype=float))
    vals = np.ravel(np.asarray(h0_vals, dtype=float))
    if grid.size == 0 or vals.size == 0 or grid.size != vals.size:
        return np.nan
    t = float(t)
    if t < grid[0]:
        return 0.0
    pos = np.searchsorted(grid, t, side="right") - 1
    pos = min(max(pos, 0), vals.size - 1)
    return float(vals[pos])

def collect_risk_samples_by_id(raw_log, split_map, hz_map, cfg, horizon):
    """Collect per-patient absolute-risk samples at `horizon` months across CV runs.

    For each (imputation, fold) record, risk is taken from stored survival
    probabilities when available, otherwise reconstructed from Cox margins and
    the baseline cumulative hazard as 1 - exp(-exp(margin) * H0(t)).

    Returns a long DataFrame ["id", "risk"] with one row per (patient, run)
    sample — ids repeat across folds; aggregation happens downstream.
    """
    rows = []
    t = float(horizon)
    for rec in raw_log:
        # Each record must identify its (imputation, fold) to join the maps.
        if "imp_idx" not in rec or "fold_idx" not in rec:
            continue
        key = (int(rec["imp_idx"]), int(rec["fold_idx"]))
        split_rec = split_map.get(key)
        hz_rec = hz_map.get(key)
        if split_rec is None:
            continue
        val_ids = get_first(split_rec, VAL_ID_KEYS, [])
        margins = np.asarray(get_first(rec, cfg["margin_key_candidates"], []), dtype=float).ravel()
        if len(val_ids) != len(margins) or len(val_ids) == 0:
            continue
        
        risk_vec = None
        # Preferred path: stored survival probabilities at the exact horizon.
        eval_times = np.asarray(rec.get("eval_times", []), dtype=float).ravel()
        probs_mat = np.asarray(get_first(rec, cfg["probs_key_candidates"], []), dtype=float)
        
        if eval_times.size > 0 and probs_mat.ndim == 2:
            j = np.where(np.isclose(eval_times, t))[0]
            if j.size > 0:
                jj = int(j[0])
                # The matrix may be stored (patients x times) or (times x patients).
                if probs_mat.shape[0] == len(val_ids) and probs_mat.shape[1] == eval_times.size:
                    risk_vec = 1.0 - probs_mat[:, jj]
                elif probs_mat.shape[1] == len(val_ids) and probs_mat.shape[0] == eval_times.size:
                    risk_vec = 1.0 - probs_mat[jj, :]
        
        # Fallback: reconstruct risk from margins and the baseline hazard.
        if risk_vec is None and hz_rec is not None:
            times = get_first(hz_rec, cfg["hz_times_candidates"], [])
            h0_vals = get_first(hz_rec, cfg["hz_vals_candidates"], [])
            H0_t = h0_at_t(times, h0_vals, t)
            if np.isfinite(H0_t):
                risk_vec = 1.0 - np.exp(-np.exp(margins) * H0_t)
        
        if risk_vec is None:
            continue
        # Probabilities can drift slightly outside [0, 1] numerically; clamp.
        risk_vec = np.clip(np.asarray(risk_vec, dtype=float).ravel(), 0.0, 1.0)
        if len(risk_vec) != len(val_ids):
            continue
        for i, pid in enumerate(val_ids):
            rv = float(risk_vec[i])
            if np.isfinite(rv):
                rows.append((str(pid), rv))
    
    if not rows:
        return pd.DataFrame(columns=["id", "risk"])
    return pd.DataFrame(rows, columns=["id", "risk"])

def summarize_risk_ci(df_samples, alpha=0.05):
    """Aggregate per-id risk samples into a mean and empirical (1-alpha) CI.

    Parameters
    ----------
    df_samples : DataFrame with columns ["id", "risk"], one row per CV sample.
    alpha : two-sided tail mass; default 0.05 -> 2.5% / 97.5% quantiles.

    Returns a DataFrame with columns
    ["id", "risk_mean", "n_samples", "risk_ci_low", "risk_ci_high"].
    """
    if df_samples.empty:
        return pd.DataFrame(columns=["id", "risk_mean", "risk_ci_low", "risk_ci_high", "n_samples"])
    lo_q, hi_q = alpha / 2.0, 1.0 - alpha / 2.0
    # Single named aggregation: quantiles stay aligned with their group key,
    # instead of the previous positional `.values` assignment which silently
    # relied on identical group ordering between two groupby passes.
    out = (
        df_samples.groupby("id")["risk"]
        .agg(
            risk_mean="mean",
            n_samples="size",
            risk_ci_low=lambda s: s.quantile(lo_q),
            risk_ci_high=lambda s: s.quantile(hi_q),
        )
        .reset_index()
    )
    return out

# Load Step 8 artifacts
IN_DIR = Path(PROJECT_ROOT) / "_out"

def _latest(pattern):
    """Newest file in IN_DIR matching `pattern` (lexicographic order equals
    chronological order because names embed a YYYYMMDD_HHMM timestamp).
    Raises FileNotFoundError with a clear message instead of a bare
    IndexError when the companion pickle is missing."""
    hits = sorted(IN_DIR.glob(pattern))
    if not hits:
        raise FileNotFoundError(f"No files matching '{pattern}' in {IN_DIR}")
    return hits[-1]

latest_shap = _latest("xgb6_corr_DUAL_SHAP_Aggregated_*.pkl")

# NOTE: pickle is only safe on trusted, locally produced artifacts.
with open(latest_shap, "rb") as f:
    shap_data = pickle.load(f)

X_all = shap_data["X_all"]
feature_names = list(shap_data["feature_names"])
if not isinstance(X_all, pd.DataFrame):
    X_all = pd.DataFrame(X_all, columns=feature_names)

# Map index label -> positional row, for locating individual patients' SHAP rows.
id_to_row = {str(idx): i for i, idx in enumerate(X_all.index)}

# Companion artifacts for risk reconstruction (fail fast if any is missing).
raw_file = _latest("xgb6_corr_DUAL_final_ev_hyp_*.pkl")
hz_file = _latest("xgb6_corr_DUAL_BaselineHazards_*.pkl")
split_file = _latest("xgb6_corr_DUAL_CV_Splits_*.pkl")

with open(raw_file, "rb") as f:
    raw_data_log = pickle.load(f)
with open(hz_file, "rb") as f:
    baseline_hazards_log = pickle.load(f)
with open(split_file, "rb") as f:
    cv_splits_log = pickle.load(f)

# Index fold-level records by (imputation, fold) for O(1) joins downstream.
split_map = {(int(s["imp_idx"]), int(s["fold_idx"])): s for s in cv_splits_log if "imp_idx" in s and "fold_idx" in s}
hz_map = {(int(h["imp_idx"]), int(h["fold_idx"])): h for h in baseline_hazards_log if "imp_idx" in h and "fold_idx" in h}

# Create Explanation objects
def create_explanation(shap_key):
    """Wrap the aggregated SHAP matrix stored under `shap_key` in `shap_data`
    into a shap.Explanation (zero base values, X_all as the data matrix).
    Returns None when the key is absent."""
    if shap_key not in shap_data:
        return None
    matrix = np.asarray(shap_data[shap_key], dtype=float)
    return shap.Explanation(
        values=matrix,
        base_values=np.zeros(X_all.shape[0]),
        data=X_all.to_numpy(),
        feature_names=feature_names,
    )

explanation_r = create_explanation("shap_r_all")  # Readmission SHAP wrapper
explanation_d = create_explanation("shap_d_all")  # Death SHAP wrapper

# Per-patient absolute risk (fold-level samples summarized to mean + CI) at month 60.
df_risk_r = summarize_risk_ci(collect_risk_samples_by_id(raw_data_log, split_map, hz_map, OUTCOME_CFG["readm"], 60))
df_risk_d = summarize_risk_ci(collect_risk_samples_by_id(raw_data_log, split_map, hz_map, OUTCOME_CFG["death"], 60))

# Keep only ids that map back to a row of X_all (needed for SHAP row lookup).
df_risk_r["id"] = df_risk_r["id"].astype(str)
df_risk_d["id"] = df_risk_d["id"].astype(str)
df_risk_r = df_risk_r[df_risk_r["id"].isin(id_to_row.keys())]
df_risk_d = df_risk_d[df_risk_d["id"].isin(id_to_row.keys())]

# --- Plotting Function ---
def composed_waterfalls(outcome_name, explanation, df_risk, horizon=60, max_display=12, figsize=(14, 7)):
    """Side-by-side waterfalls: Highest risk (Panel A) vs. Lowest risk (Panel B).

    Parameters
    ----------
    outcome_name : label used for the skip message (suptitle is disabled).
    explanation : shap.Explanation whose rows align with the global id_to_row map.
    df_risk : DataFrame with columns "id", "risk_mean", "n_samples".
    horizon : months; labelling only — risks were computed upstream at this horizon.
    max_display : max features shown per waterfall.
    figsize : figure size in inches.

    Returns the matplotlib Figure, or None when inputs are missing.
    """
    
    if explanation is None or len(df_risk) == 0:
        print(f"Skipping {outcome_name}: missing data")
        return None
    
    # Extreme cases: the single highest- and lowest-mean-risk patients.
    hi = df_risk.sort_values("risk_mean", ascending=False).iloc[0]
    lo = df_risk.sort_values("risk_mean", ascending=True).iloc[0]
    
    high_id, low_id = str(hi["id"]), str(lo["id"])
    high_risk, low_risk = float(hi["risk_mean"]), float(lo["risk_mean"])
    high_n, low_n = int(hi["n_samples"]), int(lo["n_samples"])
    
    # Positional SHAP rows (id_to_row was built from X_all.index).
    high_row = id_to_row[high_id]
    low_row = id_to_row[low_id]
    
    # Create figure
    fig = plt.figure(figsize=figsize)
    gs = gridspec.GridSpec(1, 2, figure=fig, wspace=0.40,
                          left=0.06, right=0.98, top=0.88, bottom=0.12)
    
    # Panel A: Highest Risk
    ax1 = fig.add_subplot(gs[0, 0])
    plt.sca(ax1)  # shap.plots.waterfall draws into the current axes
    shap.plots.waterfall(explanation[high_row], max_display=max_display, show=False)
    ax1.set_title(f"A)    Highest risk (ID: {high_id[-6:]}, {high_risk:.1%}, n={high_n})", 
                 loc='left', fontweight='bold', fontsize=12, pad=10)
    ax1.set_xlabel("SHAP value (log-hazard)", fontsize=11)
    ax1.tick_params(axis='both', labelsize=10)
    
    # Panel B: Lowest Risk  
    ax2 = fig.add_subplot(gs[0, 1])
    plt.sca(ax2)
    shap.plots.waterfall(explanation[low_row], max_display=max_display, show=False)
    ax2.set_title(f"B)    Lowest risk (ID: {low_id[-6:]}, {low_risk:.1%}, n={low_n})", 
                 loc='left', fontweight='bold', fontsize=12, pad=10)
    ax2.set_xlabel("SHAP value (log-hazard)", fontsize=11)
    ax2.tick_params(axis='both', labelsize=10)
    
    # Optional: remove suptitle for completely clean look
    #fig.suptitle(f"{outcome_name} at {horizon} months", fontsize=13, fontweight='bold', y=0.96)
    
    return fig

# --- Generate Plots ---
# One composed waterfall figure per outcome; each saved as PNG + PDF.
outcomes = {
    "Readmission": (explanation_r, df_risk_r),
    "Death": (explanation_d, df_risk_d)
}

for name, (exp, df_risk) in outcomes.items():
    # Skip outcomes whose SHAP bundle or risk table is unavailable.
    if exp is None or df_risk is None or len(df_risk) == 0:
        continue
        
    fig = composed_waterfalls(name, exp, df_risk, horizon=60, max_display=12, figsize=(14, 7))
    if fig is None:
        continue
    
    # Save
    stem = name.lower().replace(" ", "_")
    png = FIG_DIR / f"XGB_Waterfall_Composed_{stem}_m60_{timestamp}.png"
    pdf = FIG_DIR / f"XGB_Waterfall_Composed_{stem}_m60_{timestamp}.pdf"
    
    fig.savefig(png, dpi=300, bbox_inches="tight", facecolor="white")
    fig.savefig(pdf, dpi=300, bbox_inches="tight", facecolor="white")
    
    plt.show()
    display(HTML(f"<b>{name}</b> composed waterfalls saved"))
    plt.close(fig)  # free figure memory when generating many plots

Readmission composed waterfalls saved

Death composed waterfalls saved

Interactions

Unlike DeepHit (which learns a different risk function at each time horizon), an XGBoost Cox model is time-invariant under the proportional-hazards assumption: the model’s structure—and therefore its interactions—is constant across all time horizons. Consequently, the core interaction discovery is run once per outcome to identify the “global interactions” inherent in the model structure; the additional per-horizon scan in the code serves only to check that interaction rankings remain stable across the configured horizons.

Code
#@title ⚡ Step 11: Interaction Discovery (DUAL, Global + Time-Dependent @ 3/12/36/60m)

import os
import re
import glob
import json
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import display, Markdown, HTML  # FIXED: Added HTML
from scipy.stats import t as student_t

# -----------------------------
# 0) Config
# -----------------------------
# Fail fast if the notebook's root-setup cell has not been run in this kernel.
if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the root setup cell first.")

PROJECT_ROOT = os.path.abspath(str(PROJECT_ROOT))

IN_DIR = os.path.join(PROJECT_ROOT, "_out")    # Step 5/8 pickles are read from here
OUT_DIR = os.path.join(PROJECT_ROOT, "_out")   # Step 11 tables are written here
FIG_DIR = os.path.join(PROJECT_ROOT, "_figs")  # Step 11 figures are written here

os.makedirs(IN_DIR, exist_ok=True)
os.makedirs(OUT_DIR, exist_ok=True)
os.makedirs(FIG_DIR, exist_ok=True)

# Horizons (months) for the time-dependent scan.
# NOTE(review): the cell title says 3/12/36/60m but this list also includes
# 6 and 96 — confirm which set is intended.
TARGET_HORIZONS = [3, 6, 12, 36, 60, 96]
MAIN_TOP_K = 15           # top main effects to scan for interactions
INTER_TOP_K = 5           # top interactors per main feature for reporting
MIN_VALID_N = 500         # minimum patients required for a valid correlation
WEIGHT_CLIP_PCT = 99.5    # cap extreme risk-gradient weights
TIME_DEP_RANK_RANGE_MIN = 8 # flag a time-dependent signal when an interaction's rank shifts by >= 8 positions across horizons
TIME_DEP_ABS_DELTA_MIN = 0.05 # ...or when its correlation strength changes by >= 0.05 across horizons
PLOT_TOP_TIME_DEP = 12 # maximum number of time-dependent interactions to plot
FIG_DPI = 300
DIRECTION_MIN_GROUP_N = 30 # require >= 30 cases in each group when comparing Q1 vs. Q4
FDR_ALPHA = 0.01 # significance threshold applied to BH-adjusted q-values

# Candidate-key fallback scheme: upstream pickles have used several naming
# conventions; get_first() picks the first key present in a given record.
OUTCOME_CFG = {
    "readm": {
        "label": "Readmission",
        "shap_key_candidates": ["shap_r_all", "shap_readm_all"],
        "margin_key_candidates": ["risk_pred_readm", "risk_pred_r", "margin_pred_readm"],
        "probs_key_candidates": ["probs_readm_matrix", "probs_r_matrix", "surv_probs_readm_matrix"],
        "hz_times_candidates": ["times_r", "times_readm"],
        "hz_vals_candidates": ["h0_r", "H0_r", "h0_readm"],
    },
    "death": {
        "label": "Death",
        "shap_key_candidates": ["shap_d_all", "shap_death_all", "shap_mort_all"],
        "margin_key_candidates": ["risk_pred_death", "risk_pred_d", "risk_pred_mort", "margin_pred_death"],
        "probs_key_candidates": ["probs_death_matrix", "probs_d_matrix", "probs_mort_matrix", "surv_probs_death_matrix"],
        "hz_times_candidates": ["times_d", "times_death", "times_mort"],
        "hz_vals_candidates": ["h0_d", "H0_d", "h0_death", "h0_mort"],
    },
}

# Known keys under which CV-split records store validation-set patient ids.
VAL_ID_KEYS = ["val_ids", "valid_ids", "val_idx", "val_index", "idx_val"]
# Timestamp tag (YYYYMMDD_HHMM) embedded in artifact file names; matching on
# the tag alone covers both the `_mar26` and the plain suffix conventions.
TS_RE = re.compile(r"(\d{8}_\d{4})")

# -----------------------------
# 1) Helpers
# -----------------------------
def _tag_from_path(path):
    """Extract the YYYYMMDD_HHMM timestamp tag from a file name, or None."""
    match = TS_RE.search(os.path.basename(path))
    return None if match is None else match.group(1)

def pick_latest_complete_bundle(in_dir):
    """Locate the newest complete artifact bundle in `in_dir`.

    A bundle is a SHAP pickle plus its matching raw-predictions, baseline-
    hazard and CV-split pickles that share the same YYYYMMDD_HHMM tag. Both
    the `_{tag}_mar26.pkl` and plain `_{tag}.pkl` naming variants are accepted.

    Returns (tag, shap_path, raw_path, hazard_path, split_path).
    Raises FileNotFoundError when no complete bundle exists.
    """
    shap_files = glob.glob(os.path.join(in_dir, "xgb6_corr_DUAL_SHAP_Aggregated_*.pkl"))
    candidates = []
    for shapf in shap_files:
        tag = _tag_from_path(shapf)
        if not tag:
            continue
        
        # Accept either historical suffix convention for the companion files.
        for suffix in [f"_{tag}_mar26.pkl", f"_{tag}.pkl"]:
            rawf = os.path.join(in_dir, f"xgb6_corr_DUAL_final_ev_hyp{suffix}")
            hzf = os.path.join(in_dir, f"xgb6_corr_DUAL_BaselineHazards{suffix}")
            splitf = os.path.join(in_dir, f"xgb6_corr_DUAL_CV_Splits{suffix}")
            
            if all(os.path.exists(p) for p in (rawf, hzf, splitf)):
                # Only tags that parse as real timestamps are eligible.
                dt = pd.to_datetime(tag, format="%Y%m%d_%H%M", errors="coerce")
                if pd.notna(dt):
                    candidates.append((dt, tag, shapf, rawf, hzf, splitf))
                    break  # Found complete set
    
    if not candidates:
        raise FileNotFoundError(
            f"No complete Step 5 bundle found in '{in_dir}' with prefix xgb6_corr_DUAL_*."
        )
    # Newest bundle by parsed timestamp wins.
    candidates.sort(key=lambda x: x[0])
    _, tag, shapf, rawf, hzf, splitf = candidates[-1]
    return tag, shapf, rawf, hzf, splitf

def get_first(dct, keys, default=None):
    """First non-None dct[k] over `keys` (in order); `default` when none found."""
    for key in keys:
        value = dct.get(key)
        if value is not None:
            return value
    return default

def find_first_key(dct, keys):
    """Return the first element of `keys` that is present in `dct` (its value
    may be None); None when no key is present."""
    return next((k for k in keys if k in dct), None)

def h0_at_t(times, h0_vals, t):
    """Right-continuous step lookup of the baseline cumulative hazard at t.

    Returns 0.0 before the first grid time, the value at the largest grid
    time <= t otherwise, and NaN when the grids are empty or mismatched.
    """
    grid = np.ravel(np.asarray(times, dtype=float))
    vals = np.ravel(np.asarray(h0_vals, dtype=float))
    if grid.size == 0 or vals.size == 0 or grid.size != vals.size:
        return np.nan
    t = float(t)
    if t < grid[0]:
        return 0.0  # Before first event
    pos = np.searchsorted(grid, t, side="right") - 1
    pos = min(max(pos, 0), vals.size - 1)  # clip to valid index range
    return float(vals[pos])

def safe_corr(x, y, min_n=500, eps=1e-12):
    """Pearson correlation over pairwise-finite entries of x and y.

    Returns (r, n_valid); r is NaN when fewer than `min_n` valid pairs exist
    or when either side is numerically constant.
    """
    a = np.ravel(np.asarray(x, dtype=float))
    b = np.ravel(np.asarray(y, dtype=float))
    ok = np.isfinite(a) & np.isfinite(b)
    n_valid = int(ok.sum())
    if n_valid < min_n:
        return np.nan, n_valid
    a, b = a[ok], b[ok]
    if min(np.std(a), np.std(b)) < eps:
        return np.nan, n_valid
    return float(np.corrcoef(a, b)[0, 1]), n_valid

def linear_residual(x, y, min_n=500, eps=1e-12):
    """OLS-detrend y on x over pairwise-finite entries.

    Returns (resid, n_valid, slope, intercept); `resid` matches the flattened
    shape of `y` with NaN wherever either input was non-finite. When `x` is
    numerically constant, y is simply centred (slope 0, intercept = mean(y)).
    Returns (None, n_valid, nan, nan) when fewer than `min_n` valid pairs.
    """
    a = np.ravel(np.asarray(x, dtype=float))
    b = np.ravel(np.asarray(y, dtype=float))
    ok = np.isfinite(a) & np.isfinite(b)
    n_valid = int(ok.sum())
    if n_valid < min_n:
        return None, n_valid, np.nan, np.nan

    av, bv = a[ok], b[ok]
    out = np.full_like(b, np.nan, dtype=float)

    # Degenerate regressor: fall back to centring.
    if np.std(av) < eps:
        centre = float(np.mean(bv))
        out[ok] = bv - centre
        return out, n_valid, 0.0, centre

    design = np.column_stack([av, np.ones_like(av)])
    beta, *_ = np.linalg.lstsq(design, bv, rcond=None)
    slope, intercept = float(beta[0]), float(beta[1])
    out[ok] = bv - (slope * av + intercept)
    return out, n_valid, slope, intercept

def interaction_scores(X, shap_mat, feature_names, main_indices, main_importance=None, min_n=500):
    """Residual-correlation interaction scan.

    For each main feature i in `main_indices`, the linear (x_i -> SHAP_i)
    trend is removed; remaining structure in the residual that correlates
    with another feature j is scored as a candidate i-by-j interaction.

    Parameters
    ----------
    X : (n, p) feature matrix.
    shap_mat : (n, p) SHAP values aligned with X.
    feature_names : length-p list of names.
    main_indices : feature indices to scan (typically the top-K main effects).
    main_importance : optional per-feature importances recorded per row.
    min_n : minimum pairwise-valid sample size for correlations.

    Returns a DataFrame with one row per (main feature, interactor) pair.
    """
    X = np.asarray(X, dtype=float)
    S = np.asarray(shap_mat, dtype=float)
    p = X.shape[1]
    rows = []

    for main_rank, i in enumerate(main_indices, start=1):
        # Detrend: residual SHAP of feature i after removing its own linear effect.
        resid, n_main, slope, intercept = linear_residual(X[:, i], S[:, i], min_n=min_n)
        if resid is None:
            continue

        for j in range(p):
            if j == i:
                continue

            # Interaction strength: correlation of feature j with the residual.
            r, n_valid = safe_corr(X[:, j], resid, min_n=min_n)
            if not np.isfinite(r):
                continue

            # Direction: residual shift between Q1 and Q4 of feature j.
            delta_q, dir_q, n_dir, n_q1, n_q4 = quartile_delta_direction(
                X[:, j], resid, min_n=min_n, min_group_n=DIRECTION_MIN_GROUP_N
            )

            rows.append({
                "main_idx": int(i),
                "inter_idx": int(j),
                "main_feature": feature_names[i],
                "main_rank_global": int(main_rank),
                "main_importance_global": float(main_importance[i]) if main_importance is not None else np.nan,
                "interactor": feature_names[j],
                "corr_resid_vs_interactor": float(r),
                "abs_corr": float(abs(r)),
                "delta_q4_q1_resid": float(delta_q) if np.isfinite(delta_q) else np.nan,
                "direction_q4_q1": dir_q,
                "n_dir_valid": int(n_dir),
                "n_q1": int(n_q1),
                "n_q4": int(n_q4),
                "n_valid": int(n_valid),
                "main_linear_slope": float(slope),
                "main_linear_intercept": float(intercept),
            })

    if not rows:
        # Preserve the schema even when nothing passed the validity filters.
        return pd.DataFrame(columns=[
            "main_idx","inter_idx","main_feature","main_rank_global","main_importance_global","interactor",
            "corr_resid_vs_interactor","abs_corr","delta_q4_q1_resid","direction_q4_q1",
            "n_dir_valid","n_q1","n_q4","n_valid","main_linear_slope","main_linear_intercept"
        ])
    return pd.DataFrame(rows)

def top_interactions(df_scores, top_k):
    """Keep the `top_k` strongest interactors (by |corr|) within each
    (outcome, horizon, main feature) family, preserving main-effect rank order."""
    if len(df_scores) == 0:
        return df_scores.copy()
    ordered = df_scores.sort_values(
        by=["outcome", "horizon_months", "main_rank_global", "abs_corr"],
        ascending=[True, True, True, False],
    )
    grouped = ordered.groupby(
        ["outcome", "horizon_months", "main_feature"],
        dropna=False, as_index=False, group_keys=False,
    )
    return grouped.head(top_k).reset_index(drop=True)

def collect_risk_by_id(raw_log, split_map, hz_map, cfg, horizon):
    """Absolute risk at `horizon` months per patient id, averaged over CV runs.

    Risk is read from stored survival-probability matrices when available;
    otherwise reconstructed from Cox margins and the baseline cumulative
    hazard as 1 - exp(-exp(margin) * H0(t)).

    Returns a DataFrame ["id", "risk"] with one row per id (fold-mean risk).
    """
    rows = []
    t = float(horizon)

    for rec in raw_log:
        # Each record must identify its (imputation, fold) to join the maps.
        if "imp_idx" not in rec or "fold_idx" not in rec:
            continue
        key = (int(rec["imp_idx"]), int(rec["fold_idx"]))
        split_rec = split_map.get(key)
        hz_rec = hz_map.get(key)
        if split_rec is None:
            continue

        val_ids = get_first(split_rec, VAL_ID_KEYS, [])
        margins = np.asarray(get_first(rec, cfg["margin_key_candidates"], []), dtype=float).ravel()
        if len(val_ids) == 0 or len(margins) != len(val_ids):
            continue

        risk_vec = None

        # Preferred path: stored survival probabilities at the exact horizon.
        eval_times = np.asarray(rec.get("eval_times", []), dtype=float).ravel()
        probs_mat = np.asarray(get_first(rec, cfg["probs_key_candidates"], []), dtype=float)

        if eval_times.size > 0 and probs_mat.ndim == 2:
            j = np.where(np.isclose(eval_times, t))[0]
            if j.size > 0:
                jj = int(j[0])
                # The matrix may be stored (patients x times) or (times x patients).
                if probs_mat.shape[0] == len(val_ids) and probs_mat.shape[1] == eval_times.size:
                    risk_vec = 1.0 - probs_mat[:, jj]
                elif probs_mat.shape[1] == len(val_ids) and probs_mat.shape[0] == eval_times.size:
                    risk_vec = 1.0 - probs_mat[jj, :]
                if risk_vec is not None:
                    risk_vec = np.asarray(risk_vec, dtype=float).ravel()

        # Fallback: reconstruct risk from margins and the baseline hazard.
        if risk_vec is None and hz_rec is not None:
            times = get_first(hz_rec, cfg["hz_times_candidates"], [])
            h0_vals = get_first(hz_rec, cfg["hz_vals_candidates"], [])
            H0_t = h0_at_t(times, h0_vals, t)
            if np.isfinite(H0_t):
                surv = np.exp(-np.exp(margins) * H0_t)
                risk_vec = 1.0 - surv

        if risk_vec is not None and len(risk_vec) == len(val_ids):
            # Clamp small numerical excursions outside [0, 1].
            risk_vec = np.clip(np.asarray(risk_vec, dtype=float).ravel(), 0.0, 1.0)
            for i, pid in enumerate(val_ids):
                rv = float(risk_vec[i])
                if np.isfinite(rv):
                    rows.append((str(pid), rv))

    if not rows:
        return pd.DataFrame(columns=["id", "risk"])
    # Ids can appear in several (imputation, fold) validation sets; average
    # those samples into a single risk per id.
    return pd.DataFrame(rows, columns=["id", "risk"]).groupby("id", as_index=False)["risk"].mean()

def save_current_figure(stem):
    """Save the currently active matplotlib figure under FIG_DIR.

    Writes both a PNG (at FIG_DPI) and a PDF named `<stem>.<ext>` and
    returns the list [png_path, pdf_path] of the written files.
    """
    active_fig = plt.gcf()
    paths = [os.path.join(FIG_DIR, f"{stem}.{ext}") for ext in ("png", "pdf")]
    active_fig.savefig(paths[0], dpi=FIG_DPI, bbox_inches="tight")
    active_fig.savefig(paths[1], bbox_inches="tight")
    return paths

def corr_to_pvalue(r, n):
    if (not np.isfinite(r)) or (n <= 2) or (abs(r) >= 1):
        return np.nan
    tval = r * np.sqrt((n - 2) / max(1.0 - r * r, 1e-12))
    return float(2.0 * (1.0 - student_t.cdf(abs(tval), df=n - 2)))

def bh_fdr(pvals):
    """Benjamini-Hochberg FDR adjustment.

    Non-finite entries (NaN/inf) are passed through as NaN; the adjustment
    is computed over the finite p-values only. Returns a float array of
    q-values, clipped to [0, 1], in the original input order.
    """
    pvals_arr = np.asarray(pvals, dtype=float)
    adjusted = np.full_like(pvals_arr, np.nan, dtype=float)
    finite_mask = np.isfinite(pvals_arr)
    finite_p = pvals_arr[finite_mask]
    m_tests = len(finite_p)
    if m_tests == 0:
        return adjusted

    sort_idx = np.argsort(finite_p)
    scaled = finite_p[sort_idx] * m_tests / np.arange(1, m_tests + 1, dtype=float)
    # Enforce monotone non-decreasing q-values by taking the running
    # minimum from the largest p-value downward.
    monotone = np.minimum.accumulate(scaled[::-1])[::-1]
    monotone = np.clip(monotone, 0.0, 1.0)

    restored = np.empty_like(monotone)
    restored[sort_idx] = monotone
    adjusted[finite_mask] = restored
    return adjusted

def quartile_delta_direction(x, resid, min_n=500, min_group_n=40, eps=1e-12):
    """Direction of the residual shift between extreme quartiles of x.

    Computes delta = mean(resid | x >= Q3) - mean(resid | x <= Q1) over the
    jointly finite observations and returns a 5-tuple:
    (delta, direction, n_valid, n_q1, n_q4) where direction is "Positive",
    "Negative", "Neutral" (|delta| < eps), or "Unknown" when the sample is
    too small (< min_n), a quartile group is too small (< min_group_n), or
    either input is (numerically) constant.
    """
    x_arr = np.asarray(x, dtype=float).ravel()
    r_arr = np.asarray(resid, dtype=float).ravel()
    valid = np.isfinite(x_arr) & np.isfinite(r_arr)
    n_valid = int(valid.sum())
    if n_valid < min_n:
        return np.nan, "Unknown", n_valid, 0, 0

    xv = x_arr[valid]
    rv = r_arr[valid]
    # Degenerate inputs make quartile grouping / the delta meaningless.
    if np.std(xv) < eps or np.std(rv) < eps:
        return np.nan, "Unknown", n_valid, 0, 0

    lo_cut, hi_cut = np.quantile(xv, [0.25, 0.75])
    low_grp = rv[xv <= lo_cut]
    high_grp = rv[xv >= hi_cut]
    n_low = int(len(low_grp))
    n_high = int(len(high_grp))
    if n_low < min_group_n or n_high < min_group_n:
        return np.nan, "Unknown", n_valid, n_low, n_high

    delta = float(np.mean(high_grp) - np.mean(low_grp))
    if abs(delta) < eps:
        label = "Neutral"
    else:
        label = "Positive" if delta > 0 else "Negative"
    return delta, label, n_valid, n_low, n_high

def add_significance(df, family_cols, alpha=0.05):
    """Attach p-values, BH-FDR q-values, and a significance flag to df.

    P-values come from corr_to_pvalue(corr_resid_vs_interactor, n_valid);
    the BH-FDR correction is applied separately within each family defined
    by family_cols. Rows with NaN q-values are flagged non-significant
    (NaN < alpha is False). Returns a copy; df itself is untouched.
    """
    if len(df) == 0:
        return df
    result = df.copy()
    sig_col = f"signif_fdr_{alpha:.2f}"
    result["p_value"] = np.nan
    result["p_fdr"] = np.nan
    result[sig_col] = False

    for _, member_idx in result.groupby(family_cols, dropna=False).groups.items():
        members = list(member_idx)
        raw_p = np.array(
            [
                corr_to_pvalue(
                    result.at[lbl, "corr_resid_vs_interactor"],
                    int(result.at[lbl, "n_valid"]),
                )
                for lbl in members
            ],
            dtype=float,
        )
        adj_p = bh_fdr(raw_p)
        result.loc[members, "p_value"] = raw_p
        result.loc[members, "p_fdr"] = adj_p
        result.loc[members, sig_col] = adj_p < alpha

    return result


# -----------------------------
# 2) Load latest bundle
# -----------------------------
source_tag, shap_file, raw_file, hz_file, split_file = pick_latest_complete_bundle(IN_DIR)

# NOTE: pickle.load executes arbitrary code if a file is untrusted; these
# bundles are produced by earlier steps of this pipeline and are trusted.
with open(shap_file, "rb") as f:
    shap_data = pickle.load(f)
with open(raw_file, "rb") as f:
    raw_data_log = pickle.load(f)
with open(hz_file, "rb") as f:
    baseline_hazards_log = pickle.load(f)
with open(split_file, "rb") as f:
    cv_splits_log = pickle.load(f)

if "X_all" not in shap_data or "feature_names" not in shap_data:
    raise KeyError("SHAP file must contain X_all and feature_names.")

X_all = shap_data["X_all"]
feature_names = list(shap_data["feature_names"])

# FIXED: duplicate-name check moved BEFORE the reindex below — reindexing a
# DataFrame on duplicated column labels raises or silently misaligns, so the
# check must run first to produce a clear error.
if len(feature_names) != len(set(feature_names)):
    dupes = [f for f in feature_names if feature_names.count(f) > 1]
    raise ValueError(f"Duplicate feature names detected: {set(dupes)}")

if not isinstance(X_all, pd.DataFrame):
    X_all = pd.DataFrame(X_all, columns=feature_names)
if list(X_all.columns) != feature_names:
    # Align column order with feature_names so SHAP columns match positionally.
    X_all = X_all.reindex(columns=feature_names)
if not X_all.index.is_unique:
    raise ValueError("X_all index must be unique.")

X_np = X_all.to_numpy(dtype=float)
# Map patient id (stringified) -> row position in X_np, used to join
# horizon-specific risks back to SHAP rows.
id_to_row = {str(idx): i for i, idx in enumerate(X_all.index)}

# Index CV splits and baseline hazards by (imputation, fold) for O(1) lookup.
split_map = {
    (int(s["imp_idx"]), int(s["fold_idx"])): s
    for s in cv_splits_log
    if "imp_idx" in s and "fold_idx" in s
}
hz_map = {
    (int(h["imp_idx"]), int(h["fold_idx"])): h
    for h in baseline_hazards_log
    if "imp_idx" in h and "fold_idx" in h
}

display(Markdown(
    f"### Step 11 Interaction Discovery (DUAL)\n"
    f"- Source bundle tag: **{source_tag}**\n"
    f"- Patients: **{X_all.shape[0]}**\n"
    f"- Features: **{X_all.shape[1]}**\n"
    f"- Horizons: **{', '.join(map(str, TARGET_HORIZONS))} months**\n"
    f"- Global interactions: **log-hazard SHAP scale**\n"
    f"- Horizon interactions: **risk-scale approximation via `dr/deta * SHAP`**"
))

# -----------------------------
# 3) Run per outcome
# -----------------------------
# Accumulators consumed by section 4 (combine + summarize) and section 6 (save).
global_scores_all = []
horizon_scores_all = []
horizon_n_all = []
saved_plot_files = []
processed_outcomes = []

for out_code, cfg in OUTCOME_CFG.items():
    # Resolve this outcome's SHAP matrix key; skip the outcome with a visible
    # message (rather than aborting the run) when no candidate key is present.
    shap_key = find_first_key(shap_data, cfg["shap_key_candidates"])
    if shap_key is None:
        # FIXED: Use display(HTML(...)) instead of print
        display(HTML(f"<p>Skipping {cfg['label']}: SHAP key not found ({cfg['shap_key_candidates']}).</p>"))
        continue

    # SHAP matrix must align row/column-wise with the feature matrix.
    shap_vals = np.asarray(shap_data[shap_key], dtype=float)
    if shap_vals.shape != X_np.shape:
        raise ValueError(f"{cfg['label']} SHAP shape mismatch: {shap_vals.shape} vs {X_np.shape}")

    processed_outcomes.append(out_code)

    # Main features fixed by global importance (for horizon comparability)
    global_imp = np.abs(shap_vals).mean(axis=0)
    main_idx = np.argsort(-global_imp)[:MAIN_TOP_K]

    # Global interaction scan (time-invariant model structure)
    df_g = interaction_scores(
        X=X_np,
        shap_mat=shap_vals,
        feature_names=feature_names,
        main_indices=main_idx,
        main_importance=global_imp,
        min_n=MIN_VALID_N
    )
    # Tag results so all outcomes/scopes can be concatenated in section 4.
    df_g["outcome"] = cfg["label"]
    df_g["outcome_code"] = out_code
    df_g["scope"] = "global_log_hazard"
    df_g["horizon_months"] = np.nan
    df_g["n_patients"] = int(X_np.shape[0])
    global_scores_all.append(df_g)

    # Horizon-specific interaction scan (risk-scale approximation)
    for h in TARGET_HORIZONS:
        # Out-of-fold absolute risk at horizon h keyed by patient id, then
        # restricted to ids present in X_all so SHAP rows can be matched.
        df_risk = collect_risk_by_id(raw_data_log, split_map, hz_map, cfg, h)
        df_risk = df_risk[df_risk["id"].isin(id_to_row.keys())].copy()
        n_h = int(len(df_risk))

        # Coverage bookkeeping, recorded even when the scan is skipped below.
        horizon_n_all.append({
            "outcome": cfg["label"],
            "outcome_code": out_code,
            "horizon_months": float(h),
            "n_patients_with_risk": n_h,
            "n_total_patients": int(X_np.shape[0]),
            "coverage_pct": (100.0 * n_h / X_np.shape[0]) if X_np.shape[0] > 0 else np.nan
        })

        if n_h < MIN_VALID_N:
            continue

        row_idx = df_risk["id"].map(id_to_row).to_numpy(dtype=int)
        # Clamp away exact 0/1 so log(1 - risk) below stays finite.
        risk = np.clip(df_risk["risk"].to_numpy(dtype=float), 1e-12, 1.0 - 1e-12)

        # Risk-gradient weight from Cox transform: dr/deta = (-ln(1-r))*(1-r)
        w = (-np.log(1.0 - risk)) * (1.0 - risk)
        if np.isfinite(w).sum() == 0:
            continue

        # Winsorize extreme weights to limit leverage of near-certain events.
        cap = np.nanpercentile(w, WEIGHT_CLIP_PCT)
        w = np.clip(w, 0.0, cap)

        X_h = X_np[row_idx, :]
        # Weighted SHAP approximates per-patient attribution on the risk scale.
        S_h = shap_vals[row_idx, :] * w[:, None]

        df_h = interaction_scores(
            X=X_h,
            shap_mat=S_h,
            feature_names=feature_names,
            main_indices=main_idx,
            main_importance=global_imp,
            min_n=MIN_VALID_N
        )
        df_h["outcome"] = cfg["label"]
        df_h["outcome_code"] = out_code
        df_h["scope"] = "horizon_risk_approx"
        df_h["horizon_months"] = float(h)
        df_h["n_patients"] = n_h
        df_h["weight_clip_pct"] = float(WEIGHT_CLIP_PCT)
        horizon_scores_all.append(df_h)

# -----------------------------
# 4) Combine + summarize
# -----------------------------
# Stack per-outcome results; empty accumulators become empty frames so the
# sections below can test len(...) uniformly.
df_global_scores = pd.concat(global_scores_all, ignore_index=True) if global_scores_all else pd.DataFrame()
df_horizon_scores = pd.concat(horizon_scores_all, ignore_index=True) if horizon_scores_all else pd.DataFrame()
# FDR correction is applied within families: per outcome/scope, and for
# horizon results additionally per horizon.
df_global_scores = add_significance(df_global_scores, ["outcome", "scope"], alpha=FDR_ALPHA) if len(df_global_scores) else df_global_scores
df_horizon_scores = add_significance(df_horizon_scores, ["outcome", "scope", "horizon_months"], alpha=FDR_ALPHA) if len(df_horizon_scores) else df_horizon_scores
df_horizon_n = pd.DataFrame(horizon_n_all)

df_global_top = top_interactions(df_global_scores, INTER_TOP_K) if len(df_global_scores) else pd.DataFrame()
df_horizon_top = top_interactions(df_horizon_scores, INTER_TOP_K) if len(df_horizon_scores) else pd.DataFrame()

if len(df_horizon_scores):
    tmp = df_horizon_scores.copy()
    # Rank interactors within each (outcome, horizon, main feature) family;
    # rank 1 = strongest absolute correlation.
    tmp["rank_within_main_hz"] = (
        tmp.groupby(["outcome", "horizon_months", "main_feature"])["abs_corr"]
           .rank(method="min", ascending=False)
    )
    tmp["dir_pos"] = (tmp["direction_q4_q1"] == "Positive").astype(int)
    tmp["dir_neg"] = (tmp["direction_q4_q1"] == "Negative").astype(int)

    # Collapse across horizons to characterize each pair's stability.
    df_time = (
        tmp.groupby(["outcome", "main_feature", "interactor"], as_index=False)
           .agg(
               horizons_seen=("horizon_months", "nunique"),
               mean_abs_corr=("abs_corr", "mean"),
               sd_abs_corr=("abs_corr", "std"),
               min_abs_corr=("abs_corr", "min"),
               max_abs_corr=("abs_corr", "max"),
               min_rank=("rank_within_main_hz", "min"),
               max_rank=("rank_within_main_hz", "max"),
               n_rows=("abs_corr", "size"),
               n_positive=("dir_pos", "sum"),
               n_negative=("dir_neg", "sum"),
               min_p_fdr=("p_fdr", "min"),
           )
    )

    # std is NaN for pairs observed at a single horizon; treat as no variation.
    df_time["sd_abs_corr"] = df_time["sd_abs_corr"].fillna(0.0)
    df_time["cv_abs_corr_pct"] = 100.0 * df_time["sd_abs_corr"] / np.clip(df_time["mean_abs_corr"], 1e-12, None)
    df_time["abs_corr_delta"] = df_time["max_abs_corr"] - df_time["min_abs_corr"]
    df_time["rank_range"] = df_time["max_rank"] - df_time["min_rank"]

    df_time["n_direction_known"] = df_time["n_positive"] + df_time["n_negative"]
    # Consistent = all known directions agree; flip = both signs observed.
    df_time["direction_consistent"] = (df_time["n_direction_known"] > 0) & (
        (df_time["n_positive"] == 0) | (df_time["n_negative"] == 0)
    )
    df_time["direction_flip_flag"] = (df_time["n_positive"] > 0) & (df_time["n_negative"] > 0)
    df_time["dominant_direction"] = np.where(
        df_time["n_positive"] > df_time["n_negative"], "Positive",
        np.where(df_time["n_negative"] > df_time["n_positive"], "Negative", "Mixed")
    )

    # Membership of each pair in the per-horizon top list (any horizon).
    top_keys = set(zip(df_horizon_top["outcome"], df_horizon_top["main_feature"], df_horizon_top["interactor"])) if len(df_horizon_top) else set()
    df_time["in_top_at_least_once"] = [(o, m, i) in top_keys for o, m, i in zip(df_time["outcome"], df_time["main_feature"], df_time["interactor"])]

    # "Time-dependent" = seen at >= 2 horizons AND rank or abs-corr varies
    # beyond the configured thresholds.
    df_time["time_dependent_flag"] = (
        (df_time["horizons_seen"] >= 2) &
        ((df_time["rank_range"] >= TIME_DEP_RANK_RANGE_MIN) | (df_time["abs_corr_delta"] >= TIME_DEP_ABS_DELTA_MIN))
    )
    df_time["time_dependent_top_flag"] = df_time["time_dependent_flag"] & df_time["in_top_at_least_once"]
    # Add these aliases (clearer terminology)
    df_time["horizon_salience_variation_flag"] = df_time["time_dependent_flag"]
    df_time["horizon_salience_variation_top_flag"] = df_time["time_dependent_top_flag"]
else:
    df_time = pd.DataFrame()

# -----------------------------
# 5) Plots: time profiles of top time-dependent interactions
# -----------------------------
if len(df_horizon_scores) and len(df_time):
    for out in sorted(df_horizon_scores["outcome"].unique()):
        # NOTE(review): .head() takes the first PLOT_TOP_TIME_DEP rows in
        # groupby order, not the strongest pairs — if "top" is meant by
        # strength, sort (e.g. by mean_abs_corr, descending) first. TODO confirm.
        td = df_time[
            (df_time["outcome"] == out) &
            (df_time["time_dependent_top_flag"])
        ].head(PLOT_TOP_TIME_DEP)

        if len(td) == 0:
            continue

        # Restrict horizon scores to the selected (main, interactor) pairs.
        keys = set(zip(td["main_feature"], td["interactor"]))
        dsub = df_horizon_scores[
            (df_horizon_scores["outcome"] == out) &
            (df_horizon_scores.apply(lambda r: (r["main_feature"], r["interactor"]) in keys, axis=1))
        ].copy()

        # One row per pair, one column per horizon.
        pivot = dsub.pivot_table(
            index=["main_feature", "interactor"],
            columns="horizon_months",
            values="abs_corr",
            aggfunc="mean"
        )

        plt.figure(figsize=(12, 8))
        for (mf, it), row in pivot.iterrows():
            # A line needs at least two finite horizon points.
            xs = [h for h in TARGET_HORIZONS if h in row.index and np.isfinite(row[h])]
            ys = [row[h] for h in xs]
            if len(xs) >= 2:
                plt.plot(xs, ys, marker="o", linewidth=1.8, label=f"{mf} × {it}")

        plt.title(f"{out}: Time-Dependent Interaction Profiles (abs corr of residual signal)")
        plt.xlabel("Horizon (months)")
        plt.ylabel("Interaction strength (abs corr)")
        plt.xticks(TARGET_HORIZONS)
        plt.grid(alpha=0.25)
        plt.legend(loc="best", fontsize=8, ncol=1)
        plt.tight_layout()
        # Persist before show(); close to release figure memory in the loop.
        saved_plot_files.extend(save_current_figure(f"xgb11_dual_{out.lower()}_time_profiles_{source_tag}"))
        plt.show()
        plt.close()

# -----------------------------
# 6) Save outputs
# -----------------------------
# All artifacts carry the source bundle tag so a run can be traced back to
# its exact input files.
f_global_scores = os.path.join(OUT_DIR, f"xgb11_dual_interactions_global_scores_{source_tag}.csv")
f_global_top = os.path.join(OUT_DIR, f"xgb11_dual_interactions_global_top_{source_tag}.csv")
f_horizon_scores = os.path.join(OUT_DIR, f"xgb11_dual_interactions_horizon_scores_{source_tag}.csv")
f_horizon_top = os.path.join(OUT_DIR, f"xgb11_dual_interactions_horizon_top_{source_tag}.csv")
f_time = os.path.join(OUT_DIR, f"xgb11_dual_interactions_time_dependent_{source_tag}.csv")
f_hn = os.path.join(OUT_DIR, f"xgb11_dual_interactions_horizon_sample_sizes_{source_tag}.csv")
f_info = os.path.join(OUT_DIR, f"xgb11_dual_interactions_run_info_{source_tag}.json")

# Write only non-empty frames; section 7 checks os.path.exists before listing.
if len(df_global_scores): df_global_scores.to_csv(f_global_scores, index=False)
if len(df_global_top): df_global_top.to_csv(f_global_top, index=False)
if len(df_horizon_scores): df_horizon_scores.to_csv(f_horizon_scores, index=False)
if len(df_horizon_top): df_horizon_top.to_csv(f_horizon_top, index=False)
if len(df_time): df_time.to_csv(f_time, index=False)
if len(df_horizon_n): df_horizon_n.to_csv(f_hn, index=False)

# Machine-readable provenance: inputs, parameters, and method descriptions.
run_info = {
    "source_bundle_tag": source_tag,
    "source_files": {
        "shap": shap_file,
        "raw": raw_file,
        "baseline_hazards": hz_file,
        "cv_splits": split_file
    },
    "outcomes_processed": processed_outcomes,
    "n_patients": int(X_all.shape[0]),
    "n_features": int(X_all.shape[1]),
    "main_top_k": int(MAIN_TOP_K),
    "inter_top_k": int(INTER_TOP_K),
    "horizons_months": [float(h) for h in TARGET_HORIZONS],
    "min_valid_n": int(MIN_VALID_N),
    "global_method": "Residual-correlation interaction heuristic on SHAP log-hazard scale",
    "horizon_method": "Approximate risk-scale interaction via weighted SHAP: SHAP * dr/deta, dr/deta=(-ln(1-r))*(1-r)",
    "weight_clip_pct": float(WEIGHT_CLIP_PCT),
    "fdr_alpha": float(FDR_ALPHA),
    "direction_method": "quartile delta on residuals: mean(resid|Q4 interactor) - mean(resid|Q1 interactor)",
    "time_dependent_rule": {
        "rank_range_min": int(TIME_DEP_RANK_RANGE_MIN),
        "abs_corr_delta_min": float(TIME_DEP_ABS_DELTA_MIN),
        "flag_definition": "time_dependent_flag = (rank_range >= threshold) OR (abs_corr_delta >= threshold), with >=2 horizons"
    },
    "terminology_note": (
    "'time_dependent_flag' / 'horizon_salience_variation_flag' indicates horizon-dependent "
    "interaction salience from risk transformation (and possible PH departures), not "
    "time-varying coefficients learned by the Cox model."
    ),
    "note": "Cox model structure is time-invariant; horizon-specific differences here represent risk-transform-dependent interaction salience, not different learned trees."
}
with open(f_info, "w", encoding="utf-8") as f:
    json.dump(run_info, f, indent=2)

# -----------------------------
# 7) Display quick summaries
# -----------------------------
# Render each top table under its heading (empty frame when nothing was found).
_summary_tables = [
    ("### Global Interaction Top (log-hazard SHAP)",
     df_global_top.head(30) if len(df_global_top) else pd.DataFrame()),
    ("### Horizon Interaction Top (risk-scale approximation)",
     df_horizon_top.head(30) if len(df_horizon_top) else pd.DataFrame()),
    ("### Time-Dependent Interaction Candidates",
     df_time[df_time["time_dependent_top_flag"]].head(50) if len(df_time) else pd.DataFrame()),
]
for _title, _frame in _summary_tables:
    display(Markdown(_title))
    display(_frame)

# List every artifact from section 6 that was actually written.
display(HTML("<br><b>Saved files:</b>"))
for _path in [f_global_scores, f_global_top, f_horizon_scores, f_horizon_top, f_time, f_hn, f_info]:
    if os.path.exists(_path):
        display(HTML(f" - {_path}"))

if saved_plot_files:
    display(HTML("<br><b>Saved plots:</b>"))
    for _path in saved_plot_files:
        display(HTML(f" - {_path}"))

Step 11 Interaction Discovery (DUAL)

  • Source bundle tag: 20260306_1821
  • Patients: 70521
  • Features: 56
  • Horizons: 3, 6, 12, 36, 60, 96 months
  • Global interactions: log-hazard SHAP scale
  • Horizon interactions: risk-scale approximation via dr/deta * SHAP

Global Interaction Top (log-hazard SHAP)

main_idx inter_idx main_feature main_rank_global main_importance_global interactor corr_resid_vs_interactor abs_corr delta_q4_q1_resid direction_q4_q1 n_dir_valid n_q1 n_q4 n_valid main_linear_slope main_linear_intercept outcome outcome_code scope horizon_months n_patients p_value p_fdr signif_fdr_0.01
0 0 7 adm_age_rec3 1 0.497441 ed_attainment_corr 0.064949 0.064949 0.011032 Positive 70521 51972 18397 70521 0.051331 -1.986121 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
1 0 10 adm_age_rec3 1 0.497441 eva_fam 0.060500 0.060500 0.008246 Positive 70521 39316 31202 70521 0.051331 -1.986121 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
2 0 51 adm_age_rec3 1 0.497441 occupation_condition_corr24_unemployed 0.060183 0.060183 0.012890 Positive 70521 46023 24498 70521 0.051331 -1.986121 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
3 0 11 adm_age_rec3 1 0.497441 eva_relinterp 0.059781 0.059781 0.008260 Positive 70521 39720 30796 70521 0.051331 -1.986121 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
4 0 13 adm_age_rec3 1 0.497441 eva_sm 0.058153 0.058153 0.006740 Positive 70521 40420 30092 70521 0.051331 -1.986121 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
5 42 0 primary_sub_mod_alcohol 2 0.267121 adm_age_rec3 -0.145518 0.145518 -0.022639 Negative 70521 17641 17640 70521 0.556013 -0.243998 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
6 42 4 primary_sub_mod_alcohol 2 0.267121 prim_sub_freq_rec -0.100948 0.100948 -0.012761 Negative 70521 39308 31002 70521 0.556013 -0.243998 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
7 42 44 primary_sub_mod_alcohol 2 0.267121 primary_sub_mod_cocaine_powder 0.099220 0.099220 0.002855 Positive 70521 56727 70521 70521 0.556013 -0.243998 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
8 42 47 primary_sub_mod_alcohol 2 0.267121 plan_type_corr_pg_pr -0.084876 0.084876 -0.001739 Negative 70521 62782 70521 70521 0.556013 -0.243998 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
9 42 43 primary_sub_mod_alcohol 2 0.267121 primary_sub_mod_cocaine_paste -0.079737 0.079737 -0.009595 Negative 70521 43856 26665 70521 0.556013 -0.243998 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
10 4 42 prim_sub_freq_rec 3 0.114488 primary_sub_mod_alcohol -0.062824 0.062824 -0.008331 Negative 70521 46546 23975 70521 0.155660 -0.213759 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
11 4 44 prim_sub_freq_rec 3 0.114488 primary_sub_mod_cocaine_powder 0.053674 0.053674 0.001663 Positive 70521 56727 70521 70521 0.155660 -0.213759 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
12 4 6 prim_sub_freq_rec 3 0.114488 urbanicity_cat -0.044370 0.044370 -0.001224 Negative 70521 57504 70521 70521 0.155660 -0.213759 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
13 4 35 prim_sub_freq_rec 3 0.114488 adm_motive_sanitary_sector 0.029000 0.029000 0.003935 Positive 70521 48588 21933 70521 0.155660 -0.213759 Death death global_log_hazard NaN 70521 1.332268e-14 4.774942e-14 True
14 4 28 prim_sub_freq_rec 3 0.114488 sub_dep_icd10_status_drug_dependence 0.028462 0.028462 0.004015 Positive 70521 19222 51299 70521 0.155660 -0.213759 Death death global_log_hazard NaN 70521 4.041212e-14 1.423209e-13 True
15 51 52 occupation_condition_corr24_unemployed 4 0.099876 occupation_condition_corr24_inactive 0.145013 0.145013 0.002365 Positive 70521 59054 70521 70521 0.217878 -0.079246 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
16 51 0 occupation_condition_corr24_unemployed 4 0.099876 adm_age_rec3 0.105125 0.105125 0.009976 Positive 70521 17641 17640 70521 0.217878 -0.079246 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
17 51 24 occupation_condition_corr24_unemployed 4 0.099876 sex_rec_woman 0.052826 0.052826 0.004477 Positive 70521 52439 18082 70521 0.217878 -0.079246 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
18 51 43 occupation_condition_corr24_unemployed 4 0.099876 primary_sub_mod_cocaine_paste -0.052387 0.052387 -0.003998 Negative 70521 43856 26665 70521 0.217878 -0.079246 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
19 51 53 occupation_condition_corr24_unemployed 4 0.099876 marital_status_rec_single -0.042157 0.042157 -0.003140 Negative 70521 31856 38601 70521 0.217878 -0.079246 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
20 22 0 any_phys_dx 5 0.092606 adm_age_rec3 -0.177597 0.177597 -0.011847 Negative 70521 17641 17640 70521 0.465739 -0.059663 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
21 22 42 any_phys_dx 5 0.092606 primary_sub_mod_alcohol -0.151679 0.151679 -0.008320 Negative 70521 46546 23975 70521 0.465739 -0.059663 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
22 22 39 any_phys_dx 5 0.092606 first_sub_used_alcohol -0.100611 0.100611 -0.005474 Negative 70521 26547 39050 70521 0.465739 -0.059663 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
23 22 43 any_phys_dx 5 0.092606 primary_sub_mod_cocaine_paste 0.089074 0.089074 0.004773 Positive 70521 43856 26665 70521 0.465739 -0.059663 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
24 22 23 any_phys_dx 5 0.092606 polysubstance_strict 0.063871 0.063871 0.003744 Positive 70521 18940 51581 70521 0.465739 -0.059663 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
25 12 42 eva_ocupacion 6 0.074887 primary_sub_mod_alcohol 0.096788 0.096788 0.005576 Positive 70521 46546 23975 70521 0.096627 -0.120813 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
26 12 43 eva_ocupacion 6 0.074887 primary_sub_mod_cocaine_paste -0.087721 0.087721 -0.004937 Negative 70521 43856 26665 70521 0.096627 -0.120813 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
27 12 2 eva_ocupacion 6 0.074887 dit_m -0.054846 0.054846 -0.002005 Negative 70521 17697 17661 70521 0.096627 -0.120813 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
28 12 17 eva_ocupacion 6 0.074887 dg_psiq_cie_10_instudy -0.049600 0.049600 -0.000620 Negative 70521 58292 70521 70521 0.096627 -0.120813 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True
29 12 31 eva_ocupacion 6 0.074887 tr_outcome_dropout 0.046922 0.046922 0.002567 Positive 70521 32933 37588 70521 0.096627 -0.120813 Death death global_log_hazard NaN 70521 0.000000e+00 0.000000e+00 True

Horizon Interaction Top (risk-scale approximation)

main_idx inter_idx main_feature main_rank_global main_importance_global interactor corr_resid_vs_interactor abs_corr delta_q4_q1_resid direction_q4_q1 n_dir_valid n_q1 n_q4 n_valid main_linear_slope main_linear_intercept outcome outcome_code scope horizon_months n_patients weight_clip_pct p_value p_fdr signif_fdr_0.01
0 0 26 adm_age_rec3 1 0.497441 cohabitation_with_couple_children -0.152021 0.152021 -0.001262 Negative 70521 38858 31663 70521 0.000366 -0.011619 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
1 0 42 adm_age_rec3 1 0.497441 primary_sub_mod_alcohol 0.150514 0.150514 0.001312 Positive 70521 46546 23975 70521 0.000366 -0.011619 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
2 0 22 adm_age_rec3 1 0.497441 any_phys_dx 0.148578 0.148578 0.000199 Positive 70521 63776 70521 70521 0.000366 -0.011619 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
3 0 12 adm_age_rec3 1 0.497441 eva_ocupacion 0.124464 0.124464 0.001223 Positive 70521 22291 29413 70521 0.000366 -0.011619 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
4 0 14 adm_age_rec3 1 0.497441 eva_fisica 0.124111 0.124111 0.001257 Positive 70521 21442 27083 70521 0.000366 -0.011619 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
5 42 4 primary_sub_mod_alcohol 2 0.267121 prim_sub_freq_rec 0.179168 0.179168 0.000848 Positive 70521 39308 31002 70521 0.003733 -0.000676 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
6 42 0 primary_sub_mod_alcohol 2 0.267121 adm_age_rec3 0.167873 0.167873 0.000744 Positive 70521 17641 17640 70521 0.003733 -0.000676 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
7 42 22 primary_sub_mod_alcohol 2 0.267121 any_phys_dx 0.151457 0.151457 0.000105 Positive 70521 63776 70521 70521 0.003733 -0.000676 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
8 42 28 primary_sub_mod_alcohol 2 0.267121 sub_dep_icd10_status_drug_dependence 0.132395 0.132395 0.000634 Positive 70521 19222 51299 70521 0.003733 -0.000676 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
9 42 51 primary_sub_mod_alcohol 2 0.267121 occupation_condition_corr24_unemployed 0.111628 0.111628 0.000500 Positive 70521 46023 24498 70521 0.003733 -0.000676 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
10 4 42 prim_sub_freq_rec 3 0.114488 primary_sub_mod_alcohol 0.177511 0.177511 0.000425 Positive 70521 46546 23975 70521 0.000912 -0.001065 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
11 4 0 prim_sub_freq_rec 3 0.114488 adm_age_rec3 0.138041 0.138041 0.000392 Positive 70521 17641 17640 70521 0.000912 -0.001065 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
12 4 43 prim_sub_freq_rec 3 0.114488 primary_sub_mod_cocaine_paste -0.122979 0.122979 -0.000287 Negative 70521 43856 26665 70521 0.000912 -0.001065 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
13 4 39 prim_sub_freq_rec 3 0.114488 first_sub_used_alcohol 0.099664 0.099664 0.000237 Positive 70521 26547 39050 70521 0.000912 -0.001065 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
14 4 22 prim_sub_freq_rec 3 0.114488 any_phys_dx 0.097974 0.097974 0.000036 Positive 70521 63776 70521 70521 0.000912 -0.001065 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
15 51 42 occupation_condition_corr24_unemployed 4 0.099876 primary_sub_mod_alcohol 0.128877 0.128877 0.000230 Positive 70521 46546 23975 70521 0.001183 -0.000347 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
16 51 43 occupation_condition_corr24_unemployed 4 0.099876 primary_sub_mod_cocaine_paste -0.112797 0.112797 -0.000196 Negative 70521 43856 26665 70521 0.001183 -0.000347 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
17 51 0 occupation_condition_corr24_unemployed 4 0.099876 adm_age_rec3 0.087931 0.087931 0.000211 Positive 70521 17641 17640 70521 0.001183 -0.000347 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
18 51 39 occupation_condition_corr24_unemployed 4 0.099876 first_sub_used_alcohol 0.072079 0.072079 0.000125 Positive 70521 26547 39050 70521 0.001183 -0.000347 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
19 51 53 occupation_condition_corr24_unemployed 4 0.099876 marital_status_rec_single -0.048971 0.048971 -0.000083 Negative 70521 31856 38601 70521 0.001183 -0.000347 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
20 22 42 any_phys_dx 5 0.092606 primary_sub_mod_alcohol 0.067958 0.067958 0.000204 Positive 70521 46546 23975 70521 0.004545 -0.000261 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
21 22 0 any_phys_dx 5 0.092606 adm_age_rec3 0.046603 0.046603 0.000147 Positive 70521 17641 17640 70521 0.004545 -0.000261 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
22 22 24 any_phys_dx 5 0.092606 sex_rec_woman -0.041313 0.041313 -0.000134 Negative 70521 52439 18082 70521 0.004545 -0.000261 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
23 22 43 any_phys_dx 5 0.092606 primary_sub_mod_cocaine_paste -0.040115 0.040115 -0.000118 Negative 70521 43856 26665 70521 0.004545 -0.000261 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
24 22 39 any_phys_dx 5 0.092606 first_sub_used_alcohol 0.038524 0.038524 0.000115 Positive 70521 26547 39050 70521 0.004545 -0.000261 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
25 12 42 eva_ocupacion 6 0.074887 primary_sub_mod_alcohol 0.076359 0.076359 0.000098 Positive 70521 46546 23975 70521 0.000470 -0.000508 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
26 12 43 eva_ocupacion 6 0.074887 primary_sub_mod_cocaine_paste -0.068948 0.068948 -0.000086 Negative 70521 43856 26665 70521 0.000470 -0.000508 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
27 12 0 eva_ocupacion 6 0.074887 adm_age_rec3 0.064855 0.064855 0.000119 Positive 70521 17641 17640 70521 0.000470 -0.000508 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
28 12 39 eva_ocupacion 6 0.074887 first_sub_used_alcohol 0.050938 0.050938 0.000064 Positive 70521 26547 39050 70521 0.000470 -0.000508 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True
29 12 17 eva_ocupacion 6 0.074887 dg_psiq_cie_10_instudy -0.045739 0.045739 -0.000013 Negative 70521 58292 70521 70521 0.000470 -0.000508 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True

Time-Dependent Interaction Candidates

outcome main_feature interactor horizons_seen mean_abs_corr sd_abs_corr min_abs_corr max_abs_corr min_rank max_rank n_rows n_positive n_negative min_p_fdr cv_abs_corr_pct abs_corr_delta rank_range n_direction_known direction_consistent direction_flip_flag dominant_direction in_top_at_least_once time_dependent_flag time_dependent_top_flag horizon_salience_variation_flag horizon_salience_variation_top_flag
108 Death any_phys_dx adm_age_rec3 6 0.034211 0.015852 0.005973 0.046603 2.0 32.0 6 6 0 0.000000e+00 46.334685 0.040630 30.0 6 True False Positive True True True True True
324 Death eva_fisica adm_age_rec3 6 0.036758 0.018491 0.004784 0.052027 1.0 37.0 6 6 0 0.000000e+00 50.304049 0.047243 36.0 6 True False Positive True True True True True
363 Death eva_fisica porc_pobr 6 0.025681 0.001073 0.024557 0.027237 5.0 14.0 6 6 0 1.649984e-12 4.179453 0.002680 9.0 6 True False Positive True True True True True
365 Death eva_fisica primary_sub_mod_alcohol 6 0.038989 0.014340 0.014170 0.050810 2.0 17.0 6 6 0 0.000000e+00 36.781100 0.036640 15.0 6 True False Positive True True True True True
376 Death eva_fisica tr_outcome_referral 6 0.026308 0.001334 0.025182 0.028431 4.0 13.0 6 6 0 1.593246e-13 5.069541 0.003249 9.0 6 True False Positive True True True True True
842 Readmission adm_age_rec3 occupation_condition_corr24_inactive 6 0.041640 0.016242 0.022616 0.062608 4.0 19.0 6 0 6 0.000000e+00 39.004466 0.039992 15.0 6 True False Negative True True True True True
845 Readmission adm_age_rec3 plan_type_corr_m_pr 6 0.065182 0.035428 0.019816 0.106975 1.0 17.0 6 6 0 0.000000e+00 54.352246 0.087160 16.0 6 True False Positive True True True True True
855 Readmission adm_age_rec3 sex_rec_woman 6 0.044455 0.018339 0.020435 0.065640 5.0 16.0 6 6 0 0.000000e+00 41.252924 0.045205 11.0 6 True False Positive True True True True True
899 Readmission dg_psiq_cie_10_dg plan_type_corr_m_pr 6 0.155484 0.049038 0.085606 0.207529 1.0 1.0 6 6 0 0.000000e+00 31.538817 0.121923 0.0 6 True False Positive True True True True True
909 Readmission dg_psiq_cie_10_dg sex_rec_woman 6 0.037152 0.018991 0.010226 0.057401 5.0 23.0 6 6 0 0.000000e+00 51.116498 0.047176 18.0 6 True False Positive True True True True True
953 Readmission dit_m plan_type_corr_m_pr 6 0.148101 0.037666 0.097855 0.190840 1.0 1.0 6 6 0 0.000000e+00 25.432443 0.092985 0.0 6 True False Positive True True True True True
963 Readmission dit_m sex_rec_woman 6 0.055991 0.014129 0.037076 0.071948 4.0 13.0 6 6 0 0.000000e+00 25.234850 0.034872 9.0 6 True False Positive True True True True True
1007 Readmission ed_attainment_corr plan_type_corr_m_pr 6 0.052324 0.022455 0.022006 0.077533 1.0 18.0 6 6 0 0.000000e+00 42.914296 0.055527 17.0 6 True False Positive True True True True True
1051 Readmission ethnicity first_sub_used_alcohol 6 0.029712 0.010699 0.013323 0.040150 4.0 16.0 6 0 6 0.000000e+00 36.007700 0.026827 12.0 6 True False Negative True True True True True
1061 Readmission ethnicity plan_type_corr_m_pr 6 0.020604 0.012193 0.009086 0.041859 2.0 28.0 6 3 3 0.000000e+00 59.177944 0.032773 26.0 6 False True Mixed True True True True True
1065 Readmission ethnicity porc_pobr 6 0.039169 0.007161 0.031173 0.048783 1.0 10.0 6 0 6 0.000000e+00 18.281690 0.017611 9.0 6 True False Negative True True True True True
1067 Readmission ethnicity primary_sub_mod_alcohol 6 0.038197 0.013067 0.018221 0.051016 1.0 11.0 6 0 6 0.000000e+00 34.211010 0.032796 10.0 6 True False Negative True True True True True
1068 Readmission ethnicity primary_sub_mod_cocaine_paste 6 0.034273 0.016735 0.009008 0.050998 2.0 23.0 6 6 0 0.000000e+00 48.829741 0.041990 21.0 6 True False Positive True True True True True
1071 Readmission ethnicity sex_rec_woman 6 0.010890 0.008346 0.001943 0.024748 5.0 51.0 6 3 3 1.340725e-10 76.642631 0.022805 46.0 6 False True Mixed True True True True True
1170 Readmission occupation_condition_corr24_unemployed plan_type_corr_pg_pai 6 0.081703 0.010481 0.066008 0.092004 5.0 14.0 6 0 6 0.000000e+00 12.828482 0.025996 9.0 6 True False Negative True True True True True
1171 Readmission occupation_condition_corr24_unemployed plan_type_corr_pg_pr 6 0.101673 0.017952 0.074779 0.119353 1.0 12.0 6 6 0 0.000000e+00 17.656655 0.044574 11.0 6 True False Positive True True True True True
1176 Readmission occupation_condition_corr24_unemployed primary_sub_mod_cocaine_paste 6 0.079447 0.011109 0.066258 0.093353 5.0 16.0 6 0 6 0.000000e+00 13.982321 0.027095 11.0 6 True False Negative True True True True True
1200 Readmission plan_type_corr_m_pr dit_m 6 0.064257 0.003388 0.057631 0.066431 3.0 13.0 6 6 0 0.000000e+00 5.272467 0.008800 10.0 6 True False Positive True True True True True
1260 Readmission plan_type_corr_pg_pr eva_consumo 6 0.059048 0.017228 0.039921 0.082455 4.0 13.0 6 0 6 0.000000e+00 29.175505 0.042535 9.0 6 True False Negative True True True True True
1278 Readmission plan_type_corr_pg_pr plan_type_corr_m_pr 6 0.055519 0.029815 0.012239 0.086585 2.0 31.0 6 0 6 0.000000e+00 53.702793 0.074345 29.0 6 True False Negative True True True True True
1287 Readmission plan_type_corr_pg_pr sex_rec_woman 6 0.043727 0.014149 0.022308 0.057678 4.0 23.0 6 0 6 0.000000e+00 32.356654 0.035370 19.0 6 True False Negative True True True True True
1292 Readmission plan_type_corr_pg_pr tr_outcome_adm_discharge_rule_violation_undet 6 0.071386 0.021337 0.040463 0.093659 1.0 15.0 6 6 0 0.000000e+00 29.889457 0.053196 14.0 6 True False Positive True True True True True
1332 Readmission polysubstance_strict plan_type_corr_m_pr 6 0.056049 0.022202 0.026792 0.081400 2.0 20.0 6 6 0 0.000000e+00 39.612464 0.054608 18.0 6 True False Positive True True True True True
1333 Readmission polysubstance_strict plan_type_corr_pg_pai 6 0.047180 0.011910 0.030664 0.060102 4.0 17.0 6 0 6 0.000000e+00 25.244293 0.029439 13.0 6 True False Negative True True True True True
1440 Readmission primary_sub_mod_alcohol plan_type_corr_m_pr 6 0.054807 0.031967 0.013113 0.091903 3.0 33.0 6 6 0 0.000000e+00 58.326777 0.078790 30.0 6 True False Positive True True True True True
1458 Readmission primary_sub_mod_cocaine_paste adm_age_rec3 6 0.059261 0.018697 0.038220 0.084279 1.0 22.0 6 6 0 0.000000e+00 31.549766 0.046059 21.0 6 True False Positive True True True True True
1494 Readmission primary_sub_mod_cocaine_paste plan_type_corr_m_pr 6 0.087245 0.041531 0.030005 0.132855 1.0 26.0 6 6 0 0.000000e+00 47.602505 0.102850 25.0 6 True False Positive True True True True True
1503 Readmission primary_sub_mod_cocaine_paste sex_rec_woman 6 0.058790 0.025934 0.022417 0.086724 2.0 32.0 6 6 0 0.000000e+00 44.112467 0.064307 30.0 6 True False Positive True True True True True
1548 Readmission sex_rec_woman plan_type_corr_m_pr 6 0.391598 0.093082 0.255008 0.487011 1.0 1.0 6 6 0 0.000000e+00 23.769867 0.232003 0.0 6 True False Positive True True True True True
1555 Readmission sex_rec_woman primary_sub_mod_cocaine_paste 6 0.073602 0.014975 0.051218 0.088624 4.0 18.0 6 6 0 0.000000e+00 20.346553 0.037405 14.0 6 True False Positive True True True True True
1602 Readmission sub_dep_icd10_status_drug_dependence plan_type_corr_m_pr 6 0.163591 0.058847 0.084133 0.229403 1.0 1.0 6 6 0 0.000000e+00 35.972393 0.145270 0.0 6 True False Positive True True True True True
1603 Readmission sub_dep_icd10_status_drug_dependence plan_type_corr_pg_pai 6 0.094575 0.022120 0.063708 0.118436 2.0 3.0 6 0 6 0.000000e+00 23.388715 0.054728 1.0 6 True False Negative True True True True True
1609 Readmission sub_dep_icd10_status_drug_dependence primary_sub_mod_cocaine_paste 6 0.052870 0.008618 0.042582 0.063723 2.0 10.0 6 0 6 0.000000e+00 16.299522 0.021141 8.0 6 True False Negative True True True True True
1612 Readmission sub_dep_icd10_status_drug_dependence sex_rec_woman 6 0.048065 0.024907 0.014542 0.076024 4.0 27.0 6 6 0 0.000000e+00 51.818747 0.061483 23.0 6 True False Positive True True True True True

Saved files:
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb11_dual_interactions_global_scores_20260306_1821.csv
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb11_dual_interactions_global_top_20260306_1821.csv
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb11_dual_interactions_horizon_scores_20260306_1821.csv
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb11_dual_interactions_horizon_top_20260306_1821.csv
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb11_dual_interactions_time_dependent_20260306_1821.csv
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb11_dual_interactions_horizon_sample_sizes_20260306_1821.csv
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb11_dual_interactions_run_info_20260306_1821.json

Saved plots:
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb11_dual_death_time_profiles_20260306_1821.png
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb11_dual_death_time_profiles_20260306_1821.pdf
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb11_dual_readmission_time_profiles_20260306_1821.png
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs\xgb11_dual_readmission_time_profiles_20260306_1821.pdf
Code
#@title Step 11b: Statistical Augmentation (FDR + SMD + Scale/Flip/Threshold Diagnostics)

import os
import re
import glob
import json
import pickle
import numpy as np
import pandas as pd
from scipy.stats import t as tdist
from IPython.display import display, Markdown, HTML  # FIXED: Added HTML

# -----------------------------
# 0) Config
# -----------------------------
# PROJECT_ROOT must be injected by an earlier setup cell; fail fast otherwise.
if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the root setup cell first.")

PROJECT_ROOT = os.path.abspath(str(PROJECT_ROOT))
# Step 11 outputs are both read from and written back to the same _out folder.
IN_DIR = os.path.join(PROJECT_ROOT, "_out")
OUT_DIR = os.path.join(PROJECT_ROOT, "_out")

# FIXED: Removed duplicate os.makedirs calls
os.makedirs(IN_DIR, exist_ok=True)
os.makedirs(OUT_DIR, exist_ok=True)

# Evaluation horizons (months) used throughout the augmentation.
TARGET_HORIZONS = [3, 6, 12, 36, 60, 96]
INTER_TOP_K = 10  # top k interactions
MIN_VALID_N = 300  # minimum size to compute imbalance
MIN_GROUP_N = 30  # minimum size per group
WEIGHT_CLIP_PCT = 99.5  # clips/caps extreme risk-gradient weights

FDR_ALPHA = 0.01  # significance cutoff
ABS_R_MIN = 0.05  # minimum absolute correlation
ABS_SMD_MIN = 0.20  # minimum absolute SMD

TIME_DEP_RANK_RANGE_MIN = 8  # minimum rank changes between horizons to flag time-dep
TIME_DEP_ABS_DELTA_MIN = 0.05  # minimum correlation change to flag time-dependence
# Grids for the threshold-sensitivity sweep in section 5.
SENS_RANK_GRID = [5, 8, 10, 15]
SENS_DELTA_GRID = [0.02, 0.03, 0.05]

# Per-outcome lookup of the pickle keys that may hold each artifact; key names
# varied across upstream runs, so several candidates are tried in order.
OUTCOME_CFG = {
    "readm": {
        "label": "Readmission",
        "shap_key_candidates": ["shap_r_all", "shap_readm_all"],
        "margin_key_candidates": ["risk_pred_readm", "risk_pred_r", "margin_pred_readm"],
        "probs_key_candidates": ["probs_readm_matrix", "probs_r_matrix", "surv_probs_readm_matrix"],
        "hz_times_candidates": ["times_r", "times_readm"],
        "hz_vals_candidates": ["h0_r", "H0_r", "h0_readm"],
    },
    "death": {
        "label": "Death",
        "shap_key_candidates": ["shap_d_all", "shap_death_all", "shap_mort_all"],
        "margin_key_candidates": ["risk_pred_death", "risk_pred_d", "risk_pred_mort", "margin_pred_death"],
        "probs_key_candidates": ["probs_death_matrix", "probs_d_matrix", "probs_mort_matrix", "surv_probs_death_matrix"],
        "hz_times_candidates": ["times_d", "times_death", "times_mort"],
        "hz_vals_candidates": ["h0_d", "H0_d", "h0_death", "h0_mort"],
    },
}
# Candidate keys under which CV split records store validation-set ids.
VAL_ID_KEYS = ["val_ids", "valid_ids", "val_idx", "val_index", "idx_val"]

# FIXED: Changed TS_RE to match timestamp anywhere (handles _mar26 suffix)
TS_RE = re.compile(r"(\d{8}_\d{4})")

# -----------------------------
# 1) Helpers
# -----------------------------
def _tag_from_path(path):
    """Extract the YYYYMMDD_HHMM timestamp tag from a file name, or None."""
    match = TS_RE.search(os.path.basename(path))
    if match is None:
        return None
    return match.group(1)

def find_first_key(dct, keys):
    """Return the first key from `keys` that exists in `dct`, else None."""
    return next((k for k in keys if k in dct), None)

def get_first(dct, keys, default=None):
    """Return the value of the first key in `keys` whose entry is not None.

    Falls back to `default` when no key has a non-None value.
    """
    for key in keys:
        value = dct.get(key)
        if value is not None:
            return value
    return default

def label_to_code(lbl):
    """Map a human-readable outcome label onto its internal code.

    Returns 'readm' for readmission labels, 'death' for death/mortality
    labels, and None when the label matches neither.
    """
    text = str(lbl).lower()
    for code, needles in (("readm", ("readm",)), ("death", ("death", "mort"))):
        if any(needle in text for needle in needles):
            return code
    return None

def latest_step11_tag():
    """Locate the newest Step 11 run-info JSON and return (tag, path).

    Searches OUT_DIR with and without the legacy ``_mar26`` suffix. The
    timestamp tag is parsed from the file name when possible; otherwise it is
    read from the ``source_bundle_tag`` field inside the JSON payload.
    """
    search_patterns = (
        os.path.join(OUT_DIR, "xgb11_dual_interactions_run_info_*.json"),
        os.path.join(OUT_DIR, "xgb11_dual_interactions_run_info_*_mar26.json"),
    )
    candidates = []
    for pattern in search_patterns:
        candidates.extend(glob.glob(pattern))

    if not candidates:
        raise FileNotFoundError(
            f"Run Step 11 first. Missing xgb11_dual_interactions_run_info_*.json in '{OUT_DIR}'"
        )

    # Most recently modified file wins.
    newest = sorted(candidates, key=os.path.getmtime)[-1]
    tag = _tag_from_path(newest)
    if tag:
        return tag, newest

    # Fallback: the tag recorded inside the JSON itself.
    with open(newest, "r", encoding="utf-8") as fh:
        info = json.load(fh)
    tag = info.get("source_bundle_tag")
    if not tag:
        raise ValueError("Could not infer source_tag from Step 11 run_info.")
    return tag, newest

def corr_to_pvalues(r, n):
    """Two-sided p-values for Pearson correlations via the t transform.

    Entries with non-finite inputs, n <= 2, or |r| >= 1 yield NaN.
    """
    r = np.asarray(r, dtype=float)
    n = np.asarray(n, dtype=float)
    pvals = np.full(r.shape, np.nan, dtype=float)
    valid = np.isfinite(r) & np.isfinite(n) & (n > 2) & (np.abs(r) < 1)
    if not np.any(valid):
        return pvals
    rv = r[valid]
    dof = n[valid] - 2.0
    # Clip the denominator to avoid division by ~0 for |r| close to 1.
    denom = np.clip(1.0 - rv * rv, 1e-12, None)
    tstat = rv * np.sqrt(dof / denom)
    pvals[valid] = 2.0 * (1.0 - tdist.cdf(np.abs(tstat), df=dof))
    return pvals

def bh_fdr(pvals):
    """Benjamini-Hochberg adjusted q-values; non-finite inputs stay NaN."""
    p = np.asarray(pvals, dtype=float)
    adjusted = np.full(p.shape, np.nan, dtype=float)
    finite = np.isfinite(p)
    sub = p[finite]
    count = len(sub)
    if count == 0:
        return adjusted
    order = np.argsort(sub)
    # Scale each ordered p-value by m/rank, then enforce monotonicity from
    # the largest p downwards.
    scaled = sub[order] * count / np.arange(1, count + 1, dtype=float)
    monotone = np.minimum.accumulate(scaled[::-1])[::-1]
    monotone = np.clip(monotone, 0.0, 1.0)
    restored = np.empty_like(monotone)
    restored[order] = monotone
    adjusted[finite] = restored
    return adjusted

def add_fdr(df, group_cols, alpha=0.05):
    """Ensure `p_value`/`p_fdr` columns exist and flag FDR-significant rows.

    FDR adjustment is performed separately within each stratum defined by
    `group_cols`. Returns (augmented_frame, name_of_significance_column).
    """
    out = df.copy()
    if "p_value" not in out.columns:
        out["p_value"] = corr_to_pvalues(out["corr_resid_vs_interactor"], out["n_valid"])
    if "p_fdr" not in out.columns:
        out["p_fdr"] = np.nan
        for idx in out.groupby(group_cols, dropna=False).groups.values():
            rows = list(idx)
            out.loc[rows, "p_fdr"] = bh_fdr(out.loc[rows, "p_value"].to_numpy(dtype=float))
    flag_col = f"signif_fdr_{alpha:.2f}"
    out[flag_col] = out["p_fdr"] < alpha
    return out, flag_col

def top_interactions(df, top_k):
    """Keep the `top_k` strongest interactors per outcome/(horizon)/main-feature.

    Grouping includes `horizon_months` only when the column exists and holds
    at least one non-null value; rows are ranked by global main-feature rank
    (ascending) and then |correlation| (descending).
    """
    if len(df) == 0:
        return df.copy()
    has_horizon = "horizon_months" in df.columns and df["horizon_months"].notna().any()
    if has_horizon:
        keys = ["outcome", "horizon_months", "main_feature"]
        order_by = ["outcome", "horizon_months", "main_rank_global", "abs_corr"]
    else:
        keys = ["outcome", "main_feature"]
        order_by = ["outcome", "main_rank_global", "abs_corr"]
    # Everything ascending except the last column (strongest correlation first).
    ascending = [True] * (len(order_by) - 1) + [False]
    ranked = df.sort_values(order_by, ascending=ascending)
    return ranked.groupby(keys, dropna=False, group_keys=False).head(top_k).reset_index(drop=True)

def h0_at_t(times, h0_vals, t):
    """Step-function lookup of the baseline hazard values at time `t`.

    Returns 0.0 before the first knot, NaN for empty or length-mismatched
    inputs, and otherwise the value at the last knot <= t (clamped to the
    final knot past the end).
    """
    knots = np.asarray(times, dtype=float).ravel()
    values = np.asarray(h0_vals, dtype=float).ravel()
    # FIXED: Added length check and boundary handling
    if knots.size == 0 or values.size == 0 or knots.size != values.size:
        return np.nan
    t = float(t)
    if t < knots[0]:
        return 0.0
    pos = np.searchsorted(knots, t, side="right") - 1
    pos = min(max(pos, 0), values.size - 1)
    return float(values[pos])

def collect_risk_by_id(raw_log, split_map, hz_map, cfg, horizon):
    """Collect per-patient predicted risk at `horizon`, averaged over CV folds.

    For each (imputation, fold) record in `raw_log`, risk at the horizon is
    taken either directly from a stored probability matrix (1 - survival
    probability at the matching eval time) or, failing that, reconstructed
    from margins and the fold's baseline hazard via the Cox/Breslow relation
    S(t) = exp(-exp(margin) * H0(t)).

    Returns a DataFrame with columns ["id", "risk"], one row per patient id,
    risk averaged across all folds in which the patient appeared in the
    validation set. Empty frame when nothing could be collected.

    NOTE(review): record schemas are assumed to match the key candidates in
    `cfg` / VAL_ID_KEYS — records missing those keys are silently skipped.
    """
    rows = []
    t = float(horizon)
    for rec in raw_log:
        # Records must identify their (imputation, fold) pair to be usable.
        if "imp_idx" not in rec or "fold_idx" not in rec:
            continue
        key = (int(rec["imp_idx"]), int(rec["fold_idx"]))
        split_rec = split_map.get(key)
        hz_rec = hz_map.get(key)
        if split_rec is None:
            continue

        val_ids = get_first(split_rec, VAL_ID_KEYS, [])
        margins = np.asarray(get_first(rec, cfg["margin_key_candidates"], []), dtype=float).ravel()
        # Margins must align one-to-one with the fold's validation ids.
        if len(val_ids) == 0 or len(margins) != len(val_ids):
            continue

        risk_vec = None
        eval_times = np.asarray(rec.get("eval_times", []), dtype=float).ravel()
        probs_mat = np.asarray(get_first(rec, cfg["probs_key_candidates"], []), dtype=float)

        # Preferred path: read risk straight from the stored probability
        # matrix at the column/row matching the requested horizon.
        if eval_times.size > 0 and probs_mat.ndim == 2:
            j = np.where(np.isclose(eval_times, t))[0]
            if j.size > 0:
                jj = int(j[0])
                # The matrix orientation varies between runs; accept either
                # (patients x times) or (times x patients).
                if probs_mat.shape[0] == len(val_ids) and probs_mat.shape[1] == eval_times.size:
                    risk_vec = 1.0 - probs_mat[:, jj]
                elif probs_mat.shape[1] == len(val_ids) and probs_mat.shape[0] == eval_times.size:
                    risk_vec = 1.0 - probs_mat[jj, :]
                if risk_vec is not None:
                    risk_vec = np.asarray(risk_vec, dtype=float).ravel()

        # Fallback path: rebuild risk from margins and the baseline hazard.
        if risk_vec is None and hz_rec is not None:
            times = get_first(hz_rec, cfg["hz_times_candidates"], [])
            h0_vals = get_first(hz_rec, cfg["hz_vals_candidates"], [])
            H0_t = h0_at_t(times, h0_vals, t)
            if np.isfinite(H0_t):
                surv = np.exp(-np.exp(margins) * H0_t)
                risk_vec = 1.0 - surv

        if risk_vec is not None and len(risk_vec) == len(val_ids):
            # Clamp to [0, 1] and drop non-finite entries.
            risk_vec = np.clip(np.asarray(risk_vec, dtype=float).ravel(), 0.0, 1.0)
            for i, pid in enumerate(val_ids):
                rv = float(risk_vec[i])
                if np.isfinite(rv):
                    rows.append((str(pid), rv))

    if not rows:
        return pd.DataFrame(columns=["id", "risk"])
    # Average risk per id across folds/imputations.
    return pd.DataFrame(rows, columns=["id", "risk"]).groupby("id", as_index=False)["risk"].mean()

def linear_residual(x, y, min_n=500, eps=1e-12):
    """Residuals of `y` after removing its least-squares linear fit on `x`.

    Returns an array aligned with the inputs (NaN where either input is
    non-finite), or None when fewer than `min_n` finite pairs exist. When
    `x` is (near-)constant the fit degenerates to subtracting the mean of
    `y`.
    """
    x = np.asarray(x, dtype=float).ravel()
    y = np.asarray(y, dtype=float).ravel()
    mask = np.isfinite(x) & np.isfinite(y)
    if int(mask.sum()) < min_n:
        return None
    xs = x[mask]
    ys = y[mask]
    out = np.full_like(y, np.nan, dtype=float)
    if np.std(xs) < eps:
        # Degenerate predictor: residual is just centered y.
        out[mask] = ys - float(np.mean(ys))
        return out
    design = np.column_stack([xs, np.ones_like(xs)])
    beta, *_ = np.linalg.lstsq(design, ys, rcond=None)
    out[mask] = ys - (beta[0] * xs + beta[1])
    return out

def pooled_sd(a, b, eps=1e-12):
    """Pooled standard deviation of two samples (finite values only).

    Returns NaN when either sample has fewer than two finite values or the
    pooled SD is numerically zero.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    a = a[np.isfinite(a)]
    b = b[np.isfinite(b)]
    n_a, n_b = len(a), len(b)
    if n_a < 2 or n_b < 2:
        return np.nan
    dof = n_a + n_b - 2
    if dof <= 0:
        return np.nan
    pooled_var = ((n_a - 1) * np.var(a, ddof=1) + (n_b - 1) * np.var(b, ddof=1)) / dof
    sp = np.sqrt(pooled_var)
    return np.nan if sp < eps else float(sp)

def smd(hi, lo):
    """Standardized mean difference (hi - lo) / pooled SD; NaN if SD invalid."""
    sp = pooled_sd(hi, lo)
    return float((np.mean(hi) - np.mean(lo)) / sp) if np.isfinite(sp) else np.nan

def dir_from_delta(d, eps=1e-12):
    """Classify a delta: 'Unknown' (non-finite), 'Neutral' (|d| < eps), else sign."""
    if not np.isfinite(d):
        return "Unknown"
    if abs(d) < eps:
        return "Neutral"
    return "Negative" if d < 0 else "Positive"

# -----------------------------
# 2) Load Step11 outputs
# -----------------------------
# Resolve the most recent Step 11 run: its timestamp tag and run_info path.
source_tag, run_info_path = latest_step11_tag()

# FIXED: Try both naming patterns for Step 11 outputs
def find_step11_file(prefix, tag, out_dir):
    """Locate a Step 11 CSV for `tag`, tolerating the legacy `_mar26` suffix.

    Falls back to the most recently modified file with the same prefix, and
    returns None when nothing matches at all.
    """
    exact_candidates = (
        os.path.join(out_dir, f"{prefix}_{tag}.csv"),
        os.path.join(out_dir, f"{prefix}_{tag}_mar26.csv"),
    )
    for candidate in exact_candidates:
        if os.path.exists(candidate):
            return candidate
    # Fallback: newest file with the same prefix, regardless of tag.
    matches = glob.glob(os.path.join(out_dir, f"{prefix}_*.csv"))
    if not matches:
        return None
    return sorted(matches, key=os.path.getmtime)[-1]

# Locate and load the Step 11 score tables for the resolved run tag.
f_global_scores = find_step11_file("xgb11_dual_interactions_global_scores", source_tag, OUT_DIR)
f_horizon_scores = find_step11_file("xgb11_dual_interactions_horizon_scores", source_tag, OUT_DIR)
f_time = find_step11_file("xgb11_dual_interactions_time_dependent", source_tag, OUT_DIR)

# Global and horizon score tables are mandatory; time-dependence is optional.
if not f_global_scores or not os.path.exists(f_global_scores):
    raise FileNotFoundError(f"Missing Step 11 global scores file for tag {source_tag}")
if not f_horizon_scores or not os.path.exists(f_horizon_scores):
    raise FileNotFoundError(f"Missing Step 11 horizon scores file for tag {source_tag}")

df_global_scores = pd.read_csv(f_global_scores)
df_horizon_scores = pd.read_csv(f_horizon_scores)
df_time = pd.read_csv(f_time) if f_time and os.path.exists(f_time) else pd.DataFrame()

# Coerce horizon to numeric (CSV round-trips can leave it as object dtype).
if "horizon_months" in df_horizon_scores.columns:
    df_horizon_scores["horizon_months"] = pd.to_numeric(df_horizon_scores["horizon_months"], errors="coerce")

# Add p/FDR if missing
df_global_scores, fdr_col = add_fdr(df_global_scores, group_cols=["outcome", "scope"] if "scope" in df_global_scores.columns else ["outcome"], alpha=FDR_ALPHA)
df_horizon_scores, _ = add_fdr(df_horizon_scores, group_cols=["outcome", "horizon_months", "scope"] if "scope" in df_horizon_scores.columns else ["outcome", "horizon_months"], alpha=FDR_ALPHA)

# Flag rows that are both FDR-significant and clinically meaningful (|r| cutoff).
for _df in (df_global_scores, df_horizon_scores):
    _df["clinically_meaningful_r"] = _df["abs_corr"] >= ABS_R_MIN
    _df["stat_and_clinical_r"] = _df[fdr_col] & _df["clinically_meaningful_r"]

# Reduce to the top-k interactions per grouping for downstream enrichment.
df_global_top = top_interactions(df_global_scores, INTER_TOP_K)
df_horizon_top = top_interactions(df_horizon_scores, INTER_TOP_K)

# -----------------------------
# 3) Load source artifacts (to compute SMDs on residuals)
# -----------------------------
# FIXED: Try both naming patterns for artifacts
def find_artifact(prefix, tag, in_dir):
    """Find a Step 6 artifact pickle for `tag`, with or without the `_mar26` suffix.

    Parameters:
        prefix: artifact file-name prefix (e.g. "xgb6_corr_DUAL_CV_Splits").
        tag: timestamp tag of the source run.
        in_dir: directory holding the pickled artifacts.

    Returns the first existing path; raises FileNotFoundError otherwise.
    """
    patterns = [
        os.path.join(in_dir, f"{prefix}_{tag}.pkl"),
        os.path.join(in_dir, f"{prefix}_{tag}_mar26.pkl"),
    ]
    for p in patterns:
        if os.path.exists(p):
            return p
    # BUG FIX: the message previously interpolated the global `source_tag`
    # instead of the `tag` argument, which could mis-report the tag being
    # searched for (or raise NameError if the global was undefined).
    raise FileNotFoundError(f"Missing artifact for tag {tag}: {prefix}")

# Locate the four Step 6 artifact pickles that belong to the same run tag.
shap_file = find_artifact("xgb6_corr_DUAL_SHAP_Aggregated", source_tag, IN_DIR)
raw_file = find_artifact("xgb6_corr_DUAL_final_ev_hyp", source_tag, IN_DIR)
hz_file = find_artifact("xgb6_corr_DUAL_BaselineHazards", source_tag, IN_DIR)
split_file = find_artifact("xgb6_corr_DUAL_CV_Splits", source_tag, IN_DIR)

# NOTE(review): pickle.load on these files assumes they are trusted local
# artifacts produced by earlier cells of this pipeline.
with open(shap_file, "rb") as f:
    shap_data = pickle.load(f)
with open(raw_file, "rb") as f:
    raw_data_log = pickle.load(f)
with open(hz_file, "rb") as f:
    baseline_hazards_log = pickle.load(f)
with open(split_file, "rb") as f:
    cv_splits_log = pickle.load(f)

# Normalize the feature matrix: enforce DataFrame type and column order.
X_all = shap_data["X_all"]
feature_names = list(shap_data["feature_names"])
if not isinstance(X_all, pd.DataFrame):
    X_all = pd.DataFrame(X_all, columns=feature_names)
if list(X_all.columns) != feature_names:
    X_all = X_all.reindex(columns=feature_names)

# FIXED: Added duplicate feature check
if len(feature_names) != len(set(feature_names)):
    dupes = [f for f in feature_names if feature_names.count(f) > 1]
    raise ValueError(f"Duplicate feature names detected: {set(dupes)}")

# Coerce everything to numeric (non-numeric entries become NaN).
X_all = X_all.apply(pd.to_numeric, errors="coerce")
X_np = X_all.to_numpy(dtype=float)

# Lookup tables: feature name -> column index, patient id -> row index.
f2i = {f: i for i, f in enumerate(feature_names)}
id_to_row = {str(idx): i for i, idx in enumerate(X_all.index)}

# Index split / hazard records by their (imputation, fold) pair.
split_map = {(int(s["imp_idx"]), int(s["fold_idx"])): s for s in cv_splits_log if "imp_idx" in s and "fold_idx" in s}
hz_map = {(int(h["imp_idx"]), int(h["fold_idx"])): h for h in baseline_hazards_log if "imp_idx" in h and "fold_idx" in h}

def shap_for_outcome(outcome_code):
    """Return the aggregated SHAP matrix for an outcome, or None if absent.

    The matrix must match the shape of X_np; otherwise it is treated as
    missing.
    """
    candidates = OUTCOME_CFG[outcome_code]["shap_key_candidates"]
    key = find_first_key(shap_data, candidates)
    if key is None:
        return None
    shap_mat = np.asarray(shap_data[key], dtype=float)
    if shap_mat.shape != X_np.shape:
        return None
    return shap_mat

# Per-(outcome, horizon) caches so weighted SHAP datasets and main-effect
# residuals are computed once across many (main, interactor) pairs.
dataset_cache = {}
resid_cache = {}

def get_dataset(outcome_label, horizon=None):
    """Return (X, S) for an outcome, optionally re-weighted for a horizon.

    With horizon=None, returns the full feature matrix and raw SHAP matrix.
    With a horizon, restricts to patients with a collected risk estimate and
    scales each patient's SHAP row by a clipped risk-gradient weight.
    Returns (None, None) when the outcome is unknown, SHAP data is missing,
    or too few patients have valid risks. Results are memoized in
    dataset_cache.
    """
    outcome_code = label_to_code(outcome_label)
    if outcome_code is None:
        return None, None
    k = (outcome_code, None if horizon is None else float(horizon))
    if k in dataset_cache:
        return dataset_cache[k]

    S = shap_for_outcome(outcome_code)
    if S is None:
        dataset_cache[k] = (None, None)
        return dataset_cache[k]

    if horizon is None:
        # Global scale: no subsetting, no weighting.
        Xd, Sd = X_np, S
    else:
        cfg = OUTCOME_CFG[outcome_code]
        df_risk = collect_risk_by_id(raw_data_log, split_map, hz_map, cfg, float(horizon))
        # Keep only patients present in the feature matrix.
        df_risk = df_risk[df_risk["id"].isin(id_to_row.keys())]
        if len(df_risk) < MIN_VALID_N:
            dataset_cache[k] = (None, None)
            return dataset_cache[k]

        row_idx = df_risk["id"].map(id_to_row).to_numpy(dtype=int)
        risk = np.clip(df_risk["risk"].to_numpy(dtype=float), 1e-12, 1.0 - 1e-12)
        # Risk-gradient weight: H(t) * S(t) = -log(1-risk) * (1-risk).
        # NOTE(review): presumably the derivative of risk w.r.t. the log
        # cumulative hazard — confirm against the Step 11 definition.
        w = (-np.log(1.0 - risk)) * (1.0 - risk)
        if np.isfinite(w).sum() == 0:
            dataset_cache[k] = (None, None)
            return dataset_cache[k]
        # Cap extreme weights at the configured percentile.
        cap = np.nanpercentile(w, WEIGHT_CLIP_PCT)
        w = np.clip(w, 0.0, cap)

        Xd = X_np[row_idx, :]
        Sd = S[row_idx, :] * w[:, None]

    dataset_cache[k] = (Xd, Sd)
    return dataset_cache[k]

def get_residual(outcome_label, horizon, main_idx):
    """Cached residual of a main feature's SHAP column after its linear trend.

    Returns (residual_vector, X_matrix); either element may be None when the
    dataset is unavailable or the sample is too small.
    """
    cache_key = (str(outcome_label), None if horizon is None else float(horizon), int(main_idx))
    if cache_key in resid_cache:
        return resid_cache[cache_key]

    Xd, Sd = get_dataset(outcome_label, horizon)
    if Xd is None or Sd is None:
        result = (None, None)
    else:
        result = (linear_residual(Xd[:, main_idx], Sd[:, main_idx], min_n=MIN_VALID_N), Xd)
    resid_cache[cache_key] = result
    return result

def _empty_effect():
    """Sentinel effect-metrics record used when metrics cannot be computed."""
    return {
        "delta_q4_q1_resid": np.nan, "smd_q4_q1": np.nan, "direction_q4_q1": "Unknown",
        "delta_binarized_resid": np.nan, "smd_binarized": np.nan, "direction_binarized": "Unknown",
        "binarized_method": "unknown",
    }

def effect_metrics(outcome_label, horizon, main_feature, interactor):
    """Effect sizes of `interactor` on the residualized SHAP of `main_feature`.

    Computes two contrasts of the main-feature residuals across interactor
    levels: a quartile contrast (Q4 vs Q1) and a binarized contrast (exact
    0/1 groups for true binaries, median split otherwise). Each contrast
    yields a mean delta, an SMD, and a direction label.

    Returns a dict of the seven metric fields; NaN/'Unknown' placeholders
    when either feature is missing or sample sizes are insufficient.
    (Refactor: the previously triplicated placeholder dict is now built by
    `_empty_effect()`.)
    """
    # Unknown feature names -> nothing to compute.
    if main_feature not in f2i or interactor not in f2i:
        return _empty_effect()

    i = f2i[main_feature]
    j = f2i[interactor]
    resid, Xd = get_residual(outcome_label, horizon, i)
    if resid is None or Xd is None:
        return _empty_effect()

    x = Xd[:, j]
    r = resid
    m = np.isfinite(x) & np.isfinite(r)
    if int(m.sum()) < MIN_VALID_N:
        return _empty_effect()

    xx = x[m]
    rr = r[m]

    # Continuous contrast: residual mean in top vs bottom quartile of the
    # interactor.
    q1, q3 = np.quantile(xx, [0.25, 0.75])
    g1 = rr[xx <= q1]
    g4 = rr[xx >= q3]
    if len(g1) >= MIN_GROUP_N and len(g4) >= MIN_GROUP_N:
        delta_q = float(np.mean(g4) - np.mean(g1))
        smd_q = smd(g4, g1)
        dir_q = dir_from_delta(delta_q)
    else:
        delta_q, smd_q, dir_q = np.nan, np.nan, "Unknown"

    # Binarized contrast: exact 0/1 split for true binary features, median
    # split (high vs low) for everything else.
    vals = np.unique(np.round(xx[np.isfinite(xx)], 10))
    is_binary = (len(vals) <= 2) and set(vals).issubset({0.0, 1.0})
    if is_binary:
        lo = rr[xx == 0.0]
        hi = rr[xx == 1.0]
        method = "binary_1_minus_0"
    else:
        med = np.median(xx)
        lo = rr[xx < med]
        hi = rr[xx >= med]
        method = "median_high_minus_low"

    if len(lo) >= MIN_GROUP_N and len(hi) >= MIN_GROUP_N:
        delta_b = float(np.mean(hi) - np.mean(lo))
        smd_b = smd(hi, lo)
        dir_b = dir_from_delta(delta_b)
    else:
        delta_b, smd_b, dir_b = np.nan, np.nan, "Unknown"

    return {
        "delta_q4_q1_resid": delta_q,
        "smd_q4_q1": smd_q,
        "direction_q4_q1": dir_q,
        "delta_binarized_resid": delta_b,
        "smd_binarized": smd_b,
        "direction_binarized": dir_b,
        "binarized_method": method,
    }

def enrich_with_smd(df, is_horizon):
    """Append SMD-based effect metrics and clinical-relevance flags to `df`.

    `is_horizon` controls whether each row's `horizon_months` conditions the
    residual dataset (horizon tables) or the global dataset is used.
    """
    if len(df) == 0:
        return df.copy()

    records = []
    for _, row in df.iterrows():
        horizon = float(row["horizon_months"]) if is_horizon else None
        records.append(effect_metrics(row["outcome"], horizon, row["main_feature"], row["interactor"]))

    extra = pd.DataFrame(records)

    # Drop any pre-existing metric columns so the freshly computed values
    # replace them instead of producing duplicate column names.
    base = df.copy()
    overlap = [c for c in extra.columns if c in base.columns]
    if overlap:
        base = base.drop(columns=overlap)

    out = pd.concat([base.reset_index(drop=True), extra], axis=1)
    out["clinically_meaningful_smd"] = (out["smd_q4_q1"].abs() >= ABS_SMD_MIN) | (out["smd_binarized"].abs() >= ABS_SMD_MIN)
    out["clinically_meaningful_combined"] = out["clinically_meaningful_r"] | out["clinically_meaningful_smd"]
    out["stat_and_clinical_combined"] = out[fdr_col] & out["clinically_meaningful_combined"]
    return out

# Attach SMD/effect-size diagnostics to the global and per-horizon top tables.
df_global_top_enriched = enrich_with_smd(df_global_top, is_horizon=False)
df_horizon_top_enriched = enrich_with_smd(df_horizon_top, is_horizon=True)

# -----------------------------
# 4) Scale comparison + direction flip summaries
# -----------------------------
# Compare each horizon-specific (risk-scale) top interaction with its global
# counterpart: magnitude deltas/ratios and direction changes.
if len(df_global_top_enriched) and len(df_horizon_top_enriched):
    gcols = [
        "outcome", "main_feature", "interactor",
        "corr_resid_vs_interactor", "abs_corr", "direction_q4_q1", "smd_q4_q1", "smd_binarized",
        fdr_col, "clinically_meaningful_combined", "stat_and_clinical_combined"
    ]
    g = df_global_top_enriched[gcols].rename(columns={
        "corr_resid_vs_interactor": "corr_resid_vs_interactor_global",
        "abs_corr": "abs_corr_global",
        "direction_q4_q1": "direction_q4_q1_global",
        "smd_q4_q1": "smd_q4_q1_global",
        "smd_binarized": "smd_binarized_global",
        fdr_col: f"{fdr_col}_global",
        "clinically_meaningful_combined": "clinically_meaningful_combined_global",
        "stat_and_clinical_combined": "stat_and_clinical_combined_global",
    })

    scale_cmp = df_horizon_top_enriched.merge(g, on=["outcome", "main_feature", "interactor"], how="inner")
    scale_cmp["delta_abs_corr_risk_minus_global"] = scale_cmp["abs_corr"] - scale_cmp["abs_corr_global"]
    scale_cmp["ratio_abs_corr_risk_over_global"] = scale_cmp["abs_corr"] / np.clip(scale_cmp["abs_corr_global"], 1e-12, None)
    scale_cmp["corr_direction_changed"] = np.sign(scale_cmp["corr_resid_vs_interactor"]) != np.sign(scale_cmp["corr_resid_vs_interactor_global"])
    scale_cmp["q4q1_direction_changed"] = (
        (scale_cmp["direction_q4_q1"] != "Unknown") &
        (scale_cmp["direction_q4_q1_global"] != "Unknown") &
        (scale_cmp["direction_q4_q1"] != scale_cmp["direction_q4_q1_global"])
    )
    scale_cmp["delta_abs_smd_q4q1_risk_minus_global"] = scale_cmp["smd_q4_q1"].abs() - scale_cmp["smd_q4_q1_global"].abs()
else:
    scale_cmp = pd.DataFrame()

# keep last duplicate if any
scale_cmp = scale_cmp.loc[:, ~scale_cmp.columns.duplicated(keep="last")].copy()

# BUG FIX: the recomputation below previously ran unconditionally; when
# scale_cmp was the empty placeholder frame (which has no columns at all),
# indexing "direction_q4_q1" raised KeyError. Guard on a non-empty frame —
# the column already exists from the branch above whenever rows exist.
if len(scale_cmp):
    dir_risk = scale_cmp["direction_q4_q1"].astype(str)
    dir_glob = scale_cmp["direction_q4_q1_global"].astype(str)

    # Recompute after de-duplication using string-safe comparisons.
    scale_cmp["q4q1_direction_changed"] = (
        dir_risk.ne("Unknown") &
        dir_glob.ne("Unknown") &
        dir_risk.ne(dir_glob)
    )

# Flip across horizons (using quartile-direction from risk-scale top interactions)
if len(df_horizon_top_enriched):
    d = df_horizon_top_enriched.copy()
    # One-hot the direction labels so flips can be counted with sums.
    d["dir_pos"] = (d["direction_q4_q1"] == "Positive").astype(int)
    d["dir_neg"] = (d["direction_q4_q1"] == "Negative").astype(int)

    # A pair "flips" if it is Positive at some horizon and Negative at another.
    pair_flip = (
        d.groupby(["outcome", "main_feature", "interactor"], as_index=False)
         .agg(
             n_horizons=("horizon_months", "nunique"),
             n_pos=("dir_pos", "sum"),
             n_neg=("dir_neg", "sum"),
         )
    )
    pair_flip["direction_flip_across_horizons"] = (pair_flip["n_pos"] > 0) & (pair_flip["n_neg"] > 0)

    # Per-outcome summary: how many pairs flip, and what share.
    flip_summary_hz = (
        pair_flip.groupby("outcome", as_index=False)
        .agg(
            total_pairs=("direction_flip_across_horizons", "size"),
            flipped_pairs=("direction_flip_across_horizons", "sum")
        )
    )
    flip_summary_hz["flip_pct"] = 100.0 * flip_summary_hz["flipped_pairs"] / np.clip(flip_summary_hz["total_pairs"], 1, None)
else:
    pair_flip = pd.DataFrame()
    flip_summary_hz = pd.DataFrame()

# Flip between scales (global vs risk)
# Per (outcome, horizon): count pairs whose direction differs between the
# global and risk-weighted scales, by correlation sign and by Q4-Q1 direction.
if len(scale_cmp):
    flip_summary_scale = (
        scale_cmp.groupby(["outcome", "horizon_months"], as_index=False)
        .agg(
            n_pairs=("main_feature", "size"),
            corr_flip_n=("corr_direction_changed", "sum"),
            q4q1_flip_n=("q4q1_direction_changed", "sum"),
        )
    )
    flip_summary_scale["corr_flip_pct"] = 100.0 * flip_summary_scale["corr_flip_n"] / np.clip(flip_summary_scale["n_pairs"], 1, None)
    flip_summary_scale["q4q1_flip_pct"] = 100.0 * flip_summary_scale["q4q1_flip_n"] / np.clip(flip_summary_scale["n_pairs"], 1, None)
else:
    flip_summary_scale = pd.DataFrame()

# -----------------------------
# 5) Threshold justification + sensitivity
# -----------------------------
# Justify the time-dependence flag thresholds by placing the chosen cutoffs
# within the per-outcome empirical distributions (p50/p75/p90), then sweep a
# grid of alternative thresholds to show how sensitive the flag counts are.
if len(df_time):
    just_rows = []
    for out in sorted(df_time["outcome"].dropna().unique()):
        z = df_time[df_time["outcome"] == out].copy()
        if len(z) == 0:
            continue

        rr = z["rank_range"].to_numpy(dtype=float)
        dd = z["abs_corr_delta"].to_numpy(dtype=float)

        # chosen_percentile = share of pairs at or below the chosen cutoff.
        just_rows.append({
            "outcome": out,
            "metric": "rank_range",
            "p50": float(np.nanquantile(rr, 0.50)),
            "p75": float(np.nanquantile(rr, 0.75)),
            "p90": float(np.nanquantile(rr, 0.90)),
            "chosen_threshold": float(TIME_DEP_RANK_RANGE_MIN),
            "chosen_percentile": float(100.0 * np.nanmean(rr <= TIME_DEP_RANK_RANGE_MIN)),
        })
        just_rows.append({
            "outcome": out,
            "metric": "abs_corr_delta",
            "p50": float(np.nanquantile(dd, 0.50)),
            "p75": float(np.nanquantile(dd, 0.75)),
            "p90": float(np.nanquantile(dd, 0.90)),
            "chosen_threshold": float(TIME_DEP_ABS_DELTA_MIN),
            "chosen_percentile": float(100.0 * np.nanmean(dd <= TIME_DEP_ABS_DELTA_MIN)),
        })

    df_threshold_just = pd.DataFrame(just_rows)

    # Sensitivity sweep: re-apply the flag rule over the threshold grid and
    # record how many pairs would be flagged under each combination.
    sens_rows = []
    for out in sorted(df_time["outcome"].dropna().unique()):
        z = df_time[df_time["outcome"] == out].copy()
        if len(z) == 0:
            continue
        for rr_th in SENS_RANK_GRID:
            for dd_th in SENS_DELTA_GRID:
                # Same rule shape as Step 11: seen at >= 2 horizons AND
                # (rank swing OR correlation swing) exceeds its threshold.
                flag = (z["horizons_seen"] >= 2) & ((z["rank_range"] >= rr_th) | (z["abs_corr_delta"] >= dd_th))
                sens_rows.append({
                    "outcome": out,
                    "rank_range_threshold": int(rr_th),
                    "abs_corr_delta_threshold": float(dd_th),
                    "n_flagged": int(flag.sum()),
                    "n_total": int(len(z)),
                    "pct_flagged": float(100.0 * flag.mean()),
                    "is_current_rule": bool((rr_th == TIME_DEP_RANK_RANGE_MIN) and (abs(dd_th - TIME_DEP_ABS_DELTA_MIN) < 1e-12)),
                })
    df_threshold_sens = pd.DataFrame(sens_rows)
else:
    df_threshold_just = pd.DataFrame()
    df_threshold_sens = pd.DataFrame()

# -----------------------------
# 6) Effect-size summary
# -----------------------------
def summarize_effects(df, label, fdr_col_name=None):
    """Count flagged interactions per outcome for one enriched table.

    Parameters
    ----------
    df : pd.DataFrame
        Enriched interaction table; must contain an "outcome" column.  The
        boolean flag columns are optional — a missing column yields NaN.
    label : str
        Tag stored in the "table" column of the summary (e.g. "global_top").
    fdr_col_name : str, optional
        Name of the FDR-significance flag column.  Defaults to the
        module-level ``fdr_col`` so existing no-argument call sites are
        unchanged; passing it explicitly decouples the function from that
        global.

    Returns
    -------
    pd.DataFrame
        One row per outcome with counts of significant / clinically
        meaningful interactions; an empty frame for empty input.
    """
    if len(df) == 0:
        return pd.DataFrame()
    if fdr_col_name is None:
        fdr_col_name = fdr_col  # fall back to the Step-11 module-level name

    def _count(z, col):
        # Sum a boolean flag column; NaN when the column is absent.
        return int(z[col].sum()) if col in z.columns else np.nan

    rows = []
    for out, z in df.groupby("outcome", dropna=False):
        rows.append({
            "table": label,
            "outcome": out,
            "n_rows": int(len(z)),
            "fdr_sig_n": _count(z, fdr_col_name),
            "clinically_meaningful_r_n": _count(z, "clinically_meaningful_r"),
            "clinically_meaningful_smd_n": _count(z, "clinically_meaningful_smd"),
            "clinically_meaningful_combined_n": _count(z, "clinically_meaningful_combined"),
            "stat_and_clinical_combined_n": _count(z, "stat_and_clinical_combined"),
        })
    return pd.DataFrame(rows)

# Stack the per-table outcome summaries into one tidy frame for display/export.
effect_summary = pd.concat(
    [
        summarize_effects(df_global_top_enriched, "global_top"),
        summarize_effects(df_horizon_top_enriched, "horizon_top"),
    ],
    ignore_index=True,
)

# -----------------------------
# 7) Save outputs
# -----------------------------
def _tag_path(stem, ext="csv"):
    """Output path for one Step-11 artifact, tagged with the run's source_tag."""
    return os.path.join(OUT_DIR, f"xgb11_dual_interactions_{stem}_{source_tag}.{ext}")

f_global_enr = _tag_path("global_top_enriched")
f_horizon_enr = _tag_path("horizon_top_enriched")
f_scale_cmp = _tag_path("scale_comparison")
f_flip_hz = _tag_path("direction_flip_horizon")
f_flip_scale = _tag_path("direction_flip_scale")
f_th_just = _tag_path("threshold_justification")
f_th_sens = _tag_path("threshold_sensitivity")
f_eff = _tag_path("effectsize_summary")
f_info = _tag_path("aug_run_info", ext="json")

# Write each non-empty table; empty frames are skipped so no zero-row CSVs appear.
for _path, _frame in (
    (f_global_enr, df_global_top_enriched),
    (f_horizon_enr, df_horizon_top_enriched),
    (f_scale_cmp, scale_cmp),
    (f_flip_hz, flip_summary_hz),
    (f_flip_scale, flip_summary_scale),
    (f_th_just, df_threshold_just),
    (f_th_sens, df_threshold_sens),
    (f_eff, effect_summary),
):
    if len(_frame):
        _frame.to_csv(_path, index=False)

# Machine-readable record of thresholds/rules used in this run, for provenance.
run_info = {
    "source_tag": source_tag,
    "fdr_alpha": float(FDR_ALPHA),
    "clinical_effect_thresholds": {
        "abs_r_min": float(ABS_R_MIN),
        "abs_smd_min": float(ABS_SMD_MIN),
        "rule": "clinically meaningful if |r|>=abs_r_min OR |SMD|>=abs_smd_min"
    },
    "smd_definitions": {
        "continuous": "Q4-Q1 standardized mean difference of residual interaction signal",
        "binarized": "binary(1-0) if interactor is binary else median(high-low) split SMD"
    },
    "time_dep_rule": {
        "rank_range_min": int(TIME_DEP_RANK_RANGE_MIN),
        "abs_corr_delta_min": float(TIME_DEP_ABS_DELTA_MIN),
        "sensitivity_rank_grid": SENS_RANK_GRID,
        "sensitivity_delta_grid": SENS_DELTA_GRID
    }
}
with open(f_info, "w", encoding="utf-8") as f:
    json.dump(run_info, f, indent=2)

# -----------------------------
# 8) Display
# -----------------------------
def _show_table(title_md, frame):
    """Render a markdown section header plus the table (empty placeholder if no rows)."""
    display(Markdown(title_md))
    display(frame if len(frame) else pd.DataFrame())

_show_table("### Effect-Size Summary (Top Interactions)", effect_summary)

# Scale comparison gets special treatment: sorted by strength within
# outcome/horizon and truncated to the top 30 rows for readability.
display(Markdown("### Scale Comparison (Global vs Risk Scale)"))
if len(scale_cmp):
    display(scale_cmp.sort_values(["outcome", "horizon_months", "abs_corr"], ascending=[True, True, False]).head(30))
else:
    display(pd.DataFrame())

_show_table("### Direction Flip Summary Across Horizons", flip_summary_hz)
_show_table("### Direction Flip Summary Between Scales", flip_summary_scale)
_show_table("### Threshold Justification", df_threshold_just)
_show_table("### Threshold Sensitivity Grid", df_threshold_sens)

# Rich HTML instead of print() so saved paths render cleanly in the notebook.
display(HTML("<br><b>Saved files:</b>"))
for p in [f_global_enr, f_horizon_enr, f_scale_cmp, f_flip_hz, f_flip_scale, f_th_just, f_th_sens, f_eff, f_info]:
    if os.path.exists(p):
        display(HTML(f" - {p}"))

Effect-Size Summary (Top Interactions)

table outcome n_rows fdr_sig_n clinically_meaningful_r_n clinically_meaningful_smd_n clinically_meaningful_combined_n stat_and_clinical_combined_n
0 global_top Death 150 150 82 40 83 83
1 global_top Readmission 150 150 59 31 62 62
2 horizon_top Death 900 900 501 258 504 504
3 horizon_top Readmission 900 900 653 238 669 669

Scale Comparison (Global vs Risk Scale)

main_idx inter_idx main_feature main_rank_global main_importance_global interactor corr_resid_vs_interactor abs_corr n_dir_valid n_q1 n_q4 n_valid main_linear_slope main_linear_intercept outcome outcome_code scope horizon_months n_patients weight_clip_pct p_value p_fdr signif_fdr_0.01 clinically_meaningful_r stat_and_clinical_r delta_q4_q1_resid smd_q4_q1 direction_q4_q1 delta_binarized_resid smd_binarized direction_binarized binarized_method clinically_meaningful_smd clinically_meaningful_combined stat_and_clinical_combined corr_resid_vs_interactor_global abs_corr_global direction_q4_q1_global smd_q4_q1_global smd_binarized_global signif_fdr_0.01_global clinically_meaningful_combined_global stat_and_clinical_combined_global delta_abs_corr_risk_minus_global ratio_abs_corr_risk_over_global corr_direction_changed q4q1_direction_changed delta_abs_smd_q4q1_risk_minus_global
59 43 42 primary_sub_mod_cocaine_paste 14 0.035601 primary_sub_mod_alcohol 0.325647 0.325647 70521 46546 23975 70521 -0.000349 0.000197 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000188 0.727078 Positive 0.000188 0.727078 Positive binary_1_minus_0 True True True 0.300033 0.300033 Positive 0.663964 0.663964 True True True 0.025614 1.085370 False False 0.063114
60 43 0 primary_sub_mod_cocaine_paste 14 0.035601 adm_age_rec3 0.291960 0.291960 70521 17641 17640 70521 -0.000349 0.000197 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000182 0.544160 Positive 0.000097 0.360177 Positive median_high_minus_low True True True 0.088748 0.088748 Positive 0.188665 0.121687 True True True 0.203213 3.289784 False False 0.355494
61 43 44 primary_sub_mod_cocaine_paste 14 0.035601 primary_sub_mod_cocaine_powder -0.288833 0.288833 70521 56727 70521 70521 -0.000349 0.000197 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True -0.000039 -0.138371 Negative -0.000199 -0.760562 Negative binary_1_minus_0 True True True -0.356579 0.356579 Negative -0.176271 -0.962184 True True True -0.067746 0.810011 False False -0.037900
62 43 23 primary_sub_mod_cocaine_paste 14 0.035601 polysubstance_strict -0.211186 0.211186 70521 18940 51581 70521 -0.000349 0.000197 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True -0.000131 -0.487472 Negative -0.000131 -0.487472 Negative binary_1_minus_0 True True True -0.115123 0.115123 Negative -0.261479 -0.261479 True True True 0.096063 1.834439 False False 0.225993
5 42 4 primary_sub_mod_alcohol 2 0.267121 prim_sub_freq_rec 0.179168 0.179168 70521 39308 31002 70521 0.003733 -0.000676 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000848 0.405865 Positive 0.000467 0.219768 Positive median_high_minus_low True True True -0.100948 0.100948 Negative -0.219851 -0.135415 True True True 0.078220 1.774855 True True 0.186014
9 4 42 prim_sub_freq_rec 3 0.114488 primary_sub_mod_alcohol 0.177511 0.177511 70521 46546 23975 70521 0.000912 -0.001065 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000425 0.380776 Positive 0.000425 0.380776 Positive binary_1_minus_0 True True True -0.062824 0.062824 Negative -0.132885 -0.132885 True True True 0.114687 2.825531 True True 0.247891
51 39 0 first_sub_used_alcohol 12 0.041504 adm_age_rec3 0.171441 0.171441 70521 17641 17640 70521 0.000380 -0.000168 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000094 0.300822 Positive 0.000050 0.193345 Positive median_high_minus_low True True True -0.107464 0.107464 Negative -0.282832 -0.188272 True True True 0.063977 1.595331 True True 0.017990
28 26 0 cohabitation_with_couple_children 7 0.058190 adm_age_rec3 0.170886 0.170886 70521 17641 17640 70521 -0.000688 0.000372 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000286 0.356825 Positive 0.000176 0.285082 Positive median_high_minus_low True True True 0.127396 0.127396 Positive 0.305185 0.213123 True True True 0.043489 1.341371 False False 0.051640
6 42 0 primary_sub_mod_alcohol 2 0.267121 adm_age_rec3 0.167873 0.167873 70521 17641 17640 70521 0.003733 -0.000676 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000744 0.281510 Positive 0.000323 0.151803 Positive median_high_minus_low True True True -0.145518 0.145518 Negative -0.390027 -0.271913 True True True 0.022354 1.153617 True True -0.108517
52 39 23 first_sub_used_alcohol 12 0.041504 polysubstance_strict -0.158622 0.158622 70521 18940 51581 70521 0.000380 -0.000168 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True -0.000092 -0.362473 Negative -0.000092 -0.362473 Negative binary_1_minus_0 True True True 0.106666 0.106666 Positive 0.242040 0.242040 True True True 0.051957 1.487099 True True 0.120433
0 0 26 adm_age_rec3 1 0.497441 cohabitation_with_couple_children -0.152021 0.152021 70521 38858 31663 70521 0.000366 -0.011619 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True -0.001262 -0.309227 Negative -0.001262 -0.309227 Negative binary_1_minus_0 True True True -0.039043 0.039043 Negative -0.078555 -0.078555 True False False 0.112978 3.893637 False False 0.230671
53 39 42 first_sub_used_alcohol 12 0.041504 primary_sub_mod_alcohol 0.147320 0.147320 70521 46546 23975 70521 0.000380 -0.000168 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000080 0.314427 Positive 0.000080 0.314427 Positive binary_1_minus_0 True True True -0.118822 0.118822 Negative -0.252624 -0.252624 True True True 0.028499 1.239845 True True 0.061802
29 26 42 cohabitation_with_couple_children 7 0.058190 primary_sub_mod_alcohol 0.131287 0.131287 70521 46546 23975 70521 -0.000688 0.000372 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000172 0.279569 Positive 0.000172 0.279569 Positive binary_1_minus_0 True True True 0.049809 0.049809 Positive 0.105279 0.105279 True False False 0.081477 2.635778 False False 0.174289
63 43 39 primary_sub_mod_cocaine_paste 14 0.035601 first_sub_used_alcohol 0.129522 0.129522 70521 26547 39050 70521 -0.000349 0.000197 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000075 0.273278 Positive 0.000069 0.255599 Positive median_high_minus_low True True True 0.077659 0.077659 Positive 0.164747 0.153372 True True True 0.051863 1.667836 False False 0.108531
1 0 12 adm_age_rec3 1 0.497441 eva_ocupacion 0.124464 0.124464 70521 22291 29413 70521 0.000366 -0.011619 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.001223 0.296123 Positive 0.001055 0.257423 Positive median_high_minus_low True True True 0.054993 0.054993 Positive 0.134436 0.138846 True True True 0.069471 2.263252 False False 0.161686
13 51 43 occupation_condition_corr24_unemployed 4 0.099876 primary_sub_mod_cocaine_paste -0.112797 0.112797 70521 43856 26665 70521 0.001183 -0.000347 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True -0.000196 -0.234101 Negative -0.000196 -0.234101 Negative binary_1_minus_0 True True True -0.052387 0.052387 Negative -0.108181 -0.108181 True True True 0.060409 2.153132 False False 0.125921
30 26 25 cohabitation_with_couple_children 7 0.058190 cohabitation_family_of_origin -0.111886 0.111886 70521 44451 26070 70521 -0.000688 0.000372 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True -0.000144 -0.233244 Negative -0.000144 -0.233244 Negative binary_1_minus_0 True True True -0.103066 0.103066 Negative -0.214651 -0.214651 True True True 0.008820 1.085577 False False 0.018593
7 42 51 primary_sub_mod_alcohol 2 0.267121 occupation_condition_corr24_unemployed 0.111628 0.111628 70521 46023 24498 70521 0.003733 -0.000676 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000500 0.235915 Positive 0.000500 0.235915 Positive binary_1_minus_0 True True True -0.064899 0.064899 Negative -0.136588 -0.136588 True True True 0.046729 1.720025 True True 0.099326
31 26 54 cohabitation_with_couple_children 7 0.058190 marital_status_rec_separated_divorced_annulled_widowed 0.104768 0.104768 70521 62329 70521 70521 -0.000688 0.000372 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000024 0.039857 Positive NaN NaN Unknown median_high_minus_low False True True 0.092106 0.092106 Positive 0.033652 NaN True True True 0.012662 1.137466 False False 0.006204
2 0 13 adm_age_rec3 1 0.497441 eva_sm 0.100115 0.100115 70521 40420 30092 70521 0.000366 -0.011619 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000641 0.155710 Positive 0.000939 0.228642 Positive median_high_minus_low True True True 0.058153 0.058153 Positive 0.066130 0.164697 True True True 0.041961 1.721562 False False 0.089580
10 4 39 prim_sub_freq_rec 3 0.114488 first_sub_used_alcohol 0.099664 0.099664 70521 26547 39050 70521 0.000912 -0.001065 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000237 0.207729 Positive 0.000218 0.193663 Positive median_high_minus_low True True True -0.023307 0.023307 Negative -0.048801 -0.048779 True False False 0.076357 4.276169 True True 0.158928
3 0 8 adm_age_rec3 1 0.497441 evaluacindelprocesoteraputico 0.098085 0.098085 70521 39975 30545 70521 0.000366 -0.011619 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000613 0.148868 Positive 0.000945 0.230016 Positive median_high_minus_low True True True 0.055363 0.055363 Positive 0.068248 0.150496 True True True 0.042722 1.771664 False False 0.080620
11 4 22 prim_sub_freq_rec 3 0.114488 any_phys_dx 0.097974 0.097974 70521 63776 70521 70521 0.000912 -0.001065 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000036 0.034450 Positive NaN NaN Unknown median_high_minus_low False True True 0.027026 0.027026 Positive 0.008794 NaN True False False 0.070948 3.625177 False False 0.025655
4 0 10 adm_age_rec3 1 0.497441 eva_fam 0.095774 0.095774 70521 39316 31202 70521 0.000366 -0.011619 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000662 0.160761 Positive 0.000843 0.204852 Positive median_high_minus_low True True True 0.060500 0.060500 Positive 0.080926 0.157122 True True True 0.035274 1.583039 False False 0.079835
64 43 7 primary_sub_mod_cocaine_paste 14 0.035601 ed_attainment_corr 0.093641 0.093641 70521 51972 18397 70521 -0.000349 0.000197 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000061 0.223205 Positive 0.000035 0.126176 Positive median_high_minus_low True True True 0.044206 0.044206 Positive 0.104604 0.059437 True False False 0.049435 2.118284 False False 0.118601
65 2 14 dit_m 15 0.034467 eva_fisica 0.091995 0.091995 70521 21442 27083 70521 0.000027 -0.000141 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000108 0.229812 Positive 0.000103 0.220638 Positive median_high_minus_low True True True 0.085418 0.085418 Positive 0.213823 0.218763 True True True 0.006577 1.076997 False False 0.015989
32 26 46 cohabitation_with_couple_children 7 0.058190 tipo_de_vivienda_rec2_other_unknown 0.089551 0.089551 70521 60621 70521 70521 -0.000688 0.000372 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000019 0.030814 Positive NaN NaN Unknown median_high_minus_low False True True 0.060709 0.060709 Positive 0.020305 NaN True True True 0.028843 1.475099 False False 0.010509
14 51 0 occupation_condition_corr24_unemployed 4 0.099876 adm_age_rec3 0.087931 0.087931 70521 17641 17640 70521 0.001183 -0.000347 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000211 0.204989 Positive 0.000160 0.190486 Positive median_high_minus_low True True True 0.105125 0.105125 Positive 0.276356 0.181443 True True True -0.017194 0.836445 False False -0.071367
66 2 13 dit_m 15 0.034467 eva_sm 0.087026 0.087026 70521 40420 30092 70521 0.000027 -0.000141 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True 0.000049 0.103764 Positive 0.000112 0.240587 Positive median_high_minus_low True True True 0.103319 0.103319 Positive 0.110683 0.303318 True True True -0.016294 0.842297 False False -0.006918
8 42 23 primary_sub_mod_alcohol 2 0.267121 polysubstance_strict -0.086448 0.086448 70521 18940 51581 70521 0.003733 -0.000676 Death death horizon_risk_approx 3.0 70521 99.5 0.0 0.0 True True True -0.000416 -0.195777 Negative -0.000416 -0.195777 Negative binary_1_minus_0 False True True 0.052460 0.052460 Positive 0.118523 0.118523 True True True 0.033988 1.647897 True True 0.077254

Direction Flip Summary Across Horizons

outcome total_pairs flipped_pairs flip_pct
0 Death 180 0 0.0
1 Readmission 201 0 0.0

Direction Flip Summary Between Scales

outcome horizon_months n_pairs corr_flip_n q4q1_flip_n corr_flip_pct q4q1_flip_pct
0 Death 3.0 72 30 30 41.666667 41.666667
1 Death 6.0 72 30 30 41.666667 41.666667
2 Death 12.0 72 30 30 41.666667 41.666667
3 Death 36.0 71 28 28 39.436620 39.436620
4 Death 60.0 70 27 27 38.571429 38.571429
5 Death 96.0 67 20 20 29.850746 29.850746
6 Readmission 3.0 56 11 11 19.642857 19.642857
7 Readmission 6.0 56 11 11 19.642857 19.642857
8 Readmission 12.0 57 11 11 19.298246 19.298246
9 Readmission 36.0 60 9 9 15.000000 15.000000
10 Readmission 60.0 62 10 10 16.129032 16.129032
11 Readmission 96.0 62 9 9 14.516129 14.516129

Threshold Justification

outcome metric p50 p75 p90 chosen_threshold chosen_percentile
0 Death rank_range 3.000000 7.000000 13.000000 8.00 82.469136
1 Death abs_corr_delta 0.003627 0.007644 0.012414 0.05 100.000000
2 Readmission rank_range 7.000000 13.000000 22.000000 8.00 59.382716
3 Readmission abs_corr_delta 0.007371 0.015231 0.026958 0.05 97.777778

Threshold Sensitivity Grid

outcome rank_range_threshold abs_corr_delta_threshold n_flagged n_total pct_flagged is_current_rule
0 Death 5 0.02 349 810 43.086420 False
1 Death 5 0.03 336 810 41.481481 False
2 Death 5 0.05 328 810 40.493827 False
3 Death 8 0.02 195 810 24.074074 False
4 Death 8 0.03 181 810 22.345679 False
5 Death 8 0.05 173 810 21.358025 True
6 Death 10 0.02 145 810 17.901235 False
7 Death 10 0.03 131 810 16.172840 False
8 Death 10 0.05 123 810 15.185185 False
9 Death 15 0.02 87 810 10.740741 False
10 Death 15 0.03 73 810 9.012346 False
11 Death 15 0.05 65 810 8.024691 False
12 Readmission 5 0.02 536 810 66.172840 False
13 Readmission 5 0.03 527 810 65.061728 False
14 Readmission 5 0.05 518 810 63.950617 False
15 Readmission 8 0.02 404 810 49.876543 False
16 Readmission 8 0.03 385 810 47.530864 False
17 Readmission 8 0.05 372 810 45.925926 True
18 Readmission 10 0.02 334 810 41.234568 False
19 Readmission 10 0.03 310 810 38.271605 False
20 Readmission 10 0.05 293 810 36.172840 False
21 Readmission 15 0.02 242 810 29.876543 False
22 Readmission 15 0.03 208 810 25.679012 False
23 Readmission 15 0.05 179 810 22.098765 False

Saved files:
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb11_dual_interactions_global_top_enriched_20260306_1821.csv
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb11_dual_interactions_horizon_top_enriched_20260306_1821.csv
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb11_dual_interactions_scale_comparison_20260306_1821.csv
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb11_dual_interactions_direction_flip_horizon_20260306_1821.csv
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb11_dual_interactions_direction_flip_scale_20260306_1821.csv
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb11_dual_interactions_threshold_justification_20260306_1821.csv
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb11_dual_interactions_threshold_sensitivity_20260306_1821.csv
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb11_dual_interactions_effectsize_summary_20260306_1821.csv
- G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb11_dual_interactions_aug_run_info_20260306_1821.json

Interactions detected on log-hazard scale may differ in magnitude and direction from risk-scale interactions due to non-linear transformation (∂risk/∂η = (-ln(1-r))(1-r)). For mortality, roughly 30–42% of top interactions changed direction between scales (depending on horizon), while readmission showed 15–20% direction changes. We report both scales for completeness and emphasize risk-scale interactions for clinical interpretation.

Time-dependent interaction thresholds were set at rank range ≥8 positions (82.5th percentile for death, 59.4th for readmission) and correlation delta ≥0.05 (100th percentile for death, 97.8th for readmission), based on empirical distributions of interaction variability across horizons.

Readmission interactions showed substantially more time-dependent variability (about 46% flagged under the current rule) compared to mortality (about 21% flagged), suggesting readmission risk factors evolve more dynamically over follow-up while mortality risk factors remain relatively stable.

Functional form

Code
#@title ⚡ Step 12: Functional Form Analysis (XGBoost - DUAL SHAP, Aggregated)

import os
import re
import glob
import pickle
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from IPython.display import display, Markdown, HTML  # FIXED: Added HTML

if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the root setup cell first.")

PROJECT_ROOT = Path(PROJECT_ROOT).resolve()

TABLE_DIR = PROJECT_ROOT / "_out_tabble"   # use your requested folder name
FIG_DIR = PROJECT_ROOT / "_figs"

TABLE_DIR.mkdir(parents=True, exist_ok=True)
FIG_DIR.mkdir(parents=True, exist_ok=True)

IN_DIR = PROJECT_ROOT / "_out"  # where Step 5 SHAP files are read from

try:
    from scipy.stats import f as f_dist
except Exception:
    f_dist = None

# --- 1) CONFIG ---
# Continuous predictors whose SHAP functional form is analysed in this step.
CONTINUOUS_VARS = ['adm_age_rec3', 'porc_pobr', 'dit_m', 
'tenure_status_household', 'urbanicity_cat', 'evaluacindelprocesoteraputico', 
'eva_consumo',  'eva_fam', 'eva_relinterp', 'eva_ocupacion', 'eva_sm', 'eva_fisica', 'eva_transgnorma', 'prim_sub_freq_rec', 'ed_attainment_corr']

# Single timestamp for the whole step so every artifact (paths below and the
# *_FILENAME aliases) shares the same tag.  BUGFIX: the stamp was previously
# computed twice, so a minute rollover between the two calls could make
# EXCEL_FILENAME/PARQUET_FILENAME disagree with EXCEL_PATH/PARQUET_PATH.
timestamp = datetime.now().strftime("%Y%m%d_%H%M")

EXCEL_FILENAME = f"XGB12_corr_Functional_Forms_Dual_Aggregated_{timestamp}.xlsx"
PARQUET_FILENAME = f"XGB12_corr_Functional_Forms_Dual_Aggregated_{timestamp}.parquet"

EXCEL_PATH = TABLE_DIR / EXCEL_FILENAME
PARQUET_PATH = TABLE_DIR / PARQUET_FILENAME
CSV_PATH = TABLE_DIR / f"XGB12_corr_Functional_Forms_Dual_Aggregated_{timestamp}.csv"

MAX_SCATTER_N = 10000  # max points per scatter; larger samples are thinned to this size
POLY_DEGREE = 3        # polynomial degree used for trend fit (cubic)
RANDOM_STATE = 2125    # seed for subsampling and bootstrap
SHAP_SCALE = 1.0       # no rescaling of SHAP values
SHAP_UNIT_LABEL = "Log-Hazard (model margin)"
N_BOOT = 500           # bootstrap replicates for the trend CI
CI_ALPHA = 0.05        # two-sided CI level

# Publication typography for all Step-12 figures: serif fonts and TrueType
# (Type 42) font embedding so text in exported PDF/PS stays selectable and
# editable in vector editors.
mpl.rcParams.update({
    "font.family": "serif",
    "font.serif": ["Times New Roman", "Times", "Nimbus Roman", "DejaVu Serif"],
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
    "font.size": 14,
    "axes.grid": True,
    "grid.alpha": 0.30
})

# --- 2) HELPERS ---
def sanitize_name(txt):
    """Make a value safe for use in file names.

    Coerces to str and collapses every run of characters outside
    [A-Za-z0-9_.-] into a single underscore.
    """
    disallowed_runs = r"[^A-Za-z0-9_.-]+"
    return re.sub(disallowed_runs, "_", str(txt))

def make_sheet_name(raw, used):
    """Build an Excel-safe, unique sheet name (<= 31 chars).

    Replaces characters Excel forbids in sheet names, truncates to the
    31-character limit, and appends _1, _2, ... until the name is not in
    `used`.  Mutates `used` by adding the chosen name.
    """
    candidate = re.sub(r"[\[\]\*\?\/\\:]", "_", raw)[:31]
    stem = candidate
    counter = 1
    while candidate in used:
        tag = f"_{counter}"
        # Re-truncate so stem + tag still fits within 31 characters.
        candidate = (stem[:31 - len(tag)] + tag)[:31]
        counter += 1
    used.add(candidate)
    return candidate

def pick_latest_file(search_dir=None):
    """Return the most recently modified aggregated DUAL-SHAP pickle.

    Parameters
    ----------
    search_dir : path-like, optional
        Directory to search.  Defaults to the module-level IN_DIR, so the
        original no-argument call sites behave exactly as before.

    Returns
    -------
    pathlib.Path
        Newest file (by mtime) matching the Step-5 naming pattern.

    Raises
    ------
    FileNotFoundError
        If no matching pickle exists in the directory.

    BUGFIX: the previous version globbed two overlapping patterns —
    "...Aggregated_*.pkl" already matches the "*_mar26.pkl" variant — so
    suffixed files were collected twice.  A single glob covers both.
    """
    base = Path(IN_DIR if search_dir is None else search_dir)
    files = list(base.glob("xgb6_corr_DUAL_SHAP_Aggregated_*.pkl"))
    if not files:
        raise FileNotFoundError(f"No xgb6_corr_DUAL_SHAP_Aggregated_*.pkl found in {base}")
    return max(files, key=lambda p: p.stat().st_mtime)

def f_test_linear_vs_quadratic(x, y):
    """Nested-model F-test: does adding an x**2 term improve on a line?

    Fits y ~ 1 + x and y ~ 1 + x + x**2 by least squares on the finite
    pairs, and compares residual sums of squares.

    Returns
    -------
    (float, float)
        (F statistic, p-value).  (nan, nan) when the test is not
        computable: fewer than 10 finite pairs, a failed/degenerate fit,
        non-positive residual df, or rss patterns that make F undefined.
        The p-value alone is nan when scipy (f_dist) is unavailable.
    """
    xv = np.asarray(x, dtype=float)
    yv = np.asarray(y, dtype=float)
    finite = np.isfinite(xv) & np.isfinite(yv)
    xv, yv = xv[finite], yv[finite]
    n_obs = xv.size
    if n_obs < 10:
        return np.nan, np.nan

    design_lin = np.column_stack([np.ones(n_obs), xv])
    design_quad = np.column_stack([np.ones(n_obs), xv, xv ** 2])

    try:
        coef_lin = np.linalg.lstsq(design_lin, yv, rcond=None)[0]
        coef_quad = np.linalg.lstsq(design_quad, yv, rcond=None)[0]
        rss_lin = np.sum((yv - design_lin @ coef_lin) ** 2)
        rss_quad = np.sum((yv - design_quad @ coef_quad) ** 2)
    except Exception:
        return np.nan, np.nan

    num_df = 1
    den_df = n_obs - design_quad.shape[1]
    # Guard against degenerate cases (perfect quadratic fit, or numerical
    # noise making the smaller model fit "better").
    if den_df <= 0 or rss_quad <= 0 or rss_lin < rss_quad:
        return np.nan, np.nan

    f_stat = ((rss_lin - rss_quad) / num_df) / (rss_quad / den_df)
    if f_dist is None:
        return float(f_stat), np.nan
    return float(f_stat), float(f_dist.sf(f_stat, num_df, den_df))

def polyfit_bootstrap_ci(x, y, degree=3, n_boot=300, alpha=0.05, seed=2125):
    """Polynomial trend with a bootstrap percentile CI on a 200-point grid.

    x is mean-centred before fitting for numerical stability; the grid is
    mapped through the same centring so returned values align with the
    original x scale.

    Returns
    -------
    None
        When data are insufficient (< 20 finite pairs, < 4 unique x values,
        effective degree < 1, or the base fit fails).
    (grid, fit, None, None)
        When fewer than 20 bootstrap fits succeed (no CI available).
    (grid, fit, lo, hi)
        Otherwise: x grid, fitted curve, and percentile CI bounds.
    """
    xv = np.asarray(x, dtype=float)
    yv = np.asarray(y, dtype=float)
    finite = np.isfinite(xv) & np.isfinite(yv)
    xv, yv = xv[finite], yv[finite]
    if xv.size < 20 or np.unique(xv).size < 4:
        return None

    rng = np.random.default_rng(seed)
    grid = np.linspace(np.min(xv), np.max(xv), 200)

    x_centered = xv - np.mean(xv)
    grid_centered = grid - np.mean(xv)
    deg = min(degree, np.unique(x_centered).size - 1, x_centered.size - 1)
    if deg < 1:
        return None

    try:
        base_poly = np.poly1d(np.polyfit(x_centered, yv, deg))
        base_curve = base_poly(grid_centered)
    except Exception:
        return None

    curves = []
    n_pts = x_centered.size
    for _ in range(n_boot):
        # Resample with replacement; one rng.integers call per replicate
        # keeps the stream identical for a given seed.
        resample = rng.integers(0, n_pts, n_pts)
        xb, yb = x_centered[resample], yv[resample]
        if np.unique(xb).size < (deg + 1):
            continue  # under-determined resample; skip it
        try:
            curves.append(np.poly1d(np.polyfit(xb, yb, deg))(grid_centered))
        except Exception:
            continue

    if len(curves) < 20:
        return grid, base_curve, None, None

    stacked = np.vstack(curves)
    lo = np.quantile(stacked, alpha / 2.0, axis=0)
    hi = np.quantile(stacked, 1 - alpha / 2.0, axis=0)
    return grid, base_curve, lo, hi

# --- 3) LOAD AGGREGATED SHAP ---
# Locate the newest Step-5 pickle, load it, validate its structure, and
# expose X_all (feature matrix) plus one SHAP matrix per outcome.
display(Markdown("### Loading DUAL SHAP aggregated file"))
shap_file = pick_latest_file()
display(Markdown(f"Using file: `{shap_file}`"))

# NOTE(review): pickle.load can execute arbitrary code from an untrusted
# file.  This pickle is produced by Step 5 of the same pipeline and is
# assumed trusted — do not point pick_latest_file() at external data.
with shap_file.open("rb") as f:
    data = pickle.load(f)

# Structural validation: the pickle must be a dict with at least the
# feature matrix and the feature-name list.
if not isinstance(data, dict):
    raise ValueError("Loaded SHAP file is not a dict.")

if "X_all" not in data or "feature_names" not in data:
    raise ValueError("Missing required keys: X_all and/or feature_names.")

X_all = data["X_all"]
feature_names = data["feature_names"]

# Coerce X_all to a DataFrame; if it already is one and the name list
# matches its width, re-apply the names (a copy avoids mutating `data`).
if not isinstance(X_all, pd.DataFrame):
    X_all = pd.DataFrame(np.asarray(X_all), columns=feature_names)
else:
    if len(feature_names) == X_all.shape[1]:
        X_all = X_all.copy()
        X_all.columns = feature_names

# FIXED: Added duplicate feature check
# Duplicate names would make column-position lookups ambiguous downstream.
if len(feature_names) != len(set(feature_names)):
    dupes = [f for f in feature_names if feature_names.count(f) > 1]
    raise ValueError(f"Duplicate feature names detected: {set(dupes)}")

# Collect one SHAP matrix per outcome.  "shap_all" is a legacy single-outcome
# key, used only when neither dual key is present.
shap_arrays = {}
if "shap_r_all" in data:
    shap_arrays["Readmission"] = np.asarray(data["shap_r_all"])
if "shap_d_all" in data:
    shap_arrays["Death"] = np.asarray(data["shap_d_all"])
if "shap_all" in data and not shap_arrays:
    shap_arrays["Readmission"] = np.asarray(data["shap_all"])

if not shap_arrays:
    raise ValueError("No SHAP arrays found (shap_r_all / shap_d_all / shap_all).")

display(Markdown(f"Outcomes found: `{list(shap_arrays.keys())}`"))
display(Markdown("Detected aggregated (non-horizon) SHAP structure."))

# --- 4) FUNCTIONAL FORM ANALYSIS ---
# For every outcome and every continuous predictor: collect (feature value,
# SHAP value) pairs, test linear vs quadratic shape, and draw a scatter with
# a bootstrapped polynomial trend.  Figures are saved to FIG_DIR (PNG + PDF).
all_data_list = []   # long-format frames, one per (outcome, var); concatenated in section 5
summary_rows = []    # one summary dict per (outcome, var)

for outcome, shap_mat in shap_arrays.items():
    if shap_mat.ndim != 2:
        continue  # expect a flat (n_samples, n_features) matrix; skip anything else

    # Align rows/columns defensively in case X_all and the SHAP matrix
    # disagree in size (trims both to the common shape).
    n = min(X_all.shape[0], shap_mat.shape[0])
    p = min(X_all.shape[1], shap_mat.shape[1])
    X_use = X_all.iloc[:n, :p].copy()
    S_use = shap_mat[:n, :p]

    for var in CONTINUOUS_VARS:
        if var not in X_use.columns:
            continue

        col_idx = X_use.columns.get_loc(var)
        x_vec = X_use.iloc[:, col_idx].to_numpy()
        y_vec = np.asarray(S_use[:, col_idx], dtype=float) * SHAP_SCALE

        # Keep only rows finite in both coordinates.
        mask = np.isfinite(x_vec) & np.isfinite(y_vec)
        x_vec = x_vec[mask]
        y_vec = y_vec[mask]
        if x_vec.size == 0:
            continue

        # F-test: does adding a quadratic term improve on the linear fit?
        f_stat, p_val = f_test_linear_vs_quadratic(x_vec, y_vec)

        all_data_list.append(pd.DataFrame({
            "Feature_Value": x_vec,
            "SHAP_Impact": y_vec,
            "Predictor": var,
            "Outcome": outcome,
            "Scope": "Aggregated_All_Times"
        }))

        # Pearson correlation between feature value and SHAP impact,
        # guarded against zero-variance vectors.
        corr = np.nan
        if np.std(x_vec) > 0 and np.std(y_vec) > 0:
            corr = float(np.corrcoef(x_vec, y_vec)[0, 1])

        summary_rows.append({
            "Outcome": outcome,
            "Scope": "Aggregated_All_Times",
            "Predictor": var,
            "N": int(x_vec.size),
            "Mean_SHAP": float(np.mean(y_vec)),
            "MeanAbs_SHAP": float(np.mean(np.abs(y_vec))),
            "Q10_SHAP": float(np.quantile(y_vec, 0.10)),
            "Q50_SHAP": float(np.quantile(y_vec, 0.50)),
            "Q90_SHAP": float(np.quantile(y_vec, 0.90)),
            "F_linear_vs_quad": f_stat,
            "P_linear_vs_quad": p_val,
            "Nonlinear_p_lt_0_05": bool(pd.notna(p_val) and p_val < 0.05),
            "Corr_X_SHAP": corr
        })

        # Scatter plot; subsample deterministically when n exceeds MAX_SCATTER_N.
        plt.figure(figsize=(8, 5))
        if x_vec.size > MAX_SCATTER_N:
            rng = np.random.default_rng(RANDOM_STATE)
            idx = rng.choice(x_vec.size, MAX_SCATTER_N, replace=False)
            plt.scatter(x_vec[idx], y_vec[idx], alpha=0.30, c="#1f77b4", s=15, edgecolors="none", label="Patients (sample)")
        else:
            plt.scatter(x_vec, y_vec, alpha=0.45, c="#1f77b4", s=18, edgecolors="none", label="Patients")

        # Bootstrap polynomial trend; `fit` may omit CI bounds (None) when
        # too few bootstrap replicates succeed.
        fit = polyfit_bootstrap_ci(
            x_vec, y_vec,
            degree=POLY_DEGREE,
            n_boot=N_BOOT,
            alpha=CI_ALPHA,
            seed=RANDOM_STATE
        )
        if fit is not None:
            x_grid, y_hat, y_lo, y_hi = fit
            if y_lo is not None and y_hi is not None:
                plt.fill_between(x_grid, y_lo, y_hi, color="red", alpha=0.15, label=f"{int((1-CI_ALPHA)*100)}% CI")
            plt.plot(x_grid, y_hat, "r--", linewidth=2.2, label=f"Trend (poly-{POLY_DEGREE})")

        plt.axhline(0, color="k", linestyle=":", linewidth=1)
        plt.title(f"Functional Form: {var}\n({outcome}, Aggregated)", fontsize=13, fontweight="bold")
        plt.xlabel(f"Feature Value: {var}")
        plt.ylabel(f"SHAP Impact ({SHAP_UNIT_LABEL})")
        plt.legend(loc="best")

        # Save before show; NOTE(review): with non-inline backends consider
        # plt.close() here to cap memory across the many figures generated.
        fname = f"XGB12_corr_DUAL_FuncForm_{sanitize_name(outcome)}_{sanitize_name(var)}_{timestamp}"
        plt.savefig(FIG_DIR / f"{fname}.png", dpi=300, bbox_inches="tight")
        plt.savefig(FIG_DIR / f"{fname}.pdf", bbox_inches="tight")
        plt.show()

# Fail fast if no (outcome, predictor) pair produced any data.
if not all_data_list:
    raise ValueError("No functional-form data created. Check predictor names and SHAP structure.")

full_df = pd.concat(all_data_list, ignore_index=True)
summary_df = pd.DataFrame(summary_rows).sort_values(["Outcome", "Predictor"]).reset_index(drop=True)

# --- 5) SAVE OUTPUTS ---
# Long-format data to parquet + csv; summary, metadata, and one sheet per
# (outcome, predictor) pair to a single Excel workbook.
full_df.to_parquet(PARQUET_PATH, index=False)
full_df.to_csv(CSV_PATH, index=False)

used_sheet_names = set()
with pd.ExcelWriter(EXCEL_PATH, engine="xlsxwriter") as writer:
    summary_df.to_excel(writer, sheet_name="Effects_Summary", index=False)

    # Provenance sheet: how SHAP values should be interpreted.
    meta_df = pd.DataFrame({
        "Item": ["SHAP_SCALE", "SHAP_UNIT_LABEL", "Interpretation", "Scope"],
        "Value": [
            SHAP_SCALE,
            SHAP_UNIT_LABEL,
            "For survival:cox, SHAP is on log-hazard-ratio scale.",
            "Aggregated across time horizons (no horizon-specific SHAP in source file)."
        ]
    })
    meta_df.to_excel(writer, sheet_name="Meta", index=False)

    # One data sheet per pair; names deduplicated and capped at Excel's
    # 31-char limit by make_sheet_name.  Very large groups are subsampled
    # (deterministically) to keep workbook size manageable.
    for (outcome, predictor), g in full_df.groupby(["Outcome", "Predictor"], sort=True):
        sheet_raw = f"{outcome[:1]}_{predictor[:18]}"
        sheet_name = make_sheet_name(sheet_raw, used_sheet_names)
        g_to_save = g[["Feature_Value", "SHAP_Impact"]]
        if len(g_to_save) > 100000:
            g_to_save = g_to_save.sample(100000, random_state=RANDOM_STATE)
        g_to_save.to_excel(writer, sheet_name=sheet_name, index=False)

# FIXED: Use display(HTML(...)) instead of print for final output
display(HTML(f"<b>Done.</b> Excel: <code>{EXCEL_PATH}</code>"))
display(HTML(f"<b>Done.</b> Parquet: <code>{PARQUET_PATH}</code>"))
display(HTML(f"<b>Done.</b> CSV: <code>{CSV_PATH}</code>"))
display(HTML(f"<b>Plots saved in:</b> <code>{FIG_DIR}</code>"))

# Export for downstream cells (Step 12b reuses these).
global_functional_data = full_df
global_functional_summary = summary_df

Loading DUAL SHAP aggregated file

Using file: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\xgb6_corr_DUAL_SHAP_Aggregated_20260306_1821_mar26.pkl

Outcomes found: ['Readmission', 'Death']

Detected aggregated (non-horizon) SHAP structure.

Done. Excel: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out_tabble\XGB12_corr_Functional_Forms_Dual_Aggregated_20260306_1834.xlsx
Done. Parquet: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out_tabble\XGB12_corr_Functional_Forms_Dual_Aggregated_20260306_1834.parquet
Done. CSV: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out_tabble\XGB12_corr_Functional_Forms_Dual_Aggregated_20260306_1834.csv
Plots saved in: G:\My Drive\Alvacast\SISTRAT 2023\cons\_figs
Code
#@title Step 12b: 4×4 Facetted Plots (15 vars + legend, per outcome)

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from matplotlib import rcParams
import matplotlib.patches as mpatches

# Academic publication settings: serif fonts, compact labels, 300-dpi output,
# tight bounding box on save.
rcParams.update({
    "font.family": "serif",
    "font.serif": ["Times New Roman", "Times", "Nimbus Roman"],
    "font.size": 9,
    "axes.labelsize": 9,
    "axes.titlesize": 10,
    "xtick.labelsize": 8,
    "ytick.labelsize": 8,
    "legend.fontsize": 8,
    "figure.dpi": 300,
    "savefig.dpi": 300,
    "savefig.bbox": "tight",
    "savefig.pad_inches": 0.02,
})

# The 15 predictors shown in the 4×4 grid (the 16th cell holds the legend).
# NOTE(review): several of these (e.g. urbanicity_cat, eva_* scales) are
# categorical in the raw frame; they are presumably numerically encoded in
# X_all by this point — confirm upstream encoding.
CONTINUOUS_VARS_15 = [
    'adm_age_rec3', 'porc_pobr', 'dit_m', 
    'tenure_status_household', 'urbanicity_cat', 'evaluacindelprocesoteraputico', 
    'eva_consumo', 'eva_fam', 'eva_relinterp', 'eva_ocupacion', 
    'eva_sm', 'eva_fisica', 'eva_transgnorma', 'prim_sub_freq_rec', 
    'ed_attainment_corr'
]

def create_4x4_facetted_plot(outcome_name, shap_mat, X_df, figsize=(10, 10)):
    """
    Create a 4×4 facetted SHAP-dependence figure: up to 15 variables in the
    first 15 grid cells plus a legend in the 16th (bottom-right) cell.
    No main title (clean academic format).

    Parameters
    ----------
    outcome_name : str
        Outcome label; used only in the error raised when no feature matches.
    shap_mat : 2-D array, shape (observations, features)
        SHAP values aligned column-for-column with ``X_df``.
    X_df : pandas.DataFrame
        Feature matrix; only columns listed in ``CONTINUOUS_VARS_15`` are drawn.
    figsize : tuple of float, optional
        Figure size in inches.

    Returns
    -------
    matplotlib.figure.Figure
        The assembled figure (caller is responsible for saving/closing).

    Raises
    ------
    ValueError
        If none of the 15 target variables exist in ``X_df``.

    Notes
    -----
    Relies on notebook-level names defined elsewhere: ``SHAP_SCALE``,
    ``MAX_SCATTER_N``, ``RANDOM_STATE``, ``POLY_DEGREE``, ``N_BOOT``,
    ``CI_ALPHA`` and the helper ``polyfit_bootstrap_ci``.
    """
    
    # Filter to only variables that exist in data
    features = [f for f in CONTINUOUS_VARS_15 if f in X_df.columns]
    n_features = len(features)
    
    if n_features == 0:
        raise ValueError(f"No valid features found for {outcome_name}")
    
    fig = plt.figure(figsize=figsize)
    gs = gridspec.GridSpec(4, 4, figure=fig, 
                          wspace=0.125, hspace=0.15,  # Tight spacing
                          left=0.07, right=0.98, 
                          top=0.98, bottom=0.05)
    
    # Plot each feature in the first 15 positions, row-major order.
    for i, var in enumerate(features):
        row = i // 4
        col = i % 4
        
        ax = fig.add_subplot(gs[row, col])
        
        # Align x values and SHAP values by column position.
        col_idx = X_df.columns.get_loc(var)
        x_vec = X_df.iloc[:, col_idx].to_numpy()
        y_vec = np.asarray(shap_mat[:, col_idx], dtype=float) * SHAP_SCALE
        
        # Keep only rows where both values are finite.
        mask = np.isfinite(x_vec) & np.isfinite(y_vec)
        x_vec = x_vec[mask]
        y_vec = y_vec[mask]
        
        if x_vec.size == 0:
            # Nothing finite to plot for this variable: leave the cell blank.
            ax.axis('off')
            continue
        
        # Scatter plot; subsample above MAX_SCATTER_N to bound file size.
        if x_vec.size > MAX_SCATTER_N:
            rng = np.random.default_rng(RANDOM_STATE)
            idx_sample = rng.choice(x_vec.size, MAX_SCATTER_N, replace=False)
            ax.scatter(x_vec[idx_sample], y_vec[idx_sample], alpha=0.20, 
                      c="#1f77b4", s=5, edgecolors="none", rasterized=True)
        else:
            ax.scatter(x_vec, y_vec, alpha=0.30, c="#1f77b4", 
                      s=7, edgecolors="none", rasterized=True)
        
        # Polynomial trend line with bootstrap CI (helper defined elsewhere;
        # may return None when a fit is not possible).
        fit = polyfit_bootstrap_ci(x_vec, y_vec, degree=POLY_DEGREE, 
                                   n_boot=N_BOOT, alpha=CI_ALPHA, 
                                   seed=RANDOM_STATE)
        if fit is not None:
            x_grid, y_hat, y_lo, y_hi = fit
            if y_lo is not None and y_hi is not None:
                ax.fill_between(x_grid, y_lo, y_hi, color="red", 
                               alpha=0.12, linewidth=0)
            ax.plot(x_grid, y_hat, "r--", linewidth=1.3)
        
        # Zero reference line
        ax.axhline(0, color="black", linestyle=":", linewidth=0.7, alpha=0.6)
        
        # Feature name as subplot title
        ax.set_title(var.replace('_', ' ').title(), 
                    fontsize=11, fontweight='bold', pad=2)
        
        # Clean spines
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        
        # Axis labels — avoid repeating labels on inner facets.
        # Only the bottom row (row index 3) gets x-axis labels.
        if row == 3:
            ax.set_xlabel("Value", fontsize=12)
            ax.tick_params(axis='x', labelsize=10)
        else:
            ax.set_xlabel("")
            ax.tick_params(labelbottom=False)
        
        # Only leftmost column gets y-axis labels  
        if col == 0:
            ax.set_ylabel("SHAP (log-hazard)", fontsize=12)
            ax.tick_params(axis='y', labelsize=10)
        else:
            ax.set_ylabel("")
            ax.tick_params(labelleft=False)
    
    # 16th position (bottom-right): Legend
    legend_ax = fig.add_subplot(gs[3, 3])
    legend_ax.axis('off')
    
    # Create legend elements manually (proxy artists, no plotted data).
    from matplotlib.lines import Line2D
    
    legend_elements = [
        Line2D([0], [0], marker='o', color='w', markerfacecolor='#1f77b4', 
               markersize=8, alpha=0.6, label='Observations'),
        Line2D([0], [0], color='red', linestyle='--', linewidth=1.5, 
               label=f'Trend (poly-{POLY_DEGREE})'),
        mpatches.Patch(facecolor='red', alpha=0.12, edgecolor='none',
                      label=f'{int((1-CI_ALPHA)*100)}% CI'),
        Line2D([0], [0], color='black', linestyle=':', linewidth=0.8, 
               alpha=0.6, label='Zero reference')
    ]
    
    legend = legend_ax.legend(handles=legend_elements,
                             loc='center', 
                             frameon=True, 
                             fancybox=False,
                             edgecolor='gray',
                             fontsize=12,
                             title=f'n = {X_df.shape[0]:,}',
                             title_fontsize=13)
    legend.get_frame().set_linewidth(0.5)
    
    # NO main title - clean academic style
    
    return fig

# Generate one 4×4 facetted plot per outcome (e.g. Readmission, Death).
for outcome, shap_mat in shap_arrays.items():
    # Skip anything that is not a 2-D (observations × features) SHAP matrix.
    if shap_mat.ndim != 2:
        continue
    
    # Warn (but do not fail) when some of the 15 target variables are absent.
    available_vars = [v for v in CONTINUOUS_VARS_15 if v in X_all.columns]
    if len(available_vars) < 15:
        display(HTML(f"<b>Warning {outcome}:</b> Only {len(available_vars)}/15 variables found"))
    
    # Align feature matrix and SHAP matrix on their common rows/columns.
    n = min(X_all.shape[0], shap_mat.shape[0])
    p = min(X_all.shape[1], shap_mat.shape[1])
    X_use = X_all.iloc[:n, :p].copy()
    S_use = shap_mat[:n, :p]
    
    # Create the 4×4 facetted plot
    fig = create_4x4_facetted_plot(
        outcome_name=outcome,
        shap_mat=S_use,
        X_df=X_use,
        figsize=(11, 11)  # Square format
    )
    
    # Save as both PNG (raster) and PDF (vector, for publication).
    safe_outcome = sanitize_name(outcome)
    png_path = FIG_DIR / f"XGB12_Faceted4x4_{safe_outcome}_{timestamp}.png"
    pdf_path = FIG_DIR / f"XGB12_Faceted4x4_{safe_outcome}_{timestamp}.pdf"
    
    fig.savefig(png_path, dpi=300, bbox_inches='tight', 
                facecolor='white', edgecolor='none')
    fig.savefig(pdf_path, dpi=300, bbox_inches='tight', 
                facecolor='white', edgecolor='none')
    
    plt.show()
    display(HTML(f"<b>{outcome}:</b> 4×4 plot saved (15 vars + legend, no title)"))
    display(HTML(f"&nbsp;&nbsp;PNG: <code>{png_path.name}</code>"))
    display(HTML(f"&nbsp;&nbsp;PDF: <code>{pdf_path.name}</code>"))
    
    # Close each figure to avoid memory blowup when looping over outcomes.
    plt.close(fig)

Readmission: 4×4 plot saved (15 vars + legend, no title)
  PNG: XGB12_Faceted4x4_Readmission_20260306_1834.png
  PDF: XGB12_Faceted4x4_Readmission_20260306_1834.pdf

Death: 4×4 plot saved (15 vars + legend, no title)
  PNG: XGB12_Faceted4x4_Death_20260306_1834.png
  PDF: XGB12_Faceted4x4_Death_20260306_1834.pdf

We used XGBoost to test whether risk truly changes linearly with each predictor, and SHAP to visualize each feature’s estimated effect. If the SHAP curve looked straight (e.g., age vs. death), we kept a simple linear term. If it was curved or U‑shaped (e.g., age vs. readmission), we modeled the variable with splines or categories. If SHAP revealed interactions (e.g., age amplifying the effect of alcohol), we added the corresponding interaction terms. The final statistical model therefore reflects the functional shapes and interactions that the machine learner uncovered.

Balance between train and test

For categorical variables with more than two levels, standardized mean differences (SMDs) were computed at the level of each category by recoding the factor into a set of binary indicators (one-vs-rest). For each level, the SMD was calculated as the difference in proportions between the training and test samples divided by the pooled standard deviation of a Bernoulli variable.

Code
from pathlib import Path
import pandas as pd
from IPython.display import display

BASE_DIR = Path(
    r"G:\My Drive\Alvacast\SISTRAT 2023\data\20241015_out\pred1"
)

# Load the non-dummified imputed dataset (plain string, not an f-string:
# the original had an f-prefix with no placeholders).
imputation_nodum_1 = pd.read_parquet(
        BASE_DIR / "imputation_nondum_1.parquet",
        engine="fastparquet"
    )

# Exclusion 1: administrative discharge for administrative reasons AND death
# shortly after discharge (0.23 months ≈ 7 days) with a recorded death event.
mask_adm_death = (
    imputation_nodum_1["tr_outcome"].str.contains("adm reasons", case=False, na=False)
    & imputation_nodum_1["death_time_from_disch_m"].notna()
    & (imputation_nodum_1["death_time_from_disch_m"] <= 0.23)
    & (imputation_nodum_1["death_event"] > 0)
)
# Exclusion 2: treatment outcome recorded as "other" (any time).
mask_other = imputation_nodum_1["tr_outcome"] == "other"

# Combined exclusion mask and its complement.
exclude = mask_adm_death | mask_other
keep = ~exclude

# ── Filter imputation_nodum_1 in place (use `keep` directly; it was
# previously computed but unused) ──
imputation_nodum_1 = imputation_nodum_1[keep].copy()


# Applies only to imputation_nodum_1 (the non-dummified frame).
col = "first_sub_used"

# Levels collapsed into "other", mirroring the earlier grouped dummies.
to_other = {
    "hallucinogens",
    "opioids",
    "amphetamine-type stimulants",
    "tranquilizers/hypnotics",
    "inhalants",
    "others",
}

# Preserve the original coding before collapsing.
imputation_nodum_1["first_sub_used_original"] = imputation_nodum_1[col]

# Normalise text (trim + lowercase), then collapse rare levels into "other".
normalized = imputation_nodum_1[col].astype("string").str.strip().str.lower()
imputation_nodum_1[col] = normalized.where(~normalized.isin(to_other), "other")

# Store as categorical for memory/ordering benefits.
imputation_nodum_1[col] = imputation_nodum_1[col].astype("category")

# Sanity check: level counts after recoding (including missing).
display(imputation_nodum_1[col].value_counts(dropna=False))
first_sub_used
alcohol           51882
marijuana         26311
cocaine paste      4300
cocaine powder     3283
other              2376
Name: count, dtype: int64
Code
# Distribution of treatment outcomes after the exclusions above.
imputation_nodum_1["tr_outcome"].value_counts()
tr_outcome
dropout                                 47055
completion                              23408
referral                                10044
adm discharge - rule violation/undet     6463
adm discharge - adm reasons              1182
Name: count, dtype: int64
Code
from IPython.display import display, Markdown, HTML
import io
import sys
def fold_output(title, func):
    """Run ``func`` while capturing its stdout, then display the captured
    text inside a collapsible HTML <details> block.

    Parameters
    ----------
    title : str
        Summary text shown on the collapsed element.
    func : callable
        Zero-argument callable whose printed output is captured.
    """
    buffer = io.StringIO()
    prev_stdout = sys.stdout
    sys.stdout = buffer
    try:
        func()
    finally:
        # Restore the *previous* stream, not sys.__stdout__: under Jupyter
        # sys.stdout is a notebook OutStream, so restoring sys.__stdout__
        # would silently break later cell output. The finally block also
        # guarantees restoration even if func() raises (the original code
        # leaked the redirected stdout on exception).
        sys.stdout = prev_stdout

    html = f"""
    <details>
      <summary>{title}</summary>
      <pre>{buffer.getvalue()}</pre>
    </details>
    """
    display(HTML(html))

def glimpse(df):
    """Return an R-style 'glimpse' of a DataFrame as a markdown string.

    One bullet per column: dtype, number of unique values, and the first
    value truncated to 30 characters. The string is returned (not displayed)
    so callers can print or render it as they choose.
    """
    header = f"**{df.shape[0]} obs. of {df.shape[1]} variables:**\n"
    bullets = [
        f"- **{name}**: {df[name].dtype} ({df[name].nunique()} unique) "
        f"| eg: {str(df[name].iloc[0])[:30]}"
        for name in df.columns
    ]
    return "\n".join([header] + bullets)

# Render the glimpse inside a collapsible <details> block to keep the
# notebook output compact.
fold_output("Glimpse of transformed database: ", lambda: print(glimpse(imputation_nodum_1)))
Glimpse of transformed database:
**88152 obs. of 44 variables:**

- **readmit_time_from_adm_m**: float64 (6951 unique) | eg: 84.93548387096774
- **death_time_from_adm_m**: float64 (4855 unique) | eg: 84.93548387096774
- **adm_age_rec3**: float64 (4580 unique) | eg: 31.53
- **porc_pobr**: float64 (882 unique) | eg: 0.175679117441177
- **dit_m**: float64 (12618 unique) | eg: 15.967741935483872
- **sex_rec**: object (2 unique) | eg: man
- **tenure_status_household**: object (5 unique) | eg: stays temporarily with a relat
- **cohabitation**: object (4 unique) | eg: alone
- **sub_dep_icd10_status**: object (2 unique) | eg: drug dependence
- **any_violence**: object (2 unique) | eg: 0.No domestic violence/sex abu
- **prim_sub_freq_rec**: object (3 unique) | eg: 2.2–6 days/wk
- **tr_outcome**: object (5 unique) | eg: referral
- **adm_motive**: object (5 unique) | eg: sanitary sector
- **first_sub_used**: category (5 unique) | eg: alcohol
- **primary_sub_mod**: object (5 unique) | eg: alcohol
- **tipo_de_vivienda_rec2**: object (2 unique) | eg: other/unknown
- **national_foreign**: int32 (2 unique) | eg: 0
- **plan_type_corr**: object (5 unique) | eg: pg-pab
- **occupation_condition_corr24**: object (3 unique) | eg: unemployed
- **marital_status_rec**: object (3 unique) | eg: single
- **urbanicity_cat**: object (3 unique) | eg: 3.Urban
- **ed_attainment_corr**: object (3 unique) | eg: 2-Completed high school or les
- **evaluacindelprocesoteraputico**: object (3 unique) | eg: logro alto
- **eva_consumo**: object (3 unique) | eg: logro alto
- **eva_fam**: object (3 unique) | eg: logro intermedio
- **eva_relinterp**: object (3 unique) | eg: logro alto
- **eva_ocupacion**: object (3 unique) | eg: logro alto
- **eva_sm**: object (3 unique) | eg: logro intermedio
- **eva_fisica**: object (3 unique) | eg: logro alto
- **eva_transgnorma**: object (3 unique) | eg: logro alto
- **ethnicity**: float64 (2 unique) | eg: 0.0
- **dg_psiq_cie_10_instudy**: bool (2 unique) | eg: False
- **dg_psiq_cie_10_dg**: bool (2 unique) | eg: True
- **dx_f3_mood**: int32 (2 unique) | eg: 0
- **dx_f6_personality**: int32 (2 unique) | eg: 0
- **dx_f_any_severe_mental**: bool (2 unique) | eg: True
- **any_phys_dx**: bool (2 unique) | eg: False
- **polysubstance_strict**: int32 (2 unique) | eg: 0
- **readmit_event**: float64 (2 unique) | eg: 0.0
- **death_event**: int32 (2 unique) | eg: 0
- **readmit_time_from_disch_m**: float64 (6991 unique) | eg: 68.96774193548387
- **death_time_from_disch_m**: float64 (4955 unique) | eg: 68.96774193548387
- **center_id**: object (429 unique) | eg: 330
- **first_sub_used_original**: object (10 unique) | eg: alcohol
Code
from pathlib import Path
import os
import pandas as pd
PROJECT_ROOT = Path.cwd()   # current notebook directory
OUT_DIR = PROJECT_ROOT / "_out"

# Train/test split indicator (seed 2125, 20% test) saved by an earlier step.
# Plain string literal: the original used an f-string with no placeholders.
split_seed2125 = pd.read_parquet(
        OUT_DIR / "readm_split_seed2125_test20_mar26.parquet",
        engine="fastparquet"
    )
Code
import numpy as np
import pandas as pd
from IPython.display import display

from pathlib import Path

# Guard: this cell depends on PROJECT_ROOT being defined by the setup cell.
if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the root setup cell first.")

PROJECT_ROOT = Path(PROJECT_ROOT).resolve()
OUT_DIR = PROJECT_ROOT / "_out_tabble"   # or "_out" if you prefer
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Variables always summarized as continuous (mean/SD, median/IQR, continuous SMD).
CONTINUOUS_VARS = ['adm_age_rec3', 'porc_pobr', 'dit_m']

# Outcome/time/id columns excluded from the baseline Table 1.
EXCLUDE_COLS = {
    'readmit_time_from_disch_m', 'readmit_event', 
    'death_time_from_disch_m', 'death_event', 'center_id', 
    'readmit_time_from_adm_m', 'death_time_from_adm_m', 
    'first_sub_used_original'
}

MAX_LEVELS = 30  # safety cap to avoid huge tables for very high-cardinality vars

# --- Helpers ---
def fmt_median_iqr(x):
    """Format a sample as 'median [Q1, Q3]' with 2 decimals; 'NA' if empty
    after numeric coercion."""
    vals = pd.to_numeric(pd.Series(x), errors="coerce").dropna()
    if vals.empty:
        return "NA"
    q1 = vals.quantile(0.25)
    med = vals.quantile(0.50)
    q3 = vals.quantile(0.75)
    return f"{med:.2f} [{q1:.2f}, {q3:.2f}]"

def smd_cont(train, test):
    """Standardized mean difference (train minus test) for a continuous
    variable; NaN when either arm has <2 values or the pooled SD is 0/NaN."""
    a = pd.to_numeric(pd.Series(train), errors="coerce").dropna()
    b = pd.to_numeric(pd.Series(test), errors="coerce").dropna()
    if min(len(a), len(b)) < 2:
        return np.nan
    pooled_sd = np.sqrt(0.5 * (a.std(ddof=1) ** 2 + b.std(ddof=1) ** 2))
    if not np.isfinite(pooled_sd) or pooled_sd == 0:
        return np.nan
    return (a.mean() - b.mean()) / pooled_sd

def smd_bin(train01, test01):
    """Standardized mean difference (train minus test) for a 0/1 indicator,
    using the pooled Bernoulli SD; NaN when an arm is empty or SD is 0/NaN."""
    a = pd.to_numeric(pd.Series(train01), errors="coerce").dropna()
    b = pd.to_numeric(pd.Series(test01), errors="coerce").dropna()
    if a.empty or b.empty:
        return np.nan
    p1 = a.mean()
    p0 = b.mean()
    denom = np.sqrt(0.5 * (p1 * (1.0 - p1) + p0 * (1.0 - p0)))
    if denom == 0 or not np.isfinite(denom):
        return np.nan
    return (p1 - p0) / denom

def fmt_mean_sd(x):
    """Format a sample as 'mean +/- SD' with 2 decimals; 'NA' if empty
    after numeric coercion."""
    vals = pd.to_numeric(pd.Series(x), errors="coerce").dropna()
    if vals.empty:
        return "NA"
    return f"{vals.mean():.2f} +/- {vals.std(ddof=1):.2f}"

def fmt_n_pct(x01):
    """Format a 0/1 indicator as 'n (pct%)' where n counts 1s and pct is
    the mean ×100; 'NA' if empty after numeric coercion."""
    vals = pd.to_numeric(pd.Series(x01), errors="coerce").dropna()
    if vals.empty:
        return "NA"
    count = int((vals == 1).sum())
    return f"{count} ({100.0 * vals.mean():.1f}%)"

def indicator_for_level(series, level, string_mode=False):
    """One-vs-rest indicator for ``level``: 1.0 where the value equals the
    level, 0.0 otherwise, NaN where the input is missing.

    In ``string_mode`` both the series and the level are compared as
    whitespace-trimmed strings.
    """
    values = series.copy()
    if string_mode:
        values = values.astype("string").str.strip()
        target = str(level).strip()
    else:
        target = level
    indicator = (values == target).astype(float)
    indicator[values.isna()] = np.nan
    return indicator

# --- Split alignment ---
# Align the saved train/test indicator with the filtered dataset by row
# position; fail loudly on any mismatch rather than silently misaligning.
split = split_seed2125.copy()
if {"row_id", "is_train"}.issubset(split.columns):
    # With row_id present, require it to be exactly 0..N-1 after sorting.
    split = split.sort_values("row_id").reset_index(drop=True)
    if not np.array_equal(split["row_id"].to_numpy(), np.arange(len(split))):
        raise ValueError("split_seed2125$row_id is not 0..N-1 after sorting.")
    is_train = split["is_train"].astype(bool).to_numpy()
elif "is_train" in split.columns:
    # Without row_id, trust the stored row order.
    is_train = split["is_train"].astype(bool).to_numpy()
else:
    raise ValueError("split_seed2125 must contain column 'is_train'.")

# Positional alignment requires identical row counts.
base = imputation_nodum_1.reset_index(drop=True).copy()
if len(base) != len(is_train):
    raise ValueError(f"Row mismatch: imputation_nodum_1={len(base)}, split={len(is_train)}")

train_df = base.loc[is_train].copy()
test_df = base.loc[~is_train].copy()

# --- Build Table 1 ---
rows = []                 # accumulated Table-1 rows (one per variable/level)
excluded_high_card = []   # (name, n_levels) of variables skipped for cardinality
missing_cont = [c for c in CONTINUOUS_VARS if c not in base.columns]

for col in base.columns:
    if col in EXCLUDE_COLS:
        continue

    # Percent missing per arm, repeated on every row of this variable.
    miss_tr = 100.0 * train_df[col].isna().mean()
    miss_te = 100.0 * test_df[col].isna().mean()

    # 1) forced continuous vars: mean/SD, median/IQR, continuous SMD
    if col in CONTINUOUS_VARS:
        smd = smd_cont(train_df[col], test_df[col])
        rows.append({
            "Variable": col,
            "Type": "Continuous",
            "Level": "",
            "Train": fmt_mean_sd(train_df[col]),              # keep old display if you want
            "Test": fmt_mean_sd(test_df[col]),
            "Train_Median_IQR": fmt_median_iqr(train_df[col]),
            "Test_Median_IQR": fmt_median_iqr(test_df[col]),
            "SMD": smd,
            "|SMD|": abs(smd) if np.isfinite(smd) else np.nan,
            "%Missing_Train": round(miss_tr, 2),
            "%Missing_Test": round(miss_te, 2),
        })
        continue

    s_full = base[col]
    non_na = s_full.dropna()
    if non_na.empty:
        # All-missing column: nothing to tabulate.
        continue

    # 2) bool / numeric / categorical handling
    if pd.api.types.is_bool_dtype(s_full):
        levels_all = [False, True]
        levels_report = [True]  # one row for dichotomous variable
        string_mode = False

    elif pd.api.types.is_numeric_dtype(s_full):
        vals = pd.to_numeric(non_na, errors="coerce")
        vals = np.sort(pd.unique(vals[np.isfinite(vals)]))
        levels_all = list(vals)
        string_mode = False

        # 0/1 dichotomous -> one row (level=1)
        if len(levels_all) == 2 and set(np.round(levels_all, 10)).issubset({0.0, 1.0}):
            levels_report = [1.0]
        else:
            if len(levels_all) > MAX_LEVELS:
                excluded_high_card.append((col, len(levels_all)))
                continue
            levels_report = levels_all

    else:
        # Categorical/object: compare whitespace-trimmed string levels.
        vals = sorted(non_na.astype("string").str.strip().dropna().unique().tolist())
        levels_all = vals
        string_mode = True

        if len(levels_all) > MAX_LEVELS:
            excluded_high_card.append((col, len(levels_all)))
            continue

        # dichotomous categorical -> one row (last level)
        if len(levels_all) == 2:
            levels_report = [levels_all[-1]]
        else:
            levels_report = levels_all

    # One-vs-rest binary SMD per reported level.
    for lvl in levels_report:
        ind_tr = indicator_for_level(train_df[col], lvl, string_mode=string_mode)
        ind_te = indicator_for_level(test_df[col], lvl, string_mode=string_mode)

        smd = smd_bin(ind_tr, ind_te)
        var_type = "Dichotomous" if len(levels_all) == 2 else "Categorical (level)"

        rows.append({
            "Variable": col,
            "Type": var_type,
            "Level": str(lvl),
            "Train": fmt_n_pct(ind_tr),
            "Test": fmt_n_pct(ind_te),
            "Train_Median_IQR": "",
            "Test_Median_IQR": "",
            "SMD": smd,
            "|SMD|": abs(smd) if np.isfinite(smd) else np.nan,
            "%Missing_Train": round(miss_tr, 2),
            "%Missing_Test": round(miss_te, 2),
        })        

# Within each Type and Variable, show the largest imbalance first.
table1_split = pd.DataFrame(rows)
table1_split = table1_split.sort_values(
    ["Type", "Variable", "|SMD|"], ascending=[True, True, False], na_position="last"
).reset_index(drop=True)

# Compact run summary: sample sizes plus Table-1 bookkeeping counts.
summary_split = pd.DataFrame({
    "N_train": [len(train_df)],
    "N_test": [len(test_df)],
    "N_total": [len(base)],
    "N_rows_table1": [len(table1_split)],
    "N_missing_continuous_vars": [len(missing_cont)],
    "N_excluded_high_cardinality_vars": [len(excluded_high_card)],
})

display(summary_split)
# Surface any bookkeeping anomalies collected while building the table.
if missing_cont:
    print("Missing CONTINUOUS_VARS in dataset:", missing_cont)
if excluded_high_card:
    print("Excluded high-cardinality variables (name, n_levels):", excluded_high_card)

# optional export
csv_path = OUT_DIR / "table1_split_seed2125_multilevel_mar26.csv"
table1_split.to_csv(csv_path, index=False)
print(f"Saved: {csv_path}")
N_train N_test N_total N_rows_table1 N_missing_continuous_vars N_excluded_high_cardinality_vars
0 70521 17631 88152 89 0 0
Code
import pandas as pd
from IPython.display import HTML, display

# Lift pandas display limits so to_html() renders every row/column;
# scrolling is handled by the wrapping <div> below, not by pandas truncation.
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', None)

# Convert DataFrame to HTML and wrap in a scrollable div
html_table = table1_split.to_html()
scroll_box = f"""
<div style="max-height:500px; max-width:1000px; overflow-y:auto; overflow-x:auto; border:1px solid #ccc;">
{html_table}
</div>
"""
display(HTML(scroll_box))
Variable Type Level Train Test Train_Median_IQR Test_Median_IQR SMD |SMD| %Missing_Train %Missing_Test
0 adm_motive Categorical (level) another SUD facility/FONODROGAS/SENDA Previene 6565 (9.3%) 1705 (9.7%) -0.012324 0.012324 0.0 0.0
1 adm_motive Categorical (level) justice sector 6721 (9.5%) 1724 (9.8%) -0.008388 0.008388 0.0 0.0
2 adm_motive Categorical (level) sanitary sector 21933 (31.1%) 5425 (30.8%) 0.007176 0.007176 0.0 0.0
3 adm_motive Categorical (level) other 3663 (5.2%) 890 (5.0%) 0.006636 0.006636 0.0 0.0
4 adm_motive Categorical (level) spontaneous consultation 31639 (44.9%) 7887 (44.7%) 0.002633 0.002633 0.0 0.0
5 cohabitation Categorical (level) Others 6111 (8.7%) 1493 (8.5%) 0.007056 0.007056 0.0 0.0
6 cohabitation Categorical (level) alone 6677 (9.5%) 1693 (9.6%) -0.004573 0.004573 0.0 0.0
7 cohabitation Categorical (level) family of origin 26070 (37.0%) 6527 (37.0%) -0.001083 0.001083 0.0 0.0
8 cohabitation Categorical (level) with couple/children 31663 (44.9%) 7918 (44.9%) -0.000218 0.000218 0.0 0.0
9 ed_attainment_corr Categorical (level) 3-Completed primary school or less 18501 (26.2%) 4692 (26.6%) -0.008561 0.008561 0.0 0.0
10 ed_attainment_corr Categorical (level) 2-Completed high school or less 39070 (55.4%) 9728 (55.2%) 0.004553 0.004553 0.0 0.0
11 ed_attainment_corr Categorical (level) 1-More than high school 12950 (18.4%) 3211 (18.2%) 0.003908 0.003908 0.0 0.0
12 eva_consumo Categorical (level) logro minimo 29543 (41.9%) 7529 (42.7%) -0.016411 0.016411 0.0 0.0
13 eva_consumo Categorical (level) logro intermedio 20274 (28.7%) 4962 (28.1%) 0.013416 0.013416 0.0 0.0
14 eva_consumo Categorical (level) logro alto 20704 (29.4%) 5140 (29.2%) 0.004516 0.004516 0.0 0.0
15 eva_fam Categorical (level) logro minimo 31203 (44.2%) 7981 (45.3%) -0.020524 0.020524 0.0 0.0
16 eva_fam Categorical (level) logro intermedio 22594 (32.0%) 5513 (31.3%) 0.016553 0.016553 0.0 0.0
17 eva_fam Categorical (level) logro alto 16724 (23.7%) 4137 (23.5%) 0.005902 0.005902 0.0 0.0
18 eva_fisica Categorical (level) logro intermedio 21987 (31.2%) 5309 (30.1%) 0.023129 0.023129 0.0 0.0
19 eva_fisica Categorical (level) logro minimo 27087 (38.4%) 6935 (39.3%) -0.018962 0.018962 0.0 0.0
20 eva_fisica Categorical (level) logro alto 21447 (30.4%) 5387 (30.6%) -0.003083 0.003083 0.0 0.0
21 eva_ocupacion Categorical (level) logro minimo 29417 (41.7%) 7539 (42.8%) -0.021180 0.021180 0.0 0.0
22 eva_ocupacion Categorical (level) logro alto 22298 (31.6%) 5451 (30.9%) 0.015139 0.015139 0.0 0.0
23 eva_ocupacion Categorical (level) logro intermedio 18806 (26.7%) 4641 (26.3%) 0.007801 0.007801 0.0 0.0
24 eva_relinterp Categorical (level) logro minimo 30797 (43.7%) 7859 (44.6%) -0.018211 0.018211 0.0 0.0
25 eva_relinterp Categorical (level) logro intermedio 22491 (31.9%) 5526 (31.3%) 0.011831 0.011831 0.0 0.0
26 eva_relinterp Categorical (level) logro alto 17233 (24.4%) 4246 (24.1%) 0.008261 0.008261 0.0 0.0
27 eva_sm Categorical (level) logro minimo 30093 (42.7%) 7648 (43.4%) -0.014255 0.014255 0.0 0.0
28 eva_sm Categorical (level) logro intermedio 23257 (33.0%) 5704 (32.4%) 0.013364 0.013364 0.0 0.0
29 eva_sm Categorical (level) logro alto 17171 (24.3%) 4279 (24.3%) 0.001842 0.001842 0.0 0.0
30 eva_transgnorma Categorical (level) logro minimo 27789 (39.4%) 7032 (39.9%) -0.009793 0.009793 0.0 0.0
31 eva_transgnorma Categorical (level) logro intermedio 17465 (24.8%) 4325 (24.5%) 0.005453 0.005453 0.0 0.0
32 eva_transgnorma Categorical (level) logro alto 25267 (35.8%) 6274 (35.6%) 0.005092 0.005092 0.0 0.0
33 evaluacindelprocesoteraputico Categorical (level) logro minimo 30546 (43.3%) 7848 (44.5%) -0.024136 0.024136 0.0 0.0
34 evaluacindelprocesoteraputico Categorical (level) logro intermedio 23040 (32.7%) 5565 (31.6%) 0.023718 0.023718 0.0 0.0
35 evaluacindelprocesoteraputico Categorical (level) logro alto 16935 (24.0%) 4218 (23.9%) 0.002117 0.002117 0.0 0.0
36 first_sub_used Categorical (level) cocaine paste 3375 (4.8%) 925 (5.2%) -0.021104 0.021104 0.0 0.0
37 first_sub_used Categorical (level) cocaine powder 2597 (3.7%) 686 (3.9%) -0.010912 0.010912 0.0 0.0
38 first_sub_used Categorical (level) alcohol 41555 (58.9%) 10327 (58.6%) 0.007165 0.007165 0.0 0.0
39 first_sub_used Categorical (level) marijuana 21085 (29.9%) 5226 (29.6%) 0.005641 0.005641 0.0 0.0
40 first_sub_used Categorical (level) other 1909 (2.7%) 467 (2.6%) 0.003608 0.003608 0.0 0.0
41 marital_status_rec Categorical (level) single 38639 (54.8%) 9621 (54.6%) 0.004462 0.004462 0.0 0.0
42 marital_status_rec Categorical (level) married/cohabiting 23711 (33.6%) 5959 (33.8%) -0.003719 0.003719 0.0 0.0
43 marital_status_rec Categorical (level) separated/divorced/annulled/widowed 8171 (11.6%) 2051 (11.6%) -0.001445 0.001445 0.0 0.0
44 occupation_condition_corr24 Categorical (level) inactive 11467 (16.3%) 2777 (15.8%) 0.013903 0.013903 0.0 0.0
45 occupation_condition_corr24 Categorical (level) unemployed 24498 (34.7%) 6228 (35.3%) -0.012274 0.012274 0.0 0.0
46 occupation_condition_corr24 Categorical (level) employed 34556 (49.0%) 8626 (48.9%) 0.001517 0.001517 0.0 0.0
47 plan_type_corr Categorical (level) m-pr 2829 (4.0%) 708 (4.0%) -0.000208 0.000208 0.0 0.0
48 plan_type_corr Categorical (level) pg-pr 7739 (11.0%) 1934 (11.0%) 0.000151 0.000151 0.0 0.0
49 plan_type_corr Categorical (level) m-pai 4202 (6.0%) 1051 (6.0%) -0.000109 0.000109 0.0 0.0
50 plan_type_corr Categorical (level) pg-pab 25814 (36.6%) 6453 (36.6%) 0.000091 0.000091 0.0 0.0
51 plan_type_corr Categorical (level) pg-pai 29937 (42.5%) 7485 (42.5%) -0.000050 0.000050 0.0 0.0
52 prim_sub_freq_rec Categorical (level) 1.≤1 day/wk 8137 (11.5%) 2084 (11.8%) -0.008771 0.008771 0.0 0.0
53 prim_sub_freq_rec Categorical (level) 3.Daily 31122 (44.1%) 7729 (43.8%) 0.005923 0.005923 0.0 0.0
54 prim_sub_freq_rec Categorical (level) 2.2–6 days/wk 31262 (44.3%) 7818 (44.3%) -0.000247 0.000247 0.0 0.0
55 primary_sub_mod Categorical (level) cocaine powder 13794 (19.6%) 3370 (19.1%) 0.011295 0.011295 0.0 0.0
56 primary_sub_mod Categorical (level) alcohol 23975 (34.0%) 6047 (34.3%) -0.006339 0.006339 0.0 0.0
57 primary_sub_mod Categorical (level) others 1283 (1.8%) 310 (1.8%) 0.004606 0.004606 0.0 0.0
58 primary_sub_mod Categorical (level) cocaine paste 26665 (37.8%) 6703 (38.0%) -0.004263 0.004263 0.0 0.0
59 primary_sub_mod Categorical (level) marijuana 4804 (6.8%) 1201 (6.8%) 0.000012 0.000012 0.0 0.0
60 tenure_status_household Categorical (level) owner/transferred dwellings/pays dividends 25731 (36.5%) 6291 (35.7%) 0.016774 0.016774 0.0 0.0
61 tenure_status_household Categorical (level) illegal settlement 1078 (1.5%) 238 (1.3%) 0.015007 0.015007 0.0 0.0
62 tenure_status_household Categorical (level) renting 12625 (17.9%) 3252 (18.4%) -0.014064 0.014064 0.0 0.0
63 tenure_status_household Categorical (level) others 2073 (2.9%) 543 (3.1%) -0.008209 0.008209 0.0 0.0
64 tenure_status_household Categorical (level) stays temporarily with a relative 29014 (41.1%) 7307 (41.4%) -0.006127 0.006127 0.0 0.0
65 tr_outcome Categorical (level) referral 8082 (11.5%) 1962 (11.1%) 0.010498 0.010498 0.0 0.0
66 tr_outcome Categorical (level) adm discharge - adm reasons 930 (1.3%) 252 (1.4%) -0.009496 0.009496 0.0 0.0
67 tr_outcome Categorical (level) adm discharge - rule violation/undet 5201 (7.4%) 1262 (7.2%) 0.008370 0.008370 0.0 0.0
68 tr_outcome Categorical (level) dropout 37588 (53.3%) 9467 (53.7%) -0.007915 0.007915 0.0 0.0
69 tr_outcome Categorical (level) completion 18720 (26.5%) 4688 (26.6%) -0.001002 0.001002 0.0 0.0
70 urbanicity_cat Categorical (level) 1.Rural 6113 (8.7%) 1497 (8.5%) 0.006342 0.006342 0.0 0.0
71 urbanicity_cat Categorical (level) 3.Urban 57504 (81.5%) 14419 (81.8%) -0.006213 0.006213 0.0 0.0
72 urbanicity_cat Categorical (level) 2.Mixed 6904 (9.8%) 1715 (9.7%) 0.002116 0.002116 0.0 0.0
73 adm_age_rec3 Continuous 35.73 +/- 10.46 35.81 +/- 10.44 34.17 [27.39, 42.94] 34.23 [27.50, 43.11] -0.007985 0.007985 0.0 0.0
74 dit_m Continuous 7.37 +/- 6.16 7.30 +/- 6.10 5.58 [3.00, 9.97] 5.52 [3.00, 9.94] 0.011042 0.011042 0.0 0.0
75 porc_pobr Continuous 0.14 +/- 0.07 0.14 +/- 0.07 0.12 [0.09, 0.17] 0.12 [0.09, 0.17] 0.013918 0.013918 0.0 0.0
76 any_phys_dx Dichotomous True 6709 (9.5%) 1636 (9.3%) 0.008032 0.008032 0.0 0.0
77 any_violence Dichotomous 1.Domestic violence/sex abuse 19904 (28.2%) 4954 (28.1%) 0.002801 0.002801 0.0 0.0
78 dg_psiq_cie_10_dg Dichotomous True 31562 (44.8%) 7805 (44.3%) 0.009796 0.009796 0.0 0.0
79 dg_psiq_cie_10_instudy Dichotomous True 12229 (17.3%) 3207 (18.2%) -0.022204 0.022204 0.0 0.0
80 dx_f3_mood Dichotomous 1.0 7205 (10.2%) 1831 (10.4%) -0.005537 0.005537 0.0 0.0
81 dx_f6_personality Dichotomous 1.0 20122 (28.5%) 4948 (28.1%) 0.010415 0.010415 0.0 0.0
82 dx_f_any_severe_mental Dichotomous True 3448 (4.9%) 844 (4.8%) 0.004768 0.004768 0.0 0.0
83 ethnicity Dichotomous 1.0 4500 (6.4%) 1093 (6.2%) 0.007487 0.007487 0.0 0.0
84 national_foreign Dichotomous 1.0 453 (0.6%) 115 (0.7%) -0.001234 0.001234 0.0 0.0
85 polysubstance_strict Dichotomous 1.0 51581 (73.1%) 12759 (72.4%) 0.017428 0.017428 0.0 0.0
86 sex_rec Dichotomous woman 18082 (25.6%) 4535 (25.7%) -0.001858 0.001858 0.0 0.0
87 sub_dep_icd10_status Dichotomous hazardous consumption 19222 (27.3%) 4765 (27.0%) 0.005192 0.005192 0.0 0.0
88 tipo_de_vivienda_rec2 Dichotomous other/unknown 7813 (11.1%) 1882 (10.7%) 0.012995 0.012995 0.0 0.0
Back to top