DeepHit (part 2)

This notebook implements a DeepHit model for competing risks. It uses a single neural network to simultaneously predict the risk of death and readmission, explicitly modeling their competition. The model outputs Cumulative Incidence Functions (CIFs) for each risk, allowing evaluation of prediction accuracy over time for both outcomes using metrics like Uno’s C-Index (corrected for competing risks) and cause-specific Brier scores. It processes multiple imputed datasets and evaluates performance across different time horizons.
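In DeepHit, the output layer is a single softmax over every (cause, time interval) pair. For each patient, the network therefore returns a joint probability mass p_k(t_j), and the CIF for cause k is its running sum: F_k(t_j) = p_k(t_1) + ... + p_k(t_j). The NumPy sketch below illustrates that relation on random scores that stand in for network output. The cause coding (0 = readmission, 1 = death), the number of time bins and the extra cell for "no event by the last bin" are illustrative assumptions, not this notebook's configuration.

```python
import numpy as np

rng = np.random.default_rng(0)
n_patients, n_causes, n_bins = 4, 2, 5  # illustrative: cause 0 = readmission, 1 = death

# Stand-in for network logits: one score per (cause, time bin) plus one extra
# cell for "no event by the last bin" (an assumption of this sketch).
logits = rng.normal(size=(n_patients, n_causes * n_bins + 1))

# Softmax over all cells gives a joint PMF that sums to 1 for each patient.
exp = np.exp(logits - logits.max(axis=1, keepdims=True))
probs = exp / exp.sum(axis=1, keepdims=True)
pmf = probs[:, :-1].reshape(n_patients, n_causes, n_bins)

# CIF_k(t_j) = sum of pmf[k, s] for s <= j
cif = np.cumsum(pmf, axis=2)

# Event-free probability at each bin: 1 - sum over causes of CIF_k(t_j)
surv = 1.0 - cif.sum(axis=1)

print("CIF shape (patients, causes, bins):", cif.shape)
print("Readmission CIF, patient 0:", np.round(cif[0, 0], 3))
print("Death CIF, patient 0:", np.round(cif[0, 1], 3))
```

Both causes share one softmax, so probability mass assigned to readmission at a given time is not available to death. This is how a single network encodes their competition.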

Author

ags

Published

February 11, 2026

0. Package loading and installation

Code
# IPython magic (Jupyter/Colab only): clear the namespace before a fresh run.
# Comment this line out when running the code as a plain Python script.
%reset -f
import gc
gc.collect()

import sys
import subprocess
import time

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# Environment setup (run once in a terminal, not in the notebook):
# conda activate surv-deephit
# conda install ipykernel -y
# conda install -c conda-forge pytorch torchtuples pycox
# conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
# conda install pycox torchtuples scikit-learn scikit-survival lifelines shap seaborn matplotlib scipy pandas -c conda-forge -y
# If needed: conda install pycox torchtuples -c conda-forge -y
#
# Conda warns that it will make two changes because PyTorch is being installed with CUDA:
# conda-forge::cuda-cudart 12.9  →  nvidia::cuda-cudart 11.8
#
# Export the environment (package list) to:
# conda env export --no-builds > "G:\My Drive\Alvacast\SISTRAT 2023\dh\environment.yml"
#
# Alternative export, run from the base environment with conda-lock installed
# (findstr /V drops the machine-specific "prefix:" line):
# conda activate base
# conda install -c conda-forge conda-lock
# conda env export --no-builds | findstr /V "^prefix:" > "G:\My Drive\Alvacast\SISTRAT 2023\dh\environment.yml"

# Check device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"✅ Compute Device: {device}")

from sksurv.metrics import (
    concordance_index_ipcw,
    brier_score,
    integrated_brier_score
)
from sksurv.util import Surv



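# glimpse: compact column overview (similar to dplyr::glimpse in R)
def glimpse(df, max_width=80):
    """Print the shape, then one line per column: name, dtype and the first 5 values."""
    print(f"Rows: {df.shape[0]} | Columns: {df.shape[1]}")
    for col in df.columns:
        dtype = df[col].dtype
        preview = df[col].astype(str).head(5).tolist()
        preview_str = ", ".join(preview)
        if len(preview_str) > max_width:
            preview_str = preview_str[:max_width] + "..."
        print(f"{col:<30} {str(dtype):<15} {preview_str}")

# tabyl: one-way frequency table (similar to janitor::tabyl in R)
def tabyl(series):
    """Return each value of a Series (NaN included) with its count and its proportion (0-1, column 'percent')."""
    counts = series.value_counts(dropna=False)
    props = series.value_counts(normalize=True, dropna=False)
    return pd.DataFrame({"value": counts.index,
                         "n": counts.values,
                         "percent": props.values})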
#clean_names
import re

def clean_names(df):
    """
    Mimic janitor::clean_names for pandas DataFrames.
    - Lowercase
    - Replace spaces and special chars with underscores
    - Remove non-alphanumeric/underscore
    """
    new_cols = []
    for col in df.columns:
        # lowercase
        col = col.lower()
        # replace spaces and special chars with underscore
        col = re.sub(r"[^\w]+", "_", col)
        # strip leading/trailing underscores
        col = col.strip("_")
        new_cols.append(col)
    df.columns = new_cols
    return df
✅ Compute Device: cuda
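The clean_names helper lowercases each name, collapses every run of non-word characters into a single underscore and strips leading and trailing underscores. The quick check below uses a copy of the helper and a toy DataFrame. Its column names are taken from the X_reduced_imp0 glimpse in the Load data step, and the check suggests that the helper reproduces the dummy names used in imputation_1 (for example, tr_outcome_adm_discharge_adm_reasons).

```python
import re
import pandas as pd

def clean_names(df):
    """Copy of the helper above: lowercase, collapse non-word runs to '_', strip edges."""
    new_cols = []
    for col in df.columns:
        col = col.lower()
        col = re.sub(r"[^\w]+", "_", col)
        col = col.strip("_")
        new_cols.append(col)
    df.columns = new_cols
    return df

# Raw dummy names as printed by glimpse(X_reduced_imp0)
raw_names = [
    "tr_outcome_adm discharge - adm reasons",
    "marital_status_rec_separated/divorced/annulled/widowed",
    "prim_sub_freq_rec_2.2–6 days/wk",
    "urbanicity_cat_2.Mixed",
]
toy = pd.DataFrame([[0, 1, 0, 1]], columns=raw_names)

cleaned = clean_names(toy).columns.tolist()
for before, after in zip(raw_names, cleaned):
    print(f"{before!r:60} -> {after}")
```

The helper renames the columns in place as well as returning the DataFrame. If the original names are still needed, pass df.copy().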
Code
packages = ["torch", "torchtuples", "pycox"]

for p in packages:
    try:
        mod = __import__(p)
        print(f"✅ {p} installed | version:", getattr(mod, "__version__", "unknown"))
    except ImportError:
        print(f"❌ {p} NOT installed")
✅ torch installed | version: 2.5.1
✅ torchtuples installed | version: 0.2.2
✅ pycox installed | version: 0.3.0
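The check above imports each package just to read its version, which can be slow for torch. A lighter alternative is sketched below. It reads the installed distribution metadata with the standard-library importlib.metadata and does not import the packages. The helper name installed_versions is hypothetical. Distribution names can differ from import names; they happen to match for these three packages, but that does not hold for every package.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_versions(packages):
    """Hypothetical helper: map each distribution name to its version, or None if not installed."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result

for name, ver in installed_versions(["torch", "torchtuples", "pycox"]).items():
    print(f"✅ {name} {ver}" if ver else f"❌ {name} NOT installed")
```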

Load data

Code
from pathlib import Path

BASE_DIR = Path(
    r"G:\My Drive\Alvacast\SISTRAT 2023\data\20241015_out\pred1"
)


import pickle

with open(BASE_DIR / "imputations_list_jan26.pkl", "rb") as f:
    imputations_list_jan26 = pickle.load(f)


imputation_nodum_1 = pd.read_parquet(
    BASE_DIR / "imputation_nondum_1.parquet",
    engine="fastparquet"
)

X_reduced_imp0 = pd.read_parquet(
    BASE_DIR / "X_reduced_imp0.parquet",
    engine="fastparquet"
)

imputation_1 = pd.read_parquet(
    BASE_DIR / "imputation_1.parquet",
    engine="fastparquet"
)


# Quick check
glimpse(imputation_nodum_1)
glimpse(imputation_1)
glimpse(X_reduced_imp0)
Rows: 88504 | Columns: 43
readmit_time_from_adm_m        float64         84.93548387096774, 12.833333333333334, 13.733333333333333, 11.966666666666667, 1...
death_time_from_adm_m          float64         84.93548387096774, 87.16129032258064, 117.2258064516129, 98.93548387096774, 37.9...
adm_age_rec3                   float64         31.53, 20.61, 42.52, 60.61, 45.08
porc_pobr                      float64         0.175679117441177, 0.187835901975632, 0.130412444472313, 0.133759185671806, 0.08...
dit_m                          float64         15.967741935483872, 5.833333333333334, 0.4752688172043005, 6.966666666666667, 6....
sex_rec                        object          man, man, man, woman, man
tenure_status_household        object          stays temporarily with a relative, owner/transferred dwellings/pays dividends, s...
cohabitation                   object          alone, family of origin, with couple/children, with couple/children, family of o...
sub_dep_icd10_status           object          drug dependence, hazardous consumption, drug dependence, drug dependence, drug d...
any_violence                   object          0.No domestic violence/sex abuse, 0.No domestic violence/sex abuse, 0.No domesti...
prim_sub_freq_rec              object          2.2–6 days/wk, 3.Daily, 3.Daily, 3.Daily, 3.Daily
tr_outcome                     object          referral, dropout, adm discharge - rule violation/undet, dropout, completion
adm_motive                     object          sanitary sector, spontaneous consultation, sanitary sector, sanitary sector, spo...
first_sub_used                 object          alcohol, alcohol, alcohol, cocaine paste, alcohol
primary_sub_mod                object          alcohol, cocaine paste, cocaine paste, cocaine paste, cocaine paste
tipo_de_vivienda_rec2          object          other/unknown, formal housing, formal housing, formal housing, formal housing
national_foreign               int32           0, 0, 0, 0, 0
plan_type_corr                 object          pg-pab, pg-pab, pg-pr, m-pr, pg-pai
occupation_condition_corr24    object          unemployed, employed, employed, inactive, unemployed
marital_status_rec             object          single, single, single, married/cohabiting, single
urbanicity_cat                 object          3.Urban, 3.Urban, 3.Urban, 3.Urban, 3.Urban
ed_attainment_corr             object          2-Completed high school or less, 3-Completed primary school or less, 2-Completed...
evaluacindelprocesoteraputico  object          logro alto, logro minimo, logro minimo, logro minimo, logro alto
eva_consumo                    object          logro alto, logro minimo, logro minimo, logro intermedio, logro alto
eva_fam                        object          logro intermedio, logro minimo, logro minimo, logro intermedio, logro alto
eva_relinterp                  object          logro alto, logro minimo, logro minimo, logro intermedio, logro alto
eva_ocupacion                  object          logro alto, logro minimo, logro minimo, logro minimo, logro intermedio
eva_sm                         object          logro intermedio, logro minimo, logro minimo, logro intermedio, logro minimo
eva_fisica                     object          logro alto, logro minimo, logro intermedio, logro intermedio, logro alto
eva_transgnorma                object          logro alto, logro minimo, logro minimo, logro minimo, logro intermedio
ethnicity                      float64         0.0, 0.0, 0.0, 0.0, 0.0
dg_psiq_cie_10_instudy         bool            False, False, True, False, False
dg_psiq_cie_10_dg              bool            True, False, False, True, False
dx_f3_mood                     int32           0, 0, 0, 0, 0
dx_f6_personality              int32           0, 0, 0, 1, 0
dx_f_any_severe_mental         bool            True, False, False, False, False
any_phys_dx                    bool            False, False, False, True, False
polysubstance_strict           int32           0, 1, 1, 1, 1
readmit_event                  float64         0.0, 1.0, 1.0, 1.0, 1.0
death_event                    int32           0, 0, 0, 0, 0
readmit_time_from_disch_m      float64         68.96774193548387, 7.0, 13.258064516129032, 5.0, 7.354838709677419
death_time_from_disch_m        float64         68.96774193548387, 81.3225806451613, 116.74193548387096, 91.96774193548387, 31.0...
center_id                      object          330, 489, 341, 411, 176
Rows: 88504 | Columns: 78
readmit_time_from_adm_m        float64         84.93548387096774, 12.833333333333334, 13.733333333333333, 11.966666666666667, 1...
death_time_from_adm_m          float64         84.93548387096774, 87.16129032258064, 117.2258064516129, 98.93548387096774, 37.9...
adm_age_rec3                   float64         31.53, 20.61, 42.52, 60.61, 45.08
porc_pobr                      float64         0.175679117441177, 0.187835901975632, 0.130412444472313, 0.133759185671806, 0.08...
dit_m                          float64         15.967741935483872, 5.833333333333334, 0.4752688172043005, 6.966666666666667, 6....
national_foreign               int32           0, 0, 0, 0, 0
ethnicity                      float64         0.0, 0.0, 0.0, 0.0, 0.0
dg_psiq_cie_10_instudy         bool            False, False, True, False, False
dg_psiq_cie_10_dg              bool            True, False, False, True, False
dx_f3_mood                     int32           0, 0, 0, 0, 0
dx_f6_personality              int32           0, 0, 0, 1, 0
dx_f_any_severe_mental         bool            True, False, False, False, False
any_phys_dx                    bool            False, False, False, True, False
polysubstance_strict           int32           0, 1, 1, 1, 1
readmit_time_from_disch_m      float64         68.96774193548387, 7.0, 13.258064516129032, 5.0, 7.354838709677419
readmit_event                  float64         0.0, 1.0, 1.0, 1.0, 1.0
death_time_from_disch_m        float64         68.96774193548387, 81.3225806451613, 116.74193548387096, 91.96774193548387, 31.0...
death_event                    int32           0, 0, 0, 0, 0
sex_rec_woman                  float64         0.0, 0.0, 0.0, 1.0, 0.0
tenure_status_household_illegal_settlement float64         0.0, 0.0, 0.0, 0.0, 0.0
tenure_status_household_owner_transferred_dwellings_pays_dividends float64         0.0, 1.0, 0.0, 1.0, 0.0
tenure_status_household_renting float64         0.0, 0.0, 0.0, 0.0, 0.0
tenure_status_household_stays_temporarily_with_a_relative float64         1.0, 0.0, 1.0, 0.0, 1.0
cohabitation_alone             float64         1.0, 0.0, 0.0, 0.0, 0.0
cohabitation_with_couple_children float64         0.0, 0.0, 1.0, 1.0, 0.0
cohabitation_family_of_origin  float64         0.0, 1.0, 0.0, 0.0, 1.0
sub_dep_icd10_status_drug_dependence float64         1.0, 0.0, 1.0, 1.0, 1.0
any_violence_1_domestic_violence_sex_abuse float64         0.0, 0.0, 0.0, 0.0, 0.0
prim_sub_freq_rec_2_2_6_days_wk float64         1.0, 0.0, 0.0, 0.0, 0.0
prim_sub_freq_rec_3_daily      float64         0.0, 1.0, 1.0, 1.0, 1.0
tr_outcome_adm_discharge_adm_reasons float64         0.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_adm_discharge_rule_violation_undet float64         0.0, 0.0, 1.0, 0.0, 0.0
tr_outcome_completion          float64         0.0, 0.0, 0.0, 0.0, 1.0
tr_outcome_dropout             float64         0.0, 1.0, 0.0, 1.0, 0.0
tr_outcome_referral            float64         1.0, 0.0, 0.0, 0.0, 0.0
adm_motive_another_sud_facility_fonodrogas_senda_previene float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_justice_sector      float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_sanitary_sector     float64         1.0, 0.0, 1.0, 1.0, 0.0
adm_motive_spontaneous_consultation float64         0.0, 1.0, 0.0, 0.0, 1.0
first_sub_used_alcohol         float64         1.0, 1.0, 1.0, 0.0, 1.0
first_sub_used_cocaine_paste   float64         0.0, 0.0, 0.0, 1.0, 0.0
first_sub_used_cocaine_powder  float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_marijuana       float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_opioids         float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_tranquilizers_hypnotics float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_cocaine_paste  float64         0.0, 1.0, 1.0, 1.0, 1.0
primary_sub_mod_cocaine_powder float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_alcohol        float64         1.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_marijuana      float64         0.0, 0.0, 0.0, 0.0, 0.0
tipo_de_vivienda_rec2_other_unknown float64         1.0, 0.0, 0.0, 0.0, 0.0
plan_type_corr_m_pai           float64         0.0, 0.0, 0.0, 0.0, 0.0
plan_type_corr_m_pr            float64         0.0, 0.0, 0.0, 1.0, 0.0
plan_type_corr_pg_pai          float64         0.0, 0.0, 0.0, 0.0, 1.0
plan_type_corr_pg_pr           float64         0.0, 0.0, 1.0, 0.0, 0.0
occupation_condition_corr24_inactive float64         0.0, 0.0, 0.0, 1.0, 0.0
occupation_condition_corr24_unemployed float64         1.0, 0.0, 0.0, 0.0, 1.0
marital_status_rec_separated_divorced_annulled_widowed float64         0.0, 0.0, 0.0, 0.0, 0.0
marital_status_rec_single      float64         1.0, 1.0, 1.0, 0.0, 1.0
urbanicity_cat_1_rural         float64         0.0, 0.0, 0.0, 0.0, 0.0
urbanicity_cat_2_mixed         float64         0.0, 0.0, 0.0, 0.0, 0.0
ed_attainment_corr_2_completed_high_school_or_less float64         1.0, 0.0, 1.0, 1.0, 0.0
ed_attainment_corr_3_completed_primary_school_or_less float64         0.0, 1.0, 0.0, 0.0, 1.0
evaluacindelprocesoteraputico_logro_intermedio float64         0.0, 0.0, 0.0, 0.0, 0.0
evaluacindelprocesoteraputico_logro_minimo float64         0.0, 1.0, 1.0, 1.0, 0.0
eva_consumo_logro_intermedio   float64         0.0, 0.0, 0.0, 1.0, 0.0
eva_consumo_logro_minimo       float64         0.0, 1.0, 1.0, 0.0, 0.0
eva_fam_logro_intermedio       float64         1.0, 0.0, 0.0, 1.0, 0.0
eva_fam_logro_minimo           float64         0.0, 1.0, 1.0, 0.0, 0.0
eva_relinterp_logro_intermedio float64         0.0, 0.0, 0.0, 1.0, 0.0
eva_relinterp_logro_minimo     float64         0.0, 1.0, 1.0, 0.0, 0.0
eva_ocupacion_logro_intermedio float64         0.0, 0.0, 0.0, 0.0, 1.0
eva_ocupacion_logro_minimo     float64         0.0, 1.0, 1.0, 1.0, 0.0
eva_sm_logro_intermedio        float64         1.0, 0.0, 0.0, 1.0, 0.0
eva_sm_logro_minimo            float64         0.0, 1.0, 1.0, 0.0, 1.0
eva_fisica_logro_intermedio    float64         0.0, 0.0, 1.0, 1.0, 0.0
eva_fisica_logro_minimo        float64         0.0, 1.0, 0.0, 0.0, 0.0
eva_transgnorma_logro_intermedio float64         0.0, 0.0, 0.0, 0.0, 1.0
eva_transgnorma_logro_minimo   float64         0.0, 1.0, 1.0, 1.0, 0.0
Rows: 88504 | Columns: 123
ethnicity                      float64         0.0, 0.0, 0.0, 0.0, 0.0
dg_psiq_cie_10_instudy         int64           0, 0, 1, 0, 0
dg_psiq_cie_10_dg              int64           1, 0, 0, 1, 0
f0_organic                     int64           0, 0, 0, 0, 0
f2_psychotic                   int64           0, 0, 0, 0, 0
f3_mood                        int64           0, 0, 0, 0, 0
f4_anxiety_stress_somatoform   int64           0, 0, 0, 0, 0
f5_physio_eating_sleep_sexual  int64           0, 0, 0, 0, 0
f6_personality_adult_behaviour int64           0, 0, 0, 1, 0
f7_intellectual_disability     int64           1, 0, 0, 0, 0
f8_9_neurodevelopment_child    int64           0, 0, 0, 0, 0
dx_f2_smi_psychotic            int32           0, 0, 0, 0, 0
dx_f3_mood                     int32           0, 0, 0, 0, 0
dx_f45_anx_stress_phys         int32           0, 0, 0, 0, 0
dx_f6_personality              int32           0, 0, 0, 1, 0
dx_f0789_neurocog_dev          int32           1, 0, 0, 0, 0
phys_dx_instudy                int32           1, 1, 1, 1, 1
phys_dx_other_spec_medical_cond int32           0, 0, 0, 1, 0
phys_dx_organ_system_med_dis   int32           0, 0, 0, 0, 0
phys_dx_injuries_and_sequelae  int32           0, 0, 0, 0, 0
phys_dx_infectious_diseases    int32           0, 0, 0, 0, 0
polysubstance_strict           int32           0, 1, 1, 1, 1
treat_lt_90                    int32           0, 0, 1, 0, 0
adm_age_log                    float64         3.48216274048526, 3.0731561705187946, 3.773220602547687, 4.120824195026479, 3.83...
adm_age_pow2                   float64         994.1409000000001, 424.77209999999997, 1807.9504000000002, 3673.5721, 2032.20639...
adm_age_pow3                   float64         31345.262577000005, 8754.552980999999, 76874.05100800002, 222655.204981, 91611.8...
adm_age_c                      float64         -4.23091837657055, -15.150918376570552, 6.759081623429452, 24.84908162342945, 9....
porc_pobr                      float64         0.175679117441177, 0.187835901975632, 0.130412444472313, 0.133759185671806, 0.08...
porc_pobr_log                  float64         0.1618459529248312, 0.17213308174953304, 0.12258256123798024, 0.1255388244402368...
porc_pobr_c                    float64         0.03381019166263441, 0.04596697619708939, -0.011456481306229588, -0.008109740106...
dit_m                          float64         15.967741935483872, 5.833333333333334, 0.4752688172043005, 6.966666666666667, 6....
treat_log                      float64         2.831314008252564, 1.921812597476253, 0.3888402221385285, 2.075266170269355, 2.0...
treat_days_pow2                float64         254.96878251821022, 34.027777777777786, 0.22588044860677478, 48.534444444444446,...
treat_days_pow3                float64         4071.2757208552925, 198.49537037037044, 0.10735393363891862, 338.1232962962963, ...
sex_rec_woman                  bool            False, False, False, True, False
tenure_status_household_illegal settlement bool            False, False, False, False, False
tenure_status_household_others bool            False, False, False, False, False
tenure_status_household_renting bool            False, False, False, False, False
tenure_status_household_stays temporarily with a relative bool            True, False, True, False, True
occupation_condition_corr24_inactive bool            False, False, False, True, False
occupation_condition_corr24_unemployed bool            True, False, False, False, True
marital_status_rec_married/cohabiting bool            False, False, False, True, False
marital_status_rec_separated/divorced/annulled/widowed bool            False, False, False, False, False
marital_status_rec_single      bool            True, True, True, False, True
urbanicity_cat_2.Mixed         bool            False, False, False, False, False
urbanicity_cat_3.Urban         bool            True, True, True, True, True
ed_attainment_corr_1-More than high school bool            False, False, False, False, False
ed_attainment_corr_2-Completed high school or less bool            True, False, True, True, False
ed_attainment_corr_3-Completed primary school or less bool            False, True, False, False, True
cohabitation_with couple/children bool            False, False, True, True, False
cohabitation_family of origin  bool            False, True, False, False, True
cohabitation_Others            bool            False, False, False, False, False
sub_dep_icd10_status_drug dependence bool            True, False, True, True, True
dom_violence_Domestic violence bool            False, False, True, False, False
sex_abuse_Sexual abuse         bool            False, False, False, False, False
any_violence_0.No domestic violence/sex abuse bool            True, True, False, True, True
prim_sub_freq_1. Less than 1 day a week bool            False, False, False, False, False
prim_sub_freq_2. 1 day a week  bool            False, False, False, False, False
prim_sub_freq_3. 2 to 3 days a week bool            True, False, False, False, False
prim_sub_freq_4. 4 to 6 days a week bool            False, False, False, False, False
prim_sub_freq_5. Daily         bool            False, True, True, True, True
prim_sub_freq_rec_2.2–6 days/wk bool            True, False, False, False, False
prim_sub_freq_rec_3.Daily      bool            False, True, True, True, True
tr_outcome_adm discharge - adm reasons bool            False, False, False, False, False
tr_outcome_adm discharge - rule violation/undet bool            False, False, True, False, False
tr_outcome_dropout             bool            False, True, False, True, False
tr_outcome_other               bool            False, False, False, False, False
tr_outcome_referral            bool            True, False, False, False, False
adm_motive_another SUD facility/FONODROGAS/SENDA Previene bool            False, False, False, False, False
adm_motive_justice sector      bool            False, False, False, False, False
adm_motive_other               bool            False, False, False, False, False
adm_motive_sanitary sector     bool            True, False, True, True, False
primary_sub_amphetamine-type stimulants bool            False, False, False, False, False
primary_sub_cocaine paste      bool            False, True, True, True, True
primary_sub_cocaine powder     bool            False, False, False, False, False
primary_sub_dissociatives      bool            False, False, False, False, False
primary_sub_hallucinogens      bool            False, False, False, False, False
primary_sub_inhalants          bool            False, False, False, False, False
primary_sub_marijuana          bool            False, False, False, False, False
primary_sub_opioids            bool            False, False, False, False, False
primary_sub_others             bool            False, False, False, False, False
primary_sub_tranquilizers/hypnotics bool            False, False, False, False, False
primary_sub_mod_cocaine paste  bool            False, True, True, True, True
primary_sub_mod_cocaine powder bool            False, False, False, False, False
primary_sub_mod_alcohol        bool            True, False, False, False, False
primary_sub_mod_others         bool            False, False, False, False, False
usuario_tribunal_trat_droga_no bool            True, True, True, True, True
usuario_tribunal_trat_droga_si bool            False, False, False, False, False
tipo_de_vivienda_rec_shared/secondary unit bool            False, False, False, False, False
tipo_de_vivienda_rec_homeless/unsheltered/informal/temporary housing/institutional/collective bool            False, False, False, False, False
tipo_de_vivienda_rec_other/unknown bool            True, False, False, False, False
tipo_de_vivienda_rec2_other/unknown bool            True, False, False, True, False
evaluacindelprocesoteraputico_logro alto bool            True, False, False, False, True
evaluacindelprocesoteraputico_logro intermedio bool            False, False, False, False, False
evaluacindelprocesoteraputico_logro minimo bool            False, True, True, True, False
eva_consumo_logro alto         bool            True, False, False, False, True
eva_consumo_logro intermedio   bool            False, False, False, True, False
eva_consumo_logro minimo       bool            False, True, True, False, False
eva_fam_logro alto             bool            False, False, False, False, True
eva_fam_logro intermedio       bool            True, False, False, True, False
eva_fam_logro minimo           bool            False, True, True, False, False
eva_relinterp_logro alto       bool            True, False, False, False, True
eva_relinterp_logro intermedio bool            False, False, False, True, False
eva_relinterp_logro minimo     bool            False, True, True, False, False
eva_ocupacion_logro alto       bool            True, False, False, False, False
eva_ocupacion_logro intermedio bool            False, False, False, False, True
eva_ocupacion_logro minimo     bool            False, True, True, True, False
eva_sm_logro alto              bool            False, False, False, False, False
eva_sm_logro intermedio        bool            True, False, False, True, False
eva_sm_logro minimo            bool            False, True, True, False, True
eva_fisica_logro alto          bool            True, False, False, False, True
eva_fisica_logro intermedio    bool            False, False, True, True, False
eva_fisica_logro minimo        bool            False, True, False, False, False
eva_transgnorma_logro alto     bool            True, False, False, False, False
eva_transgnorma_logro intermedio bool            False, False, False, False, True
eva_transgnorma_logro minimo   bool            False, True, True, True, False
adm_age_cat_30-44              bool            True, False, True, False, False
adm_age_cat_45-64              bool            False, False, False, True, True
nationality_chile_other        bool            False, False, False, False, False
plan_type_corr_m-pr            bool            False, False, False, True, False
plan_type_corr_pg-pab          bool            True, True, False, False, False
plan_type_corr_pg-pai          bool            False, False, False, False, True
plan_type_corr_pg-pr           bool            False, False, True, False, False

Inspect the loaded imputation list

Code
if isinstance(imputations_list_jan26, list) and len(imputations_list_jan26) > 0:
    print("First element type:", type(imputations_list_jan26[0]))
    if isinstance(imputations_list_jan26[0], dict):
        print("First element keys:", imputations_list_jan26[0].keys())
    elif isinstance(imputations_list_jan26[0], (pd.DataFrame, np.ndarray)):
        print("First element shape:", imputations_list_jan26[0].shape)
First element type: <class 'pandas.core.frame.DataFrame'>
First element shape: (88504, 56)

This code block inspects the object loaded from imputations_list_jan26.pkl in the Load data step. In that step, the file was opened in binary read mode ('rb') and pickle.load reconstructed the serialized Python object in memory. The block:

  1. Checks that the loaded object is a non-empty list.
  2. Prints the type of its first element.
  3. Prints the keys if that element is a dict, or the shape if it is a pandas DataFrame or NumPy array.

The first element is a DataFrame with 88,504 rows and 56 columns. Presumably, each list element is one imputed dataset.
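Before looping over the imputations, it can help to confirm that every element of the list has the same shape and column order, since later steps may index them positionally. The sketch below uses a hypothetical helper, check_imputations, and toy DataFrames in place of imputations_list_jan26. It also counts remaining missing values, which should be zero in an imputed dataset.

```python
import pandas as pd

def check_imputations(imps):
    """Hypothetical helper: verify that all imputed DataFrames share shape and columns.

    Returns one summary row per imputation and raises if any element differs
    from the first one.
    """
    ref = imps[0]
    rows = []
    for i, df in enumerate(imps):
        if not isinstance(df, pd.DataFrame):
            raise TypeError(f"Imputation {i} is {type(df).__name__}, not a DataFrame")
        if df.shape != ref.shape or list(df.columns) != list(ref.columns):
            raise ValueError(f"Imputation {i} differs from imputation 0 in shape or columns")
        rows.append({"imputation": i, "n_rows": df.shape[0], "n_cols": df.shape[1],
                     "n_missing": int(df.isna().sum().sum())})
    return pd.DataFrame(rows)

# Toy stand-ins for imputations_list_jan26
toy_list = [pd.DataFrame({"adm_age_rec3": [31.5, 20.6], "dit_m": [16.0, 5.8]})
            for _ in range(3)]
summary = check_imputations(toy_list)
print(summary)
```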

Compare datasets (transformed and original)

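Inspect and compare the column names of three datasets:

  1. The first imputation in imputations_list_jan26 (56 columns, no time or event columns).
  2. imputation_nodum_1 (43 columns, categorical variables kept as single columns).
  3. imputation_1 (78 columns, categorical variables expanded into dummy columns).

Then check whether the rows line up across datasets by merging on a few continuous variables. Both inner merges return 88,516 rows, slightly more than the 88,504 rows in each dataset. The chosen keys are therefore not strictly unique: a few key combinations are shared by more than one row and are matched more than once.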

Code
# Columns of the first imputation in imputations_list_jan26
cols_first_imp = imputations_list_jan26[0].columns.tolist()
print("First imputation columns:", cols_first_imp[:10], "... total:", len(cols_first_imp))

# Columns of imputation_nodum_1 (no dummy encoding)
cols_nodum = imputation_nodum_1.columns.tolist()
print("No-dum columns:", cols_nodum[:10], "... total:", len(cols_nodum))

# Compare overlap
common_cols = set(cols_first_imp).intersection(cols_nodum)
missing_in_imp = [c for c in cols_nodum if c not in cols_first_imp]
missing_in_nodum = [c for c in cols_first_imp if c not in cols_nodum]

print("Common columns:", len(common_cols))
print("Missing in imputations_list_jan26:", missing_in_imp)

# Columns of imputation_1 (dummy-encoded first imputation)
cols_first_imp_raw = imputation_1.columns.tolist()
print("First imputation columns:", cols_first_imp_raw[:10], "... total:", len(cols_first_imp_raw))

# Compare overlap between imputation_1 and imputation_nodum_1
common_cols_raw = set(cols_first_imp_raw).intersection(cols_nodum)
missing_in_imp_raw = [c for c in cols_nodum if c not in cols_first_imp_raw]

print("Common columns:", len(common_cols_raw))
print("Missing in imputation_1:", missing_in_imp_raw)
print(common_cols_raw)

# Candidate key: continuous variables expected to identify rows (near-)uniquely
key_vars = ["adm_age_rec3", "porc_pobr", "dit_m"]

# Merge the first imputation with the no-dum dataset on the key
df_imp = imputations_list_jan26[0]
df_nodum = imputation_nodum_1

merged_check = pd.merge(
    df_imp[key_vars],
    df_nodum[key_vars],
    on=key_vars,
    how="inner"
)

print(f"Merged rows: {merged_check.shape[0]}")
print("Preview of merged check:")
print(merged_check.head())

# Free memory
del merged_check

# Alternative key shared by imputation_1 and imputation_nodum_1. Not used below:
# df_imp has no time-from-admission columns, so it cannot be merged on this key.
key_vars_raw = ['dit_m',
                'readmit_time_from_adm_m',
                'death_time_from_adm_m',
                'adm_age_rec3']

# Merge the first imputation with imputation_1 (dummy-encoded) on the same key_vars
df_raw = imputation_1

merged_check_raw = pd.merge(
    df_imp[key_vars],
    df_raw[key_vars],
    on=key_vars,
    how="inner"
)

print(f"Merged rows: {merged_check_raw.shape[0]}")
print("Preview of merged check:")
print(merged_check_raw.head())
# Merged rows as a percentage of the rows in imputation_1
print(f"{(merged_check_raw.shape[0] / imputation_1.shape[0] * 100):.2f}%")

# Free memory
del merged_check_raw
First imputation columns: ['adm_age_rec3', 'porc_pobr', 'dit_m', 'tenure_status_household', 'prim_sub_freq_rec', 'national_foreign', 'urbanicity_cat', 'ed_attainment_corr', 'evaluacindelprocesoteraputico', 'eva_consumo'] ... total: 56
No-dum columns: ['readmit_time_from_adm_m', 'death_time_from_adm_m', 'adm_age_rec3', 'porc_pobr', 'dit_m', 'sex_rec', 'tenure_status_household', 'cohabitation', 'sub_dep_icd10_status', 'any_violence'] ... total: 43
Common columns: 24
Missing in imputations_list_jan26: ['readmit_time_from_adm_m', 'death_time_from_adm_m', 'sex_rec', 'cohabitation', 'sub_dep_icd10_status', 'any_violence', 'tr_outcome', 'adm_motive', 'first_sub_used', 'primary_sub_mod', 'tipo_de_vivienda_rec2', 'plan_type_corr', 'occupation_condition_corr24', 'marital_status_rec', 'readmit_event', 'death_event', 'readmit_time_from_disch_m', 'death_time_from_disch_m', 'center_id']
First imputation columns: ['readmit_time_from_adm_m', 'death_time_from_adm_m', 'adm_age_rec3', 'porc_pobr', 'dit_m', 'national_foreign', 'ethnicity', 'dg_psiq_cie_10_instudy', 'dg_psiq_cie_10_dg', 'dx_f3_mood'] ... total: 78
Common columns: 18
Missing in imputation_1: ['sex_rec', 'tenure_status_household', 'cohabitation', 'sub_dep_icd10_status', 'any_violence', 'prim_sub_freq_rec', 'tr_outcome', 'adm_motive', 'first_sub_used', 'primary_sub_mod', 'tipo_de_vivienda_rec2', 'plan_type_corr', 'occupation_condition_corr24', 'marital_status_rec', 'urbanicity_cat', 'ed_attainment_corr', 'evaluacindelprocesoteraputico', 'eva_consumo', 'eva_fam', 'eva_relinterp', 'eva_ocupacion', 'eva_sm', 'eva_fisica', 'eva_transgnorma', 'center_id']
{'death_event', 'readmit_time_from_disch_m', 'any_phys_dx', 'dx_f_any_severe_mental', 'dx_f3_mood', 'dg_psiq_cie_10_instudy', 'polysubstance_strict', 'dg_psiq_cie_10_dg', 'national_foreign', 'dit_m', 'porc_pobr', 'dx_f6_personality', 'death_time_from_disch_m', 'readmit_event', 'ethnicity', 'death_time_from_adm_m', 'adm_age_rec3', 'readmit_time_from_adm_m'}
Merged rows: 88516
Preview of merged check:
   adm_age_rec3  porc_pobr      dit_m
0         31.53   0.175679  15.967742
1         20.61   0.187836   5.833333
2         42.52   0.130412   0.475269
3         60.61   0.133759   6.966667
4         45.08   0.083189   6.903226
Merged rows: 88516
Preview of merged check:
   adm_age_rec3  porc_pobr      dit_m
0         31.53   0.175679  15.967742
1         20.61   0.187836   5.833333
2         42.52   0.130412   0.475269
3         60.61   0.133759   6.966667
4         45.08   0.083189   6.903226
100.01%
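An inner merge on a non-unique key pairs every duplicate on one side with every duplicate on the other. A key shared by m rows on the left and r rows on the right therefore contributes m × r merged rows. The minimal sketch below uses toy data (not the SISTRAT records) to count rows with duplicated keys using DataFrame.duplicated and to predict the merged size from per-key counts.

```python
import pandas as pd

key_vars = ["adm_age_rec3", "porc_pobr", "dit_m"]

# Toy data: the last two rows share the same key values.
left = pd.DataFrame({
    "adm_age_rec3": [31.53, 20.61, 42.52, 42.52],
    "porc_pobr":    [0.176, 0.188, 0.130, 0.130],
    "dit_m":        [15.97, 5.83, 0.48, 0.48],
})
right = left.copy()

# Rows whose key combination appears more than once (keep=False flags all copies)
n_dup_rows = int(left.duplicated(subset=key_vars, keep=False).sum())
print("Rows sharing a key with another row:", n_dup_rows)

# Expected inner-merge size = sum over keys of (count_left * count_right)
counts_left = left.groupby(key_vars).size()
counts_right = right.groupby(key_vars).size()
expected = int((counts_left * counts_right).sum())

merged = pd.merge(left, right, on=key_vars, how="inner")
print("Merged rows:", len(merged), "| expected:", expected, "| input rows:", len(left))
```

For a stricter check, pd.merge(..., validate="one_to_one") raises a MergeError when the key is not unique on either side, instead of silently adding rows.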

Create evaluation time grids for follow-up (landmarks)

This code prepares the outcome data for survival analysis. From df_raw (imputation_1), it extracts the time from admission to readmission and to death for each patient. It also extracts an indicator of whether each event was observed. It then builds an evaluation grid for each outcome. Each grid holds up to 50 time points, placed at evenly spaced quantiles (5th to 95th percentile) of the observed event times. The model's performance on readmission and death will be assessed at these time points.

Code
import numpy as np

# Required columns for survival outcomes (times measured from admission, in months)
required = ["readmit_time_from_adm_m", "readmit_event",
            "death_time_from_adm_m", "death_event"]

# Check that df_raw has all required columns
missing = [c for c in required if c not in df_raw.columns]
if missing:
    raise KeyError(f"df_raw is missing columns: {missing}")

# Create time/event arrays directly from df_raw
time_readm = df_raw["readmit_time_from_adm_m"].to_numpy()
event_readm = (df_raw["readmit_event"].to_numpy() == 1)

time_death = df_raw["death_time_from_adm_m"].to_numpy()
event_death = (df_raw["death_event"].to_numpy() == 1)

print("Arrays created for df_raw:")
print("Readmission times:", time_readm[:5])
print("Readmission events:", event_readm[:5])
print("Death times:", time_death[:5])
print("Death events:", event_death[:5])

# Build evaluation grids (quantiles of event times)
event_times_readm = time_readm[event_readm]
event_times_death = time_death[event_death]

if len(event_times_readm) < 5 or len(event_times_death) < 5:
    raise ValueError("Too few events in df_raw to build reliable time grids.")

times_eval_readm = np.unique(np.quantile(event_times_readm, np.linspace(0.05, 0.95, 50)))
times_eval_death = np.unique(np.quantile(event_times_death, np.linspace(0.05, 0.95, 50)))

print("Eval times (readmission):", times_eval_readm[:5], "...", times_eval_readm[-5:])
print("Eval times (death):", times_eval_death[:5], "...", times_eval_death[-5:])
Arrays created for df_raw:
Readmission times: [84.93548387 12.83333333 13.73333333 11.96666667 14.25806452]
Readmission events: [False  True  True  True  True]
Death times: [ 84.93548387  87.16129032 117.22580645  98.93548387  37.93548387]
Death events: [False False False False False]
Eval times (readmission): [3.93548387 4.77419355 5.45058701 6.06492649 6.67741935] ... [54.44173469 58.41566162 63.23333333 68.54767171 74.68983871]
Eval times (death): [4.16290323 5.43022383 6.68564845 8.24254115 9.77961817] ... [81.92700461 85.41186103 88.78518762 93.5538183  99.21935484]

Prepare survival data

Code
import numpy as np

# Step 1. Extract survival outcomes directly from df_raw
time_readm = df_raw["readmit_time_from_adm_m"].to_numpy()
event_readm = (df_raw["readmit_event"].to_numpy() == 1)

time_death = df_raw["death_time_from_adm_m"].to_numpy()
event_death = (df_raw["death_event"].to_numpy() == 1)

# Step 2. Build structured arrays (Surv objects)
y_surv_readm = np.empty(len(time_readm), dtype=[("event", "?"), ("time", "<f8")])
y_surv_readm["event"] = event_readm
y_surv_readm["time"] = time_readm

y_surv_death = np.empty(len(time_death), dtype=[("event", "?"), ("time", "<f8")])
y_surv_death["event"] = event_death
y_surv_death["time"] = time_death

# Step 3. Replicate across imputations
n_imputations = len(imputations_list_jan26)
y_surv_readm_list = [y_surv_readm for _ in range(n_imputations)]
y_surv_death_list = [y_surv_death for _ in range(n_imputations)]

PyCox

Unlike DeepHit, which models both causes jointly in a single network, the XGBoost models are cause-specific Cox models that treat competing events as censoring. Calibration and overall accuracy under competing risks were assessed using inverse-probability-weighted Brier scores (which combine calibration and discrimination) computed directly on cause-specific cumulative incidence functions. To summarize accuracy over time, Brier scores were averaged across prespecified prediction horizons.
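The IPCW Brier score for one cause at one horizon can be written in vectorized form. This is a sketch under stated assumptions: event codes are 0 = censored, 1 = death, 2 = readmission; the censoring survival G is a plain Kaplan-Meier estimate computed by the hypothetical helper `km_survival` (standing in for the lifelines `KaplanMeierFitter` used in the notebook); and G is evaluated at the left limit T_i- for patients with an event before the horizon, which is the usual IPCW convention, whereas the notebook's `compute_brier_competing` evaluates it at T_i.

```python
import numpy as np

def km_survival(times, observed):
    """Kaplan-Meier curve: returns (unique_times, S(t) at each unique time)."""
    times = np.asarray(times, dtype=float)
    observed = np.asarray(observed, dtype=bool)
    uniq = np.unique(times)
    at_risk = np.array([(times >= u).sum() for u in uniq])
    n_obs = np.array([((times == u) & observed).sum() for u in uniq])
    return uniq, np.cumprod(1.0 - n_obs / at_risk)

def step_eval(uniq, surv, t, left=False):
    """Evaluate a right-continuous step function at t, or its left limit at t."""
    idx = np.searchsorted(uniq, t, side="left" if left else "right") - 1
    return np.where(idx >= 0, surv[np.clip(idx, 0, None)], 1.0)

def brier_competing(cif_t, times, events, cause, horizon, cens_times, cens_events):
    """IPCW Brier score of `cause` at `horizon`; G is fitted on (cens_times, cens_events)."""
    cif_t = np.asarray(cif_t, dtype=float)
    times = np.asarray(times, dtype=float)
    events = np.asarray(events)
    # Censoring distribution: "event" = censored (code 0)
    uniq, g_surv = km_survival(cens_times, np.asarray(cens_events) == 0)
    g_at_ti = step_eval(uniq, g_surv, times, left=True)   # G(T_i-)
    g_at_h = step_eval(uniq, g_surv, horizon)              # G(t)

    had_event = (times <= horizon) & (events != 0)          # any observed event by t
    target = (had_event & (events == cause)).astype(float)  # 1 only for the cause of interest
    still_at_risk = times > horizon

    resid = np.zeros_like(cif_t)
    resid[had_event] = (target[had_event] - cif_t[had_event]) ** 2 / g_at_ti[had_event]
    resid[still_at_risk] = cif_t[still_at_risk] ** 2 / g_at_h
    # Patients censored by the horizon keep a residual of 0; 1/G reweights the others
    return resid.mean()

# Hypothetical toy data: 0 = censored, 1 = death, 2 = readmission
times = np.array([1.0, 2.0, 3.0, 4.0, 6.0, 7.0])
events = np.array([1, 0, 2, 1, 0, 2])
cif_death_at_5 = np.array([0.8, 0.1, 0.2, 0.7, 0.1, 0.3])
print(brier_competing(cif_death_at_5, times, events, cause=1, horizon=5.0,
                      cens_times=times, cens_events=events))
```

Here the censoring curve is fitted on the same toy sample for brevity; in the notebook it is fitted on the training fold and applied to the validation fold.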

Competing-risk outcomes were constructed using a time-to-first-event formulation: readmission (event code 2) was recorded when it was observed at or before the death or censoring time; otherwise an observed death was recorded as the first event (code 1); all remaining patients were censored (code 0) at their death or censoring time. This ordering prevents a readmission that preceded death from being misclassified as a death. Although patients may experience multiple clinical events during follow-up, competing-risk survival models are defined on the time to the first event only; subsequent events are therefore outside the estimand rather than discarded.
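The first-event coding rule can be sketched on a few hypothetical patients. The function `first_event_coding` is illustrative (the evaluation loop below inlines the same logic) and assumes times are measured from admission, in months, with readmission and death each carrying their own event flag.

```python
import numpy as np

def first_event_coding(t_readm, e_readm, t_death, e_death):
    """Time-to-first-event coding: 0 = censored, 1 = death, 2 = readmission."""
    t_readm = np.asarray(t_readm, dtype=float)
    t_death = np.asarray(t_death, dtype=float)
    e_readm = np.asarray(e_readm, dtype=bool)
    e_death = np.asarray(e_death, dtype=bool)

    events = np.zeros(t_death.shape, dtype=int)
    times = t_death.copy()                      # default: death/censoring time

    # Readmission wins when observed at or before the death/censoring time
    readm_first = e_readm & (t_readm <= t_death)
    events[readm_first] = 2
    times[readm_first] = t_readm[readm_first]

    # Otherwise an observed death is the first event
    death_first = e_death & ~readm_first
    events[death_first] = 1
    return times, events

# Hypothetical patients (months from admission)
t_r = np.array([12.0, 20.0, 50.0, 40.0])
e_r = np.array([True, False, False, True])
t_d = np.array([60.0, 20.0, 50.0, 40.0])
e_d = np.array([True, True, False, False])
times, events = first_event_coding(t_r, e_r, t_d, e_d)
print(events)  # → [2 1 0 2]
print(times)   # → [12. 20. 50. 40.]
```

The first patient is both readmitted and later dies, but only the readmission at 12 months enters the competing-risk outcome.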

In some cross-validation folds, global IPCW-based concordance for readmission could not be computed because too few informative comparable pairs remained after censoring adjustment. This is a known limitation of inverse-probability-weighted concordance estimation in competing-risk settings and, by itself, does not indicate model misspecification or instability.
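The notion of a comparable pair can be made concrete. For a concordance index truncated at a horizon tau, a pair (i, j) contributes only when patient i has the event of interest before tau and patient j is still free of that event at T_i; competing events count as censoring. The hypothetical helper below counts such pairs without the IPCW weights that Uno's estimator adds, and it is not scikit-survival's implementation. When the count is zero the index is 0/0 and therefore undefined; in the evaluation loop, any exception raised by `concordance_index_ipcw` is caught and stored as NaN.

```python
import numpy as np

def count_comparable_pairs(times, events, cause, tau):
    """Unweighted number of comparable pairs for a tau-truncated, cause-specific C-index."""
    times = np.asarray(times, dtype=float)
    # Cases: the cause of interest observed before tau (competing events act as censoring)
    is_case = (np.asarray(events) == cause) & (times < tau)
    n_pairs = 0
    for t_i in times[is_case]:
        n_pairs += int((times > t_i).sum())     # patients still event-free at T_i
    return n_pairs

times = np.array([1.0, 2.0, 4.0, 5.0, 9.0])
events = np.array([2, 0, 1, 2, 0])              # 0 = censored, 1 = death, 2 = readmission
print(count_comparable_pairs(times, events, cause=1, tau=3.0))  # → 0 (no death before tau)
print(count_comparable_pairs(times, events, cause=2, tau=3.0))  # → 4
```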

First attempt

Code
#@title ⚡ Final Comprehensive Evaluation: Pooled DeepHit (Strict No-Leakage)
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import shap
import time
import gc
import warnings
from datetime import datetime
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score, confusion_matrix
from sksurv.metrics import concordance_index_ipcw, brier_score, integrated_brier_score
from pycox.models import DeepHit
from pycox.preprocessing.label_transforms import LabTransDiscreteTime
import torchtuples as tt
import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter

start_time = time.time()

TEST_MODE = False

# --- 1. CONFIGURATION ---
BEST_LR = 0.001
BEST_WD = 0.001
BEST_BATCH = 1024
BEST_DROPOUT = 0.5
BEST_NODES = [256, 256, 128]
BEST_ALPHA = 0.5
BEST_SIGMA = 0.5

NUM_RISKS = 3
K_FOLDS = 10 
EVAL_HORIZONS = [3, 6, 12, 24, 36, 48, 60, 72, 84, 96, 108]
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Assuming imputations_list_jan26 is a list of 5 DataFrames
N_IMPUTATIONS = len(imputations_list_jan26) 

#Check if test mode, and modify parameters accordingly
if TEST_MODE:
    N_IMPUTATIONS_TEST = 1
    K_FOLDS_TEST = 5
    EVAL_HORIZONS_TEST = [12, 24]
    MAX_EPOCHS_TEST = 30
    SHAP_FOLDS_TEST = 2
else:
    N_IMPUTATIONS_TEST = N_IMPUTATIONS
    K_FOLDS_TEST = K_FOLDS
    EVAL_HORIZONS_TEST = EVAL_HORIZONS
    MAX_EPOCHS_TEST = 100
    SHAP_FOLDS_TEST = 3

warnings.filterwarnings("ignore")

print(f"⚡ Starting Pooled Evaluation on {N_IMPUTATIONS_TEST} Imputations...")
print(f"   Device: {DEVICE} | Horizons: {EVAL_HORIZONS_TEST}")

# --- 2. NETWORK CLASS ---
class CauseSpecificNet(nn.Module):
    def __init__(self, in_f, nodes, out_f, dropout, num_risks):
        super().__init__()
        self.net = tt.practical.MLPVanilla(in_f, nodes, out_f, batch_norm=True, dropout=dropout)
        self.num_risks = num_risks
    def forward(self, x):
        return self.net(x).view(x.size(0), self.num_risks, -1)

# --- 3. METRIC HELPERS ---
def get_binary_target(events, times, risk_id, t_horizon):
    # Target: 1 if Event_k <= t, 0 otherwise (excluding censored before t)
    is_case = (events == risk_id) & (times <= t_horizon)
    mask_censored_early = (events == 0) & (times <= t_horizon)
    valid_mask = ~mask_censored_early
    y_binary = is_case[valid_mask].astype(int)
    return y_binary, valid_mask

def find_optimal_threshold(y_true, y_prob):
    # Optimizes F1 on TRAINING data
    thresholds = np.linspace(0.01, 0.99, 99)
    best_f1 = -1
    best_th = 0.5
    for th in thresholds:
        y_pred = (y_prob >= th).astype(int)
        f1 = f1_score(y_true, y_pred, zero_division=0)
        if f1 > best_f1:
            best_f1 = f1
            best_th = th
    return best_th

def calculate_binary_metrics(y_true, y_prob, fixed_threshold):
    # Applies FIXED threshold to TEST data
    y_pred = (y_prob >= fixed_threshold).astype(int)
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
    
    return {
        'F1': f1_score(y_true, y_pred, zero_division=0),
        'Sens': recall_score(y_true, y_pred, zero_division=0),
        'Spec': tn / (tn + fp) if (tn + fp) > 0 else 0,
        'PPV': precision_score(y_true, y_pred, zero_division=0),
        'NPV': tn / (tn + fn) if (tn + fn) > 0 else 0,
        'Threshold': fixed_threshold
    }

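def bootstrap_ci_non_normal(data, alpha=0.05):
    # Despite the name, no resampling is done: returns the mean and the empirical
    # (alpha/2, 1 - alpha/2) percentiles of the per-fold, per-imputation values
    if len(data) < 2: return np.mean(data), np.mean(data), np.mean(data)
    lower = np.percentile(data, 100 * (alpha / 2))
    upper = np.percentile(data, 100 * (1 - alpha / 2))
    return np.mean(data), lower, upper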

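def compute_brier_competing(cif_values_at_time_horizon, censoring_kmf,
                            Y_test, D_test, event_of_interest, time_horizon):
    """IPCW Brier score for one cause at one horizon under competing risks.

    Event-free at the horizon: CIF^2 / G(t). Event of interest by the horizon:
    (1 - CIF)^2 / G(T_i). Competing event by the horizon: CIF^2 / G(T_i).
    Censored by the horizon: contributes 0. G is the censoring Kaplan-Meier curve.
    """
    n = len(Y_test)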
    assert len(D_test) == n

    residuals = np.zeros(n)
    for idx in range(n):
        observed_time = Y_test[idx]
        event_indicator = D_test[idx]
        if observed_time > time_horizon:
            weight = censoring_kmf.predict(time_horizon)
            residuals[idx] = (cif_values_at_time_horizon[idx])**2 / weight
        else:
            weight = censoring_kmf.predict(observed_time)
            if event_indicator == event_of_interest:
                residuals[idx] = (1 - cif_values_at_time_horizon[idx])**2 / weight
            elif event_indicator != event_of_interest and event_indicator != 0:
                residuals[idx] = (cif_values_at_time_horizon[idx])**2 / weight
    return residuals.mean()

# --- 4. MAIN POOLED LOOP ---
pooled_results = []
shap_samples_agg = []
shap_values_agg = []

start_time_global = time.time()

for imp_idx in range(N_IMPUTATIONS_TEST):
    print(f"\n📚 Imputation {imp_idx + 1}/{N_IMPUTATIONS_TEST}")
    
    # Load Data
    X_raw = imputations_list_jan26[imp_idx]
    y_d = y_surv_death_list[imp_idx]
    y_r = y_surv_readm_list[imp_idx]
    
    # Pre-process
    t_d = np.asarray(y_d['time'])
    e_d = np.asarray(y_d['event']).astype(bool)

    t_r = np.asarray(y_r['time'])
    e_r = np.asarray(y_r['event']).astype(bool)

    events = np.zeros(len(X_raw), dtype=int)
    times  = t_d.copy().astype('float32')  # default = death/censor time

    # Readmission happens first
    mask_r = e_r & (t_r <= t_d)
    events[mask_r] = 2
    times[mask_r]  = t_r[mask_r]

    # Death happens first (or only)
    mask_d = e_d & (~mask_r)
    events[mask_d] = 1
    # times already = t_d
    print("Event counts:", np.bincount(events))

    # Stratification
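    # Keep predictors only: drop outcome columns if they are present, to avoid label leakage
    outcome_cols = [c for c in X_raw.columns
                    if c.startswith(("readmit_time_", "death_time_")) or c in ("readmit_event", "death_event")]
    X_curr = X_raw.drop(columns=outcome_cols)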
    plan_cols = [c for c in X_raw.columns if c.startswith("plan_type_corr")]
    X_curr[plan_cols] = X_curr[plan_cols].astype('float32')
    #✅ Step 1: define an explicit plan index mapping
    plan_idx = np.zeros(len(X_curr), dtype=int)  # 0 = pg_pab (reference)
    if 'plan_type_corr_pg_pr' in X_curr.columns:
        plan_idx[X_curr['plan_type_corr_pg_pr'] == 1] = 1
    if 'plan_type_corr_m_pr' in X_curr.columns:
        plan_idx[X_curr['plan_type_corr_m_pr'] == 1] = 2
    if 'plan_type_corr_pg_pai' in X_curr.columns:
        plan_idx[X_curr['plan_type_corr_pg_pai'] == 1] = 3
    if 'plan_type_corr_m_pai' in X_curr.columns:
        plan_idx[X_curr['plan_type_corr_m_pai'] == 1] = 4
    # At most one plan dummy may be 1 per row (all zeros = reference category pg_pab)
    plan_sum = X_curr[plan_cols].sum(axis=1)

    if (plan_sum > 1).any():
        raise ValueError("Invalid plan encoding: some rows have >1 plan types.")
    #✅ Step 2: combine with competing-risk event for stratification
    strat_labels = (events * 10) + plan_idx

    # CV Loop
    skf = StratifiedKFold(n_splits=K_FOLDS_TEST, shuffle=True, random_state=2125 + imp_idx) # Vary seed per imp
    
    for fold_idx, (train_idx, val_idx) in enumerate(skf.split(X_curr, strat_labels)):
        print(f".", end="")
        
        X_train, X_val = X_curr.iloc[train_idx].values, X_curr.iloc[val_idx].values
        t_train, e_train = times[train_idx], events[train_idx]
        t_val, e_val = times[val_idx], events[val_idx]
        
        # Scale
        scaler = StandardScaler().fit(X_train)
        X_train_s = scaler.transform(X_train).astype('float32')
        X_val_s = scaler.transform(X_val).astype('float32')
        
        # Discretize on an equidistant grid fitted on the training fold.
        # LabTransDiscreteTime casts event codes to bool (any event -> 1), which would merge
        # death (1) and readmission (2), so discretize on "any event" and restore the codes.
        labtrans = LabTransDiscreteTime(100)
        idx_train, is_ev_train = labtrans.fit_transform(t_train, e_train > 0)
        idx_val, is_ev_val = labtrans.transform(t_val, e_val > 0)
        # Events that the discretizer marks as censored keep code 0
        y_train = (idx_train.astype('int64'), np.where(is_ev_train, e_train, 0).astype('int64'))
        y_val = (idx_val.astype('int64'), np.where(is_ev_val, e_val, 0).astype('int64'))

        # Train
        in_f = X_train.shape[1]
        # Output head k-1 models event code k, so NUM_RISKS must be >= 2 (the largest code).
        # With NUM_RISKS = 3, head 2 corresponds to an unobserved event code and is not evaluated.
        out_f = labtrans.out_features * NUM_RISKS
        net = CauseSpecificNet(in_f, BEST_NODES, out_f, BEST_DROPOUT, NUM_RISKS)
        model = DeepHit(net, tt.optim.Adam, alpha=BEST_ALPHA, sigma=BEST_SIGMA, duration_index=labtrans.cuts)
        model.set_device(DEVICE)
        model.optimizer.set_lr(BEST_LR)
        model.optimizer.param_groups[0]['weight_decay'] = BEST_WD

        model.fit(X_train_s, y_train, batch_size=BEST_BATCH, epochs=MAX_EPOCHS_TEST,
                  callbacks=[tt.callbacks.EarlyStopping()], verbose=False, val_data=(X_val_s, y_val))

        # Predict (Test); array of shape (NUM_RISKS, n_durations, n_val)
        cif_val = model.predict_cif(X_val_s)

        # Predict (Train - for threshold tuning)
        cif_train = model.predict_cif(X_train_s)

        # --- SHAP (Imp 1 only to save space) ---
        if imp_idx == 0 and fold_idx < SHAP_FOLDS_TEST:
            try:
                bg = X_train_s[np.random.choice(len(X_train_s), 50, replace=False)]
                test = X_val_s[np.random.choice(len(X_val_s), 50, replace=False)]
                def pred_risk_24m(x):  # Combined risk driver: CIF(death) + CIF(readmission) at 24 months
                    # KernelExplainer passes float64 arrays; the network expects float32.
                    # With a NumPy input, predict_cif returns a NumPy array.
                    c = model.predict_cif(np.asarray(x, dtype='float32'))
                    i = min(np.searchsorted(model.duration_index, 24), len(model.duration_index) - 1)
                    # c[0] = event 1 (death), c[1] = event 2 (readmission)
                    return c[0][i, :] + c[1][i, :]
                ex = shap.KernelExplainer(pred_risk_24m, bg)
                shap_values_agg.append(ex.shap_values(test, nsamples=50, silent=True))
                shap_samples_agg.append(test)
            except Exception as err:
                print(f"\nSHAP skipped (imp {imp_idx}, fold {fold_idx}): {err}")

        # --- METRICS ---
        outcomes_map = {1: 'Death', 2: 'Readmission'}
        
        # Structures for IPCW
        y_tr_st = np.array([(bool(e!=0), t) for e, t in zip(e_train, t_train)], dtype=[('e', bool), ('t', float)])
        y_va_st = np.array([(bool(e!=0), t) for e, t in zip(e_val, t_val)], dtype=[('e', bool), ('t', float)])

        # Censoring model (G-hat)
        # Event = 1 if censored, 0 otherwise
        censoring_kmf = KaplanMeierFitter()
        censoring_kmf.fit(
            t_train,
            event_observed=(e_train == 0).astype(int)
        )

        for risk_id, outcome_name in outcomes_map.items():
            cif_idx = risk_id - 1  # pycox CIF arrays are 0-based: index k-1 holds event code k
            # Cause-Specific Structures
            y_tr_cs = np.array([(bool(e==risk_id), t) for e, t in zip(e_train, t_train)], dtype=[('e', bool), ('t', float)])
            y_va_cs = np.array([(bool(e==risk_id), t) for e, t in zip(e_val, t_val)], dtype=[('e', bool), ('t', float)])
            
            # 1. Global Metrics
            # Global risk score: sum of the CIF over the (equidistant) time grid,
            # proportional to a Riemann approximation of the area under the CIF
            risk_score_global = cif_val[cif_idx].sum(axis=0)
            
            try:
                uno_g = concordance_index_ipcw(y_tr_cs, y_va_cs, risk_score_global)[0]
            except Exception:
                uno_g = np.nan
            
            pooled_results.append({'Imp': imp_idx, 'Fold': fold_idx, 'Outcome': outcome_name, 'Time': 'Global', 'Metric': 'Uno C-Index', 'Value': uno_g})
            
            # 2. Time-Dependent Metrics
            for t in EVAL_HORIZONS_TEST:
                idx_t = np.searchsorted(model.duration_index, t)
                if idx_t >= len(model.duration_index): idx_t = len(model.duration_index) - 1
                
                # Risk Scores
                risk_t_val = cif_val[cif_idx][idx_t]
                risk_t_train = cif_train[cif_idx][idx_t] # For threshold finding
                
                # A. Uno's C-Index
                try:
                    auc_u = concordance_index_ipcw(y_tr_cs, y_va_cs, risk_t_val, tau=t)[0]
                except Exception:
                    auc_u = np.nan
                
                # B. Brier Score (competing risks, IPCW)
                # CIF at horizon t (validation)
                cif_t_val = cif_val[cif_idx][idx_t]
                brier_cr = compute_brier_competing(
                    cif_values_at_time_horizon=cif_t_val,
                    censoring_kmf=censoring_kmf,
                    Y_test=t_val,
                    D_test=e_val,
                    event_of_interest=risk_id,
                    time_horizon=t
                )
                
                # C. Binary Metrics (Leakage-Free)
                # 1. Get binary targets for TRAIN and VAL
                y_bin_train, mask_train = get_binary_target(e_train, t_train, risk_id, t)
                y_bin_val, mask_val = get_binary_target(e_val, t_val, risk_id, t)
                
                if len(np.unique(y_bin_train)) > 1 and len(np.unique(y_bin_val)) > 1:
                    # 2. Find Best Threshold on TRAIN
                    best_th = find_optimal_threshold(y_bin_train, risk_t_train[mask_train])
                    
                    # 3. Apply to VAL
                    bin_met = calculate_binary_metrics(y_bin_val, risk_t_val[mask_val], best_th)
                    
                    # Log
                    auc_roc = roc_auc_score(y_bin_val, risk_t_val[mask_val])
                    metrics_pack = {
                        'Uno C-Index': auc_u, 'AUC-ROC': auc_roc,
                        'F1': bin_met['F1'], 'Sens': bin_met['Sens'], 'Spec': bin_met['Spec'],
                        'PPV': bin_met['PPV'], 'NPV': bin_met['NPV']
                    }
                else:
                    metrics_pack = {'Uno C-Index': auc_u}

                # --- log the competing-risk Brier score once per horizon ---
                pooled_results.append({
                    'Imp': imp_idx,
                    'Fold': fold_idx,
                    'Outcome': outcome_name,
                    'Time': t,
                    'Metric': 'Brier Score (CR)',
                    'Value': brier_cr
                })
                # --- log discrimination & classification metrics ---
                for m_name, m_val in metrics_pack.items():
                    pooled_results.append({
                        'Imp': imp_idx,
                        'Fold': fold_idx,
                        'Outcome': outcome_name,
                        'Time': t,
                        'Metric': m_name,
                        'Value': m_val
                    })

# --- 5. AGGREGATION ---
df_res = pd.DataFrame(pooled_results)

# Group by Outcome/Time/Metric -> Aggregate across all Imputations and Folds
# Analytic standard errors for metrics pooled across CV folds and imputations are not
# straightforward, so report the mean with empirical 2.5th-97.5th percentile intervals.
summary_stats = []
grouped = df_res.groupby(['Outcome', 'Time', 'Metric'])

for (outcome, time_pt, metric), group in grouped:
    vals = group['Value'].dropna().values
    mean_val, lower, upper = bootstrap_ci_non_normal(vals)
    summary_stats.append({
        'Outcome': outcome, 'Time': time_pt, 'Metric': metric,
        'Mean': mean_val, 'CI_Lower': lower, 'CI_Upper': upper,
        'Format': f"{mean_val:.3f} [{lower:.3f}-{upper:.3f}]"
    })

df_summary = pd.DataFrame(summary_stats)
filename = f"DH2_Pooled_1st_try_{datetime.now().strftime('%Y%m%d_%H%M')}.csv"
df_summary.to_csv(filename, sep=';', index=False)
print(f"\n💾 Saved: {filename}")

total_duration_min = (time.time() - start_time) / 60
print(f"\n🏁 Total Execution Time: {total_duration_min:.2f} minutes")
#df_res.to_csv("df_res.csv", sep=';', index=False)
⚡ Starting Pooled Evaluation on 5 Imputations...
   Device: cuda | Horizons: [3, 6, 12, 24, 36, 48, 60, 72, 84, 96, 108]

📚 Imputation 1/5
Event counts: [66231  3203 19070]
..........
📚 Imputation 2/5
Event counts: [66231  3203 19070]
..........
📚 Imputation 3/5
Event counts: [66231  3203 19070]
..........
📚 Imputation 4/5
Event counts: [66231  3203 19070]
..........
📚 Imputation 5/5
Event counts: [66231  3203 19070]
..........
💾 Saved: DH2_Pooled_1st_try_20260209_1547.csv

🏁 Total Execution Time: 94.84 minutes
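Indexing the CIF array correctly matters for every metric above. In pycox's competing-risk DeepHit, event codes are 1-based (0 = censored) while the output heads are 0-based, so with 1 = death and 2 = readmission, `predict_cif(x)[0]` is the death CIF and `predict_cif(x)[1]` the readmission CIF, each of shape (n_durations, n_samples). The NumPy sketch below reproduces that layout under the assumption that it mirrors pycox's `predict_pmf`/`predict_cif` (append a zero logit, take a softmax over all risk-by-duration cells, drop the padding cell, cumulate over time); `logits_to_cif` is a hypothetical name, not a pycox function.

```python
import numpy as np

def logits_to_cif(logits):
    """DeepHit-style logits (n, n_risks, n_durations) -> CIFs (n_risks, n_durations, n)."""
    n, n_risks, n_dur = logits.shape
    flat = logits.reshape(n, -1)
    # Append a zero logit: the extra cell is the probability of no event by the end of the grid
    padded = np.concatenate([flat, np.zeros((n, 1))], axis=1)
    padded = padded - padded.max(axis=1, keepdims=True)       # numerically stable softmax
    probs = np.exp(padded)
    probs /= probs.sum(axis=1, keepdims=True)
    # Drop the padding cell, restore (n, risks, durations), then move to (risks, durations, n)
    pmf = probs[:, :-1].reshape(n, n_risks, n_dur).transpose(1, 2, 0)
    return pmf.cumsum(axis=1)                                    # cumulate over durations

rng = np.random.default_rng(1)
cif = logits_to_cif(rng.normal(size=(4, 2, 10)))  # hypothetical logits: 4 patients, 2 risks, 10 bins
event_codes = {1: "Death", 2: "Readmission"}
for code, name in event_codes.items():
    cif_k = cif[code - 1]                          # event code k lives at index k - 1
    print(name, cif_k.shape, round(float(cif_k[-1].max()), 3))
print(cif.shape)  # → (2, 10, 4)
```

Because all risks share one softmax, the CIFs of the two causes at the last grid point sum to less than one; the remainder is the probability of remaining event-free.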


Code
global_metrics = (
    pd.DataFrame(df_summary)
      .loc[lambda d: d['Time'] == 'Global']
      .loc[:, ['Outcome', 'Metric', 'Mean', 'CI_Lower', 'CI_Upper', 'Format']]
      .sort_values(['Outcome', 'Metric'])
      .reset_index(drop=True)
)

global_metrics
Outcome Metric Mean CI_Lower CI_Upper Format
0 Death Uno C-Index 0.632906 0.564686 0.689986 0.633 [0.565-0.690]
1 Readmission Uno C-Index 0.565386 0.516718 0.606802 0.565 [0.517-0.607]
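The global risk score behind these C-indices sums each patient's CIF over the discrete time grid. On an equidistant grid with spacing h, h times that sum equals the trapezoidal area under the CIF plus an endpoint term h(F(t_0) + F(t_m))/2, and the area under a cause-specific CIF up to a horizon is the restricted mean time lost to that cause. A minimal check on hypothetical CIFs (`cif_area` is an illustrative helper):

```python
import numpy as np

def cif_area(cif, grid):
    """Trapezoidal area under each patient's CIF; cif has shape (n_durations, n_patients)."""
    widths = np.diff(np.asarray(grid, dtype=float))[:, None]
    return ((cif[1:] + cif[:-1]) / 2.0 * widths).sum(axis=0)

rng = np.random.default_rng(2)
grid = np.linspace(0.0, 120.0, 100)           # hypothetical equidistant grid (months)
h = grid[1] - grid[0]
# Hypothetical non-decreasing CIFs for 6 patients, shape (n_durations, n_patients)
cif = np.cumsum(rng.uniform(0.0, 0.005, size=(100, 6)), axis=0)

score_sum = cif.sum(axis=0)                    # global risk score used in the evaluation loop
rmtl = cif_area(cif, grid)                     # approx. restricted mean time lost over the grid
endpoint = h * (cif[0] + cif[-1]) / 2.0        # trapezoid endpoint correction
print(np.allclose(rmtl, h * score_sum - endpoint))  # → True
```

Because the endpoint term varies between patients, ranking by the plain sum and by the trapezoidal area can differ slightly, although both summarize cumulative risk over the whole follow-up.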
Code
ibs_summary2 = (
    pd.DataFrame(df_summary)
      .loc[lambda d: d['Metric'].isin([
            "Uno C-Index",
            "AUC-ROC",
            "Brier Score (CR)",
            "F1",
            "Sens",
            "Spec",
            "PPV",
            "NPV"
        ])]
      .groupby(['Outcome', 'Metric'])['Mean']
      .agg(['mean'])
      .reset_index()
)
ibs_summary2
Outcome Metric mean
0 Death AUC-ROC 0.694087
1 Death Brier Score (CR) 0.027983
2 Death F1 0.094219
3 Death NPV 0.965889
4 Death PPV 0.101438
5 Death Sens 0.187002
6 Death Spec 0.890558
7 Death Uno C-Index 0.697315
8 Readmission AUC-ROC 0.603936
9 Readmission Brier Score (CR) 0.172425
10 Readmission F1 0.209695
11 Readmission NPV 0.746145
12 Readmission PPV 0.274274
13 Readmission Sens 0.221130
14 Readmission Spec 0.804981
15 Readmission Uno C-Index 0.605723
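The summary above takes an unweighted mean over horizons. Because the prespecified horizons are denser early (3, 6 and 12 months, then every 12 months), an unweighted mean gives more weight to early horizons than a time-weighted (trapezoidal, integrated-Brier-style) average would. The same `groupby` also folds the 'Global' Uno C-index row into the Uno C-Index average. A sketch comparing both averages on a hypothetical Brier curve that rises over time (`horizon_average` is illustrative):

```python
import numpy as np

def horizon_average(values, horizons):
    """Return (unweighted mean, time-weighted trapezoidal mean) of a metric over horizons."""
    values = np.asarray(values, dtype=float)
    horizons = np.asarray(horizons, dtype=float)
    simple = values.mean()
    area = np.sum((values[1:] + values[:-1]) / 2.0 * np.diff(horizons))
    weighted = area / (horizons[-1] - horizons[0])   # average height over the time span
    return simple, weighted

horizons = [3, 6, 12, 24, 36, 48, 60, 72, 84, 96, 108]   # horizons used in the notebook
brier = np.linspace(0.01, 0.25, len(horizons))          # hypothetical Brier scores
simple, weighted = horizon_average(brier, horizons)
print(f"unweighted: {simple:.4f}  time-weighted: {weighted:.4f}")
```

For a metric that increases with the horizon, as the competing-risk Brier scores do here, the unweighted mean is the lower of the two summaries.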
Code
# Create a display version that replaces NaN with "-"
display_df = df_summary.fillna("-")

display_df
Outcome Time Metric Mean CI_Lower CI_Upper Format
0 Death 3 AUC-ROC 0.835825 0.682982 0.964249 0.836 [0.683-0.964]
1 Death 3 Brier Score 0.348736 0.348736 0.348736 0.349 [0.349-0.349]
2 Death 3 Brier Score (CR) 0.001409 0.000680 0.002269 0.001 [0.001-0.002]
3 Death 3 F1 0.000000 0.000000 0.000000 0.000 [0.000-0.000]
4 Death 3 NPV 0.998580 0.997735 0.999313 0.999 [0.998-0.999]
... ... ... ... ... ... ... ...
195 Readmission 108 PPV 0.640940 0.626979 0.651215 0.641 [0.627-0.651]
196 Readmission 108 Sens 0.927758 0.846938 0.981864 0.928 [0.847-0.982]
197 Readmission 108 Spec 0.113724 0.041024 0.197117 0.114 [0.041-0.197]
198 Readmission 108 Uno C-Index 0.577136 0.551798 0.600818 0.577 [0.552-0.601]
199 Readmission Global Uno C-Index 0.565386 0.516718 0.606802 0.565 [0.517-0.607]

200 rows × 7 columns
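The classification metrics in this table depend on two choices made in the evaluation loop: at each horizon, patients censored before the horizon are excluded while patients with a competing event count as non-cases, and the F1-optimal cut-off is chosen on the training fold and then applied unchanged to the validation fold. A dependency-free sketch of both steps on hypothetical data follows; the helper names are illustrative, and the notebook itself uses scikit-learn's `f1_score`.

```python
import numpy as np

def binary_target(events, times, cause, horizon):
    """Cases: `cause` by `horizon`. Drop patients censored by `horizon` (status unknown)."""
    events = np.asarray(events)
    times = np.asarray(times, dtype=float)
    keep = ~((events == 0) & (times <= horizon))
    y = ((events == cause) & (times <= horizon))[keep].astype(int)
    return y, keep

def f1_threshold(y_true, y_prob, grid=None):
    """Cut-off maximizing F1 on the data given (meant for the training fold only)."""
    grid = np.linspace(0.01, 0.99, 99) if grid is None else grid
    best_th, best_f1 = 0.5, -1.0
    for th in grid:
        pred = y_prob >= th
        tp = np.sum(pred & (y_true == 1))
        fp = np.sum(pred & (y_true == 0))
        fn = np.sum(~pred & (y_true == 1))
        f1 = 2 * tp / (2 * tp + fp + fn) if tp > 0 else 0.0
        if f1 > best_f1:
            best_th, best_f1 = th, f1
    return best_th, best_f1

events = np.array([1, 2, 0, 0, 1, 2])   # 0 = censored, 1 = death, 2 = readmission
times = np.array([4.0, 5.0, 3.0, 20.0, 30.0, 8.0])
y, keep = binary_target(events, times, cause=1, horizon=12.0)
print(y)  # → [1 0 0 0 0]

y_prob_train = np.array([0.9, 0.2, 0.1, 0.4, 0.3])   # hypothetical training-fold predictions
th, best_f1 = f1_threshold(y, y_prob_train)
print(round(best_f1, 3))  # → 1.0
```

The patient censored at 3 months is dropped at the 12-month horizon, while the two readmitted patients stay in the denominator as non-cases for death.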

Code
ibs_summary3 = (
    pd.DataFrame(display_df)
      .loc[lambda d: d['Time'] != 'Global']
      .loc[lambda d: d['Metric'].isin(["Uno C-Index", 'Brier Score (CR)'])]
      .reset_index()
)
ibs_summary3
index Outcome Time Metric Mean CI_Lower CI_Upper Format
0 2 Death 3 Brier Score (CR) 0.001409 0.000680 0.002269 0.001 [0.001-0.002]
1 8 Death 3 Uno C-Index 0.834723 0.679083 0.963355 0.835 [0.679-0.963]
2 11 Death 6 Brier Score (CR) 0.003456 0.002287 0.004566 0.003 [0.002-0.005]
3 17 Death 6 Uno C-Index 0.783458 0.689148 0.856198 0.783 [0.689-0.856]
4 20 Death 12 Brier Score (CR) 0.006950 0.005464 0.008204 0.007 [0.005-0.008]
5 26 Death 12 Uno C-Index 0.740357 0.656759 0.803871 0.740 [0.657-0.804]
6 29 Death 24 Brier Score (CR) 0.014148 0.012473 0.015970 0.014 [0.012-0.016]
7 35 Death 24 Uno C-Index 0.707610 0.661904 0.744497 0.708 [0.662-0.744]
8 38 Death 36 Brier Score (CR) 0.021568 0.019394 0.023423 0.022 [0.019-0.023]
9 44 Death 36 Uno C-Index 0.687132 0.647658 0.727763 0.687 [0.648-0.728]
10 47 Death 48 Brier Score (CR) 0.028370 0.025737 0.031065 0.028 [0.026-0.031]
11 53 Death 48 Uno C-Index 0.680784 0.642279 0.727363 0.681 [0.642-0.727]
12 56 Death 60 Brier Score (CR) 0.035375 0.033468 0.038078 0.035 [0.033-0.038]
13 62 Death 60 Uno C-Index 0.674980 0.638681 0.729599 0.675 [0.639-0.730]
14 65 Death 72 Brier Score (CR) 0.041030 0.038740 0.043303 0.041 [0.039-0.043]
15 71 Death 72 Uno C-Index 0.668668 0.622088 0.704828 0.669 [0.622-0.705]
16 74 Death 84 Brier Score (CR) 0.047131 0.045207 0.048848 0.047 [0.045-0.049]
17 80 Death 84 Uno C-Index 0.661598 0.610813 0.702378 0.662 [0.611-0.702]
18 83 Death 96 Brier Score (CR) 0.051740 0.049695 0.054205 0.052 [0.050-0.054]
19 89 Death 96 Uno C-Index 0.651510 0.593922 0.687870 0.652 [0.594-0.688]
20 92 Death 108 Brier Score (CR) 0.056638 0.052577 0.062105 0.057 [0.053-0.062]
21 98 Death 108 Uno C-Index 0.644059 0.575629 0.681721 0.644 [0.576-0.682]
22 102 Readmission 3 Brier Score (CR) 0.007153 0.005684 0.008298 0.007 [0.006-0.008]
23 108 Readmission 3 Uno C-Index 0.752889 0.644374 0.832481 0.753 [0.644-0.832]
24 111 Readmission 6 Brier Score (CR) 0.022696 0.019838 0.025578 0.023 [0.020-0.026]
25 117 Readmission 6 Uno C-Index 0.680948 0.576752 0.761444 0.681 [0.577-0.761]
26 120 Readmission 12 Brier Score (CR) 0.064279 0.060927 0.068199 0.064 [0.061-0.068]
27 126 Readmission 12 Uno C-Index 0.619436 0.554982 0.673896 0.619 [0.555-0.674]
28 129 Readmission 24 Brier Score (CR) 0.129933 0.125765 0.134415 0.130 [0.126-0.134]
29 135 Readmission 24 Uno C-Index 0.591209 0.549940 0.617881 0.591 [0.550-0.618]
30 138 Readmission 36 Brier Score (CR) 0.174603 0.170337 0.179080 0.175 [0.170-0.179]
31 144 Readmission 36 Uno C-Index 0.584333 0.554861 0.611878 0.584 [0.555-0.612]
32 147 Readmission 48 Brier Score (CR) 0.206036 0.202353 0.210195 0.206 [0.202-0.210]
33 153 Readmission 48 Uno C-Index 0.581860 0.555561 0.604949 0.582 [0.556-0.605]
34 156 Readmission 60 Brier Score (CR) 0.228602 0.224022 0.231694 0.229 [0.224-0.232]
35 162 Readmission 60 Uno C-Index 0.580254 0.555123 0.604236 0.580 [0.555-0.604]
36 165 Readmission 72 Brier Score (CR) 0.246760 0.243769 0.250731 0.247 [0.244-0.251]
37 171 Readmission 72 Uno C-Index 0.578780 0.551990 0.604362 0.579 [0.552-0.604]
38 174 Readmission 84 Brier Score (CR) 0.261123 0.257673 0.265106 0.261 [0.258-0.265]
39 180 Readmission 84 Uno C-Index 0.578422 0.553650 0.603345 0.578 [0.554-0.603]
40 183 Readmission 96 Brier Score (CR) 0.272294 0.266618 0.277159 0.272 [0.267-0.277]
41 189 Readmission 96 Uno C-Index 0.578020 0.552951 0.601334 0.578 [0.553-0.601]
42 192 Readmission 108 Brier Score (CR) 0.283199 0.278440 0.288712 0.283 [0.278-0.289]
43 198 Readmission 108 Uno C-Index 0.577136 0.551798 0.600818 0.577 [0.552-0.601]
Code
import os
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime

def plot_time_metric_joint(
    df_res,
    metric,
    ylabel=None,
    title=None,
    colors={"Death": "tab:red", "Readmission": "tab:blue"},
    band_alpha=0.25,
    ylim=(0, 1),
    outdir="dh3"
):
    # --- timestamp ---
    timestamp = datetime.now().strftime("%Y%m%d_%H%M")

    # --- create output folder ---
    os.makedirs(outdir, exist_ok=True)

    plt.figure(figsize=(7.5, 5.5))

    for outcome, color in colors.items():
        dfp = (
            df_res
            .loc[
                (df_res["Metric"] == metric) &
                (df_res["Outcome"] == outcome) &
                (df_res["Time"] != "Global")
            ]
            .groupby("Time")["Value"]
            .agg(
                mean="mean",
                q25=lambda x: np.percentile(x, 25),
                q75=lambda x: np.percentile(x, 75),
            )
            .reset_index()
            .sort_values("Time")
        )

        if dfp.empty:
            continue

        plt.plot(
            dfp["Time"],
            dfp["mean"],
            color=color,
            linewidth=2,
            label=outcome
        )
        plt.fill_between(
            dfp["Time"],
            dfp["q25"],
            dfp["q75"],
            color=color,
            alpha=band_alpha
        )

    plt.xlabel("Time since admission (months)")
    plt.ylabel(ylabel if ylabel else metric)
    plt.title(title if title else f"Time-dependent {metric}")
    plt.ylim(*ylim)
    plt.legend(frameon=False)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    
    # --- filenames ---
    base = f"DH3_{metric}_{timestamp}"
    png_path = os.path.join(outdir, f"{base}.png")
    pdf_path = os.path.join(outdir, f"{base}.pdf")

    # --- save ---
    plt.savefig(png_path, dpi=300, bbox_inches="tight")
    plt.savefig(pdf_path, bbox_inches="tight")
    plt.show()
    plt.close()

    print(f"💾 Saved: {png_path}")
    print(f"💾 Saved: {pdf_path}")


plot_time_metric_joint(
    df_res,
    metric="F1",
    ylabel="F1 score",
    title="Time-dependent F1 score",
    ylim=(0, 1)
)
plot_time_metric_joint(
    df_res,
    metric="PPV",
    ylabel="Positive predictive value",
    title="Time-dependent PPV",
    ylim=(0, 1)
)
plot_time_metric_joint(
    df_res,
    metric="NPV",
    ylabel="Negative predictive value",
    title="Time-dependent NPV",
    ylim=(0, 1)
)
plot_time_metric_joint(
    df_res,
    metric="AUC-ROC",
    ylabel="AUC-ROC",
    title="Time-dependent AUC-ROC",
    ylim=(0.5, 1)
)
plot_time_metric_joint(
    df_res=df_res,
    metric="Brier Score (CR)",
    ylabel="Brier score (competing risk)",
    title="Time-dependent competing-risk Brier score"
)
plot_time_metric_joint(
    df_res=df_res,
    metric="Uno C-Index",
    ylabel="Uno’s C-index",
    title="Time-dependent discrimination (Uno’s C-index)"
)

💾 Saved: dh3\DH3_F1_20260209_1600.png
💾 Saved: dh3\DH3_F1_20260209_1600.pdf

💾 Saved: dh3\DH3_PPV_20260209_1600.png
💾 Saved: dh3\DH3_PPV_20260209_1600.pdf

💾 Saved: dh3\DH3_NPV_20260209_1600.png
💾 Saved: dh3\DH3_NPV_20260209_1600.pdf

💾 Saved: dh3\DH3_AUC-ROC_20260209_1600.png
💾 Saved: dh3\DH3_AUC-ROC_20260209_1600.pdf

💾 Saved: dh3\DH3_Brier Score (CR)_20260209_1600.png
💾 Saved: dh3\DH3_Brier Score (CR)_20260209_1600.pdf

💾 Saved: dh3\DH3_Uno C-Index_20260209_1600.png
💾 Saved: dh3\DH3_Uno C-Index_20260209_1600.pdf
Code
#@title 📈 Take-Home Message: Time-Dependent Model Performance (DeepHit, Competing Risks)

import pandas as pd
from IPython.display import display

performance_msg = pd.DataFrame([

    # --- DISCRIMINATION ---
    {
        'Metric': 'AUC-ROC',
        'Outcome': 'Death vs Readmission',
        'Pattern': 'Stable for Death, Weaker for Readmission',
        'Interpretation': (
            'Discrimination is consistently high and stable for Death across follow-up, '
            'indicating a strong and persistent ranking of mortality risk. '
            'For Readmission, AUC-ROC is systematically lower and more variable, '
            'reflecting the greater stochasticity and behavioral component of readmission events.'
        )
    },
    {
        'Metric': 'Uno’s C-Index',
        'Outcome': 'Death vs Readmission',
        'Pattern': 'Clear Separation, Occasional Degeneracy',
        'Interpretation': (
            'Uno’s C-index closely mirrors AUC-ROC, with strong time-dependent discrimination for Death. '
            'For Readmission, occasional drops or undefined values reflect sparse comparable pairs after IPCW weighting, '
            'a known property of competing-risk concordance when the event is rare.'
        )
    },

    # --- CALIBRATION / ACCURACY ---
    {
        'Metric': 'Brier Score (Competing Risk)',
        'Outcome': 'Death vs Readmission',
        'Pattern': 'Lower for Death, Higher for Readmission',
        'Interpretation': (
            'Prediction error remains low and stable for Death, indicating good calibration of absolute risk. '
            'Higher and more variable Brier scores for Readmission indicate greater uncertainty in individual-level risk prediction, '
            'consistent with heterogeneous relapse pathways.'
        )
    },

    # --- CLASSIFICATION METRICS ---
    {
        'Metric': 'F1 Score',
        'Outcome': 'Death vs Readmission',
        'Pattern': 'Moderate for Death, Low for Readmission',
        'Interpretation': (
            'The harmonic balance between precision and recall is acceptable for Death but substantially lower for Readmission. '
            'This reflects the intrinsic difficulty of simultaneously identifying true readmissions while avoiding false alarms '
            'in a low-incidence competing-risk setting.'
        )
    },
    {
        'Metric': 'PPV',
        'Outcome': 'Death vs Readmission',
        'Pattern': 'Low–Moderate Across Horizons',
        'Interpretation': (
            'Positive predictive value remains limited, especially for Readmission, indicating that many high-risk predictions '
            'do not materialize into observed events. This is expected in rare-event settings and cautions against interpreting '
            'risk scores as deterministic forecasts.'
        )
    },
    {
        'Metric': 'NPV',
        'Outcome': 'Death vs Readmission',
        'Pattern': 'Consistently High',
        'Interpretation': (
            'Negative predictive value is high for both outcomes, indicating that the model is reliable at identifying individuals '
            'unlikely to experience the event within a given horizon. This makes the model particularly useful for ruling out '
            'near-term risk rather than confirming it.'
        )
    }
])

print("\n>>> TAKE-HOME MESSAGE: TIME-DEPENDENT PERFORMANCE (DeepHit, AJ-IPCW)")
pd.set_option('display.max_colwidth', None)
display(performance_msg.style.set_properties(**{
    'text-align': 'left',
    'white-space': 'pre-wrap'
}))

>>> TAKE-HOME MESSAGE: TIME-DEPENDENT PERFORMANCE (DeepHit, AJ-IPCW)
  Metric Outcome Pattern Interpretation
0 AUC-ROC Death vs Readmission Stable for Death, Weaker for Readmission Discrimination is consistently high and stable for Death across follow-up, indicating a strong and persistent ranking of mortality risk. For Readmission, AUC-ROC is systematically lower and more variable, reflecting the greater stochasticity and behavioral component of readmission events.
1 Uno’s C-Index Death vs Readmission Clear Separation, Occasional Degeneracy Uno’s C-index closely mirrors AUC-ROC, with strong time-dependent discrimination for Death. For Readmission, occasional drops or undefined values reflect sparse comparable pairs after IPCW weighting, a known property of competing-risk concordance when the event is rare.
2 Brier Score (Competing Risk) Death vs Readmission Lower for Death, Higher for Readmission Prediction error remains low and stable for Death, indicating good calibration of absolute risk. Higher and more variable Brier scores for Readmission indicate greater uncertainty in individual-level risk prediction, consistent with heterogeneous relapse pathways.
3 F1 Score Death vs Readmission Moderate for Death, Low for Readmission The harmonic balance between precision and recall is acceptable for Death but substantially lower for Readmission. This reflects the intrinsic difficulty of simultaneously identifying true readmissions while avoiding false alarms in a low-incidence competing-risk setting.
4 PPV Death vs Readmission Low–Moderate Across Horizons Positive predictive value remains limited, especially for Readmission, indicating that many high-risk predictions do not materialize into observed events. This is expected in rare-event settings and cautions against interpreting risk scores as deterministic forecasts.
5 NPV Death vs Readmission Consistently High Negative predictive value is high for both outcomes, indicating that the model is reliable at identifying individuals unlikely to experience the event within a given horizon. This makes the model particularly useful for ruling out near-term risk rather than confirming it.
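The PPV and NPV patterns summarized above depend heavily on outcome prevalence at each horizon. The Bayes'-rule sketch below holds sensitivity and specificity fixed at hypothetical values (illustrative inputs, not estimates from this model). It shows PPV falling and NPV approaching 1 as the event becomes rarer:

```python
def ppv_npv(sensitivity: float, specificity: float, prevalence: float) -> tuple:
    """PPV and NPV implied by a fixed operating point (Bayes' rule)."""
    tp = sensitivity * prevalence                  # true positives per person screened
    fp = (1.0 - specificity) * (1.0 - prevalence)  # false positives
    fn = (1.0 - sensitivity) * prevalence          # false negatives
    tn = specificity * (1.0 - prevalence)          # true negatives
    return tp / (tp + fp), tn / (tn + fn)


# Hypothetical operating point (sensitivity 0.70, specificity 0.80); not taken from the results above
for prevalence in (0.01, 0.05, 0.30):
    ppv, npv = ppv_npv(0.70, 0.80, prevalence)
    print(f"prevalence={prevalence:.2f}  PPV={ppv:.3f}  NPV={npv:.3f}")
```

With these illustrative inputs, PPV rises from about 0.03 to 0.60 as prevalence goes from 1% to 30%. Over the same range, NPV falls from above 0.99 to about 0.86.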

Second attempt (Robust, w/SHAP, corrected for competing risk)

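  1. Implements DeepHit for competing-risks survival analysis.
  2. Handles death and readmission as competing events (time to first event).
  3. Uses an Aalen-Johansen fit to estimate the censoring distribution for the IPCW weights.
  4. Performs stratified K-fold cross-validation (K = 10) within each imputed dataset, stratifying on event type and plan type.
  5. Computes time-dependent metrics such as Uno’s C-Index and the Brier score.
  6. Evaluates binary classification metrics (F1, sensitivity, specificity, PPV, NPV) at each evaluation horizon, with thresholds chosen on the training fold.
  7. Calculates multi-horizon SHAP values for interpretability (first imputation only).
  8. Standardizes features and discretizes follow-up time onto a 100-point grid for the model.
  9. Summarizes pooled fold-by-imputation results as means with empirical 2.5th to 97.5th percentile intervals (despite the helper's name, no bootstrap resampling is performed).
  10. Exports metrics, raw predictions, and SHAP values to files.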
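Steps 1 and 2 depend on collapsing the two outcome-specific (time, event) pairs into a single time-to-first-event target. Below is a minimal numpy sketch of the rule applied in the loop below, where readmission wins ties. The helper name `first_event_targets` and the toy data are hypothetical:

```python
import numpy as np


def first_event_targets(t_death, e_death, t_readm, e_readm):
    """Code the first observed event: 0 = censored, 1 = death, 2 = readmission."""
    t_death = np.asarray(t_death, dtype=float)
    t_readm = np.asarray(t_readm, dtype=float)
    e_death = np.asarray(e_death, dtype=bool)
    e_readm = np.asarray(e_readm, dtype=bool)

    events = np.zeros(len(t_death), dtype=int)
    times = t_death.copy()

    # Readmission is the first event when observed no later than the death/censoring time
    readm_first = e_readm & (t_readm <= t_death)
    events[readm_first] = 2
    times[readm_first] = t_readm[readm_first]

    # Otherwise an observed death is the first event; everyone else stays censored at t_death
    death_first = e_death & ~readm_first
    events[death_first] = 1
    return times, events


times, events = first_event_targets(
    t_death=[10, 5, 20, 8], e_death=[1, 1, 0, 0],
    t_readm=[4, 9, 30, 8], e_readm=[1, 1, 0, 1],
)
print(times.tolist(), events.tolist())
```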
Code
#@title ⚡ Final DeepHit: AJ-IPCW, Multi-Horizon SHAP, Pooled
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import shap
import time
import gc
import warnings
import os
import pickle
from lifelines import AalenJohansenFitter
from datetime import datetime
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score, confusion_matrix
from sksurv.metrics import concordance_index_ipcw 
from pycox.models import DeepHit
from pycox.preprocessing.label_transforms import LabTransDiscreteTime
import torchtuples as tt

# --- 1. CONFIGURATION ---
TEST_MODE = False # Set True for fast debug

# Optimal Hyperparameters
BEST_LR = 0.001
BEST_WD = 0.001
BEST_BATCH = 1024
BEST_DROPOUT = 0.5
BEST_NODES = [256, 256, 128]
BEST_ALPHA = 0.5
BEST_SIGMA = 0.5
NUM_RISKS = 3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# SHAP Configuration
SHAP_HORIZONS = [3, 12, 24, 48, 60, 72]

if TEST_MODE:
    print("⚠️ TEST MODE ACTIVE")
    N_IMPUTATIONS_RUN = 1
    K_FOLDS_RUN = 2
    EVAL_HORIZONS_RUN = [12, 24]
    EPOCHS_RUN = 5
    SHAP_SAMPLES = 10
else:
    if 'imputations_list_jan26' in locals():
        N_IMPUTATIONS_RUN = len(imputations_list_jan26)
    else:
        print("❌ Data not found."); N_IMPUTATIONS_RUN = 0
    K_FOLDS_RUN = 10
    EVAL_HORIZONS_RUN = [3, 6, 12, 24, 36, 48, 60, 72, 84, 96, 108]
    EPOCHS_RUN = 100
    SHAP_SAMPLES = 50

timestamp = datetime.now().strftime("%Y%m%d_%H%M")
BASE_NAME = f"DH4_Final_AJ_{timestamp}"
CSV_PATH = f"{BASE_NAME}_Metrics.csv"
PKL_PATH = f"{BASE_NAME}_RawPreds.pkl"
SHAP_PATH = f"{BASE_NAME}_SHAP_MultiHorizon.pkl"

warnings.filterwarnings("ignore")

# --- 2. CUSTOM AALEN-JOHANSEN CENSORING ---
class AalenJohansenCensoring:
    """
    Estimates Censoring Distribution G(t) = P(C > t) using Aalen-Johansen.
    Treats 'Censoring' as Event 1, and 'Death/Readm' as Competing Event 2.
    """
    def __init__(self):
        self.ajf = AalenJohansenFitter(calculate_variance=False)
        self.max_time = 0
        
    def fit(self, durations, events_composite):
        # Input events: 0=Censored, 1=Death, 2=Readm
        # Transformation for Censoring Distribution:
        # We want to estimate probability of being Censored.
        # Target Event: 0 (The original censored people) -> Mapped to 1
        # Competing Event: 1, 2 (The original events) -> Mapped to 2
        
        aj_events = np.zeros_like(events_composite)
        # People who were originally censored (0) are now the Event of Interest (1)
        aj_events[events_composite == 0] = 1 
        # People who died/readmitted (1, 2) are now Competing Risks (2)
        aj_events[events_composite > 0] = 2
        
        self.max_time = durations.max()
        self.ajf.fit(durations, event_observed=aj_events, event_of_interest=1)
        
    def predict(self, times):
        # AJF predicts CIF_c(t) = P(C <= t, Event=Censored)
        # We need G(t) = P(C > t)
        # In a competing risk setting: P(C > t) = 1 - CIF_c(t)
        
        if np.isscalar(times):
            cif_val = self.ajf.predict(times).item()
            return 1.0 - cif_val
        else:
            cif_vals = self.ajf.predict(times).values.flatten()
            return 1.0 - cif_vals

def compute_brier_competing(cif_values_at_time_horizon, censoring_dist, 
                            Y_test, D_test, event_of_interest, time_horizon):
    """
    Brier Score using Aalen-Johansen weights.
    """
    n = len(Y_test)
    residuals = np.zeros(n)
    
    # Weight at horizon G(t)
    w_horizon = censoring_dist.predict(time_horizon)
    if w_horizon == 0: w_horizon = 1e-9
    
    # Weights at observed times G(Ti)
    # We pre-calculate all observed weights to avoid loop overhead
    w_obs_all = censoring_dist.predict(Y_test)
    w_obs_all[w_obs_all == 0] = 1e-9
    
    for idx in range(n):
        observed_time = Y_test[idx]
        event_indicator = D_test[idx]
        
        if observed_time > time_horizon:
            # Surviving past horizon
            # Weight = 1 / G(t_horizon)
            residuals[idx] = (cif_values_at_time_horizon[idx])**2 / w_horizon
        else:
            # Observed <= horizon
            # Weight = 1 / G(Ti)
            w_obs = w_obs_all[idx]
            
            if event_indicator == event_of_interest:
                # Outcome occurred
                residuals[idx] = (1 - cif_values_at_time_horizon[idx])**2 / w_obs
            elif event_indicator != event_of_interest and event_indicator != 0:
                # Competing event occurred
                residuals[idx] = (cif_values_at_time_horizon[idx])**2 / w_obs
            # Censored before horizon -> Residual 0 (Excluded)
                
    return residuals.mean()

# --- 3. STANDARD HELPERS ---
class CauseSpecificNet(nn.Module):
    def __init__(self, in_f, nodes, out_f, dropout, num_risks):
        super().__init__()
        self.net = tt.practical.MLPVanilla(in_f, nodes, out_f, batch_norm=True, dropout=dropout)
        self.num_risks = num_risks
    def forward(self, x):
        return self.net(x).view(x.size(0), self.num_risks, -1)

def get_binary_target(events, times, risk_id, t_horizon):
    # Case: Event of interest happened by t
    is_case = (events == risk_id) & (times <= t_horizon)
    # Control: Survived past t OR Competing Event happened before t
    # (We ONLY exclude people who were Censored (lost to follow-up) before t)
    mask_censored_early = (events == 0) & (times <= t_horizon)
    valid_mask = ~mask_censored_early
    # This keeps Competing Events as 0 (Valid Controls)
    y_binary = is_case[valid_mask].astype(int)
    return y_binary, valid_mask

def find_optimal_threshold(y_true, y_prob):
    thresholds = np.linspace(0.01, 0.99, 99)
    best_f1 = -1; best_th = 0.5
    for th in thresholds:
        f1 = f1_score(y_true, (y_prob >= th).astype(int), zero_division=0)
        if f1 > best_f1: best_f1 = f1; best_th = th
    return best_th

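def calculate_binary_metrics(y_true, y_prob, fixed_threshold):
    y_pred = (y_prob >= fixed_threshold).astype(int)
    # labels=[0, 1] guarantees a 2x2 matrix even if one class is absent from y_true and y_pred
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()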
    return {
        'F1': f1_score(y_true, y_pred, zero_division=0),
        'Sens': recall_score(y_true, y_pred, zero_division=0),
        'Spec': tn / (tn + fp) if (tn + fp) > 0 else 0,
        'PPV': precision_score(y_true, y_pred, zero_division=0),
        'NPV': tn / (tn + fn) if (tn + fn) > 0 else 0,
        'Threshold': fixed_threshold
    }

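def bootstrap_ci_non_normal(data, alpha=0.05):
    """Mean plus empirical (alpha/2, 1 - alpha/2) percentiles of the fold-level values.
    No resampling is performed: the interval describes the fold-to-fold spread."""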
    if len(data) < 2: return np.mean(data), np.mean(data), np.mean(data)
    lower = np.percentile(data, 100 * (alpha / 2))
    upper = np.percentile(data, 100 * (1 - alpha / 2))
    return np.mean(data), lower, upper

# --- 4. EXECUTION LOOP (CLEAN & FIXED) ---
pooled_results = []
raw_predictions_log = []

# SHAP Storage (Nested by Horizon)
shap_storage = {
    'death': {h: {'shap_values': [], 'data': []} for h in SHAP_HORIZONS},
    'readm': {h: {'shap_values': [], 'data': []} for h in SHAP_HORIZONS}
}

start_time_global = time.time()

if os.path.exists(CSV_PATH) and not TEST_MODE:
    print(f"⚠️ Found existing results. Aborting to prevent overwrite.")
else:
    print(f"⚡ Starting Analysis on {N_IMPUTATIONS_RUN} imputations...")
    
    for imp_idx in range(N_IMPUTATIONS_RUN):
        print(f"\n📚 Imputation {imp_idx + 1}/{N_IMPUTATIONS_RUN}")
        
        X_raw = imputations_list_jan26[imp_idx]
        y_d = y_surv_death_list[imp_idx]
        y_r = y_surv_readm_list[imp_idx]
        
        # --- TRAINING COMPOSITE (First Event) ---
        t_d = y_d['time'].values if hasattr(y_d['time'], 'values') else y_d['time']
        e_d = y_d['event'].values if hasattr(y_d['event'], 'values') else y_d['event']
        t_r = y_r['time'].values if hasattr(y_r['time'], 'values') else y_r['time']
        e_r = y_r['event'].values if hasattr(y_r['event'], 'values') else y_r['event']
        
        # Create Composite Targets for Training
        events_cr = np.zeros(len(X_raw), dtype=int)
        times_cr = t_d.copy().astype('float32')
        
        mask_r = (e_r.astype(bool)) & (t_r <= t_d) # Readm first
        events_cr[mask_r] = 2
        times_cr[mask_r] = t_r[mask_r].astype('float32')
        
        mask_d = (e_d.astype(bool)) & (~mask_r) # Death first/only
        events_cr[mask_d] = 1
        
        # Stratification Labels
        X_curr = pd.get_dummies(X_raw, drop_first=True).astype('float32')
        # Look up plan dummies in the encoded frame, so this works whether or not
        # X_raw already carries one-hot plan columns (everything is float32 already)
        plan_cols = [c for c in X_curr.columns if c.startswith("plan_type_corr")]

        #✅ Step 1: define an explicit plan index mapping
        plan_idx = np.zeros(len(X_curr), dtype=int)  # 0 = pg_pab (reference)
        if 'plan_type_corr_pg_pr' in X_curr.columns:
            plan_idx[X_curr['plan_type_corr_pg_pr'] == 1] = 1
        if 'plan_type_corr_m_pr' in X_curr.columns:
            plan_idx[X_curr['plan_type_corr_m_pr'] == 1] = 2
        if 'plan_type_corr_pg_pai' in X_curr.columns:
            plan_idx[X_curr['plan_type_corr_pg_pai'] == 1] = 3
        if 'plan_type_corr_m_pai' in X_curr.columns:
            plan_idx[X_curr['plan_type_corr_m_pai'] == 1] = 4
        # At most one plan dummy may be set per row (all zeros = reference plan pg_pab)
        plan_sum = X_curr[plan_cols].sum(axis=1)
        if (plan_sum > 1).any():
            raise ValueError("Invalid plan encoding: some rows have >1 plan types.")
        #✅ Step 2: combine with competing-risk event for stratification
        strat_labels = (events_cr * 10) + plan_idx

        # CV
        skf = StratifiedKFold(n_splits=K_FOLDS_RUN, shuffle=True, random_state=2125 + imp_idx)
        
        for fold_idx, (train_idx, val_idx) in enumerate(skf.split(X_curr, strat_labels)):
            print(f".", end="")
            
            # Split Data
            X_train, X_val = X_curr.iloc[train_idx].values, X_curr.iloc[val_idx].values
            t_train_c, e_train_c = times_cr[train_idx], events_cr[train_idx]
            t_val_c, e_val_c = times_cr[val_idx], events_cr[val_idx]
            
            # Raw Eval Targets
            t_d_val, e_d_val = t_d[val_idx], e_d[val_idx].astype(bool)
            t_r_val, e_r_val = t_r[val_idx], e_r[val_idx].astype(bool)
            t_train_c_raw = times_cr[train_idx]
            e_train_c_raw = events_cr[train_idx]

            # Fit AJ Censoring
            aj_censor = AalenJohansenCensoring()
            aj_censor.fit(t_train_c_raw, e_train_c_raw)

            # Scale
            scaler = StandardScaler().fit(X_train)
            X_train_s = scaler.transform(X_train).astype('float32')
            X_val_s = scaler.transform(X_val).astype('float32')
            
            # --- 🟢 FIX: Explicit int64 Casting ---
            labtrans = LabTransDiscreteTime(100)
            y_train_lt_raw = labtrans.fit_transform(t_train_c, e_train_c)
            y_train_lt = (y_train_lt_raw[0].astype('int64'), y_train_lt_raw[1].astype('int64'))
            
            y_val_lt_raw = labtrans.transform(t_val_c, e_val_c)
            y_val_lt = (y_val_lt_raw[0].astype('int64'), y_val_lt_raw[1].astype('int64'))
            # --------------------------------------
            
            # Train DeepHit
            in_f = X_train.shape[1]
            out_f = labtrans.out_features * NUM_RISKS
            net = CauseSpecificNet(in_f, BEST_NODES, out_f, BEST_DROPOUT, NUM_RISKS)
            model = DeepHit(net, tt.optim.Adam, alpha=BEST_ALPHA, sigma=BEST_SIGMA, duration_index=labtrans.cuts)
            model.set_device(DEVICE)
            model.optimizer.set_lr(BEST_LR); model.optimizer.param_groups[0]['weight_decay'] = BEST_WD
            
            # Early stopping monitors the loss on val_idx, the same fold that is scored below
            model.fit(X_train_s, y_train_lt, batch_size=BEST_BATCH, epochs=EPOCHS_RUN, 
                      callbacks=[tt.callbacks.EarlyStopping()], verbose=False, 
                      val_data=(X_val_s, y_val_lt))
            
            # Predict
            cif_val = model.predict_cif(X_val_s)
            cif_train = model.predict_cif(X_train_s)
            
            # --- STORE RAW PREDICTIONS ---
            raw_predictions_log.append({
                'imp': imp_idx,
                'fold': fold_idx,
                'cif_pred': cif_val,
                'duration_index': model.duration_index,
                'y_time_val': t_val_c, # Storing composite targets
                'y_event_val': e_val_c
            })
            # --- MULTI-HORIZON SHAP (Imp 0 Only) ---
            if imp_idx == 0:
                try:
                    bg = X_train_s[np.random.choice(len(X_train_s), SHAP_SAMPLES, replace=False)]
                    val_sub_idx = np.random.choice(len(X_val_s), min(len(X_val_s), SHAP_SAMPLES), replace=False)
                    val_samp = X_val_s[val_sub_idx]
                    
                    for h_shap in SHAP_HORIZONS:
                        # Death Wrapper
                        def prd(x):
                            with torch.no_grad():
                                c = model.predict_cif(torch.from_numpy(x).to(DEVICE))
                                i = np.searchsorted(model.duration_index, h_shap)
                                if i >= len(model.duration_index): i = -1
                                return c[1][i,:].cpu().numpy()
                        
                        s_vals_d = shap.KernelExplainer(prd, bg).shap_values(val_samp, nsamples=SHAP_SAMPLES, silent=True)
                        shap_storage['death'][h_shap]['shap_values'].append(s_vals_d)
                        shap_storage['death'][h_shap]['data'].append(pd.DataFrame(val_samp, columns=X_curr.columns))

                        # Readm Wrapper
                        def prr(x):
                            with torch.no_grad():
                                c = model.predict_cif(torch.from_numpy(x).to(DEVICE))
                                i = np.searchsorted(model.duration_index, h_shap)
                                if i >= len(model.duration_index): i = -1
                                return c[2][i,:].cpu().numpy()
                                
                        s_vals_r = shap.KernelExplainer(prr, bg).shap_values(val_samp, nsamples=SHAP_SAMPLES, silent=True)
                        shap_storage['readm'][h_shap]['shap_values'].append(s_vals_r)
                        shap_storage['readm'][h_shap]['data'].append(pd.DataFrame(val_samp, columns=X_curr.columns))
                except Exception as e:
                    # print(f"SHAP Error: {e}")
                    pass

            # --- METRICS EVALUATION ---
            outcomes = {1: ('Death', t_d_val, e_d_val, t_d[train_idx], e_d[train_idx].astype(bool)), 
                        2: ('Readm', t_r_val, e_r_val, t_r[train_idx], e_r[train_idx].astype(bool))}

            for rid, (rname, t_v, e_v, t_tr, e_tr) in outcomes.items():
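                # Structured (event, time) arrays for sksurv, built outside the try so a failure
                # cannot leave arrays from the previous outcome in scope for the horizon loop
                y_tr_st = np.array([(e, t) for e, t in zip(e_tr, t_tr)], dtype=[('e', bool), ('t', float)])
                y_va_st = np.array([(e, t) for e, t in zip(e_v, t_v)], dtype=[('e', bool), ('t', float)])
                # Global Uno (risk score = CIF summed over the time grid)
                try:
                    ug = concordance_index_ipcw(y_tr_st, y_va_st, cif_val[rid].sum(axis=0))[0]
                except Exception:
                    ug = np.nan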
                pooled_results.append({'Imp': imp_idx, 'Fold': fold_idx, 'Outcome': rname, 'Time': 'Global', 'Metric': 'Uno C-Index', 'Value': ug})

                # Time-Dependent
                for t in EVAL_HORIZONS_RUN:
                    idx = np.searchsorted(model.duration_index, t)
                    if idx >= len(model.duration_index): idx = -1
                    
                    risk_v = cif_val[rid][idx]
                    risk_tr = cif_train[rid][idx]
                    
                    # 1. Uno
                    try:
                        u = concordance_index_ipcw(y_tr_st, y_va_st, risk_v, tau=t)[0]
                    except Exception:
                        u = np.nan
                    
                    # 2. Brier (AJ-weighted)
                    try:
                        bs = compute_brier_competing(risk_v, aj_censor, t_v, 
                                                     events_cr[val_idx], 
                                                     rid, t)
                    except Exception:
                        bs = np.nan
                    
                    # 3. Binary
                    y_bin_tr, m_tr = get_binary_target(e_tr.astype(int), t_tr, 1, t) 
                    y_bin_va, m_va = get_binary_target(e_v.astype(int), t_v, 1, t)
                    
                    if len(np.unique(y_bin_tr))>1 and len(np.unique(y_bin_va))>1:
                        th = find_optimal_threshold(y_bin_tr, risk_tr[m_tr])
                        bin_m = calculate_binary_metrics(y_bin_va, risk_v[m_va], th)
                        auc = roc_auc_score(y_bin_va, risk_v[m_va])
                        mp = {'Uno C-Index': u, 'Brier Score': bs, 'AUC-ROC': auc, 
                              'F1': bin_m['F1'], 'Sens': bin_m['Sens'], 'Spec': bin_m['Spec'], 
                              'PPV': bin_m['PPV'], 'NPV': bin_m['NPV']}
                    else:
                        mp = {'Uno C-Index': u, 'Brier Score': bs}
                        
                    for k, v in mp.items():
                        pooled_results.append({'Imp': imp_idx, 'Fold': fold_idx, 'Outcome': rname, 'Time': t, 'Metric': k, 'Value': v})

            del model, net, cif_val, cif_train
            gc.collect()

    # --- EXPORT ---
    print("\n💾 Saving Results...")
    
    # 1. Metrics: mean and percentile interval per (Outcome, Time, Metric)
    df_res = pd.DataFrame(pooled_results)
    summ = []
    for (o, t, m), g in df_res.groupby(['Outcome', 'Time', 'Metric']):
        vals = g['Value'].dropna().values
        mn, lo, hi = bootstrap_ci_non_normal(vals)
        summ.append({'Outcome': o, 'Time': t, 'Metric': m, 'Mean': mn, 'CI_Lo': lo, 'CI_Hi': hi, 'Str': f"{mn:.3f} [{lo:.3f}-{hi:.3f}]"})
    pd.DataFrame(summ).to_csv(CSV_PATH, sep=';', index=False)
    
    # 2. Raw predictions (per-fold CIF arrays, duration grid and first-event targets)
    print(f"💾 Saving Raw Predictions to {PKL_PATH}...")
    with open(PKL_PATH, 'wb') as f:
        pickle.dump(raw_predictions_log, f)

    # 3. SHAP (Complex Nested Structure, Consolidating Multi-Horizon Data)
    # Check if we actually collected data (Imputation 0 must have run)
    if shap_storage['death'][SHAP_HORIZONS[0]]['shap_values']:
        print(f"💾 Consolidating and Saving Multi-Horizon SHAP to {SHAP_PATH}...")
        
        # Initialize final structure
        final_shap_export = {'death': {}, 'readm': {}}
        
        for h in SHAP_HORIZONS:
            # --- CONSOLIDATE DEATH SHAP ---
            d_vals = shap_storage['death'][h]['shap_values'] # List of arrays
            d_data = shap_storage['death'][h]['data']       # List of DataFrames
            
            if d_vals: # If list is not empty
                final_shap_export['death'][h] = {
                    # Stack arrays vertically (row-wise)
                    'shap_values': np.concatenate(d_vals, axis=0),
                    # Stack DataFrames vertically
                    'data': pd.concat(d_data, axis=0)
                }
            
            # --- CONSOLIDATE READMISSION SHAP ---
            r_vals = shap_storage['readm'][h]['shap_values']
            r_data = shap_storage['readm'][h]['data']
            
            if r_vals:
                final_shap_export['readm'][h] = {
                    'shap_values': np.concatenate(r_vals, axis=0),
                    'data': pd.concat(r_data, axis=0)
                }
                
        # Write to disk
        with open(SHAP_PATH, 'wb') as f:
            pickle.dump(final_shap_export, f)
            
    print(f"🏁 Done. Total time: {(time.time()-start_time_global)/60:.1f}m")
⚡ Starting Analysis on 5 imputations...

📚 Imputation 1/5
..........
📚 Imputation 2/5
..........
📚 Imputation 3/5
..........
📚 Imputation 4/5
..........
📚 Imputation 5/5
..........
💾 Saving Results...
💾 Saving Raw Predictions to DH4_Final_AJ_20260209_2329_RawPreds.pkl...
💾 Consolidating and Saving Multi-Horizon SHAP to DH4_Final_AJ_20260209_2329_SHAP_MultiHorizon.pkl...
🏁 Done. Total time: 38.0m
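This note is a reference for reading `cif_val[rid]` above. As far as we can tell from the pycox source, DeepHit builds cause-specific CIFs in four steps:

1. Flatten the (risk, duration) logits.
2. Append one zero logit for "no event on the grid".
3. Apply a single softmax.
4. Accumulate over time.

Its competing-risk likelihood also appears to score event code k against output row k - 1. The numpy sketch below mimics that transformation; it is not pycox code. Under this convention, death (code 1) sits in row 0 and readmission (code 2) in row 1. That is worth reconciling with the `NUM_RISKS = 3` setting and the `cif_val[1]` / `cif_val[2]` lookups in the cell above.

```python
import numpy as np


def deephit_cif(logits: np.ndarray) -> np.ndarray:
    """Map DeepHit-style logits (n, risks, durations) to CIFs (risks, durations, n).

    Sketch of a pycox-style transformation: one softmax over every (risk, duration)
    cell plus a padded zero logit for "no event on the grid", then a cumulative sum.
    """
    n, risks, durations = logits.shape
    flat = logits.reshape(n, risks * durations)
    padded = np.concatenate([flat, np.zeros((n, 1))], axis=1)  # extra "beyond the grid" cell
    padded = padded - padded.max(axis=1, keepdims=True)         # numerically stable softmax
    probs = np.exp(padded)
    probs = probs / probs.sum(axis=1, keepdims=True)
    pmf = probs[:, :-1].reshape(n, risks, durations)            # drop the padding cell
    cif = np.cumsum(pmf, axis=2)                                # accumulate over time
    return cif.transpose(1, 2, 0)                               # (risks, durations, n)


rng = np.random.default_rng(0)
cif = deephit_cif(rng.normal(size=(4, 2, 5)))  # 4 people, 2 risks, 5 grid points
cif_death, cif_readm = cif[0], cif[1]          # event code k -> row k - 1
survival = 1.0 - cif.sum(axis=0)               # overall event-free probability
print(cif.shape, survival.shape)
```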

The AJ-IPCW DeepHit pipeline is the more robust approach. It uses a censoring estimator valid under competing risks, treats competing events correctly in both discrimination and classification metrics, and evaluates a clearly defined time-to-first-event estimand. The more computationally intensive alternative provides no additional statistical guarantees and relies on assumptions that are violated in competing-risk settings.

The earlier script handled competing events differently in two ways:

- It dropped competing events from the binary evaluation. The current version keeps them as controls, which avoids artificial inflation of AUC and F1.
- It treated competing events as independent censoring by using Kaplan-Meier censoring weights.
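The classification-metric handling described above amounts to a labelling rule at each horizon:

- Individuals whose competing event came first stay in the denominator as controls.
- Only those censored before the horizon are set aside.

The minimal sketch below writes that rule against the first-event 0/1/2 coding. That coding is an assumption here, since the evaluation loop above passes outcome-specific indicators to `get_binary_target`. The function name and toy data are hypothetical.

```python
import numpy as np


def horizon_binary_target(times, events, cause, horizon):
    """Cause-specific binary labels at a fixed horizon under competing risks.

    events uses first-event coding: 0 = censored, 1 = death, 2 = readmission.
      case (1)    : the cause of interest occurred at or before the horizon
      control (0) : event-free beyond the horizon, or a competing event occurred first
      excluded    : censored at or before the horizon (status at the horizon unknown)
    """
    times = np.asarray(times, dtype=float)
    events = np.asarray(events, dtype=int)
    keep = ~((events == 0) & (times <= horizon))
    y = ((events == cause) & (times <= horizon))[keep].astype(int)
    return y, keep


# Toy data (hypothetical): months to first event and its code
y, keep = horizon_binary_target(times=[3, 5, 4, 20, 2], events=[1, 2, 0, 0, 1], cause=1, horizon=6)
print(y.tolist(), keep.tolist())
```

Here, the readmission at month 5 counts as a control for death at month 6, and the person censored at month 4 is excluded.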

To ensure the stability and generalizability of the discovered risk factors, Shapley Additive Explanations (SHAP) were computed using a pooled cross-validation approach:

- Imputation strategy: the analysis was conducted on the first imputed dataset only. This holds the imputed data fixed and isolates the variance driven by the model and the data splits.
- Pooled estimation: Kernel SHAP values were computed at each SHAP horizon on a random subsample of each fold's validation set. In the main run, 50 individuals per fold were explained against a 50-row background sample from the training fold, with nsamples = 50. These estimates were then concatenated into a single 'super-validation' set.
- Robustness: pooling ensures that the reported functional forms (e.g., the linearity of Age or the log-decay of Treatment Duration) represent the average learned behavior across all data splits. This smooths out noise specific to any single training fold.
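The pooling step mirrors the export code above, which concatenates each horizon's per-fold SHAP arrays row-wise. The sketch below shows that concatenation followed by a mean-|SHAP| ranking. The ranking summary is an illustrative addition, not something the notebook computes, and the names and values are hypothetical.

```python
import numpy as np


def pool_fold_shap(fold_shap_values, feature_names):
    """Concatenate per-fold SHAP matrices (n_fold x p) and rank features by mean |SHAP|."""
    pooled = np.concatenate(fold_shap_values, axis=0)   # the pooled "super-validation" set
    importance = np.abs(pooled).mean(axis=0)            # global importance per feature
    order = np.argsort(importance)[::-1]                # most important first
    ranking = [(feature_names[j], float(importance[j])) for j in order]
    return pooled, ranking


# Toy SHAP matrices from two folds, two features (hypothetical values)
fold_1 = np.array([[1.0, 0.1], [-1.0, 0.2]])
fold_2 = np.array([[0.5, -0.3]])
pooled, ranking = pool_fold_shap([fold_1, fold_2], ["feature_a", "feature_b"])
print(pooled.shape, ranking)
```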

Code
glimpse(df_res)
Rows: 8900 | Columns: 6
Imp                            int64           0, 0, 0, 0, 0
Fold                           int64           0, 0, 0, 0, 0
Outcome                        object          Death, Death, Death, Death, Death
Time                           object          Global, 3, 3, 3, 3
Metric                         object          Uno C-Index, Uno C-Index, Brier Score, AUC-ROC, F1
Value                          float64         0.6619960031781943, 0.8342944322631493, 0.001134686348268598, 0.8349293833964864...
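Each row of `df_res` is one fold-by-imputation estimate. Both `bootstrap_ci_non_normal` and the aggregation below summarize these rows by their 2.5th and 97.5th percentiles. That interval describes the spread of per-fold estimates rather than resampling-based uncertainty about the pooled mean.

The sketch below contrasts it with a percentile bootstrap of the mean, offered only as a comparison; the helper names and toy data are hypothetical. Neither version accounts for the dependence between folds that share training data.

```python
import numpy as np


def empirical_interval(values, alpha=0.05):
    """Mean plus empirical percentiles of the fold-level values (NaNs dropped)."""
    v = np.asarray(values, dtype=float)
    v = v[~np.isnan(v)]
    return v.mean(), np.percentile(v, 100 * alpha / 2), np.percentile(v, 100 * (1 - alpha / 2))


def bootstrap_mean_interval(values, alpha=0.05, n_boot=2000, seed=0):
    """Percentile bootstrap interval for the mean of the fold-level values (NaNs dropped)."""
    v = np.asarray(values, dtype=float)
    v = v[~np.isnan(v)]
    rng = np.random.default_rng(seed)
    resamples = v[rng.integers(0, len(v), size=(n_boot, len(v)))]  # resample with replacement
    boot_means = resamples.mean(axis=1)
    return v.mean(), np.percentile(boot_means, 100 * alpha / 2), np.percentile(boot_means, 100 * (1 - alpha / 2))


# Hypothetical fold-by-imputation AUC values (5 imputations x 10 folds) with one failed fold
rng = np.random.default_rng(1)
aucs = np.append(rng.normal(0.70, 0.04, size=49), np.nan)
print(empirical_interval(aucs))
print(bootstrap_mean_interval(aucs))
```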
Code
import os
import numpy as np
import pandas as pd

# 1. Create the directory
os.makedirs("dh5", exist_ok=True)

# 2. Aggregation with named aggregation (flat columns: mean, lo, hi).
# np.nanpercentile skips folds where a metric could not be computed (NaN),
# matching the NaN-skipping behaviour of 'mean'.
df_agg = (
    df_res
    .groupby(["Outcome", "Time", "Metric"])["Value"]
    .agg(
        mean='mean',
        lo=lambda x: np.nanpercentile(x, 2.5) if x.notna().any() else np.nan,
        hi=lambda x: np.nanpercentile(x, 97.5) if x.notna().any() else np.nan
    )
    .reset_index()
)

# 3. Filter for specific metrics (names as written by the evaluation loop)
metrics_keep = ["Uno C-Index", "Brier Score", "AUC-ROC"]
df_agg = df_agg[df_agg["Metric"].isin(metrics_keep)].copy()

# 4. Create the formatted string
df_agg["fmt"] = (
    df_agg["mean"].map('{:.3f}'.format)
    + " ["
    + df_agg["lo"].map('{:.3f}'.format)
    + "–"
    + df_agg["hi"].map('{:.3f}'.format)
    + "]"
)

# 5. Pivot for Table 2
master_time_db = (
    df_agg
    .pivot(index="Time", columns=["Outcome", "Metric"], values="fmt")
)

# Flatten columns: e.g., "Uno_C-Index_Death"
master_time_db.columns = [
    f"{metric.replace(' ', '_')}_{outcome}"
    for outcome, metric in master_time_db.columns
]

# 6. Define cols_final (Ordering the columns logically)
# This ensures Death metrics come before Readmission metrics
cols_final = [c for c in master_time_db.columns if "Death" in c] + \
             [c for c in master_time_db.columns if "Readm" in c]

# 7. Apply Styling
styled_master = (
    master_time_db[cols_final]
    .style
    .set_caption("<b>Table 2: Time-Dependent Performance Metrics (Mean + 95% CI)</b>")
    .set_table_styles([
        {'selector': 'caption', 'props': [
            ('color', '#333'), ('font-size', '16px'), ('font-weight', 'bold'), ('margin-bottom', '10px')
        ]},
        {'selector': 'th', 'props': [
            ('background-color', '#f4f4f4'), ('color', 'black'), ('border-bottom', '2px solid #555'), ('text-align', 'center')
        ]},
        {'selector': 'td', 'props': [('text-align', 'center'), ('padding', '8px')]},
        {'selector': 'tr:hover', 'props': [('background-color', '#f5f5f5')]}
    ])
    .format(na_rep="-")
)

# Display the result
styled_master
Table 2: Time-Dependent Performance Metrics (Mean + 95% CI)
  AUC-ROC_Death Brier_Score_Death Uno_C-Index_Death AUC-ROC_Readm Brier_Score_Readm Uno_C-Index_Readm
Time            
3 0.831 [0.670–0.933] 0.001 [0.001–0.002] 0.830 [0.669–0.932] 0.739 [0.568–0.865] 0.007 [0.006–0.009] 0.740 [0.567–0.866]
6 0.779 [0.695–0.880] 0.003 [0.003–0.005] 0.778 [0.695–0.879] 0.675 [0.553–0.774] 0.023 [0.020–0.026] 0.674 [0.555–0.772]
12 0.739 [0.672–0.801] 0.007 [0.005–0.009] 0.739 [0.670–0.801] 0.618 [0.543–0.675] 0.064 [0.060–0.070] 0.617 [0.540–0.674]
24 0.706 [0.659–0.749] 0.014 [0.012–0.016] 0.704 [0.658–0.748] 0.594 [0.552–0.624] 0.129 [0.124–0.135] 0.590 [0.550–0.620]
36 0.686 [0.638–0.731] 0.021 [0.019–0.023] 0.682 [0.635–0.729] 0.588 [0.555–0.611] 0.173 [0.169–0.177] 0.582 [0.551–0.604]
48 0.680 [0.633–0.716] 0.028 [0.026–0.030] 0.675 [0.632–0.710] 0.586 [0.560–0.609] 0.203 [0.199–0.207] 0.579 [0.548–0.600]
60 0.674 [0.633–0.707] 0.034 [0.032–0.036] 0.669 [0.632–0.700] 0.583 [0.559–0.604] 0.223 [0.220–0.227] 0.577 [0.548–0.599]
72 0.670 [0.632–0.702] 0.039 [0.037–0.040] 0.663 [0.624–0.695] 0.583 [0.563–0.604] 0.238 [0.235–0.241] 0.575 [0.548–0.597]
84 0.665 [0.628–0.699] 0.043 [0.041–0.045] 0.656 [0.623–0.689] 0.583 [0.563–0.603] 0.248 [0.246–0.250] 0.575 [0.548–0.596]
96 0.664 [0.622–0.703] 0.046 [0.045–0.047] 0.649 [0.617–0.679] 0.589 [0.558–0.618] 0.255 [0.252–0.257] 0.574 [0.549–0.594]
108 0.665 [0.618–0.708] 0.048 [0.047–0.050] 0.642 [0.600–0.676] 0.594 [0.568–0.619] 0.259 [0.256–0.262] 0.573 [0.550–0.592]
Global - - 0.634 [0.583–0.684] - - 0.561 [0.513–0.591]
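The Brier columns are easier to read against the score of a non-informative model. Ignoring censoring, predicting the same risk p (the cumulative incidence at the horizon) for everyone gives an expected Brier score of p(1 - p). Low values for Death and rising values for Readmission therefore partly track how common each outcome has become by each horizon.

The sketch below computes that reference and an 'index of prediction accuracy' style rescaling, 1 - BS_model / BS_reference. The helper names and numbers are hypothetical. In practice, p would come from an Aalen-Johansen estimate at the same horizon.

```python
import numpy as np


def reference_brier(incidence):
    """Brier score of a non-informative model predicting incidence p for everyone: p * (1 - p)."""
    p = np.asarray(incidence, dtype=float)
    return p * (1.0 - p)


def scaled_brier(model_brier, incidence):
    """Index-of-prediction-accuracy style rescaling: 1 - BS_model / BS_reference."""
    return 1.0 - np.asarray(model_brier, dtype=float) / reference_brier(incidence)


# Hypothetical incidences and model Brier scores (not the values in Table 2)
print(reference_brier([0.01, 0.05, 0.40]))
print(scaled_brier([0.009, 0.045, 0.22], [0.01, 0.05, 0.40]))
```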
Code
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime

HORIZONS = [3, 6, 12, 24, 36, 48, 60, 72, 84, 96, 108]

plt.rcParams.update({'font.size': 14})


def plot_time_metric_joint(
    df_res,
    durations,
    metric,
    ylabel=None,
    title=None,
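    colors=None,
    ylim=None,
    outdir="dh5"
):
    # Avoid a mutable default argument; default palette: Death red, Readmission blue
    if colors is None:
        colors = {"Death": "#d62728", "Readm": "#1f77b4"}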
    timestamp = datetime.now().strftime("%Y%m%d_%H%M")
    os.makedirs(outdir, exist_ok=True)

    # 1. Setup Figure (Giving more room to the table row)
    fig, (ax, ax_table) = plt.subplots(2, 1, figsize=(11, 8.5), 
                                       gridspec_kw={'height_ratios': [3, 1]},
                                       sharex=True)

    # --- TOP: PERFORMANCE PLOT ---
    dfp_all = df_res[df_res["Time"] != "Global"].copy()
    dfp_all["Time"] = pd.to_numeric(dfp_all["Time"])
    dfp_all = dfp_all[dfp_all["Time"].isin(HORIZONS)]

    for outcome, color in colors.items():
        dfp = (
            dfp_all.loc[(dfp_all["Metric"] == metric) & (dfp_all["Outcome"] == outcome)]
            .groupby("Time")["Value"]
            .agg(mean="mean", lo=lambda x: np.nanpercentile(x, 2.5), hi=lambda x: np.nanpercentile(x, 97.5))
            .reset_index().sort_values("Time")
        )
        if dfp.empty: continue
        ax.plot(dfp["Time"], dfp["mean"], color=color, label=outcome, 
                marker='o', markersize=7, linewidth=2, markeredgecolor='w', zorder=3)
        ax.fill_between(dfp["Time"], dfp["lo"], dfp["hi"], color=color, alpha=0.15, zorder=2)

    if metric in ["AUC-ROC", "Uno C-Index"]:
        ax.axhline(0.5, color="gray", linestyle="--", alpha=0.5)

    # Keep tick labels on the top panel even though the x-axis is shared with the table
    ax.tick_params(labelbottom=True)
    ax.set_xticks(HORIZONS)
    
    ax.set_ylabel(ylabel if ylabel else metric, fontweight='bold')
    #ax.set_title(title if title else f"{metric} over Time", fontsize=14, pad=15)
    ax.legend(frameon=True, loc='best')
    ax.grid(True, linestyle=':', alpha=0.6)
    if ylim: ax.set_ylim(*ylim)

    # --- BOTTOM: THE STAGGERED TABLE ---
    ax_table.axis('off')
    ax_table.set_ylim(0, 1)
    
    # Label for the section
    ax_table.text(-5, 0.6, "At Risk:", fontweight='bold', va='center', ha='right', fontsize=12)
    
    # Number of individuals whose first-event time is at or beyond each horizon
    nar_counts = [int(np.sum(np.asarray(durations) >= h)) for h in HORIZONS]
    
    for i, count in enumerate(nar_counts):
        # Intercalate: Even indices slightly higher than odd indices
        y_pos = 0.65 if i % 2 == 0 else 0.15
        
        # Add the count
        ax_table.text(HORIZONS[i], y_pos, f"{int(count):,}", 
                      ha='center', va='center', fontsize=12,
                      bbox=dict(facecolor='white', edgecolor='gray', boxstyle='round,pad=0.2', alpha=0.1))
        
        # Add a tiny connector line from the tick to the number
        ax_table.plot([HORIZONS[i], HORIZONS[i]], [y_pos + 0.1, 1.0], 
                      color='gray', linestyle='-', linewidth=0.5, alpha=0.3, clip_on=False)

    ax_table.set_xlabel("Months after Discharge", fontweight='bold', fontsize=11, labelpad=20)
    
    for spine in ['top', 'right']: ax.spines[spine].set_visible(False)

    plt.tight_layout()
    # Adjust hspace so the months and the "At Risk" numbers don't touch
    plt.subplots_adjust(hspace=0.25) 
    
    fname = f"{outdir}/DH5_Final_{metric}_{timestamp}"
    plt.savefig(fname + ".png", dpi=300, bbox_inches="tight")
    plt.savefig(fname + ".pdf", bbox_inches="tight")
    plt.show()

    
# --- Execution ---

# Note: the metric is labelled "Brier Score" (not "Brier Score (CR)") in df_res,
# so that is the name used here.
plot_time_metric_joint(df_res, metric="AUC-ROC", ylabel="AUC-ROC", durations=times_cr, ylim=(0.5, 1.0))
plot_time_metric_joint(df_res, metric="Uno C-Index", ylabel="Uno C-Index", durations=times_cr, ylim=(0.5, 1.0))
plot_time_metric_joint(df_res, metric="Brier Score", ylabel="Brier Score", durations=times_cr, ylim=(0, 0.25))
plot_time_metric_joint(df_res, metric="F1", durations=times_cr, ylabel="F1 Score")
plot_time_metric_joint(df_res, metric="PPV", durations=times_cr, ylabel="PPV (Precision)")
plot_time_metric_joint(df_res, metric="NPV", durations=times_cr, ylabel="NPV")

# --- Additional Binary Metrics ---

plot_time_metric_joint(
    df_res, 
    metric="Sens", 
    ylabel="Sensitivity (Recall)", 
    title="Time-dependent Sensitivity", 
    durations=times_cr,
    ylim=(0, 1)
)

plot_time_metric_joint(
    df_res, 
    metric="Spec", 
    ylabel="Specificity", 
    title="Time-dependent Specificity", 
    durations=times_cr,
    ylim=(0, 1)
)
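A minimal sketch of the two summaries the plotting function relies on, assuming `df_res` is in long format with `Metric`, `Outcome`, `Time` and `Value` columns and one row per fold or imputation (the layout the function above filters on). The helper names `summarise_metric` and `number_at_risk` are illustrative, not part of the notebook. The shaded band is the empirical 2.5th to 97.5th percentile of the per-fold values, so it describes spread across folds and imputations rather than a formal confidence interval.

```python
import numpy as np
import pandas as pd


def summarise_metric(df_res, metric, outcome, horizons):
    """Mean and empirical 2.5/97.5 percentile band of one metric per horizon."""
    d = df_res[(df_res["Metric"] == metric)
               & (df_res["Outcome"] == outcome)
               & (df_res["Time"] != "Global")].copy()
    d["Time"] = pd.to_numeric(d["Time"])
    d = d[d["Time"].isin(horizons)]
    return (d.groupby("Time")["Value"]
              .agg(mean="mean",
                   lo=lambda x: np.percentile(x, 2.5),
                   hi=lambda x: np.percentile(x, 97.5))
              .reset_index()
              .sort_values("Time"))


def number_at_risk(durations, horizons):
    """Subjects still under observation (time >= horizon) at each horizon."""
    durations = np.asarray(durations, dtype=float)
    return [int(np.sum(durations >= h)) for h in horizons]


# Toy example: two folds, one metric, two horizons plus a "Global" row (ignored).
toy = pd.DataFrame({
    "Metric": ["AUC-ROC"] * 5,
    "Outcome": ["Death"] * 5,
    "Time": ["3", "3", "12", "12", "Global"],
    "Value": [0.80, 0.86, 0.70, 0.74, 0.65],
})
print(summarise_metric(toy, "AUC-ROC", "Death", [3, 12]))
print(number_at_risk([1, 5, 12, 30, 60], [3, 12]))
```

Filtering out the "Global" row before `pd.to_numeric` mirrors the plotting function, which would otherwise fail to convert the mixed `Time` column.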

See also (DeepHit competing-risks example notebook): https://github.com/georgehc/survival-intro/blob/main/S6.1.4_DeepHit_competing.ipynb

Code
import shutil
import os
import pickle
import pandas as pd
import numpy as np

# 1. Define the Problematic Path (G: Drive)
# Note: Use raw string r"..." to handle backslashes safely
source_path = r"G:\My Drive\Alvacast\SISTRAT 2023\dh\DH4_Final_AJ_20260209_2329_RawPreds.pkl"

# 2. Define a Safe Local Destination (e.g., your main User folder)
# This puts it in C:\Users\YourName\temp_pred_file.pkl
dest_path = os.path.join(os.path.expanduser("~"), "temp_pred_file.pkl")

print(f"🔄 Attempting to rescue file from G: drive...")
print(f"   From: {source_path}")
print(f"   To:   {dest_path}")

try:
    # Copy the file locally (this forces Google Drive to download it)
    shutil.copyfile(source_path, dest_path)
    print("✅ Copy successful! File is now local.")

    # 3. Load the LOCAL copy
    print("📂 Loading data...")
    with open(dest_path, 'rb') as f:
        raw_log = pickle.load(f)
    
    print(f"✅ Success! Loaded {len(raw_log)} folds.")
    
    # 4. Run the Diagnostic on the clean data
    if len(raw_log) > 0:
        entry = raw_log[0]
        cif_preds = entry['cif_pred']
        
        print("\n📊 DATA DIAGNOSTIC:")
        print(f"   Heads Found: {len(cif_preds)}")
        
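        # Head 0: CIF at the last time index (row -1) for every patient.
        # "Max Risk" in the printout refers to this end-of-grid CIF (mean and maximum).
        risk0 = cif_preds[0][-1, :]
        print(f"   Head 0 Max Risk: {np.mean(risk0):.4f} (Avg) | {np.max(risk0):.4f} (Max)")
        
        # Head 1: CIF at the last time index (row -1) for every patient
        risk1 = cif_preds[1][-1, :]
        print(f"   Head 1 Max Risk: {np.mean(risk1):.4f} (Avg) | {np.max(risk1):.4f} (Max)")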

        # Clean up temp file (optional, keeps your drive clean)
        # os.remove(dest_path) 

except FileNotFoundError:
    print("\n❌ Error: Python cannot find the source file at all.")
    print("   Check: Is the G: drive connected? Is the filename exactly correct?")
except OSError as e:
    print(f"\n❌ Copy Failed: {e}")
    print("   Try manually dragging the file from G: to your Desktop and updating the path.")
🔄 Attempting to rescue file from G: drive...
   From: G:\My Drive\Alvacast\SISTRAT 2023\dh\DH4_Final_AJ_20260209_2329_RawPreds.pkl
   To:   C:\Users\andre\temp_pred_file.pkl
✅ Copy successful! File is now local.
📂 Loading data...
✅ Success! Loaded 50 folds.

📊 DATA DIAGNOSTIC:
   Heads Found: 3
   Head 0 Max Risk: 0.5681 (Avg) | 0.8436 (Max)
   Head 1 Max Risk: 0.2570 (Avg) | 0.4073 (Max)
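The diagnostic above reads only the last time index of each head. A slightly broader sanity check is sketched below with a hypothetical helper `check_cif_heads`. It assumes each entry of `cif_pred` is a cause-specific CIF shaped `(n_times, n_patients)`, which is the layout the indexing `cif_preds[k][-1, :]` implies. Because three heads were found for two competing events, confirm what the third entry holds before interpreting `sum_le_1`, which assumes every head is a cause.

```python
import numpy as np


def check_cif_heads(cif_heads, tol=1e-6):
    """Basic sanity checks on a list of CIF arrays shaped (n_times, n_patients)."""
    heads = [np.asarray(c, dtype=float) for c in cif_heads]
    report = {}
    for k, cif in enumerate(heads):
        # A cumulative incidence function should never decrease over time...
        report[f"head{k}_monotone"] = bool(np.all(np.diff(cif, axis=0) >= -tol))
        # ...and must stay within [0, 1].
        report[f"head{k}_in_unit"] = bool(cif.min() >= -tol and cif.max() <= 1 + tol)
        # Mean CIF at the last time index (the "Avg" value printed above).
        report[f"head{k}_mean_last"] = float(cif[-1].mean())
    # Cause-specific CIFs of one patient cannot add up to more than 1.
    report["sum_le_1"] = bool(np.all(np.sum(heads, axis=0) <= 1 + tol))
    return report


# Usage (assumes raw_log from the cell above is loaded):
# print(check_cif_heads(raw_log[0]['cif_pred']))
```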
Code
#@title 📊 Step 5b: Calibration (Swapped Heads & Local Load)
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from lifelines import AalenJohansenFitter
import pickle
import os
from datetime import datetime

# --- CONFIGURATION ---
# 1. Local copy of the predictions pickle created by the previous cell
LOCAL_PKL_PATH = os.path.join(os.path.expanduser("~"), "temp_pred_file.pkl")
TARGET_TIMES = [12, 24, 36, 48, 60]
RISK_GROUPS = 10

timestamp = datetime.now().strftime("%Y%m%d_%H%M")
BASE_OUT = os.path.join(os.path.expanduser("~"), "dh5")
os.makedirs(BASE_OUT, exist_ok=True)

def get_calibration_data_swapped(raw_log, risk_id, time_point):
    all_pred = []
    
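    # 🟢 HEAD-TO-CAUSE MAPPING (heuristic, "swapped" heads)
    # Inferred from the mean final-time CIF in the diagnostic above:
    # Head 0 (Avg 57%) -> Readmission (Risk 2)
    # Head 1 (Avg 26%) -> Death (Risk 1)
    # Verify against the event coding used in training before relying on this mapping.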
    
    if risk_id == 2: # Readmission
        network_index = 0  # Use Head 0
    elif risk_id == 1: # Death
        network_index = 1  # Use Head 1
    else:
        return None

    for entry in raw_log:
        cif_list = entry['cif_pred'] 
        dur_idx = entry['duration_index']
        y_time = entry['y_time_val']
        y_event = entry['y_event_val']
        
        # Find the time index: the CIF is a step function, so use the last grid time <= time_point
        idx_t = np.searchsorted(dur_idx, time_point, side='right') - 1
        idx_t = int(np.clip(idx_t, 0, len(dur_idx) - 1))
        
        # Extract Prob
        prob_event = cif_list[network_index][idx_t, :]
        
        df_fold = pd.DataFrame({
            'prob': prob_event,
            'time': y_time,
            'event': y_event
        })
        all_pred.append(df_fold)

    if not all_pred: return None

    # Reset Index
    df_all = pd.concat(all_pred).reset_index(drop=True)
    
    try:
        df_all['decile'] = pd.qcut(df_all['prob'], RISK_GROUPS, labels=False, duplicates='drop')
    except Exception:  # qcut fails when predictions cannot be split into groups
        return None 
        
    calibration_points = []
    ajf = AalenJohansenFitter(calculate_variance=False, jitter_level=0.001)
    
    for g in sorted(df_all['decile'].unique()):
        group = df_all[df_all['decile'] == g]
        mean_pred = group['prob'].mean()
        
        T = group['time']
        E = group['event']
        
        # Check if event exists
        if risk_id not in E.values:
            obs_freq = 0.0
        else:
            try:
                ajf.fit(T, E, event_of_interest=risk_id)
                if time_point > T.max():
                    obs_freq = ajf.predict(T.max()).item()
                else:
                    obs_freq = ajf.predict(time_point).item()
            except Exception:  # fitting can fail on degenerate groups
                obs_freq = np.nan
            
        calibration_points.append({'decile': g, 'mean_pred': mean_pred, 'obs_freq': obs_freq})
        
    return pd.DataFrame(calibration_points)

def plot_calibration_final(risk_name, risk_id, time_horizons):
    plt.figure(figsize=(8, 8))
    cmap = plt.get_cmap('viridis', len(time_horizons))  # matplotlib.cm.get_cmap was removed in Matplotlib 3.9
    max_val = 0
    
    # Load Data
    try:
        with open(LOCAL_PKL_PATH, 'rb') as f:
            raw_log = pickle.load(f)
    except Exception as e:
        print(f"❌ Could not load local file: {e}")
        return

    print(f"\n📈 Plotting {risk_name} (Risk ID: {risk_id})...")
    
    for i, t in enumerate(time_horizons):
        cal_df = get_calibration_data_swapped(raw_log, risk_id=risk_id, time_point=t)
        
        if cal_df is not None and not cal_df.empty:
            current_max = max(cal_df['mean_pred'].max(), cal_df['obs_freq'].max())
            if current_max > max_val: max_val = current_max
            
            plt.plot(cal_df['mean_pred'], cal_df['obs_freq'], 
                     marker='o', linewidth=2, color=cmap(i), label=f"{t} Months")
        else:
            print(f"   ⚠️ No data for {t} months")

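    # Plot the diagonal (perfect calibration); use the full [0, 1] range if nothing was plotted
    limit = min(1.0, max_val * 1.1) if max_val > 0 else 1.0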
    plt.plot([0, 1], [0, 1], 'k--', alpha=0.5, label="Perfect")
    
    plt.xlim(0, limit)
    plt.ylim(0, limit)
    plt.xlabel("Predicted Probability")
    plt.ylabel("Observed Frequency (AJ)")
    #plt.title(f"Calibration: {risk_name}")
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.savefig(os.path.join(BASE_OUT, f"DH5_Calib_{risk_name}_{timestamp}.png"), dpi=300, bbox_inches="tight")
    plt.savefig(os.path.join(BASE_OUT, f"DH5_Calib_{risk_name}_{timestamp}.pdf"), bbox_inches="tight")
    plt.show()

print(os.path.join(BASE_OUT, f"DH5_Calib_{timestamp}.png"))

# --- EXECUTE ---
# We force the mapping: 
# Death (1) -> uses Head 1
# Readm (2) -> uses Head 0
plot_calibration_final('Death', 1, TARGET_TIMES)       
plot_calibration_final('Readmission', 2, TARGET_TIMES)
C:\Users\andre\dh5\DH5_Calib_20260210_1000.png

📈 Plotting Death (Risk ID: 1)...


📈 Plotting Readmission (Risk ID: 2)...
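The "Observed Frequency (AJ)" axis comes from lifelines' `AalenJohansenFitter`. For readers who want to see what that estimator computes, here is a minimal NumPy re-implementation together with the decile grouping used in `get_calibration_data_swapped`. It is a sketch, not the lifelines code: it handles tied times directly instead of jittering them, and the helper names are hypothetical. At each distinct time u up to t, the CIF of cause k grows by S(u-) x d_k(u) / n(u), where S is all-cause event-free survival, d_k(u) the number of cause-k events at u, and n(u) the number at risk.

```python
import numpy as np
import pandas as pd


def aalen_johansen_cif(time, event, cause, t):
    """Aalen-Johansen cumulative incidence of `cause` at time t (hypothetical helper).

    time  : observed follow-up times
    event : 0 = censored, 1..K = competing causes
    """
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=int)
    surv = 1.0  # all-cause event-free survival just before the current time
    cif = 0.0
    # Distinct times up to t, ascending; if t exceeds the last time, this is the CIF at T.max()
    for u in np.unique(time[time <= t]):
        at_risk = np.sum(time >= u)
        d_all = np.sum((time == u) & (event != 0))
        d_cause = np.sum((time == u) & (event == cause))
        cif += surv * d_cause / at_risk   # S(u-) x cause-specific hazard at u
        surv *= 1.0 - d_all / at_risk     # Kaplan-Meier step for any event
    return float(cif)


def calibration_table(pred, time, event, cause, t, groups=10):
    """Mean predicted CIF vs Aalen-Johansen observed CIF within risk groups."""
    df = pd.DataFrame({"pred": pred, "time": time, "event": event})
    df["group"] = pd.qcut(df["pred"], groups, labels=False, duplicates="drop")
    rows = []
    for g, sub in df.groupby("group"):
        rows.append({"group": g,
                     "mean_pred": sub["pred"].mean(),
                     "obs_freq": aalen_johansen_cif(sub["time"], sub["event"], cause, t)})
    return pd.DataFrame(rows)
```

Without censoring the estimate reduces to the plain proportion of cause-k events, which is a quick way to check it: for times `[1, 2, 3]` and causes `[1, 1, 2]`, the CIF of cause 1 at t = 3 is 2/3.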

Code
#@title 📈 Take-Home Message: Time-Dependent Performance (DeepHit, AJ-IPCW)

import pandas as pd
from IPython.display import display

# --- UPDATE: Reflecting Table 2 Values ---
performance_msg = pd.DataFrame([

    # --- DISCRIMINATION (AUC / C-Index) ---
    {
        'Metric': 'Discrimination (AUC & Uno’s C)',
        'Outcome': 'Death vs Readmission',
        'Pattern': 'Strong Early Performance, Gradual Stabilization',
        'Interpretation': (
            'The model shows strong early discrimination for Death (AUC ~0.83 at 3m), stabilizing around 0.66-0.67 '
            'by Year 5. Readmission starts moderate (AUC ~0.74 at 3m) and plateaus lower, near 0.58-0.59. '
            'This suggests that near-term events are largely explained by identifiable baseline factors, while long-term '
            'readmission risk depends increasingly on chance or on unmeasured post-discharge events.'
        )
    },

    # --- CALIBRATION (Brier Score & Visual Inspection) ---
    {
        'Metric': 'Brier Score (Competing Risk) & Calibration Plots',
        'Outcome': 'Death vs Readmission',
        'Pattern': 'Misleadingly "Good" Metrics for Death vs. Robust Calibration for Readmission',
        'Interpretation': (
            'The Brier scores present a paradox: Death has low error scores (0.001–0.048), driven largely by event rarity, yet '
            'visual inspection reveals severe underestimation (e.g., observed risk reaches ~10% while predictions remain <1%). '
            'The model systematically underpredicts mortality. '
            'Conversely, Readmission has higher Brier scores (up to 0.259), reflecting a more common outcome with greater inherent uncertainty, '
            'but its calibration curves lie close to the diagonal. '
            'Conclusion: Readmission probabilities appear quantitatively reliable and can be used for direct risk communication, '
            'whereas Death risk estimates are poorly calibrated and should only be used for relative ranking.'
        )
    },
    # --- CLINICAL UTILITY (Implicit) ---
    {
        'Metric': 'Global Performance',
        'Outcome': 'Death vs Readmission',
        'Pattern': 'Death Dominance (C-Index 0.63 vs 0.56)',
        'Interpretation': (
            'Globally, the model ranks mortality risk better (C-Index 0.634) than readmission risk (0.561). '
            'Clinically, this suggests the tool is best used as a "Mortality Alert" system rather than a precise readmission forecaster. '
            'The weaker readmission performance points to a need for dynamic (longitudinal) data updates rather than reliance on baseline data collected at admission alone.'
        )
    }
])

print("\n>>> TAKE-HOME MESSAGE: TIME-DEPENDENT PERFORMANCE (DeepHit, AJ-IPCW)")
pd.set_option('display.max_colwidth', None)

# Apply styling
styled_table = performance_msg.style.set_properties(**{
    'text-align': 'left',
    'white-space': 'pre-wrap',
    'font-size': '11pt'
}).set_table_styles([
    {'selector': 'th', 'props': [('background-color', '#f2f2f2'), ('font-weight', 'bold'), ('text-align', 'center')]},
    {'selector': 'td', 'props': [('vertical-align', 'top'), ('padding', '10px')]}
]).hide(axis="index")

display(styled_table)

>>> TAKE-HOME MESSAGE: TIME-DEPENDENT PERFORMANCE (DeepHit, AJ-IPCW)
Metric Outcome Pattern Interpretation
Discrimination (AUC & Uno’s C) Death vs Readmission Strong Early Performance, Gradual Stabilization The model shows strong early discrimination for Death (AUC ~0.83 at 3m), stabilizing around 0.66-0.67 by Year 5. Readmission starts moderate (AUC ~0.74 at 3m) and plateaus lower, near 0.58-0.59. This suggests that near-term events are largely explained by identifiable baseline factors, while long-term readmission risk depends increasingly on chance or on unmeasured post-discharge events.
Brier Score (Competing Risk) & Calibration Plots Death vs Readmission Misleadingly "Good" Metrics for Death vs. Robust Calibration for Readmission The Brier scores present a paradox: Death has low error scores (0.001–0.048), driven largely by event rarity, yet visual inspection reveals severe underestimation (e.g., observed risk reaches ~10% while predictions remain <1%). The model systematically underpredicts mortality. Conversely, Readmission has higher Brier scores (up to 0.259), reflecting a more common outcome with greater inherent uncertainty, but its calibration curves lie close to the diagonal. Conclusion: Readmission probabilities appear quantitatively reliable and can be used for direct risk communication, whereas Death risk estimates are poorly calibrated and should only be used for relative ranking.
Global Performance Death vs Readmission Death Dominance (C-Index 0.63 vs 0.56) Globally, the model ranks mortality risk better (C-Index 0.634) than readmission risk (0.561). Clinically, this suggests the tool is best used as a "Mortality Alert" system rather than a precise readmission forecaster. The weaker readmission performance points to a need for dynamic (longitudinal) data updates rather than reliance on baseline data collected at admission alone.

⏱️ Elapsed time: 23.7663 minutes
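Why a rare outcome can post a low Brier score without being well calibrated is easiest to see in a simplified, uncensored binary example: even a useless constant prediction scores low when events are rare. Comparing against a prevalence-only reference (a scaled Brier score, 1 minus the ratio of the model's Brier score to the reference's) removes that effect. This sketch ignores censoring and competing events, so it is not the censoring-adjusted (AJ-IPCW) estimate referred to in the table title; the helper names are illustrative.

```python
import numpy as np


def brier(y, p):
    """Mean squared difference between binary outcomes y and predicted risks p."""
    y = np.asarray(y, dtype=float)
    p = np.broadcast_to(np.asarray(p, dtype=float), y.shape)
    return float(np.mean((y - p) ** 2))


def scaled_brier(y, p):
    """1 - BS(model) / BS(prevalence-only reference); below 0 means worse than the reference."""
    y = np.asarray(y, dtype=float)
    return 1.0 - brier(y, p) / brier(y, y.mean())


# 2 events among 100 patients; the "model" predicts 0.5% risk for everyone.
y = np.array([1] * 2 + [0] * 98)
print(brier(y, 0.005))          # small (about 0.02) purely because events are rare
print(scaled_brier(y, 0.005))   # slightly negative: no better than predicting the prevalence
```

The reference Brier score of a prevalence-only prediction is p(1 - p), so for rare outcomes every model's raw Brier score is bounded near p, which is why the scaled version is more informative.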

SHAP

We generated SHAP (SHapley Additive exPlanations) plots separately for Death and Readmission at each evaluated horizon, using four plot types: a bar plot of mean absolute SHAP values, a beeswarm plot, and waterfall plots for the highest- and lowest-risk patients.

Code
import os

# Change the working directory to the project folder on Google Drive
os.chdir("G:/My Drive/Alvacast/SISTRAT 2023/dh")
# Verify
print(os.getcwd())
G:\My Drive\Alvacast\SISTRAT 2023\dh
Code
SHAP values are multiplied by 100 immediately before plotting, so they are expressed in percentage points of predicted risk rather than in probability units.
Code
#@title 📊 Step 6: SHAP Plots (High/Low Risk + Percentage Scaling + Safe Path)
import pickle
import shap
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import os
import shutil

# --- CONFIGURATION ---
# 1. Source File (Assumes script is running from G: Drive folder)
# We use os.getcwd() to find the file in the current folder
G_DRIVE_PATH = os.path.join(os.getcwd(), "DH4_Final_AJ_20260209_2329_SHAP_MultiHorizon.pkl")

# 2. Setup SAFE Local Paths (On your C: Drive)
# This creates a folder "DH_Analysis" in your User folder to bypass G: drive errors
USER_HOME = os.path.expanduser("~") 
WORK_DIR = os.path.join(USER_HOME, "DH_Analysis")
os.makedirs(WORK_DIR, exist_ok=True)

# Local destination for the pickle
LOCAL_PKL = os.path.join(WORK_DIR, "temp_shap_data.pkl")

# Output folder for images
OUTPUT_DIR = os.path.join(WORK_DIR, "dh6_plots")
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"📍 Working Locally at: {WORK_DIR}")
print(f"📂 Output Images will be saved to: {OUTPUT_DIR}")

# --- 1. RESCUE DATA (Copy from G: to C:) ---
# An existing local copy is reused as-is; delete it to force a fresh copy from G:
if not os.path.exists(LOCAL_PKL):
    print(f"\n🔄 Copying file from G: Drive to local disk...")
    try:
        shutil.copyfile(G_DRIVE_PATH, LOCAL_PKL)
        print("   ✅ Copy successful!")
    except Exception as e:
        print(f"   ❌ Copy Failed: {e}")
        print("   👉 Solution: Manually copy the .pkl file to your C: drive and update the path.")
        raise
else:
    print(f"\n✅ Found local copy of data: {LOCAL_PKL}")

# Load the local copy
with open(LOCAL_PKL, 'rb') as f:
    shap_data_export = pickle.load(f)

# --- 2. PLOTTING FUNCTION ---
def generate_shap_plots(outcome_name, horizon, shap_dict):
    print(f"\n🎨 Generating plots for: {outcome_name} @ {horizon} Months")
    
    raw_values = shap_dict['shap_values']
    raw_data = shap_dict['data']
    feature_names = raw_data.columns.tolist()
    
    # 🟢 MINIMAL FIX: Scale values by 100 to show percentages
    # This turns "0.001" (which plots as +0) into "0.1" (which plots as +0.1)
    scaled_values = raw_values * 100 
    
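    # Reconstruct the Explanation object from the SCALED values.
    # base_values are set to 0 (the expected value is not used here), so waterfall
    # plots show each patient's deviation from the average prediction, not the risk itself.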
    expl = shap.Explanation(
        values=scaled_values,
        data=raw_data.values,
        feature_names=feature_names,
        base_values=np.zeros(scaled_values.shape[0])
    )
    
    # A. BAR PLOT
    plt.figure(figsize=(10, 8))
    shap.plots.bar(expl, show=False, max_display=15)
    plt.title(f"Feature Importance ({outcome_name} - {horizon}m)\n(Impact in % Risk Points)", fontsize=16)
    plt.savefig(f"{OUTPUT_DIR}/DH6_Bar_{outcome_name}_{horizon}m.png", bbox_inches='tight', dpi=300)
    plt.savefig(f"{OUTPUT_DIR}/DH6_Bar_{outcome_name}_{horizon}m.pdf", bbox_inches="tight")
    plt.show()
    plt.close()

    # B. BEESWARM PLOT
    plt.figure(figsize=(10, 8))
    shap.plots.beeswarm(expl, show=False, max_display=15)
    plt.title(f"Feature Impact ({outcome_name} - {horizon}m)", fontsize=16)
    plt.xlabel("SHAP value (Impact on % Risk)")
    plt.savefig(f"{OUTPUT_DIR}/DH6_Beeswarm_{outcome_name}_{horizon}m.png", bbox_inches='tight', dpi=300)
    plt.savefig(f"{OUTPUT_DIR}/DH6_Beeswarm_{outcome_name}_{horizon}m.pdf", bbox_inches="tight")
    plt.show()
    plt.close()

    # C. WATERFALL PLOTS
    # Pick contrasting cases: SHAP values are additive, so each patient's row sum is
    # their predicted risk minus the average prediction (raw units; same ordering as scaled).
    total_risk_dev = np.sum(raw_values, axis=1)

    # Highest vs lowest predicted risk
    high_risk_idx = int(np.argmax(total_risk_dev))
    low_risk_idx = int(np.argmin(total_risk_dev))

    cases = [
        ('HighRisk', high_risk_idx), 
        ('LowRisk',  low_risk_idx)
    ]

    for label, idx in cases:
        plt.figure(figsize=(10, 8))
        try:
            shap.plots.waterfall(expl[idx], show=False, max_display=12)
            
            # 🟢 TITLE: Includes Row ID and % label
            plt.title(f"{label} Patient (ID: {idx}) - {outcome_name} @ {horizon}m\n(Impact in % Risk Points)", fontsize=14)
            
            fname = f"{OUTPUT_DIR}/DH6_Waterfall_{outcome_name}_{horizon}m_{label}_ID{idx}"
            plt.savefig(fname + ".png", bbox_inches='tight', dpi=300)
            plt.savefig(fname + ".pdf", bbox_inches='tight')
            plt.show()
            print(f"   ✅ Saved: {label} (ID: {idx})")
        except Exception as e:
            print(f"   ⚠️ Failed: {e}")
        finally:
            plt.close()    

# --- 3. EXECUTION ---
outcomes = ['death', 'readm']
count = 0

for outcome in outcomes:
    if outcome in shap_data_export:
        horizons_dict = shap_data_export[outcome]
        for horizon, data_dict in horizons_dict.items():
            if data_dict and 'shap_values' in data_dict:
                generate_shap_plots(outcome, horizon, data_dict)
                count += 1

print(f"\n🏁 Done! Generated plots for {count} scenarios.")
print(f"📂 Open this folder to see images: {OUTPUT_DIR}")

# --- 4. OPTIONAL: COPY BACK TO G: DRIVE ---
# Tries to copy images back to Google Drive (if accessible)
try:
    dest_folder = os.path.join(os.getcwd(), "dh6")
    shutil.copytree(OUTPUT_DIR, dest_folder, dirs_exist_ok=True)
    print(f"📤 Uploaded images back to Google Drive folder: {dest_folder}")
except Exception as e:
    print(f"⚠️ Could not copy back to G: drive (likely due to device error): {e}")
    print("   ✅ Your images are safe on your C: drive at the path above.")
📍 Working Locally at: C:\Users\andre\DH_Analysis
📂 Output Images will be saved to: C:\Users\andre\DH_Analysis\dh6_plots

✅ Found local copy of data: C:\Users\andre\DH_Analysis\temp_shap_data.pkl

🎨 Generating plots for: death @ 3 Months

   ✅ Saved: HighRisk (ID: 87)

   ✅ Saved: LowRisk (ID: 447)

🎨 Generating plots for: death @ 12 Months

   ✅ Saved: HighRisk (ID: 87)

   ✅ Saved: LowRisk (ID: 447)

🎨 Generating plots for: death @ 24 Months

   ✅ Saved: HighRisk (ID: 87)

   ✅ Saved: LowRisk (ID: 447)

🎨 Generating plots for: death @ 48 Months

   ✅ Saved: HighRisk (ID: 87)

   ✅ Saved: LowRisk (ID: 447)

🎨 Generating plots for: death @ 60 Months

   ✅ Saved: HighRisk (ID: 87)

   ✅ Saved: LowRisk (ID: 447)

🎨 Generating plots for: death @ 72 Months

   ✅ Saved: HighRisk (ID: 87)

   ✅ Saved: LowRisk (ID: 447)

🎨 Generating plots for: readm @ 3 Months

   ✅ Saved: HighRisk (ID: 87)

   ✅ Saved: LowRisk (ID: 447)

🎨 Generating plots for: readm @ 12 Months

   ✅ Saved: HighRisk (ID: 87)

   ✅ Saved: LowRisk (ID: 447)

🎨 Generating plots for: readm @ 24 Months

   ✅ Saved: HighRisk (ID: 87)

   ✅ Saved: LowRisk (ID: 447)

🎨 Generating plots for: readm @ 48 Months

   ✅ Saved: HighRisk (ID: 87)

   ✅ Saved: LowRisk (ID: 447)

🎨 Generating plots for: readm @ 60 Months

   ✅ Saved: HighRisk (ID: 87)

   ✅ Saved: LowRisk (ID: 447)

🎨 Generating plots for: readm @ 72 Months

   ✅ Saved: HighRisk (ID: 87)

   ✅ Saved: LowRisk (ID: 447)

🏁 Done! Generated plots for 12 scenarios.
📂 Open this folder to see images: C:\Users\andre\DH_Analysis\dh6_plots
📤 Uploaded images back to Google Drive folder: G:\My Drive\Alvacast\SISTRAT 2023\dh\dh6
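Step 6 multiplies SHAP values by 100, which changes only the unit: values become percentage points of predicted risk (not percentages of the risk itself), and both rankings and additivity are preserved. Because the notebook builds the `shap.Explanation` with `base_values` set to zero, the waterfall plots end at each patient's deviation from the average prediction rather than at the predicted risk. The sketch below, with a hypothetical helper `shap_to_points`, shows how the predicted risk could be recovered if the explainer's expected value were available; whether the exported pickle stores it is not shown here.

```python
import numpy as np


def shap_to_points(shap_values, base_value=None):
    """Convert SHAP values from probability units to percentage points (hypothetical helper).

    If the explainer's expected value `base_value` (probability units) is known,
    also return each patient's implied predicted risk in percent, using SHAP
    additivity: prediction = base_value + sum of the patient's SHAP values.
    """
    pts = np.asarray(shap_values, dtype=float) * 100.0
    if base_value is None:
        return pts, None
    return pts, 100.0 * base_value + pts.sum(axis=1)


# Toy example: two patients, three features, and an assumed expected value of 5% risk.
shap_vals = np.array([[0.001, -0.002, 0.010],
                      [0.000,  0.004, -0.001]])
pts, pred_pct = shap_to_points(shap_vals, base_value=0.05)
row_sums = pts.sum(axis=1)  # deviation from the expected value, in points
high, low = int(np.argmax(row_sums)), int(np.argmin(row_sums))
print(pts[0], pred_pct, high, low)
```

The high/low selection gives the same patients whether it uses raw or scaled values, since multiplying by 100 does not change the ordering of the row sums.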

Export to Excel

  1. Computes per-horizon SHAP importance for every predictor and outcome
  2. Classifies predictors as risk, protective, or mixed via the correlation between feature values and SHAP values
  3. Quantifies robustness as the standard deviation of absolute SHAP values across patients
  4. Produces ranked, time-specific predictor tables for death and readmission
  5. Exports publication-ready Excel reports with automated local + Drive saving
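Step 6.3 below loops over features and calls `scipy.stats.pearsonr` one feature at a time. As a sketch, the same per-feature correlation and the same ±0.1 labelling rule can be computed in one vectorized pass. The helper name `classify_directions` is hypothetical, and two details differ from the notebook: a constant SHAP column gets a correlation of 0 instead of the NaN `pearsonr` would return, and constant features keep their actual robustness value. Note that "Non-Linear/Mixed" means only |r| ≤ 0.1, which also covers weak linear effects.

```python
import numpy as np
import pandas as pd


def classify_directions(X, shap_values, thr=0.1, eps=1e-6):
    """Vectorized per-feature SHAP summary (hypothetical helper).

    X           : DataFrame (patients x features) of feature values
    shap_values : array (patients x features), probability units
    """
    Xv = np.asarray(X, dtype=float)
    S = np.asarray(shap_values, dtype=float)

    # Pearson correlation for every feature column at once
    Xc = Xv - Xv.mean(axis=0)
    Sc = S - S.mean(axis=0)
    denom = np.sqrt((Xc ** 2).sum(axis=0) * (Sc ** 2).sum(axis=0))
    safe = np.where(denom > 0, denom, 1.0)
    corr = np.where(denom > 0, (Xc * Sc).sum(axis=0) / safe, 0.0)  # 0 when undefined

    # Same labelling rule as Step 6.3
    direction = np.where(corr > thr, "Risk Factor (↑)",
                         np.where(corr < -thr, "Protective (↓)", "Non-Linear/Mixed"))
    direction = np.where(Xv.std(axis=0) < eps, "Neutral (Constant)", direction)

    abs_s = np.abs(S)
    out = pd.DataFrame({
        "Feature": list(X.columns),
        "Importance (%)": abs_s.mean(axis=0) * 100,
        "Direction": direction,
        "Correlation": corr,
        "Robustness (SD)": abs_s.std(axis=0) * 100,
    })
    out = out.sort_values("Importance (%)", ascending=False).reset_index(drop=True)
    out.insert(0, "Rank", range(1, len(out) + 1))
    return out
```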
Code
#@title 📊 Step 6.3: Export Detailed Predictor Analysis (Excel) with Auto-Upload
import pandas as pd
import numpy as np
import pickle
import os
import shutil
from scipy.stats import pearsonr
from datetime import datetime

# --- 1. CONFIGURATION ---
# Use the safe local path (C: Drive) to avoid G: drive write errors
USER_HOME = os.path.expanduser("~")
WORK_DIR = os.path.join(USER_HOME, "DH_Analysis")
LOCAL_PKL = os.path.join(WORK_DIR, "temp_shap_data.pkl") 

# Output Filename
TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M")
EXCEL_FILENAME = f"DH64_Predictor_Analysis_{TIMESTAMP}.xlsx"
EXCEL_PATH = os.path.join(WORK_DIR, EXCEL_FILENAME)

print(f"📍 Working Directory: {WORK_DIR}")

# --- 2. HELPER FUNCTIONS ---
def get_direction_and_robustness(feature_values, shap_values):
    """
    Classifies a feature as Risk Factor, Protective, or Non-Linear/Mixed from the
    Pearson correlation between its values and its SHAP values.
    Returns (direction, correlation, robustness).
    """
    # Robustness: standard deviation of the absolute impact across patients
    robustness = np.std(np.abs(shap_values))

    # Direction: correlation between feature value and SHAP value
    try:
        if np.std(feature_values) < 1e-6:
            return "Neutral (Constant)", 0.0, 0.0
        if np.std(shap_values) == 0:
            # pearsonr is undefined (NaN) when the SHAP values are constant
            return "Neutral (Constant)", 0.0, robustness

        corr, _ = pearsonr(feature_values, shap_values)

        if corr > 0.1: direction = "Risk Factor (↑)"     # Higher Value = Higher Risk
        elif corr < -0.1: direction = "Protective (↓)"   # Higher Value = Lower Risk
        else: direction = "Non-Linear/Mixed"

        return direction, corr, robustness
    except Exception:
        return "Unknown", 0.0, robustness

def analyze_horizon(outcome_name, horizon, data_dict):
    """
    Analyzes all features for a specific outcome and time horizon.
    """
    print(f"   ... Processing {outcome_name} @ {horizon}m")
    
    X_data = data_dict['data']          
    shap_vals = data_dict['shap_values'] 
    feature_names = X_data.columns.tolist()
    
    metrics_list = []
    
    for i, fname in enumerate(feature_names):
        f_values = X_data.iloc[:, i].values
        s_values = shap_vals[:, i]
        
        # A. Global Importance (Mean |SHAP|) x 100 for %
        importance = np.mean(np.abs(s_values)) * 100 
        
        # B. Direction
        direction, corr, robust = get_direction_and_robustness(f_values, s_values)
        
        metrics_list.append({
            'Horizon': f"{horizon} Months",
            'Feature': fname,
            'Importance (%)': importance,
            'Direction': direction,
            'Correlation': corr,
            'Robustness (SD)': robust * 100 
        })
        
    df = pd.DataFrame(metrics_list)
    df = df.sort_values(by='Importance (%)', ascending=False)
    df.insert(0, 'Rank', range(1, len(df) + 1)) 
    return df

# --- 3. MAIN EXECUTION ---
if not os.path.exists(LOCAL_PKL):
    print("❌ Error: SHAP pickle file not found in C: drive.")
    print("   Run the plotting script (Step 6) first to copy the data locally.")
else:
    print(f"📂 Loading SHAP data from: {LOCAL_PKL}")
    with open(LOCAL_PKL, 'rb') as f:
        shap_data_export = pickle.load(f)

    all_sheets = {}

    for outcome in ['death', 'readm']:
        if outcome not in shap_data_export: continue
        
        print(f"⚡ Analyzing Outcome: {outcome}...")
        outcome_dfs = []
        
        horizons = sorted(shap_data_export[outcome].keys())
        for h in horizons:
            data_dict = shap_data_export[outcome][h]
            if data_dict and 'shap_values' in data_dict:
                df_h = analyze_horizon(outcome, h, data_dict)
                outcome_dfs.append(df_h)
        
        if outcome_dfs:
            full_df = pd.concat(outcome_dfs)
            sheet_name = f"{outcome.capitalize()} Predictors"
            all_sheets[sheet_name] = full_df

    # Export to Excel (Locally First)
    if all_sheets:
        print(f"\n💾 Saving Analysis to local Excel...")
        with pd.ExcelWriter(EXCEL_PATH, engine='xlsxwriter') as writer:
            
            for sheet_name, df in all_sheets.items():
                df.to_excel(writer, sheet_name=sheet_name, index=False)
                worksheet = writer.sheets[sheet_name]
                worksheet.set_column('A:A', 5)    # Rank
                worksheet.set_column('B:B', 15)   # Horizon
                worksheet.set_column('C:C', 35)   # Feature
                worksheet.set_column('D:D', 15)   # Importance (%)
                worksheet.set_column('E:E', 20)   # Direction
                
            meta_data = pd.DataFrame([
                {'Metric': 'Importance (%)', 'Definition': 'Mean absolute SHAP value x 100, in percentage points of predicted risk.'},
                {'Metric': 'Direction', 'Definition': 'Risk Factor (↑) if correlation > 0.1, Protective (↓) if < -0.1, otherwise Non-Linear/Mixed.'},
                {'Metric': 'Correlation', 'Definition': 'Pearson correlation between feature value and SHAP value (+1 = strong risk, -1 = strong protective).'},
                {'Metric': 'Robustness (SD)', 'Definition': 'Standard deviation of absolute SHAP values across patients, x 100 (percentage points).'}
            ])
            meta_data.to_excel(writer, sheet_name='Definitions', index=False)
            writer.sheets['Definitions'].set_column('A:B', 60)

        print(f"✅ Report saved locally: {EXCEL_PATH}")

        # --- 4. COPY BACK TO G: DRIVE ---
        print("\n📤 Uploading to Google Drive...")
        try:
            # We copy specifically the EXCEL file to the current G: folder
            dest_path = os.path.join(os.getcwd(), EXCEL_FILENAME)
            shutil.copy(EXCEL_PATH, dest_path)
            print(f"✅ Success! Excel is now available on G: Drive at:")
            print(f"   {dest_path}")
        except Exception as e:
            print(f"⚠️ Could not upload to G: drive: {e}")
            print(f"   👉 You can find the file manually at: {EXCEL_PATH}")

    else:
        print("⚠️ No valid SHAP data found to export.")
📍 Working Directory: C:\Users\andre\DH_Analysis
📂 Loading SHAP data from: C:\Users\andre\DH_Analysis\temp_shap_data.pkl
⚡ Analyzing Outcome: death...
   ... Processing death @ 3m
   ... Processing death @ 12m
   ... Processing death @ 24m
   ... Processing death @ 48m
   ... Processing death @ 60m
   ... Processing death @ 72m
⚡ Analyzing Outcome: readm...
   ... Processing readm @ 3m
   ... Processing readm @ 12m
   ... Processing readm @ 24m
   ... Processing readm @ 48m
   ... Processing readm @ 60m
   ... Processing readm @ 72m

💾 Saving Analysis to local Excel...
✅ Report saved locally: C:\Users\andre\DH_Analysis\DH64_Predictor_Analysis_20260210_1320.xlsx

📤 Uploading to Google Drive...
✅ Success! Excel is now available on G: Drive at:
   G:\My Drive\Alvacast\SISTRAT 2023\dh\DH64_Predictor_Analysis_20260210_1320.xlsx

Summary of SHAP influences

  1. Integrates SHAP influences across all time horizons into one global feature ranking
  2. Uses mean absolute SHAP, averaged over horizons, to measure the overall importance of each covariate
  3. Quantifies time variability to detect early-only vs persistent predictors
  4. Determines direction via SHAP–feature correlation (risk vs protective)
  5. Separates analyses for death and readmission outcomes
  6. Aggregates evidence across patients and time, not single snapshots
  7. Produces a ranked list of candidate covariates for a Cox proportional hazards model, grounded in DeepHit explanations
  8. Outputs publication-ready Excel tables with rankings and metadata
  9. Flags unstable features with high temporal variability
  10. Translates complex SHAP dynamics into interpretable survival predictors
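The time integration described above can be sketched on a small table of per-horizon importances. Step 7 averages each feature's importance over horizons and uses its standard deviation across horizons as time variability; the visible Step 7 code stops before the rule that flags unstable features (item 9), so the coefficient-of-variation cut-off used here (0.5) is an assumption for illustration, and `integrate_over_horizons` is a hypothetical helper.

```python
import pandas as pd


def integrate_over_horizons(importance, cv_threshold=0.5):
    """Time-integrated importance table (hypothetical helper).

    importance   : DataFrame with one row per horizon and one column per feature,
                   holding mean |SHAP| x 100 at that horizon.
    cv_threshold : illustrative cut-off on SD / mean for flagging instability.
    """
    avg = importance.mean(axis=0)
    sd = importance.std(axis=0, ddof=0)  # population SD, like np.std in Step 7
    cv = sd / avg.where(avg > 0)          # NaN when a feature never matters
    out = pd.DataFrame({
        "Avg Importance (%)": avg,
        "Time Variability (SD)": sd,
        "CV": cv,
        "Unstable": cv > cv_threshold,    # NaN compares as False
    })
    out = out.sort_values("Avg Importance (%)", ascending=False)
    out.insert(0, "Rank", range(1, len(out) + 1))
    return out


# Toy example: 'age' matters equally at every horizon, 'early_shock' fades out.
imp = pd.DataFrame({"age": [2.0, 2.0, 2.0], "early_shock": [3.0, 0.5, 0.1]},
                   index=[3, 12, 24])
print(integrate_over_horizons(imp))
```

Dividing by the mean makes the instability flag comparable across features of very different overall importance, which a raw SD would not.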
Code
#@title 📊 Step 7: Generate Time-Integrated Feature Importance Table (for CoxPH)
import pandas as pd
import numpy as np
import pickle
import os
from scipy.stats import pearsonr
from datetime import datetime

# --- CONFIGURATION ---
USER_HOME = os.path.expanduser("~")
WORK_DIR = os.path.join(USER_HOME, "DH_Analysis")
LOCAL_PKL = os.path.join(WORK_DIR, "temp_shap_data.pkl") 
TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M")
EXCEL_PATH = os.path.join(WORK_DIR, f"DH7_CoxPH_Candidate_Features_{TIMESTAMP}.xlsx")

print(f"📍 Working Directory: {WORK_DIR}")

# --- HELPER: Directionality Check ---
def get_global_direction(feature_values, shap_matrix):
    """
    Placeholder (not used for the results): direction is computed per feature in the
    main loop below, as the Pearson correlation between feature values and SHAP values
    at each horizon, averaged across horizons.
    """
    return "Calculated in loop"

# --- MAIN ANALYSIS ---
if not os.path.exists(LOCAL_PKL):
    print("❌ Error: Data not found. Run Step 6 first.")
else:
    print(f"📂 Loading SHAP data...")
    with open(LOCAL_PKL, 'rb') as f:
        shap_data_export = pickle.load(f)

    with pd.ExcelWriter(EXCEL_PATH, engine='xlsxwriter') as writer:
        
        for outcome in ['death', 'readm']:
            if outcome not in shap_data_export: continue
            
            print(f"⚡ Integrating Time Horizons for: {outcome}...")
            
            # 1. Collect Data Across All Horizons
            horizons = sorted(shap_data_export[outcome].keys())
            
            # Dictionary to store {feature_name: [list of importance scores over time]}
            feature_map = {}
            feature_directions = {}
            
            # We need the raw feature data (X) to check direction (High Value = Risk?)
            # We assume X is the same for all horizons (it is static data)
            first_h = horizons[0]
            X_data = shap_data_export[outcome][first_h]['data']
            feature_names = X_data.columns.tolist()
            
            # Initialize storage
            for f in feature_names:
                feature_map[f] = []
                feature_directions[f] = []

            # 2. Iterate Time and Accumulate Evidence
            for h in horizons:
                data_dict = shap_data_export[outcome][h]
                if not data_dict or 'shap_values' not in data_dict: continue
                
                raw_shap = data_dict['shap_values'] # (Patients, Features)
                
                for i, fname in enumerate(feature_names):
                    # A. Magnitude (Importance)
                    # We use Mean Absolute SHAP * 100 (percentage points)
                    imp = np.mean(np.abs(raw_shap[:, i])) * 100
                    feature_map[fname].append(imp)
                    
                    # B. Direction (Correlation at this specific time)
                    f_values = X_data.iloc[:, i].values
                    # pearsonr is undefined (returns NaN) if either input is constant
                    if np.std(f_values) > 1e-6 and np.std(raw_shap[:, i]) > 1e-12:
                        corr, _ = pearsonr(f_values, raw_shap[:, i])
                        feature_directions[fname].append(corr)
                    else:
                        feature_directions[fname].append(0)

            # 3. Aggregate into a Global Summary
            summary_list = []
            for fname in feature_names:
                # Average Importance across all evaluated time points
                avg_importance = np.mean(feature_map[fname])
                
                # Consistency (Standard Deviation of importance over time)
                # High STD = Feature is only important at specific times (e.g. early shock)
                time_variability = np.std(feature_map[fname])
                
                # Average Direction
                avg_corr = np.mean(feature_directions[fname])
                if avg_corr > 0.05: direction = "Risk (↑)"
                elif avg_corr < -0.05: direction = "Protective (↓)"
                else: direction = "Mixed/Neutral"
                
                summary_list.append({
                    'Feature': fname,
                    'Global Importance (Mean %)': avg_importance,
                    'Time Variability': time_variability,
                    'Direction': direction,
                    'Avg Correlation': avg_corr
                })
            
            # 4. Create DataFrame & Rank
            df_final = pd.DataFrame(summary_list)
            df_final = df_final.sort_values(by='Global Importance (Mean %)', ascending=False)
            df_final.insert(0, 'Rank', range(1, len(df_final) + 1))
            
            # 5. Save to Excel
            sheet_name = f"{outcome.capitalize()} - Integrated"
            df_final.to_excel(writer, sheet_name=sheet_name, index=False)
            
            # Formatting
            worksheet = writer.sheets[sheet_name]
            worksheet.set_column('B:B', 35) # Feature
            worksheet.set_column('C:C', 20) # Importance
            worksheet.set_column('D:D', 15) # Variability
            
            # 6. Interpret for the User
            print(f"   ✅ Top 5 for {outcome}: {df_final['Feature'].iloc[:5].tolist()}")

        # Metadata Sheet
        meta = pd.DataFrame([
            {'Metric': 'Global Importance', 'Definition': 'Mean absolute SHAP value × 100 (percentage points of predicted risk), averaged across all evaluated time horizons. Summarises importance over the whole follow-up.'},
            {'Metric': 'Time Variability', 'Definition': 'Standard deviation of the importance across horizons. Low = consistent predictor. High = important only at specific times (e.g., short-term).'},
            {'Metric': 'Direction', 'Definition': 'Sign of the average correlation between feature values and their SHAP values. Risk = higher values increase predicted risk. Protective = higher values decrease predicted risk. Mixed/Neutral = |correlation| <= 0.05.'}
        ])
        meta.to_excel(writer, sheet_name='Legend', index=False)
        writer.sheets['Legend'].set_column('A:B', 60)

    print(f"\n💾 Summary Table Saved: {EXCEL_PATH}")
    
    # --- COPY TO PROJECT FOLDER (G: DRIVE) ---
    # Copies the workbook into the current working directory, assumed to be the synced G: Drive folder
    try:
        dest_path = os.path.join(os.getcwd(), f"DH7_Integrated_CoxPH_Candidate_Features_{TIMESTAMP}.xlsx")
        shutil.copy(EXCEL_PATH, dest_path)
        print(f"📤 Uploaded to G: Drive: {dest_path}")
    except Exception as e:
        print(f"⚠️ Could not upload to G: Drive ({e}). File is on C: Drive.")
📍 Working Directory: C:\Users\andre\DH_Analysis
📂 Loading SHAP data...
⚡ Integrating Time Horizons for: death...
   ✅ Top 5 for death: ['ethnicity', 'tipo_de_vivienda_rec2_other_unknown', 'primary_sub_mod_marijuana', 'eva_consumo', 'dit_m']
⚡ Integrating Time Horizons for: readm...
   ✅ Top 5 for readm: ['eva_consumo', 'cohabitation_others', 'dit_m', 'ethnicity', 'adm_age_rec3']

💾 Summary Table Saved: C:\Users\andre\DH_Analysis\DH7_CoxPH_Candidate_Features_20260210_1332.xlsx
📤 Uploaded to G: Drive: G:\My Drive\Alvacast\SISTRAT 2023\dh\DH7_Integrated_CoxPH_Candidate_Features_20260210_1332.xlsx
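This sketch checks the Step 7.1 aggregation in isolation by reproducing its two summary numbers on toy data. Importance at each horizon is mean |SHAP| × 100. Across horizons, its mean gives Global Importance and its population standard deviation gives Time Variability. `integrate_importance` and the toy `shap_by_horizon` dictionary are hypothetical. The sketch assumes each horizon holds a patients × features SHAP array, as the loop above does.

```python
import numpy as np
import pandas as pd


def integrate_importance(shap_by_horizon, feature_names):
    """Hypothetical helper mirroring Step 7.1.

    shap_by_horizon: {horizon: array of shape (patients, features)}
    Returns one row per feature with Global Importance (mean over horizons of
    mean |SHAP| x 100) and Time Variability (population SD over horizons).
    """
    # One row per horizon, one column per feature: mean |SHAP| x 100
    per_horizon = np.vstack([
        np.mean(np.abs(shap), axis=0) * 100
        for _, shap in sorted(shap_by_horizon.items())
    ])
    summary = pd.DataFrame({
        "Feature": feature_names,
        "Global Importance (Mean %)": per_horizon.mean(axis=0),
        "Time Variability": per_horizon.std(axis=0),  # ddof=0, same as np.std
    })
    summary = summary.sort_values("Global Importance (Mean %)", ascending=False)
    summary.insert(0, "Rank", range(1, len(summary) + 1))
    return summary.reset_index(drop=True)


# Toy data: "a" matters equally at both horizons; "b" matters only at 60 months
shap_by_horizon = {
    12: np.array([[0.02, 0.00], [-0.02, 0.00]]),
    60: np.array([[0.02, 0.06], [-0.02, -0.06]]),
}
result = integrate_importance(shap_by_horizon, ["a", "b"])
print(result)
```

Here `b` outranks `a` on average importance. Its large Time Variability, however, flags it as a horizon-specific predictor, which is the distinction the Legend sheet describes.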
Code
#@title 📋 Step 7.2, Final Summary Table: Predictors (HTML, Positron-safe)

import pandas as pd
from IPython.display import display

# 1. Create the Summary Data
data = [
    {
        "Rank": 1,
        "Feature": "dit_m (Treatment Duration)",
        "Role": "⚠️ Risk Factor",
        "Stability": "High (Consistent)",
        "Interpretation": "Among the strongest predictors for both death and readmission across all time horizons (top 5 for each outcome in the integrated ranking). Higher values are associated with higher predicted risk, and its importance grows at longer follow-ups (from short-term immediate effects to cumulative long-term vulnerability). A potentially modifiable clinical factor tied to treatment adherence and illness severity."
    },
    {
        "Rank": 2,
        "Feature": "ethnicity",
        "Role": "⚠️ Risk Factor",
        "Stability": "High",
        "Interpretation": "Consistently ranks in the top 3 for both outcomes, with strong positive correlations (often >0.80). Indicates potential disparities in access to care or socio-cultural stressors; importance remains stable but scales with horizon length, highlighting enduring systemic influences."
    },
    {
        "Rank": 3,
        "Feature": "adm_age_rec3 (Age at Admission)",
        "Role": "⚠️ Risk Factor",
        "Stability": "Moderate",
        "Interpretation": "Top predictor emphasizing older age as a universal vulnerability. Correlation around 0.60 for death and readmission; becomes more prominent in mid-to-long-term horizons (12-72 months), likely due to compounding health comorbidities over time."
    },
    {
        "Rank": 4,
        "Feature": "plan_type_corr_m_pr (Women-only Residential Treatment Plan Type)",
        "Role": "⚠️ Risk Factor",
        "Stability": "High",
        "Interpretation": "Frequently in top 5, with high correlations (>0.85). Specific plan types may indicate higher-risk patient profiles or less effective interventions; consistent across horizons but with amplified importance long-term, suggesting need for plan optimization."
    },
    {
        "Rank": 5,
        "Feature": "primary_sub_mod_marijuana (Primary Substance: Marijuana)",
        "Role": "🛡️ Protective",
        "Stability": "Low Variability (Consistent)",
        "Interpretation": "Strong protective association relative to harder substances like alcohol, with negative correlations (~-0.80). Appears within the top 10-15 ranks, with a stable role across all horizons, consistent with lower predicted mortality and recidivism risk for marijuana-primary users."
    },
    {
        "Rank": 6,
        "Feature": "any_phys_dx (Any Physical Diagnosis)",
        "Role": "⚠️ Risk Factor",
        "Stability": "Moderate",
        "Interpretation": "Common in top 10 for both outcomes, correlation ~0.70. Highlights comorbidities as drivers; more influential in mid-term horizons (12-48 months), where physical health deterioration impacts recovery."
    },
    {
        "Rank": 7,
        "Feature": "ed_attainment_corr (Education Attainment)",
        "Role": "🛡️ Protective",
        "Stability": "Moderate",
        "Interpretation": "Protective factor with negative correlations (~-0.50). Ranks higher in mid-term; suggests higher education correlates with better coping mechanisms or access to resources, reducing long-term risks."
    },
    {
        "Rank": 8,
        "Feature": "primary_sub_mod_alcohol (Primary Substance: Alcohol)",
        "Role": "⚠️ Risk Factor",
        "Stability": "Low Variability (Consistent)",
        "Interpretation": "Often non-linear/mixed but with positive correlations in key horizons. Stable driver of higher mortality and readmission vs. other substances; consistent low-to-mid rank, indicating specific intervention needs for alcohol users."
    },
    {
        "Rank": 9,
        "Feature": "occupation_condition_corr24_unemployed (Unemployed Status)",
        "Role": "⚠️ Risk Factor",
        "Stability": "Moderate",
        "Interpretation": "Top 10-15 risk factor, correlation ~0.43. Gains relevance in longer horizons; unemployment likely exacerbates stress and relapse, pointing to socio-economic interventions."
    },
    {
        "Rank": 10,
        "Feature": "porc_pobr (Poverty %)",
        "Role": "🛡️ Protective",
        "Stability": "High (Contextual)",
        "Interpretation": "Appears sporadically, but as a protective factor (negative correlation ~-0.30). Possibly a proxy for eligibility for aid programs and access to safety nets in vulnerable populations; more evident at longer horizons."
    }
]

# 2. Convert to DataFrame
df_summary = pd.DataFrame(data)

# 3. Style for HTML display (Positron-safe)
styled_table = (
    df_summary.style
    .set_caption("📊 Table: Five-Year Integrated Predictors (DeepHit → CoxPH)")
    .set_properties(**{
        "text-align": "left",
        "white-space": "pre-wrap",
        "font-size": "14px"
    })
    .set_table_styles([
        {"selector": "caption", "props": [
            ("font-size", "16px"),
            ("font-weight", "bold"),
            ("margin-bottom", "10px")
        ]},
        {"selector": "th", "props": [
            ("background-color", "#f4f4f4"),
            ("border-bottom", "2px solid #555"),
            ("text-align", "center")
        ]},
        {"selector": "td", "props": [
            ("padding", "8px")
        ]},
        {"selector": "tr:hover", "props": [
            ("background-color", "#f9f9f9")
        ]}
    ])
)

# 4. Display
display(styled_table)
Table 2: 📊 Table: Five-Year Integrated Predictors (DeepHit → CoxPH)
  Rank Feature Role Stability Interpretation
0 1 dit_m (Treatment Duration) ⚠️ Risk Factor High (Consistent) Dominant predictor for both death and readmission across all time horizons. Higher values significantly elevate risk, with importance amplifying in longer-term follow-ups (e.g., from short-term immediate effects to cumulative long-term vulnerability). Represents a key modifiable clinical factor tied to treatment adherence and illness severity.
1 2 ethnicity ⚠️ Risk Factor High Consistently ranks in the top 3 for both outcomes, with strong positive correlations (often >0.80). Indicates potential disparities in access to care or socio-cultural stressors; importance remains stable but scales with horizon length, highlighting enduring systemic influences.
2 3 adm_age_rec3 (Age at Admission) ⚠️ Risk Factor Moderate Top predictor emphasizing older age as a universal vulnerability. Correlation around 0.60 for death and readmission; becomes more prominent in mid-to-long-term horizons (12-72 months), likely due to compounding health comorbidities over time.
3 4 plan_type_corr_m_pr (Treatment Plan Type) ⚠️ Risk Factor High Frequently in top 5, with high correlations (>0.85). Specific plan types may indicate higher-risk patient profiles or less effective interventions; consistent across horizons but with amplified importance long-term, suggesting need for plan optimization.
4 5 primary_sub_mod_marijuana (Primary Substance: Marijuana) 🛡️ Protective Low (Consistent) Strong protective effect compared to harder substances like alcohol, with negative correlations (~-0.80). Appears in top 10-15 ranks; stable role across all horizons, implying lower mortality and recidivism risks for marijuana-primary users.
5 6 any_phys_dx (Any Physical Diagnosis) ⚠️ Risk Factor Moderate Common in top 10 for both outcomes, correlation ~0.70. Highlights comorbidities as drivers; more influential in mid-term horizons (12-48 months), where physical health deterioration impacts recovery.
6 7 ed_attainment_corr (Education Attainment) 🛡️ Protective Moderate Protective factor with negative correlations (~-0.50). Ranks higher in mid-term; suggests higher education correlates with better coping mechanisms or access to resources, reducing long-term risks.
7 8 primary_sub_mod_alcohol (Primary Substance: Alcohol) ⚠️ Risk Factor Low (Consistent) Often non-linear/mixed but with positive correlations in key horizons. Stable driver of higher mortality and readmission vs. other substances; consistent low-to-mid rank, indicating specific intervention needs for alcohol users.
8 9 occupation_condition_corr24_unemployed (Unemployed Status) ⚠️ Risk Factor Moderate Top 10-15 risk factor, correlation ~0.43. Gains relevance in longer horizons; unemployment likely exacerbates stress and relapse, pointing to socio-economic interventions.
9 10 porc_pobr (Poverty %) 🛡️ Protective High (Contextual) Appears sporadically but as protective (negative correlation ~ -0.30). Likely proxies for eligibility in aid programs; more evident in long-term, enhancing safety nets in vulnerable populations.
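The Role and Stability columns in this table are typed by hand. The helper below sketches how they could instead be derived from the Step 7.1 sheet. It maps Direction to Role and bins the coefficient of variation (Time Variability ÷ Global Importance) into stability labels. `label_predictors`, the `cv_low`/`cv_high` cut-offs of 0.5 and 1.0, and the "Mixed/Neutral" fallback role are hypothetical choices, not values used in the notebook.

```python
import numpy as np
import pandas as pd


def label_predictors(df_final, cv_low=0.5, cv_high=1.0):
    """Hypothetical mapping from Step 7.1 columns to Role / Stability labels.

    cv_low and cv_high are illustrative cut-offs on the coefficient of variation.
    """
    out = df_final[["Rank", "Feature"]].copy()
    out["Role"] = df_final["Direction"].map({
        "Risk (↑)": "⚠️ Risk Factor",
        "Protective (↓)": "🛡️ Protective",
    }).fillna("Mixed/Neutral")
    # Relative variability: SD of importance over horizons / mean importance
    cv = df_final["Time Variability"] / df_final["Global Importance (Mean %)"]
    out["Stability"] = pd.cut(
        cv,
        bins=[-np.inf, cv_low, cv_high, np.inf],
        labels=["High (Low Variability)", "Moderate", "Low (High Variability)"],
    )
    return out


# Toy stand-in for the Step 7.1 sheet (made-up numbers)
toy_final = pd.DataFrame({
    "Rank": [1, 2, 3],
    "Feature": ["a", "b", "c"],
    "Global Importance (Mean %)": [2.0, 1.0, 0.5],
    "Time Variability": [0.4, 0.8, 0.9],
    "Direction": ["Risk (↑)", "Protective (↓)", "Mixed/Neutral"],
})
labels = label_predictors(toy_final)
print(labels)
```

A relative measure is used because Time Variability on its own scales with importance. A small absolute SD can still be large for a weak predictor.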
Code
import pandas as pd
from IPython.display import display

# 1. Create the Summary Data
data = [
    {
        "Rank": 1,
        "Feature": "dit_m (Treatment Duration)",
        "Role": "⚠️ Risk Factor",
        "Stability": "High (Low Variability)",
        "Interpretation": "Prominent, non-anomalous predictor for both death and readmission in the horizon-integrated SHAP summary used to select CoxPH candidates. Global importance around 0.015% with time variability of 0.010-0.011, indicating a fairly consistent impact across horizons. Moderate positive average correlation (~0.35), pointing to longer durations as a sustained risk marker; potentially relevant to treatment retention strategies."
    },
    {
        "Rank": 2,
        "Feature": "ethnicity",
        "Role": "⚠️ Risk Factor",
        "Stability": "Low (High Variability)",
        "Interpretation": "Ranks highly in readmission (global importance 0.010%) and anomalously high in death, with extreme variability suggesting time-specific bursts. Average correlation >0.85 in readmission and 0.68 in death; highlights potential disparities in care access or socio-cultural factors, with amplified effects in certain periods."
    },
    {
        "Rank": 3,
        "Feature": "adm_age_rec3 (Age at Admission)",
        "Role": "⚠️ Risk Factor",
        "Stability": "Moderate",
        "Interpretation": "Consistent top feature across outcomes, global importance ~0.009% and moderate variability (0.006). High average correlation (~0.61), underscoring older age as a compounding risk factor over time, particularly in long-term horizons due to accumulating comorbidities."
    },
    {
        "Rank": 4,
        "Feature": "plan_type_corr_m_pr (Women-only Residential Treatment Plan Type)",
        "Role": "⚠️ Risk Factor",
        "Stability": "High",
        "Interpretation": "Key treatment-related risk, with global importance 0.007% and low variability (0.005). Very high average correlation (>0.87), indicating this plan type associates with elevated hazards; stable across time, suggesting opportunities for intervention through plan refinement."
    },
    {
        "Rank": 5,
        "Feature": "primary_sub_mod_marijuana (Primary Substance: Marijuana)",
        "Role": "🛡️ Protective",
        "Stability": "Moderate (Higher Variability)",
        "Interpretation": "Strong protective association in both outcomes, global importance ~0.031% in death and 0.004% in readmission, with moderate-high variability (0.061 in death). Negative average correlation (~-0.72): marijuana as primary substance is associated with lower predicted risk than alcohol or other substances; the variability suggests the protection is context-dependent."
    },
    {
        "Rank": 6,
        "Feature": "any_phys_dx (Any Physical Diagnosis)",
        "Role": "⚠️ Risk Factor",
        "Stability": "Moderate",
        "Interpretation": "Ranks in top 10-11, global importance ~0.005% with moderate variability (0.003). Strong positive correlation (~0.69), highlighting physical comorbidities as a major driver; more influential in mid-term horizons where health decline affects outcomes."
    },
    {
        "Rank": 7,
        "Feature": "ed_attainment_corr (Education Attainment)",
        "Role": "🛡️ Protective",
        "Stability": "Moderate",
        "Interpretation": "Protective across models, global importance ~0.004% with moderate variability (0.003). Negative correlation (~-0.50), linking higher education to better resilience and resource access; consistent mid-rank, with stable protective effects over time."
    },
    {
        "Rank": 8,
        "Feature": "primary_sub_mod_alcohol (Primary Substance: Alcohol)",
        "Role": "⚠️ Risk Factor",
        "Stability": "High (Low Variability)",
        "Interpretation": "Stable risk factor, global importance ~0.004% with low variability (0.002-0.003). Weak positive correlation (~0.11), suggesting that alcohol as primary substance is associated with somewhat higher hazards vs. other substances; its consistent influence supports targeted therapies for alcohol users."
    },
    {
        "Rank": 9,
        "Feature": "occupation_condition_corr24_unemployed (Unemployed Status)",
        "Role": "⚠️ Risk Factor",
        "Stability": "Moderate",
        "Interpretation": "Top 10-11 socio-economic risk, global importance ~0.004% with moderate variability (0.003). Correlation ~0.39, indicating unemployment exacerbates stress and relapse; gains prominence in longer horizons, advocating for employment support programs."
    },
    {
        "Rank": 10,
        "Feature": "porc_pobr (Poverty %)",
        "Role": "🛡️ Protective",
        "Stability": "High (Contextual)",
        "Interpretation": "Protective in both outcomes, global importance ~0.002% with low variability (0.002). Negative correlation (~-0.32), likely serving as a proxy for aid eligibility and safety nets; the protective association stays stable over time in vulnerable groups."
    }
]

# 2. Convert to DataFrame
df_summary = pd.DataFrame(data)

# 3. Style for HTML display (Positron-safe)
styled_table = (
    df_summary.style
    .set_caption("📊 Table: Five-Year Integrated Predictors (DeepHit → CoxPH)")
    .set_properties(**{
        "text-align": "left",
        "white-space": "pre-wrap",
        "font-size": "14px"
    })
    .set_table_styles([
        {"selector": "caption", "props": [
            ("font-size", "16px"),
            ("font-weight", "bold"),
            ("margin-bottom", "10px")
        ]},
        {"selector": "th", "props": [
            ("background-color", "#f4f4f4"),
            ("border-bottom", "2px solid #555"),
            ("text-align", "center")
        ]},
        {"selector": "td", "props": [
            ("padding", "8px")
        ]},
        {"selector": "tr:hover", "props": [
            ("background-color", "#f9f9f9")
        ]}
    ])
)

# 4. Display
display(styled_table)
Table 3: 📊 Table: Five-Year Integrated Predictors (DeepHit → CoxPH)
  Rank Feature Role Stability Interpretation
0 1 dit_m (Likely Treatment/Illness Duration) ⚠️ Risk Factor High (Low Variability) Prominent non-anomalous predictor for both death and readmission in the integrated CoxPH model. Global importance around 0.015% with low time variability (0.010-0.011), indicating consistent impact across horizons. Strong positive average correlation (~0.35), emphasizing longer durations as a sustained risk driver; modifiable through enhanced treatment retention strategies.
1 2 ethnicity ⚠️ Risk Factor Low (High Variability) Ranks highly in readmission (global importance 0.010%) and anomalously high in death, with extreme variability suggesting time-specific bursts. Average correlation >0.85 in readmission and 0.68 in death; highlights potential disparities in care access or socio-cultural factors, with amplified effects in certain periods.
2 3 adm_age_rec3 (Age at Admission) ⚠️ Risk Factor Moderate Consistent top feature across outcomes, global importance ~0.009% and moderate variability (0.006). High average correlation (~0.61), underscoring older age as a compounding risk factor over time, particularly in long-term horizons due to accumulating comorbidities.
3 4 plan_type_corr_m_pr (Treatment Plan Type) ⚠️ Risk Factor High Key treatment-related risk, with global importance 0.007% and low variability (0.005). Very high average correlation (>0.87), indicating this plan type associates with elevated hazards; stable across time, suggesting opportunities for intervention through plan refinement.
4 5 primary_sub_mod_marijuana (Primary Substance: Marijuana) 🛡️ Protective Moderate (Higher Variability) Strong protective role in both outcomes, global importance ~0.031% in death and 0.004% in readmission, with moderate-high variability (0.061 in death). Negative average correlation (~-0.72), showing marijuana as primary substance reduces risks compared to alcohol or others; variability implies context-dependent protection.
5 6 any_phys_dx (Any Physical Diagnosis) ⚠️ Risk Factor Moderate Ranks in top 10-11, global importance ~0.005% with moderate variability (0.003). Strong positive correlation (~0.69), highlighting physical comorbidities as a major driver; more influential in mid-term horizons where health decline affects outcomes.
6 7 ed_attainment_corr (Education Attainment) 🛡️ Protective Moderate Protective across models, global importance ~0.004% with moderate variability (0.003). Negative correlation (~-0.50), linking higher education to better resilience and resource access; consistent mid-rank, with stable protective effects over time.
7 8 primary_sub_mod_alcohol (Primary Substance: Alcohol) ⚠️ Risk Factor High (Low Variability) Stable risk factor, global importance ~0.004% with low variability (0.002-0.003). Positive correlation (~0.11), marking alcohol as a specific driver of higher hazards vs. other substances; consistent influence suggests targeted therapies for alcohol users.
8 9 occupation_condition_corr24_unemployed (Unemployed Status) ⚠️ Risk Factor Moderate Top 10-11 socio-economic risk, global importance ~0.004% with moderate variability (0.003). Correlation ~0.39, indicating unemployment exacerbates stress and relapse; gains prominence in longer horizons, advocating for employment support programs.
9 10 porc_pobr (Poverty %) 🛡️ Protective High (Contextual) Protective in both outcomes, global importance ~0.002% with low variability (0.002). Negative correlation (~-0.32), likely serving as a proxy for aid eligibility and safety nets; contextual stability enhances its role in vulnerable groups over time.
  • Correlation: the direction and approximate strength of the linear relationship between a feature's values and its SHAP contribution to the predicted probability of the adverse outcome (death or readmission). Positive values mean higher feature values push predicted risk up; negative values mean they push it down. Reported values are averaged (pooled) across time horizons.
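The direction label can be reproduced with a short sketch. It computes Pearson's r between a feature's values and its SHAP values at each horizon, averages over horizons, and applies the same ±0.05 cut-off used in Step 7.1. `classify_direction` and the toy arrays are hypothetical. Constant inputs are scored as 0, as the notebook does, because Pearson's r is undefined for them.

```python
import numpy as np
from scipy.stats import pearsonr


def classify_direction(x, shap_by_horizon, threshold=0.05):
    """Hypothetical helper: average Pearson r between feature values x and the
    feature's SHAP values at each horizon, then label the average."""
    corrs = []
    for shap in shap_by_horizon:
        # Pearson's r is undefined when either input is constant; score it as 0
        if np.std(x) > 1e-6 and np.std(shap) > 1e-12:
            r, _ = pearsonr(x, shap)
            corrs.append(r)
        else:
            corrs.append(0.0)
    avg = float(np.mean(corrs))
    if avg > threshold:
        return avg, "Risk (↑)"
    if avg < -threshold:
        return avg, "Protective (↓)"
    return avg, "Mixed/Neutral"


x = np.array([0, 0, 1, 1, 1], dtype=float)  # e.g. a binary indicator
shap_12m = np.array([-0.01, -0.01, 0.02, 0.02, 0.02])  # x = 1 raises predicted risk
shap_60m = np.array([-0.02, -0.02, 0.03, 0.03, 0.03])
print(classify_direction(x, [shap_12m, shap_60m]))
print(classify_direction(x, [-shap_12m, -shap_60m]))
```

The correlation is a direction indicator, not an effect size. A feature whose SHAP values rise perfectly monotonically but by a tiny amount still gets r close to 1, so magnitude should be read from the importance column.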

Interaction

The following code:
1. Automatically scans SHAP outputs to discover feature–feature interactions
2. Uses SHAP residuals to isolate interaction effects from main effects
3. Tests interactions via robust Spearman correlation
4. Focuses on interactions among the top 20 most important predictors
5. Detects interactions separately for death and readmission outcomes
6. Tracks interaction strength across multiple time horizons
7. Classifies interactions as robust, time-dependent, or transient
8. Identifies trends (growing, fading, stable) over follow-up time
9. Aggregates results into ranked, interpretable interaction summaries
10. Exports publication-ready Excel reports with raw and summary tables
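The residual idea behind steps 2 and 3 can be checked on synthetic data where the answer is known. In the sketch below, the SHAP values of a main feature `x1` contain a quadratic main effect plus a term whose size depends on a binary `x2`. Once a degree-2 fit of SHAP on `x1` is removed, only `x2` should correlate with the residuals. `residual_interactions` is a hypothetical, array-based simplification of `calculate_interaction_strength`: it keeps the same quadratic fit, Spearman test, 0.15 threshold, and top-3 cut. The data are simulated, not SHAP outputs.

```python
import numpy as np
from scipy.stats import spearmanr


def residual_interactions(main_val, main_shap, X_others, names,
                          threshold=0.15, top_k=3):
    """Hypothetical array-based version of the Step 7.3 interaction test."""
    # 1. Remove the main effect: quadratic fit of SHAP on the feature's own value
    coeffs = np.polyfit(main_val, main_shap, 2)
    residuals = main_shap - np.polyval(coeffs, main_val)
    # 2. Rank-correlate every candidate interactor with what is left over
    hits = []
    for j, name in enumerate(names):
        col = X_others[:, j]
        if np.std(col) < 1e-6:
            continue  # constant column: no information
        rho, _ = spearmanr(col, residuals)
        if abs(rho) > threshold:
            hits.append((name, abs(rho)))
    # 3. Keep the strongest few, as the notebook does
    return sorted(hits, key=lambda t: t[1], reverse=True)[:top_k]


# Simulated data (not model output): x1 >= 0, binary x2, unrelated x3
rng = np.random.default_rng(0)
n = 2000
x1 = rng.uniform(0, 1, size=n)
x2 = rng.integers(0, 2, size=n).astype(float)
x3 = rng.normal(size=n)
# "SHAP" of x1 = quadratic main effect + an x1-slope that depends on x2 + noise
shap_x1 = 0.3 * x1**2 + 0.5 * x1 * x2 + rng.normal(scale=0.02, size=n)

hits = residual_interactions(x1, shap_x1, np.column_stack([x2, x3]), ["x2", "x3"])
print(hits)
```

With this design, `x2` should be flagged and `x3` should not. The toy case also exposes a caveat: the test detects interactors that shift the level of the residuals. Because `x1` is non-negative here, a slope change with `x2` appears as a level shift. For a feature centred on zero, the same slope change would leave the residuals nearly uncorrelated with `x2`, so the method can miss such interactions.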

Code
#@title ⚡ Step 7.3: Automated Interaction Discovery (Correct Pattern)
import pickle
import numpy as np
import pandas as pd
import os
import glob
import shutil
from scipy.stats import pearsonr, spearmanr
from datetime import datetime

# --- 1. CONFIGURATION ---
# Search Pattern for your specific file
SEARCH_DIR = os.getcwd() #r"G:\My Drive\Alvacast\SISTRAT 2023\dh"
SEARCH_PATTERN = "DH4_Final_AJ_*_SHAP_MultiHorizon.pkl"

# Safe Local Working Directory
USER_HOME = os.path.expanduser("~")
WORK_DIR = os.path.join(USER_HOME, "DH_Analysis")
LOCAL_PKL = os.path.join(WORK_DIR, "shap_data_interactions.pkl") 

# Output config
TOP_N_MAIN_FEATURES = 20  # Check interactions for top 20 predictors
TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M")
EXCEL_FILENAME = f"DH73_Interaction_Analysis_{TIMESTAMP}.xlsx"
EXCEL_PATH = os.path.join(WORK_DIR, EXCEL_FILENAME)

os.makedirs(WORK_DIR, exist_ok=True)
print(f"📍 Working Directory: {WORK_DIR}")

# --- 2. FIND & COPY DATA ---
# Find the file matching your specific pattern
search_path = os.path.join(SEARCH_DIR, SEARCH_PATTERN)
found_files = sorted(glob.glob(search_path))

if not found_files:
    print(f"❌ No files found matching: {search_path}")
    print("   Please check the path and filename pattern.")
    # Fallback to local check just in case
    if os.path.exists(LOCAL_PKL):
        print("   ⚠️ Using existing local file instead.")
    else:
        raise FileNotFoundError("Could not find the SHAP pickle file.")
else:
    target_file = found_files[-1] # Take the latest one
    print(f"📂 Found Source: {os.path.basename(target_file)}")
    
    # Copy to local C: drive to avoid G: drive read errors during heavy processing
    if not os.path.exists(LOCAL_PKL) or os.path.getsize(LOCAL_PKL) != os.path.getsize(target_file):
        print("   🔄 Copying to local disk for safe processing...")
        shutil.copyfile(target_file, LOCAL_PKL)
        print("   ✅ Copy complete.")
    else:
        print("   ✅ Local copy already exists.")

# --- 3. INTERACTION ENGINE ---
def calculate_interaction_strength(main_idx, shap_matrix, X_matrix, feature_names):
    """
    Estimates interaction strength using SHAP Dependence:
    Interaction ~ Correlation(Residuals of Main Feature SHAP, Interactor Feature Value)
    """
    main_shap = shap_matrix[:, main_idx]
    main_val = X_matrix.iloc[:, main_idx].values
    
    # 1. Remove Main Effect (Polynomial Fit deg=2 to account for non-linearity)
    if np.std(main_val) < 1e-6:
        return []
    try:
        z = np.polyfit(main_val, main_shap, 2)
        p = np.poly1d(z)
        residuals = main_shap - p(main_val)
    except (np.linalg.LinAlgError, ValueError):
        # Fit failed (e.g. NaNs or a degenerate design); skip this feature
        return []

    candidates = []
    
    # 2. Check all other features
    for j in range(X_matrix.shape[1]):
        if j == main_idx: continue
        
        interactor_val = X_matrix.iloc[:, j].values
        if np.std(interactor_val) < 1e-6: continue
        
        # Spearman Correlation (Robust to outliers)
        corr, _ = spearmanr(interactor_val, residuals)
        strength = abs(corr)
        
        # Threshold for "Meaningful" Interaction
        if strength > 0.15:
            candidates.append({
                'Interactor': feature_names[j],
                'Strength': strength
            })
            
    return sorted(candidates, key=lambda x: x['Strength'], reverse=True)[:3]

# --- 4. MAIN ANALYSIS LOOP ---
print(f"🚀 Loading Data...")
with open(LOCAL_PKL, 'rb') as f:
    shap_data_export = pickle.load(f)

all_interactions = []

for outcome in ['death', 'readm']:
    if outcome not in shap_data_export: continue
    
    print(f"\n🔍 Scanning {outcome.upper()}...")
    horizons = sorted(shap_data_export[outcome].keys())
    
    for h in horizons:
        data_dict = shap_data_export[outcome][h]
        if not data_dict or 'shap_values' not in data_dict: continue
        
        shap_matrix = data_dict['shap_values']
        X_data = data_dict['data']
        feature_names = X_data.columns.tolist()
        
        # Identify Top Features
        mean_abs_shap = np.mean(np.abs(shap_matrix), axis=0)
        top_indices = np.argsort(mean_abs_shap)[::-1][:TOP_N_MAIN_FEATURES]
        
        for main_idx in top_indices:
            main_name = feature_names[main_idx]
            interactors = calculate_interaction_strength(main_idx, shap_matrix, X_data, feature_names)
            
            for item in interactors:
                # Alphabetical key to unify A*B and B*A
                pair_key = " * ".join(sorted([main_name, item['Interactor']]))
                
                all_interactions.append({
                    'Outcome': outcome,
                    'Horizon': h,
                    'Main Feature': main_name,
                    'Interactor': item['Interactor'],
                    'Pair': pair_key,
                    'Strength': item['Strength']
                })
        print(f"   Checked Horizon: {h}m")

# --- 5. AGGREGATE & EXPORT ---
if all_interactions:
    df_raw = pd.DataFrame(all_interactions)
    
    # Calculate Stability & Trends
    summary_list = []
    for (outcome, pair), group in df_raw.groupby(['Outcome', 'Pair']):
        
        # Stability: Frequency across horizons
        freq = len(group['Horizon'].unique())
        # Count the horizons of THIS outcome (not the last loop's `horizons` variable)
        total_horizons = len(shap_data_export[outcome])
        stability_score = freq / total_horizons
        
        # Trend: Strength vs Time
        if freq > 1:
            slope = np.polyfit(group['Horizon'], group['Strength'], 1)[0]
        else:
            slope = 0
            
        if slope > 0.002: trend = "Growing ↗️"
        elif slope < -0.002: trend = "Fading ↘️"
        else: trend = "Stable ➡️"
        
        # Classification
        if stability_score > 0.7: m_type = "Robust (General)"
        elif stability_score < 0.3: m_type = "Transient (Noise?)"
        else: m_type = "Time-Dependent"
        
        summary_list.append({
            'Outcome': outcome,
            'Pair': pair,
            'Avg Strength': group['Strength'].mean(),
            'Max Strength': group['Strength'].max(),
            'Trend': trend,
            'Type': m_type,
            'Frequency': f"{freq}/{total_horizons}"
        })
        
    df_summary = pd.DataFrame(summary_list).sort_values('Avg Strength', ascending=False)
    
    # Save Locally First
    print(f"\n💾 Saving Report to {EXCEL_PATH}...")
    with pd.ExcelWriter(EXCEL_PATH, engine='xlsxwriter') as writer:
        df_summary.to_excel(writer, sheet_name='Top Candidates', index=False)
        df_raw.to_excel(writer, sheet_name='Raw Data', index=False)
        
        # Format
        worksheet = writer.sheets['Top Candidates']
        worksheet.set_column('B:B', 50) # Pair Width
    
    # Copy back to G: Drive
    try:
        dest_path = os.path.join(SEARCH_DIR, EXCEL_FILENAME)
        shutil.copy(EXCEL_PATH, dest_path)
        print(f"✅ Success! Uploaded to G: Drive: {dest_path}")
    except Exception as e:
        print(f"⚠️ Copy to G: Drive failed ({e}). File is available locally.")
else:
    print("⚠️ No strong interactions found.")
📍 Working Directory: C:\Users\andre\DH_Analysis
📂 Found Source: DH4_Final_AJ_20260209_2329_SHAP_MultiHorizon.pkl
   🔄 Copying to local disk for safe processing...
   ✅ Copy complete.
🚀 Loading Data...

🔍 Scanning DEATH...
   Checked Horizon: 3m
   Checked Horizon: 12m
   Checked Horizon: 24m
   Checked Horizon: 48m
   Checked Horizon: 60m
   Checked Horizon: 72m

🔍 Scanning READM...
   Checked Horizon: 3m
   Checked Horizon: 12m
   Checked Horizon: 24m
   Checked Horizon: 48m
   Checked Horizon: 60m
   Checked Horizon: 72m

💾 Saving Report to C:\Users\andre\DH_Analysis\DH73_Interaction_Analysis_20260210_1428.xlsx...
✅ Success! Uploaded to G: Drive: G:\My Drive\Alvacast\SISTRAT 2023\dh\DH73_Interaction_Analysis_20260210_1428.xlsx
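Section 5 turns each pair's per-horizon hits into a frequency-based type and a slope-based trend. The sketch below isolates that rule so the cut-offs can be inspected. Stability is the share of checked horizons in which the pair appeared: above 0.7 is Robust, below 0.3 is Transient, and anything else is Time-Dependent. The trend is the sign of a least-squares slope of strength on horizon in months, with a tolerance of ±0.002 per month. `classify_pair` is a hypothetical helper. The number of horizons checked for the outcome is passed in explicitly, and the strengths are made-up values.

```python
import numpy as np


def classify_pair(horizons_found, strengths, n_horizons_checked,
                  slope_tol=0.002, robust_cut=0.7, transient_cut=0.3):
    """Hypothetical helper mirroring the Step 7.3 summary rules for one pair."""
    freq = len(set(horizons_found))
    stability = freq / n_horizons_checked
    # Linear trend of interaction strength over follow-up time (strength per month)
    slope = np.polyfit(horizons_found, strengths, 1)[0] if freq > 1 else 0.0
    if slope > slope_tol:
        trend = "Growing ↗️"
    elif slope < -slope_tol:
        trend = "Fading ↘️"
    else:
        trend = "Stable ➡️"
    if stability > robust_cut:
        m_type = "Robust (General)"
    elif stability < transient_cut:
        m_type = "Transient (Noise?)"
    else:
        m_type = "Time-Dependent"
    return {"Frequency": f"{freq}/{n_horizons_checked}", "Slope": slope,
            "Trend": trend, "Type": m_type}


checked = [3, 12, 24, 48, 60, 72]  # horizons (months) scanned for one outcome
r1 = classify_pair(checked, [0.20 + 0.005 * h for h in checked], len(checked))
r2 = classify_pair([12, 24, 48], [0.40, 0.35, 0.25], len(checked))
r3 = classify_pair([3], [0.18], len(checked))
print(r1, r2, r3, sep="\n")
```

A pair can be recorded twice at the same horizon, once with each feature as the main feature. Frequency counts distinct horizons, so such duplicates affect only the averaged strength and the fitted slope.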
Code
import pandas as pd
from IPython.display import display

# 1. Create the Summary Data (expanded to top ~20 key interactions, prioritizing robust and time-dependent across all horizons)
data = [
    {
        "Rank": 1,
        "Interaction": "plan_type_corr_m_pai * plan_type_corr_m_pr",
        "Outcomes": "Readmission (Primary), Death",
        "Avg Strength": 0.565,
        "Dynamics": "Stable ➡️, Robust (6/6 Horizons)",
        "Interpretation": "Dominant robust interaction between treatment plan types (PAI and PR). Synergistic amplification of risks for both outcomes across all horizons; highlights potential mismatches in plan combinations, urging review of integrated treatment strategies to reduce compounded hazards."
    },
    {
        "Rank": 2,
        "Interaction": "marital_status_rec_separated_divorced_annulled_widowed * plan_type_corr_m_pr",
        "Outcomes": "Readmission, Death",
        "Avg Strength": 0.517,
        "Dynamics": "Stable ➡️, Robust (6/6 Horizons)",
        "Interpretation": "Socio-demographic and treatment plan interaction. Non-stable marital status interacts with PR plan to consistently elevate risks; may indicate relational instability exacerbating treatment challenges, suggesting family-inclusive interventions."
    },
    {
        "Rank": 3,
        "Interaction": "plan_type_corr_m_pai * tr_outcome_adm_discharge_rule_violation_undet",
        "Outcomes": "Death, Readmission",
        "Avg Strength": 0.510,
        "Dynamics": "Stable ➡️, Robust (6/6 Horizons)",
        "Interpretation": "Plan type PAI with undetermined/rule violation discharge. Stable risk amplification; underscores non-compliance issues in specific plans, recommending enhanced discharge planning and follow-up support."
    },
    {
        "Rank": 4,
        "Interaction": "plan_type_corr_m_pr * plan_type_corr_pg_pr",
        "Outcomes": "Readmission (Primary), Death",
        "Avg Strength": 0.450,
        "Dynamics": "Stable ➡️, Robust (6/6 Horizons)",
        "Interpretation": "Interaction among PR plan variants. Consistent elevation of hazards; suggests overlapping inefficiencies in plan designs, with opportunities for streamlining to prevent synergistic negative effects."
    },
    {
        "Rank": 5,
        "Interaction": "adm_motive_justice_sector * first_sub_used_cocaine_powder",
        "Outcomes": "Readmission (Primary), Death",
        "Avg Strength": 0.447,
        "Dynamics": "Stable ➡️ (Readm), Fading ↘️ (Death), Robust/Time-Dependent (5/6, 2/6)",
        "Interpretation": "Justice sector referral with early cocaine powder use. Amplifies risks, more stably for readmission; fading in death implies acute phase vulnerabilities, advocating for specialized legal-substance programs."
    },
    {
        "Rank": 6,
        "Interaction": "eva_consumo * eva_fisica",
        "Outcomes": "Readmission",
        "Avg Strength": 0.442,
        "Dynamics": "Growing ↗️, Time-Dependent (2/6 Horizons)",
        "Interpretation": "Consumption and physical health evaluations. Increasing interaction over time; points to worsening physical conditions tied to substance use driving readmissions, emphasizing holistic health assessments."
    },
    {
        "Rank": 7,
        "Interaction": "eva_consumo * evaluacindelprocesoteraputico",
        "Outcomes": "Readmission (Primary), Death",
        "Avg Strength": 0.395,
        "Dynamics": "Growing ↗️, Robust/Time-Dependent (5/6, 4/6)",
        "Interpretation": "Consumption eval with therapeutic process. Growing strength signals poor therapy engagement linked to consumption patterns; recommends adaptive therapeutic approaches to counteract this escalation."
    },
    {
        "Rank": 8,
        "Interaction": "first_sub_used_cocaine_powder * primary_sub_mod_marijuana",
        "Outcomes": "Death, Readmission",
        "Avg Strength": 0.394,
        "Dynamics": "Stable ➡️, Robust (6/6 Horizons)",
        "Interpretation": "Early cocaine powder with primary marijuana modality. Stable risk elevation; may represent gateway effects or mixed substance profiles, with marijuana offering partial mitigation but not elimination of risks."
    },
    {
        "Rank": 9,
        "Interaction": "national_foreign * tr_outcome_adm_discharge_rule_violation_undet",
        "Outcomes": "Death, Readmission",
        "Avg Strength": 0.393,
        "Dynamics": "Stable ➡️, Time-Dependent (4/6 Horizons)",
        "Interpretation": "National/foreign status with rule violation discharge. Consistent interaction; could reflect access barriers for foreigners in discharge processes, suggesting culturally sensitive support mechanisms."
    },
    {
        "Rank": 10,
        "Interaction": "adm_motive_another_sud_facility_fonodrogas_senda_previene * tr_outcome_adm_discharge_rule_violation_undet",
        "Outcomes": "Death, Readmission",
        "Avg Strength": 0.389,
        "Dynamics": "Stable ➡️, Time-Dependent/Robust (4/6, 6/6)",
        "Interpretation": "Referral from other SUD facilities with violation discharge. Amplifies risks stably; indicates challenges in facility transfers and compliance, calling for improved inter-agency protocols."
    },
    {
        "Rank": 11,
        "Interaction": "eva_fisica * eva_transgnorma",
        "Outcomes": "Death",
        "Avg Strength": 0.373,
        "Dynamics": "Growing ↗️, Robust (5/6 Horizons)",
        "Interpretation": "Physical and normative transgression evaluations. Increasing over time; suggests physical health decline linked to behavioral norms violations, with growing impact on mortality risks."
    },
    {
        "Rank": 12,
        "Interaction": "tr_outcome_adm_discharge_adm_reasons * tr_outcome_adm_discharge_rule_violation_undet",
        "Outcomes": "Readmission (Primary), Death",
        "Avg Strength": 0.373,
        "Dynamics": "Stable ➡️, Time-Dependent (2/6, 3/6)",
        "Interpretation": "Administrative discharge reasons with rule violations. Interaction elevates readmission/death risks; highlights administrative failures compounding outcomes, recommending better discharge categorization."
    },
    {
        "Rank": 13,
        "Interaction": "eva_consumo * eva_sm",
        "Outcomes": "Readmission",
        "Avg Strength": 0.367,
        "Dynamics": "Growing ↗️, Robust (6/6 Horizons)",
        "Interpretation": "Consumption and mental health evals. Escalating interaction; mental health issues intertwined with substance consumption drive recurrent readmissions, urging dual-diagnosis treatments."
    },
    {
        "Rank": 14,
        "Interaction": "eva_relinterp * evaluacindelprocesoteraputico",
        "Outcomes": "Readmission",
        "Avg Strength": 0.411,
        "Dynamics": "Stable ➡️, Time-Dependent (2/6 Horizons)",
        "Interpretation": "Interpersonal relations eval with therapeutic process. Stable in select horizons; poor relations may hinder therapy, increasing readmission risks in mid-term periods."
    },
    {
        "Rank": 15,
        "Interaction": "eva_fam * eva_relinterp",
        "Outcomes": "Readmission",
        "Avg Strength": 0.405,
        "Dynamics": "Stable ➡️, Time-Dependent (2/6 Horizons)",
        "Interpretation": "Family and interpersonal evals. Interaction suggests family dynamics affecting relations, stably raising readmission in certain horizons; family therapy integration could mitigate."
    },
    {
        "Rank": 16,
        "Interaction": "eva_relinterp * eva_sm",
        "Outcomes": "Readmission",
        "Avg Strength": 0.403,
        "Dynamics": "Stable ➡️, Time-Dependent (2/6 Horizons)",
        "Interpretation": "Interpersonal and mental health evals. Consistent in mid-horizons; relational issues compounded by mental health may predict readmissions, highlighting social support needs."
    },
    {
        "Rank": 17,
        "Interaction": "eva_consumo * eva_fam",
        "Outcomes": "Death",
        "Avg Strength": 0.419,
        "Dynamics": "Growing ↗️, Time-Dependent (2/6 Horizons)",
        "Interpretation": "Consumption and family evals. Growing interaction for death; family dysfunction linked to consumption patterns escalates long-term mortality, suggesting family-focused prevention."
    },
    {
        "Rank": 18,
        "Interaction": "primary_sub_mod_marijuana * tipo_de_vivienda_rec2_other_unknown",
        "Outcomes": "Death, Readmission",
        "Avg Strength": 0.390,
        "Dynamics": "Growing ↗️ (Death), Stable ➡️ (Readm), Time-Dependent/Robust (3/6, 6/6)",
        "Interpretation": "Primary marijuana with other/unknown housing type. Amplifies risks; unstable housing with marijuana use may worsen outcomes, with growing death impact over time."
    },
    {
        "Rank": 19,
        "Interaction": "ethnicity * plan_type_corr_m_pr",
        "Outcomes": "Death",
        "Avg Strength": 0.386,
        "Dynamics": "Growing ↗️, Time-Dependent (2/6 Horizons)",
        "Interpretation": "Ethnicity with PR plan type. Increasing interaction; potential cultural-plan mismatches elevate death risks, recommending culturally adapted treatment plans."
    },
    {
        "Rank": 20,
        "Interaction": "eva_consumo * evaluacindelprocesoteraputico",
        "Outcomes": "Death",
        "Avg Strength": 0.359,
        "Dynamics": "Growing ↗️, Time-Dependent (4/6 Horizons)",
        "Interpretation": "Repeated for death context: Consumption and therapeutic eval growing synergy; poor process engagement with high consumption predicts escalating death hazards."
    }
]

# 2. Convert to DataFrame
df_summary = pd.DataFrame(data)

# 3. Style for HTML display (Positron-safe)
styled_table = (
    df_summary.style
    .set_caption("📊 Table: Key Feature Interactions (All Horizons)")
    .set_properties(**{
        "text-align": "left",
        "white-space": "pre-wrap",
        "font-size": "14px"
    })
    .set_table_styles([
        {"selector": "caption", "props": [
            ("font-size", "16px"),
            ("font-weight", "bold"),
            ("margin-bottom", "10px")
        ]},
        {"selector": "th", "props": [
            ("background-color", "#f4f4f4"),
            ("border-bottom", "2px solid #555"),
            ("text-align", "center")
        ]},
        {"selector": "td", "props": [
            ("padding", "8px")
        ]},
        {"selector": "tr:hover", "props": [
            ("background-color", "#f9f9f9")
        ]}
    ])
)

# 4. Display
display(styled_table)
Table 4: 📊 Table: Key Feature Interactions (All Horizons)

| Rank | Interaction | Outcomes | Avg Strength | Dynamics | Interpretation |
|---:|---|---|---:|---|---|
| 1 | `plan_type_corr_m_pai * plan_type_corr_m_pr` | Readmission (Primary), Death | 0.565 | Stable ➡️, Robust (6/6 Horizons) | Dominant robust interaction between treatment plan types (PAI and PR). Synergistic amplification of risks for both outcomes across all horizons; highlights potential mismatches in plan combinations, urging review of integrated treatment strategies to reduce compounded hazards. |
| 2 | `marital_status_rec_separated_divorced_annulled_widowed * plan_type_corr_m_pr` | Readmission, Death | 0.517 | Stable ➡️, Robust (6/6 Horizons) | Socio-demographic and treatment plan interaction. Non-stable marital status interacts with PR plan to consistently elevate risks; may indicate relational instability exacerbating treatment challenges, suggesting family-inclusive interventions. |
| 3 | `plan_type_corr_m_pai * tr_outcome_adm_discharge_rule_violation_undet` | Death, Readmission | 0.510 | Stable ➡️, Robust (6/6 Horizons) | Plan type PAI with undetermined/rule violation discharge. Stable risk amplification; underscores non-compliance issues in specific plans, recommending enhanced discharge planning and follow-up support. |
| 4 | `plan_type_corr_m_pr * plan_type_corr_pg_pr` | Readmission (Primary), Death | 0.450 | Stable ➡️, Robust (6/6 Horizons) | Interaction among PR plan variants. Consistent elevation of hazards; suggests overlapping inefficiencies in plan designs, with opportunities for streamlining to prevent synergistic negative effects. |
| 5 | `adm_motive_justice_sector * first_sub_used_cocaine_powder` | Readmission (Primary), Death | 0.447 | Stable ➡️ (Readm), Fading ↘️ (Death), Robust/Time-Dependent (5/6, 2/6) | Justice sector referral with early cocaine powder use. Amplifies risks, more stably for readmission; fading in death implies acute phase vulnerabilities, advocating for specialized legal-substance programs. |
| 6 | `eva_consumo * eva_fisica` | Readmission | 0.442 | Growing ↗️, Time-Dependent (2/6 Horizons) | Consumption and physical health evaluations. Increasing interaction over time; points to worsening physical conditions tied to substance use driving readmissions, emphasizing holistic health assessments. |
| 7 | `eva_consumo * evaluacindelprocesoteraputico` | Readmission (Primary), Death | 0.395 | Growing ↗️, Robust/Time-Dependent (5/6, 4/6) | Consumption eval with therapeutic process. Growing strength signals poor therapy engagement linked to consumption patterns; recommends adaptive therapeutic approaches to counteract this escalation. |
| 8 | `first_sub_used_cocaine_powder * primary_sub_mod_marijuana` | Death, Readmission | 0.394 | Stable ➡️, Robust (6/6 Horizons) | Early cocaine powder with primary marijuana modality. Stable risk elevation; may represent gateway effects or mixed substance profiles, with marijuana offering partial mitigation but not elimination of risks. |
| 9 | `national_foreign * tr_outcome_adm_discharge_rule_violation_undet` | Death, Readmission | 0.393 | Stable ➡️, Time-Dependent (4/6 Horizons) | National/foreign status with rule violation discharge. Consistent interaction; could reflect access barriers for foreigners in discharge processes, suggesting culturally sensitive support mechanisms. |
| 10 | `adm_motive_another_sud_facility_fonodrogas_senda_previene * tr_outcome_adm_discharge_rule_violation_undet` | Death, Readmission | 0.389 | Stable ➡️, Time-Dependent/Robust (4/6, 6/6) | Referral from other SUD facilities with violation discharge. Amplifies risks stably; indicates challenges in facility transfers and compliance, calling for improved inter-agency protocols. |
| 11 | `eva_fisica * eva_transgnorma` | Death | 0.373 | Growing ↗️, Robust (5/6 Horizons) | Physical and normative transgression evaluations. Increasing over time; suggests physical health decline linked to behavioral norms violations, with growing impact on mortality risks. |
| 12 | `tr_outcome_adm_discharge_adm_reasons * tr_outcome_adm_discharge_rule_violation_undet` | Readmission (Primary), Death | 0.373 | Stable ➡️, Time-Dependent (2/6, 3/6) | Administrative discharge reasons with rule violations. Interaction elevates readmission/death risks; highlights administrative failures compounding outcomes, recommending better discharge categorization. |
| 13 | `eva_consumo * eva_sm` | Readmission | 0.367 | Growing ↗️, Robust (6/6 Horizons) | Consumption and mental health evals. Escalating interaction; mental health issues intertwined with substance consumption drive recurrent readmissions, urging dual-diagnosis treatments. |
| 14 | `eva_relinterp * evaluacindelprocesoteraputico` | Readmission | 0.411 | Stable ➡️, Time-Dependent (2/6 Horizons) | Interpersonal relations eval with therapeutic process. Stable in select horizons; poor relations may hinder therapy, increasing readmission risks in mid-term periods. |
| 15 | `eva_fam * eva_relinterp` | Readmission | 0.405 | Stable ➡️, Time-Dependent (2/6 Horizons) | Family and interpersonal evals. Interaction suggests family dynamics affecting relations, stably raising readmission in certain horizons; family therapy integration could mitigate. |
| 16 | `eva_relinterp * eva_sm` | Readmission | 0.403 | Stable ➡️, Time-Dependent (2/6 Horizons) | Interpersonal and mental health evals. Consistent in mid-horizons; relational issues compounded by mental health may predict readmissions, highlighting social support needs. |
| 17 | `eva_consumo * eva_fam` | Death | 0.419 | Growing ↗️, Time-Dependent (2/6 Horizons) | Consumption and family evals. Growing interaction for death; family dysfunction linked to consumption patterns escalates long-term mortality, suggesting family-focused prevention. |
| 18 | `primary_sub_mod_marijuana * tipo_de_vivienda_rec2_other_unknown` | Death, Readmission | 0.390 | Growing ↗️ (Death), Stable ➡️ (Readm), Time-Dependent/Robust (3/6, 6/6) | Primary marijuana with other/unknown housing type. Amplifies risks; unstable housing with marijuana use may worsen outcomes, with growing death impact over time. |
| 19 | `ethnicity * plan_type_corr_m_pr` | Death | 0.386 | Growing ↗️, Time-Dependent (2/6 Horizons) | Ethnicity with PR plan type. Increasing interaction; potential cultural-plan mismatches elevate death risks, recommending culturally adapted treatment plans. |
| 20 | `eva_consumo * evaluacindelprocesoteraputico` | Death | 0.359 | Growing ↗️, Time-Dependent (4/6 Horizons) | Repeated for death context: Consumption and therapeutic eval growing synergy; poor process engagement with high consumption predicts escalating death hazards. |

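Main interaction themes (rank numbers refer to Table 4):

  • Treatment plan × other plan type or discharge outcome (ranks 1, 3, 4)
  • Consumption evaluation × other evaluation scores (ranks 6, 7, 13, 17, 20)
  • Justice-sector referral × first substance used (rank 5)
  • Family / interpersonal evaluation × other evaluations (ranks 14-16)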
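Only the tail of the summary loop that produces the `Avg Strength`, `Trend`, `Type`, and `Frequency` fields appears at the top of this section. The sketch below reconstructs that kind of per-pair summary under explicit assumptions: a long table like the `Raw Data` sheet, with one row per interaction pair and horizon in hypothetical columns `Pair`, `Horizon` (months), and `Strength`; `Type` is `Robust` when a pair is flagged at 5 or more of the 6 horizons, which is consistent with the `5/6` and `4/6` labels in Table 4 but not confirmed by the source; and `Trend` is the sign of a least-squares slope of strength over horizon, with an arbitrary tolerance `tol`. It is an illustration, not the notebook's actual code.

```python
import numpy as np
import pandas as pd

# Hypothetical reconstruction: column names and thresholds are assumptions.
def summarize_pairs(df_raw, total_horizons=6, robust_min=5, tol=1e-3):
    """Collapse per-horizon interaction strengths into one summary row per pair."""
    rows = []
    for pair, group in df_raw.groupby("Pair"):
        freq = group["Horizon"].nunique()  # number of horizons at which the pair was flagged
        # Trend = least-squares slope of strength vs. horizon (units: strength per month)
        if freq >= 2:
            slope = np.polyfit(group["Horizon"].to_numpy(dtype=float),
                               group["Strength"].to_numpy(dtype=float), 1)[0]
        else:
            slope = 0.0
        if slope > tol:
            trend = "Growing ↗️"
        elif slope < -tol:
            trend = "Fading ↘️"
        else:
            trend = "Stable ➡️"
        rows.append({
            "Pair": pair,
            "Avg Strength": group["Strength"].mean(),
            "Max Strength": group["Strength"].max(),
            "Trend": trend,
            "Type": "Robust" if freq >= robust_min else "Time-Dependent",
            "Frequency": f"{freq}/{total_horizons}",
        })
    return (pd.DataFrame(rows)
              .sort_values("Avg Strength", ascending=False)
              .reset_index(drop=True))

# Toy input with invented strengths, for illustration only
toy = pd.DataFrame({
    "Pair": ["a * b"] * 6 + ["c * d"] * 2,
    "Horizon": [3, 12, 24, 48, 60, 72, 48, 72],
    "Strength": [0.50] * 6 + [0.30, 0.60],
})
print(summarize_pairs(toy))
```

With a 5-of-6 cutoff, a pair seen at every horizon is labelled `Robust`, while a pair seen only at two late horizons with rising strength becomes `Growing` and `Time-Dependent`, mirroring the kinds of labels in Table 4.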

Functional form

Code
#@title ⚡ Step 8: Functional Form Analysis (Parquet + In-Memory)
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
import glob
import shutil
from datetime import datetime

# --- 1. CONFIGURATION ---
CONTINUOUS_VARS = ['adm_age_rec3', 'porc_pobr', 'dit_m']
SEARCH_DIR = r"G:\My Drive\Alvacast\SISTRAT 2023\dh"
SEARCH_PATTERN = "DH4_Final_AJ_*_SHAP_MultiHorizon.pkl"

USER_HOME = os.path.expanduser("~")
WORK_DIR = os.path.join(USER_HOME, "DH_Analysis")
LOCAL_PKL = os.path.join(WORK_DIR, "shap_data_functional.pkl")
OUTPUT_DIR = os.path.join(WORK_DIR, "dh8")

TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M")
EXCEL_FILENAME = f"DH8_Functional_Forms_{TIMESTAMP}.xlsx"
PARQUET_FILENAME = f"DH8_Functional_Forms_{TIMESTAMP}.parquet"

EXCEL_PATH = os.path.join(WORK_DIR, EXCEL_FILENAME)
PARQUET_PATH = os.path.join(WORK_DIR, PARQUET_FILENAME)

os.makedirs(WORK_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"📍 Working Directory: {WORK_DIR}")

# --- 2. FIND & COPY DATA ---
search_path = os.path.join(SEARCH_DIR, SEARCH_PATTERN)
found_files = sorted(glob.glob(search_path))

if not found_files:
    if os.path.exists(LOCAL_PKL):
        print("⚠️ Source file not found in G: Drive. Using existing local copy.")
    else:
        raise FileNotFoundError(f"❌ Could not find file matching: {search_path}")
else:
    target_file = found_files[-1]
    print(f"📂 Found Source: {os.path.basename(target_file)}")
    if not os.path.exists(LOCAL_PKL) or os.path.getsize(LOCAL_PKL) != os.path.getsize(target_file):
        print("   🔄 Copying to local disk...")
        shutil.copyfile(target_file, LOCAL_PKL)
        print("   ✅ Copy complete.")
    else:
        print("   ✅ Local copy is up to date.")

# --- 3. ANALYSIS LOOP ---
print(f"🚀 Loading Data...")
with open(LOCAL_PKL, 'rb') as f:
    shap_data_export = pickle.load(f)

writer = pd.ExcelWriter(EXCEL_PATH, engine='xlsxwriter')
all_data_list = [] # List to collect all data for Parquet

print(f"⚡ Analyzing Variables: {CONTINUOUS_VARS}")

for outcome in ['death', 'readm']:
    if outcome not in shap_data_export: continue
    
    print(f"\n🔍 Processing Outcome: {outcome.upper()}")
    horizons = sorted(shap_data_export[outcome].keys())
    
    for h in horizons:
        data_dict = shap_data_export[outcome][h]
        if not data_dict or 'shap_values' not in data_dict: continue
        
        # Scale SHAP by 100 (%)
        shap_vals = data_dict['shap_values'] * 100 
        X_data = data_dict['data']
        feature_names = X_data.columns.tolist()
        
        for var in CONTINUOUS_VARS:
            if var not in feature_names: continue
            
            col_idx = feature_names.index(var)
            x_vec = X_data.iloc[:, col_idx].values
            y_vec = shap_vals[:, col_idx]
            
            # 1. PLOTTING
            plt.figure(figsize=(8, 6))
            plt.scatter(x_vec, y_vec, alpha=0.5, c='#1f77b4', s=30, label='Patients')
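            try:
                z = np.polyfit(x_vec, y_vec, 3)
                p = np.poly1d(z)
                x_trend = np.linspace(x_vec.min(), x_vec.max(), 100)
                plt.plot(x_trend, p(x_trend), "r--", linewidth=2.5, label="Trend")
            except (np.linalg.LinAlgError, ValueError, TypeError) as err:
                # Keep the scatter plot; only the cubic trend line is skipped
                print(f"   ⚠️ Trend fit skipped for {var} @ {h}m: {err}")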

            plt.title(f"Functional Form: {var}\n({outcome.capitalize()} @ {h} Months)", fontsize=14)
            plt.xlabel(f"Feature Value: {var}", fontsize=12)
            plt.ylabel("Impact on Risk (%)", fontsize=12)
            plt.axhline(0, color='k', linestyle=':', alpha=0.5)
            plt.grid(True, alpha=0.3)
            plt.legend()
            
            plot_fname = f"{outcome}_{h}m_{var}.png"
            plot_path = os.path.join(OUTPUT_DIR, plot_fname)
            plt.savefig(plot_path, dpi=300, bbox_inches='tight')
            plt.close()
            
            # 2. COLLECT DATA
            # Store everything in a clean format
            temp_df = pd.DataFrame({
                "Feature_Value": x_vec,
                "SHAP_Impact": y_vec,
                "Predictor": var,
                "Outcome": outcome,
                "Time": int(h)
            })
            
            # Add to Master List (for Parquet)
            all_data_list.append(temp_df)
            
            # Add to Excel (Sheet by Sheet)
            sheet_name = f"{outcome[:1]}_{h}m_{var[:5]}"
            temp_df[['Feature_Value', 'SHAP_Impact']].to_excel(writer, sheet_name=sheet_name, index=False)
            
        print(f"   Done Horizon: {h}m")

# Save Excel
writer.close()

# Save Parquet (Much faster for reloading)
if all_data_list:
    full_df = pd.concat(all_data_list, ignore_index=True)
    full_df.to_parquet(PARQUET_PATH, index=False)
    print(f"\n✅ Parquet Saved: {PARQUET_PATH}")
    
    # Store in a global variable for immediate use in the next cell
    # (Step 9 checks for it and skips reloading the Parquet file from disk)
    global_functional_data = full_df 
    print("✅ Data stored in memory as 'global_functional_data'")

print(f"✅ Analysis Complete.")

# --- 4. COPY BACK TO G: DRIVE ---
print("\n📤 Uploading results to Google Drive...")
try:
    # Copy Excel
    shutil.copy(EXCEL_PATH, os.path.join(SEARCH_DIR, EXCEL_FILENAME))
    
    # Copy Parquet (New!)
    shutil.copy(PARQUET_PATH, os.path.join(SEARCH_DIR, PARQUET_FILENAME))
    print(f"   ✅ Parquet uploaded: {PARQUET_FILENAME}")

    # Copy Plots
    dest_folder = os.path.join(SEARCH_DIR, "dh8")
    if os.path.exists(dest_folder): shutil.rmtree(dest_folder)
    shutil.copytree(OUTPUT_DIR, dest_folder)
    print(f"   ✅ Plots uploaded: {dest_folder}")
    
except Exception as e:
    print(f"⚠️ Upload failed ({e}). Files are safe on C: Drive.")
📍 Working Directory: C:\Users\andre\DH_Analysis
📂 Found Source: DH4_Final_AJ_20260209_2329_SHAP_MultiHorizon.pkl
   ✅ Local copy is up to date.
🚀 Loading Data...
⚡ Analyzing Variables: ['adm_age_rec3', 'porc_pobr', 'dit_m']

🔍 Processing Outcome: DEATH
   Done Horizon: 3m
   Done Horizon: 12m
   Done Horizon: 24m
   Done Horizon: 48m
   Done Horizon: 60m
   Done Horizon: 72m

🔍 Processing Outcome: READM
   Done Horizon: 3m
   Done Horizon: 12m
   Done Horizon: 24m
   Done Horizon: 48m
   Done Horizon: 60m
   Done Horizon: 72m

✅ Parquet Saved: C:\Users\andre\DH_Analysis\DH8_Functional_Forms_20260210_1543.parquet
✅ Data stored in memory as 'global_functional_data'
✅ Analysis Complete.

📤 Uploading results to Google Drive...
   ✅ Parquet uploaded: DH8_Functional_Forms_20260210_1543.parquet
   ✅ Plots uploaded: G:\My Drive\Alvacast\SISTRAT 2023\dh\dh8
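Step 8 draws its red dashed trend with `np.polyfit(x_vec, y_vec, 3)` on raw feature values. For wide-ranged predictors such as `dit_m`, raw powers up to x³ can make that least-squares problem poorly conditioned, and NumPy may emit a `RankWarning`. One alternative, offered as a sketch rather than a drop-in replacement, is `numpy.polynomial.Polynomial.fit`, which maps x onto [-1, 1] before fitting and maps back when the polynomial is evaluated. The helper name `cubic_trend` and the toy curve are hypothetical; non-finite rows are dropped instead of breaking the fit.

```python
import numpy as np
from numpy.polynomial import Polynomial

def cubic_trend(x_vec, y_vec, n_points=100):
    """Degree-3 trend of SHAP impact vs. feature value (hypothetical helper)."""
    x = np.asarray(x_vec, dtype=float)
    y = np.asarray(y_vec, dtype=float)
    ok = np.isfinite(x) & np.isfinite(y)       # drop NaN/inf rows before fitting
    x, y = x[ok], y[ok]
    if x.size < 4 or np.ptp(x) == 0:
        return None                            # need >= 4 points and some spread in x
    poly = Polynomial.fit(x, y, deg=3)         # fits on a scaled domain internally
    x_trend = np.linspace(x.min(), x.max(), n_points)
    return x_trend, poly(x_trend)              # evaluation maps back to the original x scale

# Toy curve on a wide x range (not real SHAP values)
rng = np.random.default_rng(2)
dur = rng.uniform(0, 2000, 300)
impact = 5.0 - 0.004 * dur + 1e-6 * dur**2
x_trend, y_trend = cubic_trend(dur, impact)
print(x_trend[:3].round(1), y_trend[:3].round(3))
```

Inside the Step 8 loop, the returned pair could be passed to `plt.plot(x_trend, y_trend, "r--")`, and a `None` result would simply skip the trend line.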
Code
#@title 📊 Step 9: Faceted Functional Forms (Wrapped Layout)
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import glob
import os
import shutil
from datetime import datetime  # needed for TIMESTAMP below

# --- 1. CONFIGURATION ---
USER_HOME = os.path.expanduser("~")
WORK_DIR = os.path.join(USER_HOME, "DH_Analysis")
OUTPUT_DIR = os.path.join(WORK_DIR, "dh_9")
SEARCH_DIR = r"G:\My Drive\Alvacast\SISTRAT 2023\dh"
TIMESTAMP = datetime.now().strftime("%Y%m%d_%H%M")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# --- 2. LOAD DATA ---
df_plot = None

# Option A: Check Memory (Fastest); Step 8 leaves its result in the notebook's global namespace
if 'global_functional_data' in globals():
    print("⚡ Using data from memory (global_functional_data)...")
    df_plot = global_functional_data.copy()
else:
    # Option B: Load from Disk (Parquet)
    search_pattern = os.path.join(WORK_DIR, "DH8_Functional_Forms_*.parquet")
    found_files = sorted(glob.glob(search_pattern))
    
    if found_files:
        target_file = found_files[-1]
        print(f"📂 Loading Data from: {os.path.basename(target_file)}")
        df_plot = pd.read_parquet(target_file)
    else:
        print("❌ No Parquet file found from Step 8.")
        print("   Please run the 'Step 8' code block first.")

# --- 3. GENERATE PLOTS ---
if df_plot is not None:
    # Rename columns for cleaner plotting labels
    df_plot = df_plot.rename(columns={
        "Feature_Value": "Feature Value",
        "SHAP_Impact": "Risk Impact (%)"
    })

    # Set Style
    sns.set_style("whitegrid")
    sns.set_context("notebook", font_scale=1.2)
    
    unique_features = df_plot['Predictor'].unique()
    unique_outcomes = df_plot['Outcome'].unique()
    
    print(f"⚡ Generating Plots for {len(unique_features)} Predictors x {len(unique_outcomes)} Outcomes...")
    
    for feature in unique_features:
        for outcome in unique_outcomes:
            
            # Filter Data: Specific Feature AND Specific Outcome
            subset = df_plot[
                (df_plot['Predictor'] == feature) & 
                (df_plot['Outcome'] == outcome)
            ].copy()
            
            if subset.empty: continue

            # Determine Color
            # Red for Death, Blue for Readmission
            color = '#d62728' if 'death' in outcome.lower() else '#1f77b4'

            # Create FacetGrid
            # col="Time" creates one panel per horizon (3, 12, 24, 48, 60, 72 months)
            # col_wrap=3 puts 3-24m on the first row and 48-72m on the second
            g = sns.lmplot(
                data=subset, 
                x="Feature Value", 
                y="Risk Impact (%)", 
                col="Time",      
                col_wrap=3,      # 🟢 WRAP: 3 panels per row (6 horizons = 2 rows)
                height=3.5, 
                aspect=1.2,
                scatter_kws={'alpha': 0.2, 's': 15, 'color': color, 'linewidths': 0}, 
                line_kws={'linewidth': 2.5, 'color': 'black'}, 
                order=3,         # Polynomial fit (Degree 3)
                sharex=True, 
                sharey=True      # Share Y-axis to compare magnitude changes over time
            )
            
            # Titles & Layout
            g.fig.suptitle(f"{outcome.capitalize()}: {feature}", fontsize=20, y=1.05, weight='bold', color='#333')
            g.set_titles("{col_name} Months")
            
            # Add Zero Line to every subplot
            for ax in g.axes.flatten():
                ax.axhline(0, color='gray', linestyle='--', linewidth=1)
                
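            # Save Locally (PNG at 300 dpi plus a vector PDF)
            # Filename pattern: Facet_<Outcome>_<feature>_<TIMESTAMP>.png, e.g. Facet_Death_adm_age_rec3_20260210_1634.png
            fname = f"Facet_{outcome.capitalize()}_{feature}_{TIMESTAMP}.png"
            fname2 = f"Facet_{outcome.capitalize()}_{feature}_{TIMESTAMP}.pdf"
            save_path = os.path.join(OUTPUT_DIR, fname)
            save_path2 = os.path.join(OUTPUT_DIR, fname2)
            g.savefig(save_path, dpi=300, bbox_inches='tight')
            g.savefig(save_path2, bbox_inches='tight')
            plt.close(g.fig)  # close this grid's figure explicitly to free memory inside the loop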
            
            print(f"   ✅ Saved: {fname}")

    print(f"\n🏁 Done! Plots saved in: {OUTPUT_DIR}")

    # --- 4. UPLOAD TO G: DRIVE ---
    print("\n📤 Uploading plots to Google Drive...")
    
    # 🟢 NEW DESTINATION: dh_9
    dest_folder = os.path.join(SEARCH_DIR, "dh_9")
    
    try:
        # Remove old folder if exists to ensure clean update
        if os.path.exists(dest_folder): shutil.rmtree(dest_folder)
        shutil.copytree(OUTPUT_DIR, dest_folder)
        print(f"   ✅ Success: {dest_folder}")
    except Exception as e:
        print(f"⚠️ Upload failed ({e}). Files are safe on C: Drive.")
⚡ Using data from memory (global_functional_data)...
⚡ Generating Plots for 3 Predictors x 2 Outcomes...
   ✅ Saved: Facet_Death_adm_age_rec3_20260210_1634.png
   ✅ Saved: Facet_Readm_adm_age_rec3_20260210_1634.png
   ✅ Saved: Facet_Death_porc_pobr_20260210_1634.png
   ✅ Saved: Facet_Readm_porc_pobr_20260210_1634.png
   ✅ Saved: Facet_Death_dit_m_20260210_1634.png
   ✅ Saved: Facet_Readm_dit_m_20260210_1634.png

🏁 Done! Plots saved in: C:\Users\andre\DH_Analysis\dh_9

📤 Uploading plots to Google Drive...
   ✅ Success: G:\My Drive\Alvacast\SISTRAT 2023\dh\dh_9
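The faceted panels above are read by eye. To put numbers behind statements such as Table 5's "slope is steepest at 48-72 months", one option (an assumption, not part of the notebook) is to fit a straight line of `SHAP_Impact` on `Feature_Value` within each `Predictor` × `Outcome` × `Time` cell of the long table from Step 8 (`global_functional_data`, or the `DH8_Functional_Forms_*.parquet` file) and pivot the slopes by horizon. A linear slope is only a summary: for a curved form such as `dit_m`, it averages over the steep early drop and the plateau. `horizon_slopes` is a hypothetical helper, and the toy data is simulated.

```python
import numpy as np
import pandas as pd

def horizon_slopes(df):
    """Slope of SHAP impact vs. feature value per (Predictor, Outcome, Time), pivoted by Time."""
    def _slope(g):
        x = g["Feature_Value"].to_numpy(dtype=float)
        y = g["SHAP_Impact"].to_numpy(dtype=float)
        ok = np.isfinite(x) & np.isfinite(y)
        if ok.sum() < 2 or np.ptp(x[ok]) == 0:
            return np.nan  # slope undefined without spread in x
        return np.polyfit(x[ok], y[ok], 1)[0]

    slopes = (
        df.groupby(["Predictor", "Outcome", "Time"])[["Feature_Value", "SHAP_Impact"]]
          .apply(_slope)
          .rename("Slope")
          .reset_index()
    )
    return slopes.pivot(index=["Predictor", "Outcome"], columns="Time", values="Slope")

# Simulated long table in the Step 8 layout (not real results)
rng = np.random.default_rng(0)
age = rng.uniform(18, 70, 200)
toy = pd.concat(
    [pd.DataFrame({"Feature_Value": age,
                   "SHAP_Impact": b * (age - age.mean()),
                   "Predictor": "adm_age_rec3", "Outcome": "death", "Time": t})
     for t, b in [(3, 0.01), (72, 0.05)]],
    ignore_index=True,
)
print(horizon_slopes(toy).round(3))
```

Applied to `global_functional_data`, each row of the result would show how one predictor's average slope (in percentage points per feature unit) changes across the six horizons.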
Code
import pandas as pd
from IPython.display import display

# 1. Prepare the Data (with Functional Form Correction)
data = [
    {
        "Feature": "1. Age at Admission\n(adm_age_rec3)",
        "Death (Biological Risk)": "🔴 Strong Monotonic Risk\n\nLinear dose-response. Older age is the dominant structural driver of mortality. No 'safe' threshold.",
        "Readmission (System Risk)": "🔵 Moderate Monotonic Risk\n\nPositive association, but significantly flatter slope than Death. Risk is less deterministic.",
        "Temporal Dynamics": "⚡ Amplifies over time\n\nSlope is steepest at 48-72 months. Biological vulnerability compounds, making age a progressively stronger constraint on survival.",
        "Suggested Functional Form Correction": "✅ Linear Term (or Polynomial Degree 2)\n\nThe relationship is predominantly linear. A standard linear term `Age` is sufficient, though `Age^2` might capture the slight acceleration in elderly cohorts."
    },
    {
        "Feature": "2. Days in Treatment\n(dit_m)",
        "Death (Biological Risk)": "🛡️ Strong Protective Effect\n(Logarithmic Decay)\n\nHigh risk for early dropouts. Risk drops precipitously in first 6-9 months, then plateaus.",
        "Readmission (System Risk)": "🛡️ Strong Protective Effect\n\nMirror image of death. Retention prevents the 'revolving door'. Longer stay = deeper system integration.",
        "Temporal Dynamics": "⚓ Early Critical Window\n\nMaximal impact at 3-12 months. The protective effect is established early; extending care beyond 18 months yields diminishing marginal returns.",
        "Suggested Functional Form Correction": "✅ Log-Transformation: log(X + 1)\n\nThe sharp initial drop followed by a plateau violates the linearity assumption. Using `log(Duration)` or a spline (e.g., `pspline(Duration)`) is critical to model the diminishing returns accurately."
    },
    {
        "Feature": "3. Poverty of Commune\n(porc_pobr)",
        "Death (Biological Risk)": "⚪ Noise / Weak Risk\n\nVery slight positive trend or noise. Socioeconomic status is not a direct biological killer in this context.",
        "Readmission (System Risk)": "🛡️ Slight Protective Paradox\n\nHigher poverty correlates with slightly LOWER readmission risk (Negative SHAP).",
        "Temporal Dynamics": "⏳ Latent Effect\n\nEffects are negligible at 3 months but become clearer at 60-72 months, suggesting long-term structural influence.",
        "Suggested Functional Form Correction": "✅ Threshold / Categorical\n\nThe effect is weak and noisy. It may be better modeled as a binary indicator (`High Poverty` vs. `Low Poverty`) or an interaction term with `Age` rather than a continuous linear predictor."
    }
]

# 2. Convert to DataFrame
df_summary = pd.DataFrame(data)

# 3. Style for HTML display (Positron-safe)
styled_table = (
    df_summary.style
    .set_caption("📊 Table: Functional Form & Interaction Synthesis")
    .set_properties(**{
        "text-align": "left",
        "white-space": "pre-wrap",
        "font-size": "14px",
        "vertical-align": "top"
    })
    .set_table_styles([
        {"selector": "caption", "props": [
            ("font-size", "16px"),
            ("font-weight", "bold"),
            ("margin-bottom", "10px"),
            ("color", "#333")
        ]},
        {"selector": "th", "props": [
            ("background-color", "#f4f4f4"),
            ("border-bottom", "2px solid #555"),
            ("text-align", "center"),
            ("font-weight", "bold"),
            ("font-size", "14px")
        ]},
        {"selector": "td", "props": [
            ("padding", "12px"),
            ("border-bottom", "1px solid #e0e0e0")
        ]},
        {"selector": "tr:hover", "props": [
            ("background-color", "#f9f9f9")
        ]}
    ])
)

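# 4. Display (display() is built into Jupyter; the explicit import also covers plain IPython)
from IPython.display import display
display(styled_table)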
Table 5: 📊 Table: Functional Form & Interaction Synthesis (Thesis Results)
| Feature | Death (Biological Risk) | Readmission (System Risk) | Temporal Dynamics | Suggested Functional Form Correction |
|---|---|---|---|---|
| 1. Age at Admission (adm_age_rec3) | 🔴 Strong Monotonic Risk. Linear dose-response. Older age is the dominant structural driver of mortality. No 'safe' threshold. | 🔵 Moderate Monotonic Risk. Positive association, but significantly flatter slope than Death. Risk is less deterministic. | ⚡ Amplifies over time. Slope is steepest at 48-72 months. Biological vulnerability compounds, making age a progressively stronger constraint on survival. | ✅ Linear Term (or Polynomial Degree 2). The relationship is predominantly linear. A standard linear term `Age` is sufficient, though `Age^2` might capture the slight acceleration in elderly cohorts. |
| 2. Days in Treatment (dit_m) | 🛡️ Strong Protective Effect (Logarithmic Decay). High risk for early dropouts. Risk drops precipitously in first 6-9 months, then plateaus. | 🛡️ Strong Protective Effect. Mirror image of death. Retention prevents the 'revolving door'. Longer stay = deeper system integration. | ⚓ Early Critical Window. Maximal impact at 3-12 months. The protective effect is established early; extending care beyond 18 months yields diminishing marginal returns. | ✅ Log-Transformation: log(X + 1). The sharp initial drop followed by a plateau violates the linearity assumption. Using `log(Duration + 1)` or a penalized spline (e.g., `pspline(Duration)` in R's survival package) is critical to model the diminishing returns accurately. |
| 3. Poverty of Commune (porc_pobr) | ⚪ Noise / Weak Risk. Very slight positive trend or noise. Socioeconomic status does not act as a direct biological driver of mortality in this context. | 🛡️ Slight Protective Paradox. Higher poverty correlates with slightly LOWER readmission risk (Negative SHAP). | ⏳ Latent Effect. Effects are negligible at 3 months but become clearer at 60-72 months, suggesting long-term structural influence. | ✅ Threshold / Categorical. The effect is weak and noisy. It may be better modeled as a binary indicator (`High Poverty` vs. `Low Poverty`) or an interaction term with `Age` rather than a continuous linear predictor. |
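The corrections in the last column of Table 5 can be expressed as derived covariates for a model with a linear predictor, such as a cause-specific Cox model. Below is a minimal sketch under stated assumptions. It expects a pandas DataFrame with numeric columns `adm_age_rec3`, `dit_m` and `porc_pobr`. The helper `add_functional_forms`, the derived column names and the median cut-off for `high_poverty` are hypothetical choices, not part of the original analysis. The sketch covers four of the suggested corrections: the quadratic age term, log(X + 1) for treatment duration, the binary poverty indicator and a poverty-by-age interaction. It does not show a spline basis for `dit_m`, which could be built with something like scikit-learn's `SplineTransformer`.

```python
from typing import Optional

import numpy as np
import pandas as pd


def add_functional_forms(df: pd.DataFrame,
                         poverty_threshold: Optional[float] = None) -> pd.DataFrame:
    """Return a copy of `df` with candidate transformed covariates.

    Hypothetical helper: input column names follow the table (adm_age_rec3,
    dit_m, porc_pobr); derived column names and the default cut-off are assumptions.
    """
    out = df.copy()

    # Age: centred linear term plus its square (the optional Age^2 term).
    # Centring before squaring reduces collinearity between the two columns.
    age_c = out["adm_age_rec3"] - out["adm_age_rec3"].mean()
    out["age_c"] = age_c
    out["age_c_sq"] = age_c ** 2

    # Treatment duration: log(X + 1) stays finite at X = 0 and compresses
    # long stays, matching the "sharp drop, then plateau" shape.
    if (out["dit_m"] < 0).any():
        raise ValueError("dit_m must be non-negative for log(X + 1)")
    out["log_dit_m"] = np.log1p(out["dit_m"])

    # Poverty: binary High/Low indicator. The median split is an assumed
    # default, not a threshold taken from the analysis.
    if poverty_threshold is None:
        poverty_threshold = float(out["porc_pobr"].median())
    out["high_poverty"] = (out["porc_pobr"] > poverty_threshold).astype(int)

    # Poverty x Age interaction on the centred age scale.
    out["high_poverty_x_age_c"] = out["high_poverty"] * age_c
    return out


# Usage on a toy frame (illustrative values only, not study data).
toy = pd.DataFrame({
    "adm_age_rec3": [20.0, 30.0, 40.0],
    "dit_m": [0.0, 6.0, 24.0],
    "porc_pobr": [0.05, 0.10, 0.20],
})
features = add_functional_forms(toy)
print(features)
```

Centring age before squaring is a design choice. It keeps `age_c` and `age_c_sq` less correlated, so the linear coefficient still reads as the effect near the mean age.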