DeepSurv (part 1)

This notebook implements a DeepSurv model for competing risks. It uses a single neural network to simultaneously predict the risk of death and readmission, explicitly modeling their competition. The model outputs Cumulative Incidence Functions (CIFs) for each risk, allowing evaluation of prediction accuracy over time for both outcomes using metrics like Uno’s C-Index (corrected for competing risks) and cause-specific Brier scores. It processes multiple imputed datasets and evaluates performance across different time horizons.

Author

ags

Published

April 1, 2026

0. Package loading and installation

Code
# Commented out IPython magic to ensure Python compatibility.
# For Jupyter/Colab notebooks
# NOTE(review): %reset -f wipes the entire interactive namespace on every
# run of this cell — anything defined earlier in the session is lost.
%reset -f
import gc
gc.collect()

import numpy as np
import pandas as pd
import time

# One-time environment setup, kept for provenance (run manually if needed):
#conda install -c conda-forge \
#    numpy \
#    scipy \
#    pandas \
#    pyarrow \
#    scikit-survival \
#    spyder \
#    lifelines

# conda install -c conda-forge fastparquet
# conda install -c conda-forge xgboost
# conda install -c conda-forge pytorch cpuonly
# conda install -c pytorch pytorch cpuonly
# conda install -c conda-forge matplotlib
# conda install -c conda-forge seaborn
# conda install spyder-notebook -c spyder-ide
# conda install notebook nbformat nbconvert
# conda install -c conda-forge xlsxwriter
# conda install -c conda-forge shap

# import subprocess, sys

# subprocess.check_call([
#     sys.executable,
#     "-m",
#     "pip",
#     "install",
#     "matplotlib"
# ])

# subprocess.check_call([
#     sys.executable,
#     "-m",
#     "pip",
#     "install",
#     "seaborn"
# ])

# Sanity check: record the numpy version running in this kernel.
print("numpy:", np.__version__)


from sksurv.metrics import (
    concordance_index_ipcw,
    brier_score,
    integrated_brier_score
)
from sksurv.util import Surv

#Dput
def dput_df(df, digits=6):
    """Print (and return) an R ``dput``-like dict representation of a DataFrame.

    Float values are rounded to ``digits`` decimal places; every other value
    is passed through unchanged.

    Parameters
    ----------
    df : pandas.DataFrame
        Frame to serialize.
    digits : int, default 6
        Decimal places used when rounding float values.

    Returns
    -------
    dict
        ``{"columns": [...], "data": [[...], ...]}``. The dict is also
        printed to stdout. Returning it is new (previous versions returned
        None), so existing callers are unaffected.
    """
    data = {
        "columns": list(df.columns),
        "data": [
            # round() keeps numpy floats; isinstance covers both python
            # float and np.floating (object-dtype rows may hold either)
            [round(x, digits) if isinstance(x, (float, np.floating)) else x
             for x in row]
            for row in df.to_numpy()
        ]
    }
    print(data)
    return data


#Glimpse function
def glimpse(df, max_width=80):
    """dplyr::glimpse-style overview of a DataFrame.

    Prints the frame's dimensions, then one line per column showing the
    column name, its dtype, and the first five values (stringified and
    truncated to ``max_width`` characters).
    """
    n_rows, n_cols = df.shape
    print(f"Rows: {n_rows} | Columns: {n_cols}")
    for name in df.columns:
        sample = ", ".join(df[name].astype(str).head(5).tolist())
        if len(sample) > max_width:
            sample = sample[:max_width] + "..."
        print(f"{name:<30} {str(df[name].dtype):<15} {sample}")
#Tabyl function
def tabyl(series):
    """janitor::tabyl-style frequency table for a Series.

    Returns a DataFrame with one row per distinct value (NaN included),
    its count ``n``, and its share in ``percent`` — note the column holds
    proportions in [0, 1], not percentages, despite the name.
    """
    freq = series.value_counts(dropna=False)
    return pd.DataFrame({
        "value": freq.index,
        "n": freq.values,
        # identical to value_counts(normalize=True, dropna=False)
        "percent": (freq / freq.sum()).values,
    })
#clean_names
import re

def clean_names(df):
    """
    Mimic janitor::clean_names for pandas DataFrames.

    Each column name is lowercased, every run of non-word characters is
    collapsed into a single underscore, and leading/trailing underscores
    are stripped. The frame is renamed in place and also returned.
    """
    df.columns = [
        re.sub(r"[^\w]+", "_", name.lower()).strip("_")
        for name in df.columns
    ]
    return df
numpy: 2.0.1

Load data

Code

from pathlib import Path

# NOTE(review): hardcoded absolute Windows path — runs only on this
# machine; consider reading it from an environment variable or config cell.
BASE_DIR = Path(
    r"G:\My Drive\Alvacast\SISTRAT 2023\data\20241015_out\pred1"
)

import pickle

# SECURITY: pickle.load executes arbitrary code embedded in the file —
# only load pickles from trusted sources (this one is self-produced).
with open(BASE_DIR / "imputations_list_jan26.pkl", "rb") as f:
    imputations_list_jan26 = pickle.load(f)

# First dummy-encoded imputed dataset (see its .info() output further down).
imputation_1 = pd.read_parquet(
    BASE_DIR / "imputation_1.parquet",
    engine="fastparquet"
)
Code

import pandas as pd

# Load the five non-dummy-encoded imputed datasets into module-level
# variables imputation_nodum_1 .. imputation_nodum_5 via globals().
# NOTE(review): the file stem is "nondum" while the variable name is
# "nodum" — presumably intentional, but verify; downstream cells reference
# imputation_nodum_i, so the variable spelling is load-bearing.
for i in range(1, 6):
    globals()[f"imputation_nodum_{i}"] = pd.read_parquet(
        BASE_DIR / f"imputation_nondum_{i}.parquet",
        engine="fastparquet"
    )
Code
from IPython.display import display, HTML
import io
import sys

def fold_output(title, func):
    """Run ``func``, capture everything it prints, and display the captured
    text inside a collapsible HTML ``<details>`` block.

    Parameters
    ----------
    title : str
        Text shown on the ``<summary>`` toggle (inserted verbatim).
    func : callable
        Zero-argument callable whose stdout output is captured.
    """
    import html as _html  # local import: only needed for escaping here

    buffer = io.StringIO()
    sys.stdout = buffer
    try:
        func()
    finally:
        # FIX: restore stdout even if func() raises; previously an exception
        # left sys.stdout pointing at the buffer for the rest of the session.
        sys.stdout = sys.__stdout__

    # Escape the captured text so literal <, > and & don't break the markup.
    body = f"""
    <details>
      <summary>{title}</summary>
      <pre>{_html.escape(buffer.getvalue())}</pre>
    </details>
    """
    display(HTML(body))


# Collapsible .info() summaries for the two freshly loaded frames.
fold_output(
    "Show imputation_nodum_1 structure",
    lambda: imputation_nodum_1.info()
)

fold_output(
    "Show imputation_1 structure",
    lambda: imputation_1.info()
)
Show imputation_nodum_1 structure

RangeIndex: 88504 entries, 0 to 88503
Data columns (total 43 columns):
 #   Column                         Non-Null Count  Dtype  
---  ------                         --------------  -----  
 0   readmit_time_from_adm_m        88504 non-null  float64
 1   death_time_from_adm_m          88504 non-null  float64
 2   adm_age_rec3                   88504 non-null  float64
 3   porc_pobr                      88504 non-null  float64
 4   dit_m                          88504 non-null  float64
 5   sex_rec                        88504 non-null  object 
 6   tenure_status_household        88504 non-null  object 
 7   cohabitation                   88504 non-null  object 
 8   sub_dep_icd10_status           88504 non-null  object 
 9   any_violence                   88504 non-null  object 
 10  prim_sub_freq_rec              88504 non-null  object 
 11  tr_outcome                     88504 non-null  object 
 12  adm_motive                     88504 non-null  object 
 13  first_sub_used                 88504 non-null  object 
 14  primary_sub_mod                88504 non-null  object 
 15  tipo_de_vivienda_rec2          88504 non-null  object 
 16  national_foreign               88504 non-null  int32  
 17  plan_type_corr                 88504 non-null  object 
 18  occupation_condition_corr24    88504 non-null  object 
 19  marital_status_rec             88504 non-null  object 
 20  urbanicity_cat                 88504 non-null  object 
 21  ed_attainment_corr             88504 non-null  object 
 22  evaluacindelprocesoteraputico  88504 non-null  object 
 23  eva_consumo                    88504 non-null  object 
 24  eva_fam                        88504 non-null  object 
 25  eva_relinterp                  88504 non-null  object 
 26  eva_ocupacion                  88504 non-null  object 
 27  eva_sm                         88504 non-null  object 
 28  eva_fisica                     88504 non-null  object 
 29  eva_transgnorma                88504 non-null  object 
 30  ethnicity                      88504 non-null  float64
 31  dg_psiq_cie_10_instudy         88504 non-null  bool   
 32  dg_psiq_cie_10_dg              88504 non-null  bool   
 33  dx_f3_mood                     88504 non-null  int32  
 34  dx_f6_personality              88504 non-null  int32  
 35  dx_f_any_severe_mental         88504 non-null  bool   
 36  any_phys_dx                    88504 non-null  bool   
 37  polysubstance_strict           88504 non-null  int32  
 38  readmit_event                  88504 non-null  float64
 39  death_event                    88504 non-null  int32  
 40  readmit_time_from_disch_m      88504 non-null  float64
 41  death_time_from_disch_m        88504 non-null  float64
 42  center_id                      88475 non-null  object 
dtypes: bool(4), float64(9), int32(5), object(25)
memory usage: 25.0+ MB
Show imputation_1 structure

RangeIndex: 88504 entries, 0 to 88503
Data columns (total 78 columns):
 #   Column                                                              Non-Null Count  Dtype  
---  ------                                                              --------------  -----  
 0   readmit_time_from_adm_m                                             88504 non-null  float64
 1   death_time_from_adm_m                                               88504 non-null  float64
 2   adm_age_rec3                                                        88504 non-null  float64
 3   porc_pobr                                                           88504 non-null  float64
 4   dit_m                                                               88504 non-null  float64
 5   national_foreign                                                    88504 non-null  int32  
 6   ethnicity                                                           88504 non-null  float64
 7   dg_psiq_cie_10_instudy                                              88504 non-null  bool   
 8   dg_psiq_cie_10_dg                                                   88504 non-null  bool   
 9   dx_f3_mood                                                          88504 non-null  int32  
 10  dx_f6_personality                                                   88504 non-null  int32  
 11  dx_f_any_severe_mental                                              88504 non-null  bool   
 12  any_phys_dx                                                         88504 non-null  bool   
 13  polysubstance_strict                                                88504 non-null  int32  
 14  readmit_time_from_disch_m                                           88504 non-null  float64
 15  readmit_event                                                       88504 non-null  float64
 16  death_time_from_disch_m                                             88504 non-null  float64
 17  death_event                                                         88504 non-null  int32  
 18  sex_rec_woman                                                       88504 non-null  float64
 19  tenure_status_household_illegal_settlement                          88504 non-null  float64
 20  tenure_status_household_owner_transferred_dwellings_pays_dividends  88504 non-null  float64
 21  tenure_status_household_renting                                     88504 non-null  float64
 22  tenure_status_household_stays_temporarily_with_a_relative           88504 non-null  float64
 23  cohabitation_alone                                                  88504 non-null  float64
 24  cohabitation_with_couple_children                                   88504 non-null  float64
 25  cohabitation_family_of_origin                                       88504 non-null  float64
 26  sub_dep_icd10_status_drug_dependence                                88504 non-null  float64
 27  any_violence_1_domestic_violence_sex_abuse                          88504 non-null  float64
 28  prim_sub_freq_rec_2_2_6_days_wk                                     88504 non-null  float64
 29  prim_sub_freq_rec_3_daily                                           88504 non-null  float64
 30  tr_outcome_adm_discharge_adm_reasons                                88504 non-null  float64
 31  tr_outcome_adm_discharge_rule_violation_undet                       88504 non-null  float64
 32  tr_outcome_completion                                               88504 non-null  float64
 33  tr_outcome_dropout                                                  88504 non-null  float64
 34  tr_outcome_referral                                                 88504 non-null  float64
 35  adm_motive_another_sud_facility_fonodrogas_senda_previene           88504 non-null  float64
 36  adm_motive_justice_sector                                           88504 non-null  float64
 37  adm_motive_sanitary_sector                                          88504 non-null  float64
 38  adm_motive_spontaneous_consultation                                 88504 non-null  float64
 39  first_sub_used_alcohol                                              88504 non-null  float64
 40  first_sub_used_cocaine_paste                                        88504 non-null  float64
 41  first_sub_used_cocaine_powder                                       88504 non-null  float64
 42  first_sub_used_marijuana                                            88504 non-null  float64
 43  first_sub_used_opioids                                              88504 non-null  float64
 44  first_sub_used_tranquilizers_hypnotics                              88504 non-null  float64
 45  primary_sub_mod_cocaine_paste                                       88504 non-null  float64
 46  primary_sub_mod_cocaine_powder                                      88504 non-null  float64
 47  primary_sub_mod_alcohol                                             88504 non-null  float64
 48  primary_sub_mod_marijuana                                           88504 non-null  float64
 49  tipo_de_vivienda_rec2_other_unknown                                 88504 non-null  float64
 50  plan_type_corr_m_pai                                                88504 non-null  float64
 51  plan_type_corr_m_pr                                                 88504 non-null  float64
 52  plan_type_corr_pg_pai                                               88504 non-null  float64
 53  plan_type_corr_pg_pr                                                88504 non-null  float64
 54  occupation_condition_corr24_inactive                                88504 non-null  float64
 55  occupation_condition_corr24_unemployed                              88504 non-null  float64
 56  marital_status_rec_separated_divorced_annulled_widowed              88504 non-null  float64
 57  marital_status_rec_single                                           88504 non-null  float64
 58  urbanicity_cat_1_rural                                              88504 non-null  float64
 59  urbanicity_cat_2_mixed                                              88504 non-null  float64
 60  ed_attainment_corr_2_completed_high_school_or_less                  88504 non-null  float64
 61  ed_attainment_corr_3_completed_primary_school_or_less               88504 non-null  float64
 62  evaluacindelprocesoteraputico_logro_intermedio                      88504 non-null  float64
 63  evaluacindelprocesoteraputico_logro_minimo                          88504 non-null  float64
 64  eva_consumo_logro_intermedio                                        88504 non-null  float64
 65  eva_consumo_logro_minimo                                            88504 non-null  float64
 66  eva_fam_logro_intermedio                                            88504 non-null  float64
 67  eva_fam_logro_minimo                                                88504 non-null  float64
 68  eva_relinterp_logro_intermedio                                      88504 non-null  float64
 69  eva_relinterp_logro_minimo                                          88504 non-null  float64
 70  eva_ocupacion_logro_intermedio                                      88504 non-null  float64
 71  eva_ocupacion_logro_minimo                                          88504 non-null  float64
 72  eva_sm_logro_intermedio                                             88504 non-null  float64
 73  eva_sm_logro_minimo                                                 88504 non-null  float64
 74  eva_fisica_logro_intermedio                                         88504 non-null  float64
 75  eva_fisica_logro_minimo                                             88504 non-null  float64
 76  eva_transgnorma_logro_intermedio                                    88504 non-null  float64
 77  eva_transgnorma_logro_minimo                                        88504 non-null  float64
dtypes: bool(4), float64(69), int32(5)
memory usage: 48.6 MB
Code
from IPython.display import display, Markdown

# Structural peek at the unpickled object: confirm it is a non-empty list,
# then report the first element's type, plus its keys (if a dict) or its
# shape (if a DataFrame / ndarray).
if isinstance(imputations_list_jan26, list) and len(imputations_list_jan26) > 0:
    display(Markdown(f"**First element type:** `{type(imputations_list_jan26[0])}`"))

    if isinstance(imputations_list_jan26[0], dict):
        display(Markdown(f"**First element keys:** `{list(imputations_list_jan26[0].keys())}`"))

    elif isinstance(imputations_list_jan26[0], (pd.DataFrame, np.ndarray)):
        display(Markdown(f"**First element shape:** `{imputations_list_jan26[0].shape}`"))

First element type: <class 'pandas.core.frame.DataFrame'>

First element shape: (88504, 56)

This code block:

  1. Imports the pickle library: This library implements binary protocols for serializing and de-serializing a Python object structure.
  2. Specifies the file path: it points to the target .pkl file inside BASE_DIR.
  3. Opens the file in binary read mode ('rb'): This is necessary for loading pickle files.
  4. Loads the object: pickle.load(f) reads the serialized object from the file and reconstructs it in memory.
  5. Prints confirmation and basic information: It verifies that the file was loaded and shows the type of the loaded object, and some details about the first element if it’s a list containing common data structures.

Format data

Due to inconsistencies and structural heterogeneity across previously merged datasets, we decided not to proceed with a direct inspection and comparison of column names between the first imputed dataset from imputations_list_jan26 (which likely included dummy-encoded variables) and imputation_nodum_1 (which likely retained non–dummy-encoded variables).

Instead, we reconstructed the analytic datasets de novo using the most recent source files available in the original directory (BASE_DIR). Time-to-event variables were re-derived to ensure internal consistency. Variables that could introduce information leakage (e.g., time from admission) were excluded, and the center identifier variable was removed prior to modeling.

Code
#1.2. Build Surv objects from df_final
from IPython.display import display, Markdown
from sksurv.util import Surv

# Build cause-specific structured survival arrays (event indicator + months
# since discharge) for readmission and death, one pair per imputed dataset.
# Results are stored in globals as y_surv_readm_i / y_surv_death_i.
# NOTE(review): a later cleanup cell keeps only the unsuffixed names
# "y_surv_readm" / "y_surv_death" — verify these suffixed objects survive it.
for i in range(1, 6):
    # Get the DataFrame
    df = globals()[f"imputation_nodum_{i}"]

    # Extract time and event arrays (event coded 1 = observed event)
    time_readm  = df["readmit_time_from_disch_m"].to_numpy()
    event_readm = (df["readmit_event"].to_numpy() == 1)
    time_death  = df["death_time_from_disch_m"].to_numpy()
    event_death = (df["death_event"].to_numpy() == 1)

    # Create survival objects (structured dtype [('event','?'), ('time','<f8')])
    y_surv_readm = Surv.from_arrays(event=event_readm, time=time_readm)
    y_surv_death = Surv.from_arrays(event=event_death, time=time_death)

    # Store in global variables (optional but matches your pattern)
    globals()[f"y_surv_readm_{i}"]  = y_surv_readm
    globals()[f"y_surv_death_{i}"]  = y_surv_death

    # Print info
    display(Markdown(f"\n--- Imputation {i} ---"))
    display(Markdown(
    f"**y_surv_readm dtype:** {y_surv_readm.dtype}  \n"
    f"**shape:** {y_surv_readm.shape}"
    ))
    display(Markdown(
    f"**y_surv_death dtype:** {y_surv_death.dtype}  \n"
    f"**shape:** {y_surv_death.shape}"
    ))

— Imputation 1 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

— Imputation 2 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

— Imputation 3 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

— Imputation 4 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

— Imputation 5 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

Code
fold_output(
    "Show imputation_nodum_1 (newer database) glimpse",
    lambda: glimpse(imputation_nodum_1)
)
fold_output(
    "Show first db of imputations_list_jan26 (older) glimpse",
    lambda: glimpse(imputations_list_jan26[0])
)
Show imputation_nodum_1 (newer database) glimpse
Rows: 88504 | Columns: 43
readmit_time_from_adm_m        float64         84.93548387096774, 12.833333333333334, 13.733333333333333, 11.966666666666667, 1...
death_time_from_adm_m          float64         84.93548387096774, 87.16129032258064, 117.2258064516129, 98.93548387096774, 37.9...
adm_age_rec3                   float64         31.53, 20.61, 42.52, 60.61, 45.08
porc_pobr                      float64         0.175679117441177, 0.187835901975632, 0.130412444472313, 0.133759185671806, 0.08...
dit_m                          float64         15.967741935483872, 5.833333333333334, 0.4752688172043005, 6.966666666666667, 6....
sex_rec                        object          man, man, man, woman, man
tenure_status_household        object          stays temporarily with a relative, owner/transferred dwellings/pays dividends, s...
cohabitation                   object          alone, family of origin, with couple/children, with couple/children, family of o...
sub_dep_icd10_status           object          drug dependence, hazardous consumption, drug dependence, drug dependence, drug d...
any_violence                   object          0.No domestic violence/sex abuse, 0.No domestic violence/sex abuse, 0.No domesti...
prim_sub_freq_rec              object          2.2–6 days/wk, 3.Daily, 3.Daily, 3.Daily, 3.Daily
tr_outcome                     object          referral, dropout, adm discharge - rule violation/undet, dropout, completion
adm_motive                     object          sanitary sector, spontaneous consultation, sanitary sector, sanitary sector, spo...
first_sub_used                 object          alcohol, alcohol, alcohol, cocaine paste, alcohol
primary_sub_mod                object          alcohol, cocaine paste, cocaine paste, cocaine paste, cocaine paste
tipo_de_vivienda_rec2          object          other/unknown, formal housing, formal housing, formal housing, formal housing
national_foreign               int32           0, 0, 0, 0, 0
plan_type_corr                 object          pg-pab, pg-pab, pg-pr, m-pr, pg-pai
occupation_condition_corr24    object          unemployed, employed, employed, inactive, unemployed
marital_status_rec             object          single, single, single, married/cohabiting, single
urbanicity_cat                 object          3.Urban, 3.Urban, 3.Urban, 3.Urban, 3.Urban
ed_attainment_corr             object          2-Completed high school or less, 3-Completed primary school or less, 2-Completed...
evaluacindelprocesoteraputico  object          logro alto, logro minimo, logro minimo, logro minimo, logro alto
eva_consumo                    object          logro alto, logro minimo, logro minimo, logro intermedio, logro alto
eva_fam                        object          logro intermedio, logro minimo, logro minimo, logro intermedio, logro alto
eva_relinterp                  object          logro alto, logro minimo, logro minimo, logro intermedio, logro alto
eva_ocupacion                  object          logro alto, logro minimo, logro minimo, logro minimo, logro intermedio
eva_sm                         object          logro intermedio, logro minimo, logro minimo, logro intermedio, logro minimo
eva_fisica                     object          logro alto, logro minimo, logro intermedio, logro intermedio, logro alto
eva_transgnorma                object          logro alto, logro minimo, logro minimo, logro minimo, logro intermedio
ethnicity                      float64         0.0, 0.0, 0.0, 0.0, 0.0
dg_psiq_cie_10_instudy         bool            False, False, True, False, False
dg_psiq_cie_10_dg              bool            True, False, False, True, False
dx_f3_mood                     int32           0, 0, 0, 0, 0
dx_f6_personality              int32           0, 0, 0, 1, 0
dx_f_any_severe_mental         bool            True, False, False, False, False
any_phys_dx                    bool            False, False, False, True, False
polysubstance_strict           int32           0, 1, 1, 1, 1
readmit_event                  float64         0.0, 1.0, 1.0, 1.0, 1.0
death_event                    int32           0, 0, 0, 0, 0
readmit_time_from_disch_m      float64         68.96774193548387, 7.0, 13.258064516129032, 5.0, 7.354838709677419
death_time_from_disch_m        float64         68.96774193548387, 81.3225806451613, 116.74193548387096, 91.96774193548387, 31.0...
center_id                      object          330, 489, 341, 411, 176
Show first db of imputations_list_jan26 (older) glimpse
Rows: 88504 | Columns: 56
adm_age_rec3                   float64         31.53, 20.61, 42.52, 60.61, 45.08
porc_pobr                      float64         0.175679117441177, 0.187835901975632, 0.130412444472313, 0.133759185671806, 0.08...
dit_m                          float64         15.967741935483872, 5.833333333333334, 0.4752688172043005, 6.966666666666667, 6....
tenure_status_household        int64           3, 0, 3, 0, 3
prim_sub_freq_rec              int64           1, 2, 2, 2, 2
national_foreign               int32           0, 0, 0, 0, 0
urbanicity_cat                 int64           0, 0, 0, 0, 0
ed_attainment_corr             float64         1.0, 2.0, 1.0, 1.0, 2.0
evaluacindelprocesoteraputico  int64           0, 2, 2, 2, 0
eva_consumo                    int64           0, 2, 2, 1, 0
eva_fam                        int64           1, 2, 2, 1, 0
eva_relinterp                  int64           0, 2, 2, 1, 0
eva_ocupacion                  int64           0, 2, 2, 2, 1
eva_sm                         int64           1, 2, 2, 1, 2
eva_fisica                     int64           0, 2, 1, 1, 0
eva_transgnorma                int64           0, 2, 2, 2, 1
ethnicity                      float64         0.0, 0.0, 0.0, 0.0, 0.0
dg_psiq_cie_10_instudy         bool            False, False, True, False, False
dg_psiq_cie_10_dg              bool            True, False, False, True, False
dx_f3_mood                     int32           0, 0, 0, 0, 0
dx_f6_personality              int32           0, 0, 0, 1, 0
dx_f_any_severe_mental         bool            True, False, False, False, False
any_phys_dx                    bool            False, False, False, True, False
polysubstance_strict           int32           0, 1, 1, 1, 1
sex_rec_woman                  float64         0.0, 0.0, 0.0, 1.0, 0.0
cohabitation_family_of_origin  float64         0.0, 1.0, 0.0, 0.0, 1.0
cohabitation_with_couple_children float64         0.0, 0.0, 1.0, 1.0, 0.0
cohabitation_others            float64         0.0, 0.0, 0.0, 0.0, 0.0
sub_dep_icd10_status_drug_dependence float64         1.0, 0.0, 1.0, 1.0, 1.0
any_violence_1_domestic_violence_sex_abuse float64         0.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_referral            float64         1.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_dropout             float64         0.0, 1.0, 0.0, 1.0, 0.0
tr_outcome_adm_discharge_rule_violation_undet float64         0.0, 0.0, 1.0, 0.0, 0.0
tr_outcome_adm_discharge_adm_reasons float64         0.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_other               float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_sanitary_sector     float64         1.0, 0.0, 1.0, 1.0, 0.0
adm_motive_another_sud_facility_fonodrogas_senda_previene float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_justice_sector      float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_other               float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_alcohol         float64         1.0, 1.0, 1.0, 0.0, 1.0
first_sub_used_cocaine_paste   float64         0.0, 0.0, 0.0, 1.0, 0.0
first_sub_used_cocaine_powder  float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_cocaine_powder float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_alcohol        float64         1.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_marijuana      float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_others         float64         0.0, 0.0, 0.0, 0.0, 0.0
tipo_de_vivienda_rec2_other_unknown float64         1.0, 0.0, 0.0, 0.0, 0.0
plan_type_corr_pg_pr           float64         0.0, 0.0, 1.0, 0.0, 0.0
plan_type_corr_m_pr            float64         0.0, 0.0, 0.0, 1.0, 0.0
plan_type_corr_pg_pai          float64         0.0, 0.0, 0.0, 0.0, 1.0
plan_type_corr_m_pai           float64         0.0, 0.0, 0.0, 0.0, 0.0
occupation_condition_corr24_unemployed float64         1.0, 0.0, 0.0, 0.0, 1.0
occupation_condition_corr24_inactive float64         0.0, 0.0, 0.0, 1.0, 0.0
marital_status_rec_single      float64         1.0, 1.0, 1.0, 0.0, 1.0
marital_status_rec_separated_divorced_annulled_widowed float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_other           float64         0.0, 0.0, 0.0, 0.0, 0.0

For each imputed dataset (1–5), we identified and removed predictors with zero variance, as they provide no useful information and can destabilize models. We printed the dropped variables and produced a cleaned version of each design matrix. This ensures that all downstream analyses use only informative predictors.

Code
# Keep only these objects when pruning the interactive namespace below.
objects_to_keep = {
    "objects_to_keep",
    "imputation_nodum_1",
    "imputation_nodum_2",
    "imputation_nodum_3",
    "imputation_nodum_4",
    "imputation_nodum_5",
    # FIX: the survival objects were created with numeric suffixes
    # (y_surv_readm_1 .. y_surv_death_5) by the Surv-building loop; the
    # unsuffixed names below never existed, so the Surv arrays were being
    # deleted by this cleanup. Keep both spellings for safety.
    "y_surv_readm",
    "y_surv_death",
    *(f"y_surv_readm_{i}" for i in range(1, 6)),
    *(f"y_surv_death_{i}" for i in range(1, 6)),
    "imputations_list_jan26"
}

import types

# Delete every non-protected, non-callable, non-module global to free
# memory before the heavier preprocessing steps.
for name in list(globals().keys()):
    obj = globals()[name]
    if (
        name not in objects_to_keep
        and not name.startswith("_")
        and not callable(obj)
        and not isinstance(obj, types.ModuleType)  # <- protects ALL modules
    ):
        del globals()[name]
Code
from IPython.display import display, Markdown

# 1. Define columns to exclude (same as before)
# Outcome columns — these are the labels, never predictors.
target_cols = [
    "readmit_time_from_disch_m",
    "readmit_event",
    "death_time_from_disch_m",
    "death_event",
]

# Times measured from admission would leak post-baseline information.
leak_time_cols = [
    "readmit_time_from_adm_m",
    "death_time_from_adm_m",
]

# Facility identifier — excluded prior to modeling (see narrative above).
center_id = ["center_id"]

cols_to_exclude = target_cols + center_id  + leak_time_cols

# 2. Create list of your EXISTING imputation DataFrames (1-5)
imputed_dfs = [
    imputation_nodum_1,
    imputation_nodum_2,
    imputation_nodum_3,
    imputation_nodum_4,
    imputation_nodum_5
]

# 3. Preprocessing loop
# For each imputed dataset: drop zero-variance predictors, then drop the
# target/leakage/identifier columns, collecting the cleaned design
# matrices in X_reduced_list (same order as imputed_dfs).
X_reduced_list = []

for d, df in enumerate(imputed_dfs):
    imputation_num = d + 1  # Convert 0-index to 1-index for display

    display(Markdown(f"\n=== Imputation dataset {imputation_num} ==="))

    # a) Identify and drop constant predictors
    # nunique(dropna=False) <= 1 flags columns with a single distinct value
    # (NaN counted as a value), which carry no information.
    const_mask = (df.nunique(dropna=False) <= 1)
    dropped_const = df.columns[const_mask].tolist()
    display(Markdown(f"**Constant predictors dropped ({len(dropped_const)}):**"))
    display(Markdown(f"{dropped_const if dropped_const else 'None'}"))

    # b) Remove constant columns
    X_reduced = df.loc[:, ~const_mask]

    # c) Drop target/leakage columns (if present)
    cols_to_drop = [col for col in cols_to_exclude if col in X_reduced.columns]
    if cols_to_drop:
        X_reduced = X_reduced.drop(columns=cols_to_drop)
        display(Markdown(f"**Dropped target/leakage columns:** {cols_to_drop}"))
    else:
        display(Markdown("No target/leakage columns found to drop"))

    # d) Store cleaned DataFrame
    X_reduced_list.append(X_reduced)

    # e) Report shapes
    display(Markdown(f"**Original shape:** {df.shape}"))
    display(Markdown(
        f"**Cleaned shape:** {X_reduced.shape} "
        f"(removed {df.shape[1] - X_reduced.shape[1]} columns)"
    ))

display(Markdown("\n✅ **Preprocessing complete! X_reduced_list contains 5 cleaned DataFrames.**"))

=== Imputation dataset 1 ===

Constant predictors dropped (0):

None

Dropped target/leakage columns: [‘readmit_time_from_disch_m’, ‘readmit_event’, ‘death_time_from_disch_m’, ‘death_event’, ‘center_id’, ‘readmit_time_from_adm_m’, ‘death_time_from_adm_m’]

Original shape: (88504, 43)

Cleaned shape: (88504, 36) (removed 7 columns)

=== Imputation dataset 2 ===

Constant predictors dropped (0):

None

Dropped target/leakage columns: [‘readmit_time_from_disch_m’, ‘readmit_event’, ‘death_time_from_disch_m’, ‘death_event’, ‘center_id’, ‘readmit_time_from_adm_m’, ‘death_time_from_adm_m’]

Original shape: (88504, 43)

Cleaned shape: (88504, 36) (removed 7 columns)

=== Imputation dataset 3 ===

Constant predictors dropped (0):

None

Dropped target/leakage columns: [‘readmit_time_from_disch_m’, ‘readmit_event’, ‘death_time_from_disch_m’, ‘death_event’, ‘center_id’, ‘readmit_time_from_adm_m’, ‘death_time_from_adm_m’]

Original shape: (88504, 43)

Cleaned shape: (88504, 36) (removed 7 columns)

=== Imputation dataset 4 ===

Constant predictors dropped (0):

None

Dropped target/leakage columns: [‘readmit_time_from_disch_m’, ‘readmit_event’, ‘death_time_from_disch_m’, ‘death_event’, ‘center_id’, ‘readmit_time_from_adm_m’, ‘death_time_from_adm_m’]

Original shape: (88504, 43)

Cleaned shape: (88504, 36) (removed 7 columns)

=== Imputation dataset 5 ===

Constant predictors dropped (0):

None

Dropped target/leakage columns: [‘readmit_time_from_disch_m’, ‘readmit_event’, ‘death_time_from_disch_m’, ‘death_event’, ‘center_id’, ‘readmit_time_from_adm_m’, ‘death_time_from_adm_m’]

Original shape: (88504, 43)

Cleaned shape: (88504, 36) (removed 7 columns)

Preprocessing complete! X_reduced_list contains 5 cleaned DataFrames.

Dummify

A structured preprocessing pipeline was implemented prior to modeling. Ordered categorical variables (e.g., housing status, educational attainment, clinical evaluations, and substance use frequency) were manually mapped to numeric scales reflecting their natural ordering. For nominal categorical variables, prespecified reference categories were enforced to ensure consistent baseline comparisons across imputations. All remaining categorical predictors were then converted to dummy variables using one-hot encoding with the first category dropped to prevent multicollinearity. The procedure was applied consistently across all imputed datasets to ensure harmonized model inputs.

Code
import pandas as pd
import numpy as np
from sklearn.preprocessing import OrdinalEncoder
import pandas as pd
import numpy as np
from pandas.api.types import CategoricalDtype

def preprocess_features_robust(df):
    """Encode predictors for modeling (applied identically to each imputation).

    Steps:
      1. Map ordered categorical variables to numeric scales reflecting their
         natural ordering; values that fail to map are imputed with the
         column mode.
      2. Force prespecified reference categories on nominal variables so the
         dummy baselines are consistent across imputations.
      3. One-hot encode all remaining categoricals (first level dropped to
         avoid multicollinearity).

    Parameters
    ----------
    df : pd.DataFrame
        Raw predictor matrix for one imputed dataset.

    Returns
    -------
    pd.DataFrame
        Fully numeric, dummy-encoded feature matrix.
    """
    df_clean = df.copy()

    # ---------------------------------------------------------
    # 1. Ordinal encoding
    # ---------------------------------------------------------
    ordered_mappings = {
        # --- Housing & Urbanicity (worse conditions -> higher code) ---
        "tenure_status_household": {
            "illegal settlement": 4,                       # Situación Calle
            "stays temporarily with a relative": 3,        # Allegado
            "others": 2,                                   # En pensión / Otros
            "renting": 1,                                  # Arrendando
            "owner/transferred dwellings/pays dividends": 0 # Vivienda Propia
        },
        "urbanicity_cat": {
            "1.Rural": 2,
            "2.Mixed": 1,
            "3.Urban": 0
        },

        # --- Clinical Evaluations (Minimo -> Intermedio -> Alto) ---
        "evaluacindelprocesoteraputico": {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_consumo":      {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_fam":          {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_relinterp":    {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_ocupacion":    {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_sm":           {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_fisica":       {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_transgnorma":  {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},

        # --- Frequency (Less freq -> More freq) ---
        "prim_sub_freq_rec": {
            "1.≤1 day/wk": 0,
            "2.2–6 days/wk": 1,
            "3.Daily": 2
        },

        # --- Education (Less -> More) ---
        "ed_attainment_corr": {
            "3-Completed primary school or less": 2,
            "2-Completed high school or less": 1,
            "1-More than high school": 0
        }
    }

    for col, mapping in ordered_mappings.items():
        if col in df_clean.columns:
            df_clean[col] = df_clean[col].astype(str).str.strip()
            df_clean[col] = df_clean[col].map(mapping)

            n_missing = df_clean[col].isnull().sum()
            if n_missing > 0:
                if n_missing == len(df_clean):
                    # FIX: previously fell through to mode()[0], which raises
                    # IndexError when every value failed to map (empty mode).
                    # Warn and leave the column as-is instead.
                    print(f"⚠️ WARNING: Mapping failed completely for '{col}'.")
                    continue
                # Mode imputation for the (few) unmapped values.
                mode_val = df_clean[col].mode()[0]
                df_clean[col] = df_clean[col].fillna(mode_val)

    # ---------------------------------------------------------
    # 2. FORCE reference categories for dummies
    # ---------------------------------------------------------
    # FIX: "plan_type_corr" was listed twice ("ambulatory", then "pg-pab");
    # Python dict literals keep only the last duplicate, so "pg-pab" was the
    # reference actually in effect. The duplicate entry is removed here.
    dummy_reference = {
        "sex_rec": "man",
        "plan_type_corr": "pg-pab",
        "marital_status_rec": "married/cohabiting",
        "cohabitation": "alone",
        "sub_dep_icd10_status": "hazardous consumption",
        "tr_outcome": "completion",
        "adm_motive": "spontaneous consultation",
        "tipo_de_vivienda_rec2": "formal housing",
        "occupation_condition_corr24": "employed",
        "any_violence": "0.No domestic violence/sex abuse",
        "first_sub_used": "marijuana",
        "primary_sub_mod": "marijuana",
        }

    for col, ref in dummy_reference.items():
        if col in df_clean.columns:
            df_clean[col] = df_clean[col].astype(str).str.strip()
            cats = df_clean[col].unique().tolist()

            if ref in cats:
                # Put the reference first so drop_first removes it below.
                new_order = [ref] + [c for c in cats if c != ref]
                cat_type = CategoricalDtype(categories=new_order, ordered=False)
                df_clean[col] = df_clean[col].astype(cat_type)
            else:
                print(f"⚠️ Reference '{ref}' not found in {col}")

    # ---------------------------------------------------------
    # 3. One-hot encoding
    # ---------------------------------------------------------
    df_final = pd.get_dummies(df_clean, drop_first=True, dtype=float)

    return df_final

# Apply the encoding to every imputed dataset, then standardize column names
# (clean_names is a helper defined earlier in this notebook).
X_encoded_list_final = [preprocess_features_robust(X) for X in X_reduced_list]
X_encoded_list_final = [clean_names(X) for X in X_encoded_list_final]
Code
from IPython.display import display, Markdown

# Diagnostic: confirm the manually ordinal-mapped columns now hold the
# expected numeric codes in the first imputation.
display(Markdown("### --- Diagnostic Check ---"))
sample_df = X_encoded_list_final[0]

# (column, whether its absence is worth a loud warning)
for col, warn_if_absent in [
    ("tenure_status_household", True),
    ("urbanicity_cat", False),
    ("ed_attainment_corr", False),
]:
    if col in sample_df.columns:
        display(Markdown(f"**Unique values in '{col}':**"))
        display(Markdown(str(sample_df[col].unique())))
    elif warn_if_absent:
        display(Markdown("❌ 'tenure_status_household' is missing entirely from input data!"))

— Diagnostic Check —

Unique values in ‘tenure_status_household’:

[3 0 1 2 4]

Unique values in ‘urbanicity_cat’:

[0 1 2]

Unique values in ‘ed_attainment_corr’:

[1 2 0]

We recoded first substance use so small categories are grouped into Others

Code
# Columns to combine into a single "other" first-substance indicator
cols_to_group = [
    "first_sub_used_opioids",
    "first_sub_used_others",
    "first_sub_used_hallucinogens",
    "first_sub_used_inhalants",
    "first_sub_used_tranquilizers_hypnotics",
    "first_sub_used_amphetamine_type_stimulants",
]

# Collapse the rare first-substance categories into one dummy per dataset.
# FIX: iterate over the actual list length instead of a hard-coded range(5),
# and drop cols_to_group directly — the previous filter
# `c != "first_sub_used_other"` was always true because that name is not in
# cols_to_group ("first_sub_used_others", with an 's', is).
for i in range(len(X_encoded_list_final)):
    df = X_encoded_list_final[i].copy()
    # Combined dummy: 1 if ANY of the grouped columns equals 1
    df["first_sub_used_other"] = df[cols_to_group].max(axis=1)
    # Remove the now-redundant source columns
    df = df.drop(columns=cols_to_group)
    # Replace the dataset in the original list
    X_encoded_list_final[i] = df
Code
import sys
fold_output(
    "Show first db of X_encoded_list_final (newer) glimpse",
    lambda: glimpse(X_encoded_list_final[0])
)
Show first db of X_encoded_list_final (newer) glimpse
Rows: 88504 | Columns: 56
adm_age_rec3                   float64         31.53, 20.61, 42.52, 60.61, 45.08
porc_pobr                      float64         0.175679117441177, 0.187835901975632, 0.130412444472313, 0.133759185671806, 0.08...
dit_m                          float64         15.967741935483872, 5.833333333333334, 0.4752688172043005, 6.966666666666667, 6....
tenure_status_household        int64           3, 0, 3, 0, 3
prim_sub_freq_rec              int64           1, 2, 2, 2, 2
national_foreign               int32           0, 0, 0, 0, 0
urbanicity_cat                 int64           0, 0, 0, 0, 0
ed_attainment_corr             int64           1, 2, 1, 1, 2
evaluacindelprocesoteraputico  int64           0, 2, 2, 2, 0
eva_consumo                    int64           0, 2, 2, 1, 0
eva_fam                        int64           1, 2, 2, 1, 0
eva_relinterp                  int64           0, 2, 2, 1, 0
eva_ocupacion                  int64           0, 2, 2, 2, 1
eva_sm                         int64           1, 2, 2, 1, 2
eva_fisica                     int64           0, 2, 1, 1, 0
eva_transgnorma                int64           0, 2, 2, 2, 1
ethnicity                      float64         0.0, 0.0, 0.0, 0.0, 0.0
dg_psiq_cie_10_instudy         bool            False, False, True, False, False
dg_psiq_cie_10_dg              bool            True, False, False, True, False
dx_f3_mood                     int32           0, 0, 0, 0, 0
dx_f6_personality              int32           0, 0, 0, 1, 0
dx_f_any_severe_mental         bool            True, False, False, False, False
any_phys_dx                    bool            False, False, False, True, False
polysubstance_strict           int32           0, 1, 1, 1, 1
sex_rec_woman                  float64         0.0, 0.0, 0.0, 1.0, 0.0
cohabitation_family_of_origin  float64         0.0, 1.0, 0.0, 0.0, 1.0
cohabitation_with_couple_children float64         0.0, 0.0, 1.0, 1.0, 0.0
cohabitation_others            float64         0.0, 0.0, 0.0, 0.0, 0.0
sub_dep_icd10_status_drug_dependence float64         1.0, 0.0, 1.0, 1.0, 1.0
any_violence_1_domestic_violence_sex_abuse float64         0.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_referral            float64         1.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_dropout             float64         0.0, 1.0, 0.0, 1.0, 0.0
tr_outcome_adm_discharge_rule_violation_undet float64         0.0, 0.0, 1.0, 0.0, 0.0
tr_outcome_adm_discharge_adm_reasons float64         0.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_other               float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_sanitary_sector     float64         1.0, 0.0, 1.0, 1.0, 0.0
adm_motive_another_sud_facility_fonodrogas_senda_previene float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_justice_sector      float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_other               float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_alcohol         float64         1.0, 1.0, 1.0, 0.0, 1.0
first_sub_used_cocaine_paste   float64         0.0, 0.0, 0.0, 1.0, 0.0
first_sub_used_cocaine_powder  float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_alcohol        float64         1.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_cocaine_paste  float64         0.0, 1.0, 1.0, 1.0, 1.0
primary_sub_mod_cocaine_powder float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_others         float64         0.0, 0.0, 0.0, 0.0, 0.0
tipo_de_vivienda_rec2_other_unknown float64         1.0, 0.0, 0.0, 0.0, 0.0
plan_type_corr_pg_pr           float64         0.0, 0.0, 1.0, 0.0, 0.0
plan_type_corr_m_pr            float64         0.0, 0.0, 0.0, 1.0, 0.0
plan_type_corr_pg_pai          float64         0.0, 0.0, 0.0, 0.0, 1.0
plan_type_corr_m_pai           float64         0.0, 0.0, 0.0, 0.0, 0.0
occupation_condition_corr24_unemployed float64         1.0, 0.0, 0.0, 0.0, 1.0
occupation_condition_corr24_inactive float64         0.0, 0.0, 0.0, 1.0, 0.0
marital_status_rec_single      float64         1.0, 1.0, 1.0, 0.0, 1.0
marital_status_rec_separated_divorced_annulled_widowed float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_other           float64         0.0, 0.0, 0.0, 0.0, 0.0

For each imputed dataset, we fitted two regularized Cox models (one for readmission and one for death) using Coxnet, which applies elastic-net penalization with a strong LASSO component to enable variable selection. The loop fits both models on every imputation, prints basic model information, and stores all fitted models so they can later be combined or compared across imputations.

Create bins for followup (landmarks)

We extracted the observed event times and corresponding event indicators directly from the structured survival objects (y_surv_readm and y_surv_death). Using the observed event times, we constructed evaluation grids based on the 5th to 95th percentiles of the event-time distribution. These grids define standardized time points at which model performance is assessed for both readmission and mortality outcomes.

Code
import numpy as np
from IPython.display import display, Markdown

# Observed (uncensored) event times for each outcome, pulled straight from
# the structured survival arrays.
event_times_readm = y_surv_readm["time"][y_surv_readm["event"]]
event_times_death = y_surv_death["time"][y_surv_death["event"]]

def _quantile_grid(times, lo=0.05, hi=0.95, n_points=50):
    """Unique evaluation times spanning the lo–hi quantiles of `times`."""
    return np.unique(np.quantile(times, np.linspace(lo, hi, n_points)))

times_eval_readm = _quantile_grid(event_times_readm)
times_eval_death = _quantile_grid(event_times_death)

# Show only the head/tail of each grid
for label, grid in [("readmission", times_eval_readm), ("death", times_eval_death)]:
    display(Markdown(
        f"**Eval times ({label}):** `{grid[:5]}` ... `{grid[-5:]}`"
    ))

Eval times (readmission): [0.38709677 0.67741935 1.03225806 1.41935484 1.76666667][46.81833443 50.96030612 55.16129032 60.84848585 67.08322581]

Eval times (death): [0. 0.09677419 1.06666667 2.1691691 3.34812377][74.72632653 78.4516129 82.39472203 86.41935484 92.36311828]

Correct immortal time bias

First, we eliminated immortal time bias (without this correction, patients who died would incorrectly appear as readmission-free follow-up time).

This correction is essentially the Cause-Specific Hazard preparation. It is the correct way to handle Aim 3 unless you switch to a Fine-Gray model (which treats death as a distinct competing event, coded 2, rather than as censoring, coded 0). For RSF/Coxnet, coding death as censoring (0) is the correct approach.

Code
import numpy as np

# Step 3. Replicate across imputations (safe copies)
# One independent copy of each outcome array per imputed dataset, so the
# competing-risks correction below cannot mutate the shared originals.
n_imputations = len(X_encoded_list_final)
y_surv_readm_list = [y_surv_readm.copy() for _ in range(n_imputations)]
y_surv_death_list = [y_surv_death.copy() for _ in range(n_imputations)]

def correct_competing_risks(y_readm_list, y_death_list):
    """Censor readmission outcomes at death (cause-specific hazard setup).

    For each paired set of structured arrays, subjects whose death is
    observed strictly before the readmission time are recoded as censored
    for readmission at their death time. Inputs are left untouched;
    corrected copies are returned.

    NOTE(review): deaths at exactly the readmission/censoring time are left
    unchanged (strict `<`) — confirm this tie-handling is intended.
    """
    adjusted_list = []
    for y_readm, y_death in zip(y_readm_list, y_death_list):
        adjusted = y_readm.copy()

        # Death observed, and it precedes the recorded readmission time
        died_first = y_death["event"] & (y_death["time"] < adjusted["time"])

        adjusted["event"][died_first] = False
        adjusted["time"][died_first] = y_death["time"][died_first]

        adjusted_list.append(adjusted)
    return adjusted_list

# Step 4. Apply correction
# Produces per-imputation readmission outcomes censored at prior deaths.
y_surv_readm_list_corrected = correct_competing_risks(
    y_surv_readm_list,
    y_surv_death_list
)
Code
# Check type and length
# NOTE(review): only the LAST expression of a notebook cell is displayed, so
# this tuple is evaluated but never shown; wrap it in display(...) to see it.
type(y_surv_readm_list_corrected), len(y_surv_readm_list_corrected)

# Look at the first element
y_surv_readm_list_corrected[0][:5]   # first 5 rows
array([(False, 68.96774194), ( True,  7.        ), ( True, 13.25806452),
       ( True,  5.        ), ( True,  7.35483871)],
      dtype=[('event', '?'), ('time', '<f8')])
Code
from IPython.display import display, HTML
import html

def nb_print(*args, sep=" "):
    """Render the given arguments as escaped monospace text in the notebook."""
    text = sep.join(map(str, args))
    display(HTML(f"<pre style='margin:0'>{html.escape(text)}</pre>"))

The fully preprocessed and encoded feature matrices were renamed from X_encoded_list_final to imputations_list_mar26 to reflect the finalized March 2026 analytic version of the imputed datasets.

This object contains the harmonized, ordinal-encoded, and one-hot encoded predictor matrices for all five imputations and will serve as the definitive input for subsequent modeling procedures.

Code
# Freeze the finalized analytic inputs under a versioned name and delete the
# working alias so stale references to it fail fast.
imputations_list_mar26 = X_encoded_list_final
del X_encoded_list_final
Code
import numpy as np
import pandas as pd
from IPython.display import display, Markdown

# ── Build exclusion mask (same for all imputations, based on imputation 0) ──
df0 = imputations_list_mar26[0]

# Condition 1: administrative discharge for administrative reasons AND
# death within 0.23 months (≈ 7 days) of discharge.
mask_adm_death = (
    (df0["tr_outcome_adm_discharge_adm_reasons"] == 1)
    & (y_surv_death["event"] == True)
    & (y_surv_death["time"] <= 0.23)
)

# Condition 2: tr_outcome_other == 1 (any time)
mask_other = df0["tr_outcome_other"] == 1

# Combined exclusion mask
exclude = mask_adm_death | mask_other
keep = ~exclude

# ── Report ──
n_total = len(df0)
n_excl_adm = mask_adm_death.sum()
n_excl_other = mask_other.sum()
n_excl_both = (mask_adm_death & mask_other).sum()
n_excl_total = exclude.sum()
n_remaining = keep.sum()

report = f"""### Exclusion Report

| Criterion | n excluded |
|---|---:|
| `tr_outcome_adm_discharge_adm_reasons == 1` & death time ≤ 7 days | {n_excl_adm} |
| `tr_outcome_other == 1` (any time) | {n_excl_other} |
| Both criteria (overlap) | {n_excl_both} |
| **Total unique excluded** | **{n_excl_total}** |
| **Remaining observations** | **{n_remaining}** / {n_total} |
"""
display(Markdown(report))

# ── Apply filter to all imputation lists and outcome arrays ──
imputations_list_mar26 = [df.loc[keep].reset_index(drop=True) for df in imputations_list_mar26]
# NOTE(review): y_surv_readm_list_mar26 and y_surv_readm_list_corrected_mar26
# are BOTH derived from the corrected list below — confirm the first was not
# meant to come from the uncorrected y_surv_readm_list.
y_surv_readm_list_mar26 = [y[keep] for y in y_surv_readm_list_corrected]
y_surv_death_list_mar26 = [y[keep] for y in y_surv_death_list]
y_surv_readm_list_corrected_mar26 = [y[keep] for y in y_surv_readm_list_corrected]

# Single (non-list) outcome arrays for convenience
y_surv_readm_mar26 = y_surv_readm[keep]
y_surv_death_mar26 = y_surv_death[keep]

# Rebuild eval time grids on the filtered data (5th–95th percentiles)
event_times_readm_mar26 = y_surv_readm_mar26["time"][y_surv_readm_mar26["event"]]
event_times_death_mar26 = y_surv_death_mar26["time"][y_surv_death_mar26["event"]]

times_eval_readm_mar26 = np.unique(
    np.quantile(event_times_readm_mar26, np.linspace(0.05, 0.95, 50))
)
times_eval_death_mar26 = np.unique(
    np.quantile(event_times_death_mar26, np.linspace(0.05, 0.95, 50))
)

Exclusion Report

Criterion n excluded
tr_outcome_adm_discharge_adm_reasons == 1 & death time ≤ 7 days 137
tr_outcome_other == 1 (any time) 215
Both criteria (overlap) 0
Total unique excluded 352
Remaining observations 88152 / 88504

Train / test split (80/20)

  1. Sets a fixed random seed to make the 80/20 split exactly reproducible.

  2. Verifies required datasets exist (features and survival outcomes) before doing anything.

  3. Creates a “death-corrected” outcome list if it was not already available.

  4. Derives stratification labels from treatment plan and completion categories plus readmission/death events.

  5. Uses a step-down strategy if some strata are too rare, simplifying the stratification to keep it feasible.

  6. Caches a “full snapshot” of all imputations and outcomes so reruns don’t silently change the split.

  7. Checks row alignment so every imputation and every outcome has the same number of observations.

  8. Optionally checks stability across imputations for plan/completion columns (should not vary much).

  9. Loads split indices from disk when available, ensuring the exact same train/test split across sessions.

  10. Builds train/test datasets consistently for all imputations, then runs strict diagnostics to confirm balance.

Code
#@title 🧪 / 🎓 Reproducible 80/20 split before ML (integrated + idempotent + persisted)
# Stratification hierarchy:
#   1) plan + completion + readm_event + death_event
#   2) mixed fallback for rare full strata (<2) -> plan + readm + death
#   3) full fallback -> plan + readm + death

import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split
from IPython.display import display, Markdown

SEED = 2125                    # RNG seed for the reproducible 80/20 split
TEST_SIZE = 0.20               # held-out test fraction
FORCE_RESPLIT = False          # True to force new split
STRICT_SPLIT_CHECKS = True     # CI-style hard checks
MAX_EVENT_GAP = 0.01           # 1% tolerance
PERSIST_SPLIT_INDICES = True   # save/reload split indices under _out/

# --- Project-root anchored paths ---
from pathlib import Path

def find_project_root(markers=("AGENTS.md", ".git", "environment.yml")):
    """Walk from the current working directory upward and return the first
    directory containing any of `markers`.

    Raises RuntimeError when the CWD is invalid (e.g. deleted) or when no
    ancestor directory contains a marker.
    """
    try:
        here = Path.cwd().resolve()
    except OSError as e:
        raise RuntimeError(
            "Invalid working directory. Run this notebook from inside the project folder."
        ) from e

    for candidate in [here, *here.parents]:
        marker_found = any((candidate / marker).exists() for marker in markers)
        if marker_found:
            return candidate

    raise RuntimeError(
        f"Could not locate project root starting from {here}. "
        f"Expected one of markers: {markers}."
    )

# Anchor all outputs to the project root so the notebook works from any CWD
# inside the project.
PROJECT_ROOT = find_project_root()
OUT_DIR = PROJECT_ROOT / "_out"
OUT_DIR.mkdir(parents=True, exist_ok=True)

# Split-index file name encodes seed, test fraction, and data version.
SPLIT_FILE = OUT_DIR / f"deepsurv_split_seed{SEED}_test{int(TEST_SIZE*100)}_mar26.npz"

def nb_print_md(msg):
    # Render a string as Markdown in the notebook output.
    display(Markdown(str(msg)))

nb_print_md(f"**Project root:** `{PROJECT_ROOT}`")


# ---------- Requirements ----------
# Fail fast if the upstream cells that build the filtered inputs were not run.
required = [
    "imputations_list_mar26",
    "y_surv_readm_list_mar26",
    "y_surv_readm_list_corrected_mar26",
    "y_surv_death_list_mar26",
]
missing = [v for v in required if v not in globals()]
if missing:
    raise ValueError(f"Missing required objects: {missing}")

# No death-specific competing-risk correction exists upstream; fall back to
# plain copies so downstream code can treat both outcomes uniformly.
if "y_surv_death_list_corrected_mar26" not in globals():
    y_surv_death_list_corrected_mar26 = [y.copy() for y in y_surv_death_list_mar26]

# ---------- Helpers ----------
def get_plan_labels(df):
    """Integer plan-type label per row, derived from one-hot plan columns.

    Codes: 0 = reference plan (all dummies 0), 1 = pg_pr, 2 = m_pr,
    3 = pg_pai, 4 = m_pai. Columns absent from `df` are skipped, and later
    columns overwrite earlier ones when a row has multiple dummies set —
    both behaviors match the original if-chain.
    """
    # FIX(idiom): replaces four copy-pasted if-blocks with one data-driven loop.
    plan_columns = [
        ("plan_type_corr_pg_pr", 1),
        ("plan_type_corr_m_pr", 2),
        ("plan_type_corr_pg_pai", 3),
        ("plan_type_corr_m_pai", 4),
    ]
    labels = np.zeros(len(df), dtype=int)
    for col, code in plan_columns:
        if col in df.columns:
            hits = pd.to_numeric(df[col], errors="coerce").fillna(0).to_numpy() == 1
            labels[hits] = code
    return labels

def get_completion_labels(df):
    """Integer treatment-completion label per row from one-hot outcome columns.

    Codes: 0 = reference outcome (completion, all dummies 0), 1 = referral,
    2 = dropout, 3 = adm. discharge (rule violation/undetermined),
    4 = adm. discharge (administrative reasons), 5 = other. Columns absent
    from `df` are skipped; later columns overwrite earlier ones when a row
    has multiple dummies set — both behaviors match the original if-chain.
    """
    # FIX(idiom): replaces five copy-pasted if-blocks with one data-driven loop.
    outcome_columns = [
        ("tr_outcome_referral", 1),
        ("tr_outcome_dropout", 2),
        ("tr_outcome_adm_discharge_rule_violation_undet", 3),
        ("tr_outcome_adm_discharge_adm_reasons", 4),
        ("tr_outcome_other", 5),
    ]
    labels = np.zeros(len(df), dtype=int)
    for col, code in outcome_columns:
        if col in df.columns:
            hits = pd.to_numeric(df[col], errors="coerce").fillna(0).to_numpy() == 1
            labels[hits] = code
    return labels

def build_strata(X0, y_readm0, y_death0):
    """
    Build stratification labels with progressive fallback:

    1) full: plan + completion + readmission_event + death_event
    2) mixed: only rare full strata (<2 rows) are replaced by fallback labels
       (plan + readmission_event + death_event)
    3) fallback: plan + readmission_event + death_event for all rows

    Returns:
        strata (np.ndarray), mode (str), readm_evt (np.ndarray), death_evt (np.ndarray)
    """
    plan = get_plan_labels(X0)
    comp = get_completion_labels(X0)
    readm_evt = y_readm0["event"].astype(int)
    death_evt = y_death0["event"].astype(int)

    def _join(*arrays):
        # Underscore-join the string forms of the given label arrays, row-wise.
        parts = [pd.Series(a).astype(str) for a in arrays]
        joined = parts[0]
        for part in parts[1:]:
            joined = joined + "_" + part
        return joined

    full_labels = _join(plan, comp, readm_evt, death_evt)
    if full_labels.value_counts().min() >= 2:
        return full_labels.to_numpy(), "full(plan+completion+readm+death)", readm_evt, death_evt

    fallback_labels = _join(plan, readm_evt, death_evt)
    mixed_labels = full_labels.copy()
    is_rare = mixed_labels.map(mixed_labels.value_counts()) < 2
    mixed_labels[is_rare] = fallback_labels[is_rare]
    if mixed_labels.value_counts().min() >= 2:
        return mixed_labels.to_numpy(), "mixed(rare->plan+readm+death)", readm_evt, death_evt

    if fallback_labels.value_counts().min() >= 2:
        return fallback_labels.to_numpy(), "fallback(plan+readm+death)", readm_evt, death_evt

    raise ValueError("Could not build stratification labels with >=2 rows per stratum.")

def split_df_list(df_list, tr_idx, te_idx):
    """Slice every DataFrame into independent (train, test) copies with a
    fresh 0..n-1 index."""
    def _take(df, idx):
        return df.iloc[idx].reset_index(drop=True).copy()

    train_parts = [_take(df, tr_idx) for df in df_list]
    test_parts = [_take(df, te_idx) for df in df_list]
    return train_parts, test_parts

def split_surv_list(y_list, tr_idx, te_idx):
    """Slice every structured survival array into independent (train, test)
    copies."""
    train_parts, test_parts = [], []
    for y in y_list:
        train_parts.append(y[tr_idx].copy())
        test_parts.append(y[te_idx].copy())
    return train_parts, test_parts

# ---------- Cache full data once (idempotent re-runs) ----------
# The snapshot lives in kernel globals, so re-running this cell reuses the
# same pre-split data instead of re-slicing already-split lists.
if "_split_cache_death_mar26" not in globals():
    _split_cache_death_mar26 = {}
cache = _split_cache_death_mar26

if FORCE_RESPLIT:
    cache.pop("idx", None)

if FORCE_RESPLIT or "full" not in cache:
    cache["full"] = {
        "X": [df.reset_index(drop=True).copy() for df in imputations_list_mar26],
        "y_readm": [y.copy() for y in y_surv_readm_list_mar26],
        "y_readm_corr": [y.copy() for y in y_surv_readm_list_corrected_mar26],
        "y_death": [y.copy() for y in y_surv_death_list_mar26],
        "y_death_corr": [y.copy() for y in y_surv_death_list_corrected_mar26],
    }

full = cache["full"]

# ---------- Consistency checks ----------
# Every imputation and every outcome list must agree on row count.
n_imp = len(full["X"])
n = len(full["X"][0])

if any(len(df) != n for df in full["X"]):
    raise ValueError("Row mismatch inside full X list.")

for name, obj in [
    ("y_readm", full["y_readm"]),
    ("y_readm_corr", full["y_readm_corr"]),
    ("y_death", full["y_death"]),
    ("y_death_corr", full["y_death_corr"]),
]:
    if len(obj) != n_imp:
        raise ValueError(f"{name} length ({len(obj)}) != n_imputations ({n_imp})")
    if any(len(y) != n for y in obj):
        raise ValueError(f"Row mismatch between X and {name}.")

# ---------- Optional diagnostic: plan/completion consistency across imputations ----------
# Stratification is built from imputation 0 only, so these columns should be
# (near) identical across imputations; max_diff_rows quantifies any drift.
plan_comp_cols = [
    c for c in [
        "plan_type_corr_pg_pr",
        "plan_type_corr_m_pr",
        "plan_type_corr_pg_pai",
        "plan_type_corr_m_pai",
        "tr_outcome_referral",
        "tr_outcome_dropout",
        "tr_outcome_adm_discharge_rule_violation_undet",
        "tr_outcome_adm_discharge_adm_reasons"#,
        #"tr_outcome_other", #2026-03-26: Excluded from consistency check since it was used as an exclusion criterion and thus may differ by design across imputations
    ] if c in full["X"][0].columns
]

max_diff_rows = 0
if plan_comp_cols:
    base_pc = full["X"][0][plan_comp_cols].astype("string").fillna("__NA__").reset_index(drop=True)
    for i in range(1, n_imp):
        cur_pc = full["X"][i][plan_comp_cols].astype("string").fillna("__NA__").reset_index(drop=True)
        diff_rows = int((base_pc != cur_pc).any(axis=1).sum())
        max_diff_rows = max(max_diff_rows, diff_rows)

# ---------- Try loading indices from disk ----------
# Reuse persisted indices only when the row count matches and indices are in
# range, so a changed dataset silently invalidates the stale file.
loaded_from_disk = False
if PERSIST_SPLIT_INDICES and (not FORCE_RESPLIT) and SPLIT_FILE.exists() and ("idx" not in cache):
    z = np.load(SPLIT_FILE, allow_pickle=False)
    tr = z["train_idx"].astype(int)
    te = z["test_idx"].astype(int)
    n_disk = int(z["n_full"][0]) if "n_full" in z else n
    if n_disk == n and tr.max() < n and te.max() < n:
        cache["idx"] = (np.sort(tr), np.sort(te))
        cache["strat_mode"] = str(z["strat_mode"][0]) if "strat_mode" in z else "loaded_from_disk"
        loaded_from_disk = True

# ---------- Compute or reuse split indices ----------
if FORCE_RESPLIT or "idx" not in cache:
    strata_used, strat_mode, readm_evt_all, death_evt_all = build_strata(
        full["X"][0], full["y_readm"][0], full["y_death"][0]
    )
    idx = np.arange(n)
    train_idx, test_idx = train_test_split(
        idx, test_size=TEST_SIZE, random_state=SEED, shuffle=True, stratify=strata_used
    )
    # Sorted indices keep downstream slicing order-stable across runs.
    train_idx = np.sort(train_idx)
    test_idx = np.sort(test_idx)
    cache["idx"] = (train_idx, test_idx)
    cache["strat_mode"] = strat_mode

    if PERSIST_SPLIT_INDICES:
        np.savez_compressed(
            SPLIT_FILE,
            train_idx=train_idx,
            test_idx=test_idx,
            n_full=np.array([n], dtype=int),
            seed=np.array([SEED], dtype=int),
            test_size=np.array([TEST_SIZE], dtype=float),
            strat_mode=np.array([strat_mode], dtype="U64"),
        )
else:
    train_idx, test_idx = cache["idx"]
    train_idx = np.sort(train_idx)
    test_idx = np.sort(test_idx)
    # Event indicators are still needed for the balance diagnostics below.
    readm_evt_all = full["y_readm"][0]["event"].astype(int)
    death_evt_all = full["y_death"][0]["event"].astype(int)

# ---------- Build train/test from full snapshot every run ----------
imputations_list_mar26_train, imputations_list_mar26_test = split_df_list(full["X"], train_idx, test_idx)

y_surv_readm_list_train, y_surv_readm_list_test = split_surv_list(full["y_readm"], train_idx, test_idx)
y_surv_readm_list_corrected_train, y_surv_readm_list_corrected_test = split_surv_list(full["y_readm_corr"], train_idx, test_idx)

y_surv_death_list_train, y_surv_death_list_test = split_surv_list(full["y_death"], train_idx, test_idx)
y_surv_death_list_corrected_train, y_surv_death_list_corrected_test = split_surv_list(full["y_death_corr"], train_idx, test_idx)

# Downstream code uses TRAIN only
# NOTE(review): this rebinds imputations_list_mar26 (and the outcome lists)
# to the TRAIN subset — cells above this one can no longer be re-run in
# isolation against the full data; the full snapshot lives in cache["full"].
imputations_list_mar26 = imputations_list_mar26_train
y_surv_readm_list = y_surv_readm_list_train
y_surv_readm_list_corrected = y_surv_readm_list_corrected_train
y_surv_death_list = y_surv_death_list_train
y_surv_death_list_corrected = y_surv_death_list_corrected_train

# ---------- Diagnostics + strict checks ----------
strata_diag, strat_mode_diag, _, _ = build_strata(full["X"][0], full["y_readm"][0], full["y_death"][0])
sdiag = pd.Series(strata_diag)
train_strata = set(sdiag.iloc[train_idx].unique())
test_strata = set(sdiag.iloc[test_idx].unique())
missing_in_test = sorted(train_strata - test_strata)
missing_in_train = sorted(test_strata - train_strata)

# Absolute difference in event rates between train and test
readm_gap = abs(readm_evt_all[train_idx].mean() - readm_evt_all[test_idx].mean())
death_gap = abs(death_evt_all[train_idx].mean() - death_evt_all[test_idx].mean())

# full-strata rarity report (before fallback)
strata_full = pd.Series(
    get_plan_labels(full["X"][0]).astype(str) + "_" +
    get_completion_labels(full["X"][0]).astype(str) + "_" +
    full["y_readm"][0]["event"].astype(int).astype(str) + "_" +
    full["y_death"][0]["event"].astype(int).astype(str)
)
vc_full = strata_full.value_counts()
rare_rows = int((strata_full.map(vc_full) < 2).sum())

if STRICT_SPLIT_CHECKS:
    assert len(np.intersect1d(train_idx, test_idx)) == 0, "Train/Test index overlap detected."
    assert (len(train_idx) + len(test_idx)) == n, "Train/Test sizes do not sum to n."
    assert len(missing_in_test) == 0, f"Strata missing in test: {missing_in_test}"
    assert len(missing_in_train) == 0, f"Strata missing in train: {missing_in_train}"
    assert readm_gap < MAX_EVENT_GAP, f"Readmission rate imbalance > {MAX_EVENT_GAP:.0%} (gap={readm_gap:.4f})"
    assert death_gap < MAX_EVENT_GAP, f"Death rate imbalance > {MAX_EVENT_GAP:.0%} (gap={death_gap:.4f})"

# ---------- Summary ----------
nb_print_md(f"**Loaded indices from disk:** `{loaded_from_disk}`")
nb_print_md(f"**Split file:** `{SPLIT_FILE}`")
nb_print_md(f"**Split mode used:** `{cache.get('strat_mode', strat_mode_diag)}`")
nb_print_md(f"**Plan/completion diff rows across imputations (max vs imp0):** `{max_diff_rows}`")
nb_print_md(f"**Full strata count:** `{vc_full.size}` | **Min full stratum size:** `{int(vc_full.min())}` | **Rows in rare full strata (<2):** `{rare_rows}`")
nb_print_md(f"**Train/Test sizes:** `{len(train_idx)}` ({len(train_idx)/n:.1%}) / `{len(test_idx)}` ({len(test_idx)/n:.1%})")
nb_print_md(
    "**Readmission rate all/train/test:** "
    f"`{readm_evt_all.mean():.3%}` / `{readm_evt_all[train_idx].mean():.3%}` / `{readm_evt_all[test_idx].mean():.3%}`"
)
nb_print_md(
    "**Death rate all/train/test:** "
    f"`{death_evt_all.mean():.3%}` / `{death_evt_all[train_idx].mean():.3%}` / `{death_evt_all[test_idx].mean():.3%}`"
)
nb_print_md(
    f"**Strata in train/test:** `{len(train_strata)}` / `{len(test_strata)}` | "
    f"**Missing train→test:** `{len(missing_in_test)}` | **Missing test→train:** `{len(missing_in_train)}`"
)

Project root: G:\My Drive\Alvacast\SISTRAT 2023\dh

Loaded indices from disk: True

Split file: G:\My Drive\Alvacast\SISTRAT 2023\dh\_out\deepsurv_split_seed2125_test20_mar26.npz

Split mode used: fallback(plan+readm+death)

Plan/completion diff rows across imputations (max vs imp0): 0

Full strata count: 95 | Min full stratum size: 1 | Rows in rare full strata (<2): 4

Train/Test sizes: 70521 (80.0%) / 17631 (20.0%)

Readmission rate all/train/test: 21.622% / 21.621% / 21.627%

Death rate all/train/test: 4.310% / 4.309% / 4.311%

Strata in train/test: 20 / 20 | Missing train→test: 0 | Missing test→train: 0

Code
# counts per stratum in train/test (sdiag, train_idx, test_idx come from the
# split cell above)
train_counts = sdiag.iloc[train_idx].value_counts()
test_counts  = sdiag.iloc[test_idx].value_counts()

min_train = int(train_counts.min())
min_test  = int(test_counts.min())

nb_print_md(f"**Min stratum count in TRAIN (used strata):** `{min_train}`")
nb_print_md(f"**Min stratum count in TEST (used strata):** `{min_test}`")

# strata that got 0 in test or 0 in train (absent from one side's counts)
zero_in_test = sorted(set(train_counts.index) - set(test_counts.index))
zero_in_train = sorted(set(test_counts.index) - set(train_counts.index))

nb_print_md(f"**Strata with 0 in TEST:** `{len(zero_in_test)}`")
nb_print_md(f"**Strata with 0 in TRAIN:** `{len(zero_in_train)}`")

# show examples with their full-data counts
if len(zero_in_test) > 0:
    ex = zero_in_test[:10]
    nb_print_md(f"**Examples 0 in TEST (up to 10):** `{ex}`")
    nb_print_md(f"**Full-data counts:** `{[int(sdiag.value_counts()[k]) for k in ex]}`")

Min stratum count in TRAIN (used strata): 18

Min stratum count in TEST (used strata): 5

Strata with 0 in TEST: 0

Strata with 0 in TRAIN: 0

Code
# Rebuild the pre-fallback (full) strata labels to report how many would be
# too rare (<2 rows) for stratified splitting.
strata_full = pd.Series(
    get_plan_labels(full["X"][0]).astype(str) + "_" +
    get_completion_labels(full["X"][0]).astype(str) + "_" +
    full["y_readm"][0]["event"].astype(int).astype(str) + "_" +
    full["y_death"][0]["event"].astype(int).astype(str)
)

vc = strata_full.value_counts()
display(Markdown(f"**# full strata:** `{vc.size}`"))
display(Markdown(f"**Min stratum size (full):** `{int(vc.min())}`"))
display(Markdown(f"**# strata with count < 2:** `{int((vc < 2).sum())}`"))

# full strata: 95

Min stratum size (full): 1

# strata with count < 2: 4

Code
# Quantify how many rows fall into rare full strata (<2 members).
# NOTE(review): plan, readm_evt, death_evt and fb are computed here but not
# used below — likely leftovers from an earlier version of this cell.
plan = get_plan_labels(full["X"][0])
readm_evt = full["y_readm"][0]["event"].astype(int)
death_evt = full["y_death"][0]["event"].astype(int)

fb = pd.Series(plan.astype(str) + "_" + readm_evt.astype(str) + "_" + death_evt.astype(str))

rare_mask = strata_full.map(strata_full.value_counts()) < 2
n_rare = int(rare_mask.sum())

display(Markdown(f"**Rows in rare full-strata (<2):** `{n_rare}`"))
if n_rare > 0:
    display(Markdown(
        f"**Rare rows proportion:** `{n_rare/len(strata_full):.3%}`"
    ))

Rows in rare full-strata (<2): 4

Rare rows proportion: 0.005%

Code
# Use the actual stratification mode that was used to split
# (build_strata is deterministic, so this reproduces the split-time labels)
strata_used, strat_mode, _, _ = build_strata(full["X"][0], full["y_readm"][0], full["y_death"][0])

s = pd.Series(strata_used)
train_strata = set(s.iloc[train_idx].unique())
test_strata = set(s.iloc[test_idx].unique())

# Strata that appear on one side of the split only (should both be empty)
missing_in_test = sorted(train_strata - test_strata)
missing_in_train = sorted(test_strata - train_strata)

display(Markdown(f"**Strata used:** `{strat_mode}`"))
display(Markdown(f"**# strata in train:** `{len(train_strata)}` | **# strata in test:** `{len(test_strata)}`"))
display(Markdown(f"**Strata present in train but missing in test:** `{len(missing_in_test)}`"))
display(Markdown(f"**Strata present in test but missing in train:** `{len(missing_in_train)}`"))

Strata used: fallback(plan+readm+death)

# strata in train: 20 | # strata in test: 20

Strata present in train but missing in test: 0

Strata present in test but missing in train: 0

Code
# Persist the train/test assignment so downstream notebooks can reproduce
# exactly the same split.
from pathlib import Path
import pandas as pd
import numpy as np
from IPython.display import display, Markdown

PROJECT_ROOT = find_project_root()   # no hardcoded absolute path
OUT_DIR = PROJECT_ROOT / "_out"
OUT_DIR.mkdir(parents=True, exist_ok=True)

SPLIT_PARQUET = OUT_DIR / f"deepsurv_split_seed{SEED}_test{int(TEST_SIZE*100)}_mar26.parquet"

# One row per subject: its position and whether it landed in the training set.
row_ids = np.arange(n)
split_df = pd.DataFrame(
    {"row_id": row_ids, "is_train": np.isin(row_ids, train_idx)}
)
split_df.to_parquet(SPLIT_PARQUET, index=False)

display(Markdown(f"**Project root:** `{PROJECT_ROOT}`"))
display(Markdown(f"**Saved split to:** `{SPLIT_PARQUET}`"))

Project root: G:\My Drive\Alvacast\SISTRAT 2023\dh

Saved split to: G:\My Drive\Alvacast\SISTRAT 2023\dh\_out\deepsurv_split_seed2125_test20_mar26.parquet

Code
# Export a compact 4-column split file (row id, split flag, death time, age)
# for cross-checking the same split from R.
import pandas as pd
import numpy as np
from pathlib import Path

SEED = 2125
TEST_SIZE = 0.20

# Use the first imputation (complete data)
X_full = full["X"][0]
y_death_full = full["y_death"][0]

# Find admission age column: prefer the canonical name, otherwise fall back to
# the first column whose name contains 'adm_age'.
age_col = 'adm_age_rec3' if 'adm_age_rec3' in X_full.columns else \
          [c for c in X_full.columns if 'adm_age' in c][0]

# Create the 4-column split file.
# BUGFIX: the previous list comprehension (`i in train_idx` per row) was an
# O(n^2) membership scan over ~88k rows; np.isin is the vectorized equivalent
# (and matches how the sibling split-export cell builds its flag).
split_export = pd.DataFrame({
    'row_id': np.arange(1, len(X_full) + 1),  # 1-based for R
    'is_train': np.isin(np.arange(len(X_full)), train_idx),
    'death_time_from_disch_m': np.round(y_death_full['time'], 2),
    'adm_age_rec3': X_full[age_col]
})

out_dir = PROJECT_ROOT / "_out"
out_dir.mkdir(exist_ok=True)

# Export
fname = out_dir / f"comb_split_seed{SEED}_test{int(TEST_SIZE*100)}_mar26.parquet"
split_export.to_parquet(fname, index=False)

nb_print(f"Exported: {fname}")
nb_print(f"Total: {len(split_export)} rows")
nb_print(f"Train: {split_export['is_train'].sum()} ({100*split_export['is_train'].mean():.1f}%)")
nb_print(f"Test: {(~split_export['is_train']).sum()} ({100*(~split_export['is_train']).mean():.1f}%)")
nb_print(f"\nFirst 5 rows:")
nb_print(split_export.head())
Exported: G:\My Drive\Alvacast\SISTRAT 2023\dh\_out\comb_split_seed2125_test20_mar26.parquet
Total: 88152 rows
Train: 70521 (80.0%)
Test: 17631 (20.0%)
First 5 rows:
   row_id  is_train  death_time_from_disch_m  adm_age_rec3
0       1      True                    68.97         31.53
1       2      True                    81.32         20.61
2       3      True                   116.74         42.52
3       4     False                    91.97         60.61
4       5      True                    31.03         45.08
Code
# Quick sanity stats on the first imputed dataset.
df0 = imputations_list_mar26[0]

# Mean admission age and number of foreign-national records
mean_age = df0["adm_age_rec3"].mean()
count_foreign = (df0["national_foreign"] == 1).sum()

# Print results
nb_print(f"Mean of adm_age_rec3: {mean_age:.4f}")
nb_print(f"Count of national_foreign == 1: {count_foreign}")
Mean of adm_age_rec3: 35.7256
Count of national_foreign == 1: 453

We cleaned our environment safely so that:

  • Old models

  • Temporary objects

  • Large intermediate datasets

do not interfere with the next modeling block.

Code
# Safe cleanup before Readmission XGBoost blocks: drop large intermediates from
# the kernel namespace while keeping imports, callables, and allow-listed data.
import types
import gc

# Ensure logger exists (some target cells expect it)
if "nb_print" not in globals():
    def nb_print(*args, **kwargs):
        print(*args, **kwargs)

# Compatibility: one Optuna/Bootstrap block checks jan26 naming
#if "imputations_list_jan26" not in globals() and "imputations_list_mar26" in globals():
#    imputations_list_jan26 = imputations_list_mar26

# Allow-list of globals that must survive the deletion loop below.
KEEP = {
    "nb_print", "study",
    "imputations_list_mar26", #"imputations_list",#"imputations_list_jan26"
    "X_train", "y_surv_readm_list_corrected", "y_surv_readm_list", "y_surv_death_list",
    # Optional plot config objects:
    "plt", "sns", "matplotlib", "mpl", "rcParams", "PROJECT_ROOT"
}

# ensure both variants are kept
KEEP.update({
    "y_surv_readm_list_corrected_mar26", "y_surv_readm_list_corrected",
    "y_surv_readm_list_mar26", "y_surv_death_list_mar26"
})

# Sanity-check X/y row alignment before tuning.
# NOTE(review): despite the original "after cleanup" wording, this check runs
# BEFORE the deletion loop below.
if "imputations_list_mar26" in globals() and "y_surv_readm_list_corrected" in globals():
    assert len(imputations_list_mar26[0]) == len(y_surv_readm_list_corrected[0]), \
        f"Row mismatch: X={len(imputations_list_mar26[0])}, y={len(y_surv_readm_list_corrected[0])}"
        
# Delete everything not allow-listed. Snapshot with list() because globals()
# is mutated during iteration; underscore names (IPython internals) are spared.
for name, obj in list(globals().items()):
    if name in KEEP or name.startswith("_"):
        continue
    if isinstance(obj, types.ModuleType):   # keep imports
        continue
    if callable(obj):                        # keep functions/classes
        continue
    del globals()[name]

gc.collect()

# Verify the objects the next modeling block depends on survived the cleanup.
required = ["y_surv_readm_list_corrected", "y_surv_readm_list", "y_surv_death_list"]
missing = [x for x in required if x not in globals()]
nb_print("Missing required objects:", missing)
Missing required objects: []

PyCox

We tune DeepSurv cause-specific Cox models for 1–5 year death and readmission via Optuna TPE with 5-fold stratified CV.

  1. Optuna TPE replaces grid search (100 trials).
  2. Cause-specific Cox handles competing risks.
  3. Separate models fit for death and readmission.
  4. Risk evaluated at 3, 6, 12, 36, 60 months.
  5. IPCW C-index + IBS combined into one metric.
  6. Metric: √((1−C)² + IBS²), lower is better.
  7. Horizon weights emphasize 1–5 year outcomes.
  8. Wide search: 1–4 layers, lr/dropout/decay.
  9. Re-seeded per trial for full reproducibility.
  10. Early stopping + patience tuned per trial.
Code
# How many imputed datasets are available, and the dimensions of each one.
nb_print(f"Imputations: {len(imputations_list_mar26)}")
for i, imp in enumerate(imputations_list_mar26):
    n_rows, n_cols = imp.shape
    nb_print(f"Imputation{i}: {n_rows} rows × {n_cols} columns")
Imputations: 5
Imputation0: 70521 rows × 56 columns
Imputation1: 70521 rows × 56 columns
Imputation2: 70521 rows × 56 columns
Imputation3: 70521 rows × 56 columns
Imputation4: 70521 rows × 56 columns
Code
#@title ⚡ Step 1: DeepSurv (Cause-Specific CoxPH) Tuning – Optuna 100 trials, 5-Fold, Multi-Horizon

import gc
import time
import warnings
import numpy as np
import pandas as pd
import torch
import torchtuples as tt
import random
import os
import optuna
from datetime import datetime
from pycox.models import CoxPH
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sksurv.metrics import concordance_index_ipcw, brier_score, integrated_brier_score
import tempfile
from contextlib import contextmanager

# Tuning configuration
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EVAL_HORIZONS = [3, 6, 12, 36, 60]  # months at which risk is evaluated
HORIZON_WEIGHTS = {3: 0.125, 6: 0.125, 12: 0.25, 36: 0.25, 60: 0.25}  # emphasize 1-5 year horizons
N_FOLDS = 5
N_TRIALS = 100
SEED = 2125
COOLDOWN_SEC = 2  # seconds slept after each CV fold (see objective())

# Silence known-noisy warnings (torch checkpoint loading; all-NaN nanmean).
warnings.filterwarnings("ignore", message=".*weights_only=False.*")
warnings.filterwarnings("ignore", category=RuntimeWarning, message=".*Mean of empty slice.*")

@contextmanager
def local_workdir():
    """Run the enclosed block with the CWD switched to a fresh temp directory.

    Keeps scratch files written to the CWD (e.g. by training callbacks) off
    Google Drive, avoiding its file locks. The previous working directory is
    always restored, even if the body raises.

    Yields:
        str: path of the temporary directory (removed on exit).
    """
    previous_dir = os.getcwd()
    with tempfile.TemporaryDirectory() as scratch:
        os.chdir(scratch)
        try:
            yield scratch
        finally:
            # Restore before TemporaryDirectory tears down, so the temp dir is
            # never the CWD at removal time (matters on Windows).
            os.chdir(previous_dir)

def set_seed(seed=SEED):
    """Seed every RNG source used here (Python, hash seed, NumPy, Torch/CUDA)
    so each Optuna trial is fully reproducible."""
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN kernel selection trades some speed for reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def prepare_stratified_data(df_idx=0):
    """Assemble features, competing-risk event codes, times, and strat labels.

    Event coding: 0 = censored, 2 = readmission, 1 = death. Death is assigned
    last and therefore takes precedence when both indicators are set.

    Args:
        df_idx: index of the imputed dataset (and matching y arrays) to use.

    Returns:
        tuple: (df, events, times, strat_labels) where
            strat_labels = events * 10 + plan_category.
    """
    df = imputations_list_mar26[df_idx]
    y_d = y_surv_death_list[df_idx]
    y_r = y_surv_readm_list[df_idx]

    def _as_array(col):
        # Accept both pandas Series and plain numpy arrays.
        return col.values if hasattr(col, 'values') else col

    times = _as_array(y_d['time']).copy().astype('float32')
    died = _as_array(y_d['event']).astype(bool)
    readmitted = _as_array(y_r['event']).astype(bool)

    events = np.zeros(len(df), dtype=int)
    events[readmitted] = 2
    events[died] = 1  # death overrides readmission when both occurred

    # One-hot plan indicator columns -> single ordinal category (0 = none).
    plan_cols = ['plan_type_corr_pg_pr', 'plan_type_corr_m_pr',
                 'plan_type_corr_pg_pai', 'plan_type_corr_m_pai']
    plan_category = np.zeros(len(df), dtype=int)
    for code, col in enumerate((c for c in plan_cols if c in df.columns), start=1):
        plan_category[df[col] == 1] = code

    strat_labels = (events * 10) + plan_category
    return df, events, times, strat_labels

def to_structured(times, events_bool):
    """Pack parallel time/event arrays into the structured dtype sksurv expects.

    Fields: 'e' (bool event indicator) and 't' (float event/censoring time).
    """
    structured = np.empty(len(times), dtype=[('e', bool), ('t', float)])
    structured['e'] = events_bool.astype(bool)
    structured['t'] = times.astype(float)
    return structured

def fit_deepsurv(X_train_s, t_train, e_train, X_val_s, t_val, e_val, params):
    """Fit one cause-specific DeepSurv (pycox CoxPH) model with early stopping.

    Args:
        X_train_s, X_val_s: standardized float32 feature matrices.
        t_train, t_val: event/censoring times.
        e_train, e_val: binary event indicators for ONE cause (the competing
            cause is treated as censored).
        params: dict with 'nodes', 'dropout', 'lr', 'weight_decay',
            'batch_size', 'epochs', 'patience'.

    Returns:
        Trained CoxPH model with baseline hazards computed, so
        predict_surv_df() can be called on it.
    """
    net = tt.practical.MLPVanilla(
        in_features=X_train_s.shape[1],
        num_nodes=params['nodes'],
        out_features=1,
        batch_norm=True,
        dropout=params['dropout'],
        output_bias=False  # presumably dropped because Cox risk is shift-invariant — TODO confirm
    )
    model = CoxPH(net, tt.optim.Adam)
    model.set_device(DEVICE)
    model.optimizer.set_lr(params['lr'])
    # Write weight decay directly into the underlying optimizer param group.
    model.optimizer.param_groups[0]['weight_decay'] = params['weight_decay']

    model.fit(
        X_train_s,
        (t_train.astype('float32'), e_train.astype('int64')),
        batch_size=params['batch_size'],
        epochs=params['epochs'],
        callbacks=[tt.callbacks.EarlyStopping(patience=params['patience'])],
        verbose=False,
        val_data=(X_val_s, (t_val.astype('float32'), e_val.astype('int64')))
    )
    # Baseline hazards must be computed before survival-curve prediction.
    model.compute_baseline_hazards()
    return model

def risk_at_horizon(model, X, horizon):
    """Per-subject event risk 1 - S(horizon) from a fitted survival model.

    Looks up the survival curve at the last grid time <= horizon
    (step-function convention), clipped to the ends of the grid.
    """
    surv_df = model.predict_surv_df(X)
    time_grid = surv_df.index.values
    pos = np.searchsorted(time_grid, horizon, side='right') - 1
    pos = int(np.clip(pos, 0, len(time_grid) - 1))
    survival_at_h = surv_df.iloc[pos].values
    return 1.0 - survival_at_h

def compute_ibs(model, X_val_s, y_train_struct, y_val_struct, horizons):
    """Integrated Brier score over the evaluation horizons for one model.

    Returns np.nan when the horizon window collapses, fewer than two horizons
    fall inside the model's time grid, or the sksurv call fails.
    """
    surv_df = model.predict_surv_df(X_val_s)
    time_grid = surv_df.index.values

    # Restrict horizons to the model's time grid (small epsilon margin).
    lo = max(horizons[0], time_grid[0] + 1e-4)
    hi = min(horizons[-1], time_grid[-1] - 1e-4)
    if lo >= hi:
        return np.nan

    eval_times = np.array([h for h in horizons if lo <= h <= hi])
    if len(eval_times) < 2:
        return np.nan

    # Survival probability per subject at each horizon (step-function lookup).
    columns = []
    for h in eval_times:
        pos = np.searchsorted(time_grid, h, side='right') - 1
        pos = int(np.clip(pos, 0, len(time_grid) - 1))
        columns.append(surv_df.iloc[pos].values)
    surv_matrix = np.column_stack(columns)

    try:
        return integrated_brier_score(y_train_struct, y_val_struct, surv_matrix, eval_times)
    except Exception:
        return np.nan

# --- Prepare data ---
# Materialize features/labels once (first imputation) and start the wall clock.
X_all, events_all, times_all, strat_labels = prepare_stratified_data()
start_time = time.time()

# --- Optuna objective ---
def objective(trial):
    """One Optuna trial: 5-fold CV of a pair of cause-specific DeepSurv models.

    Each fold fits one model for death and one for readmission on the same
    hyperparameters, then scores the combined metric
    sqrt((1 - weighted C-index)^2 + mean IBS^2) — lower is better.
    Returns the fold average (NaN folds ignored via nanmean).
    """
    set_seed(SEED)  # re-seed per trial for full reproducibility

    # Wide search space
    n_layers = trial.suggest_int('n_layers', 1, 4)
    nodes = []
    for layer_i in range(n_layers):
        nodes.append(trial.suggest_categorical(f'n_units_l{layer_i}', [64, 128, 256, 512]))

    params = {
        'lr': trial.suggest_float('lr', 1e-5, 1e-2, log=True),
        'weight_decay': trial.suggest_float('weight_decay', 1e-6, 1e-2, log=True),
        'batch_size': trial.suggest_categorical('batch_size', [256, 512, 1024, 2048]),
        'dropout': trial.suggest_float('dropout', 0.05, 0.6),
        'nodes': nodes,
        'epochs': trial.suggest_int('epochs', 50, 300), # 2026-03-31: widened for more thorough tuning. NOTE(review): prior note said "200 to 500" but the actual bounds are 50-300 — confirm intent
        'patience': trial.suggest_int('patience', 8, 25),
    }

    skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=2125)
    fold_scores = []

    for fold_idx, (train_idx, val_idx) in enumerate(skf.split(X_all, strat_labels)):
        torch.cuda.empty_cache()
        gc.collect()

        model_d = None  # for cleanup in case of early failure
        model_r = None  # for cleanup in case of early failure

        X_train, X_val = X_all.iloc[train_idx], X_all.iloc[val_idx]
        e_train_raw, e_val_raw = events_all[train_idx], events_all[val_idx]
        t_train, t_val = times_all[train_idx], times_all[val_idx]

        # Scale on the training fold only (no leakage into validation).
        scaler = StandardScaler().fit(X_train)
        X_train_s = scaler.transform(X_train).astype('float32')
        X_val_s = scaler.transform(X_val).astype('float32')

        # Cause-specific indicators: 1 = death, 2 = readmission.
        e_train_d = (e_train_raw == 1)
        e_val_d = (e_val_raw == 1)
        e_train_r = (e_train_raw == 2)
        e_val_r = (e_val_raw == 2)

        try:
            # Train inside a local temp CWD to avoid Google Drive file locks.
            with local_workdir():
                model_d = fit_deepsurv(X_train_s, t_train, e_train_d, X_val_s, t_val, e_val_d, params)
                model_r = fit_deepsurv(X_train_s, t_train, e_train_r, X_val_s, t_val, e_val_r, params)

            y_tr_d = to_structured(t_train, e_train_d)
            y_va_d = to_structured(t_val, e_val_d)
            y_tr_r = to_structured(t_train, e_train_r)
            y_va_r = to_structured(t_val, e_val_r)

            # --- Per-horizon weighted C-index ---
            horizon_cindices = []
            weights_used = []
            for h in EVAL_HORIZONS:
                risk_d = risk_at_horizon(model_d, X_val_s, h)
                risk_r = risk_at_horizon(model_r, X_val_s, h)
                try:
                    c_d = concordance_index_ipcw(y_tr_d, y_va_d, risk_d, tau=h)[0]
                except Exception:
                    c_d = np.nan
                try:
                    c_r = concordance_index_ipcw(y_tr_r, y_va_r, risk_r, tau=h)[0]
                except Exception:
                    c_r = np.nan

                # Average over causes, skipping a horizon only if BOTH failed.
                c_avg = np.nanmean([c_d, c_r])
                if not np.isnan(c_avg):
                    horizon_cindices.append(c_avg)
                    weights_used.append(HORIZON_WEIGHTS[h])

            if len(horizon_cindices) == 0:
                # No usable horizon at all: record NaN and move to the next fold.
                fold_scores.append(np.nan)
                del model_d, model_r
                gc.collect()
                time.sleep(COOLDOWN_SEC)
                continue

            weighted_c = np.average(horizon_cindices, weights=weights_used)

            # --- IBS for each cause ---
            ibs_d = compute_ibs(model_d, X_val_s, y_tr_d, y_va_d, EVAL_HORIZONS)
            ibs_r = compute_ibs(model_r, X_val_s, y_tr_r, y_va_r, EVAL_HORIZONS)

            ibs_vals = [v for v in [ibs_d, ibs_r] if not np.isnan(v)]
            avg_ibs = np.mean(ibs_vals) if len(ibs_vals) > 0 else np.nan

            # --- Combined metric: sqrt((1 - C)^2 + IBS^2), lower is better ---
            if not np.isnan(avg_ibs):
                combined = np.sqrt((1.0 - weighted_c) ** 2 + avg_ibs ** 2)
            else:
                combined = 1.0 - weighted_c  # fallback: just error from C-index

            fold_scores.append(combined)

        except Exception as exc:
            nb_print(f"  [FAIL trial {trial.number} fold {fold_idx}] {exc}")
            fold_scores.append(np.nan)

        # Release both fold models (None if a failure occurred before fitting).
        del model_d, model_r
        gc.collect()
        time.sleep(COOLDOWN_SEC)

    avg_score = np.nanmean(fold_scores)
    nb_print(f"  Trial {trial.number:>3d} | combined={avg_score:.4f} | {params['nodes']} lr={params['lr']:.1e} do={params['dropout']:.2f}")
    return avg_score

# --- Run Optuna (minimize combined metric) ---
nb_print(f"⚡ Starting DeepSurv Optuna tuning: {N_TRIALS} trials, {N_FOLDS}-fold CV, device={DEVICE}")

# TPE sampler seeded so the suggestion sequence is reproducible.
sampler = optuna.samplers.TPESampler(seed=SEED)
study = optuna.create_study(direction='minimize', sampler=sampler,
                            study_name='deepsurv_cs_tuning')

study.optimize(objective, n_trials=N_TRIALS, show_progress_bar=True)

# --- Results ---
nb_print("\n" + "=" * 60)
nb_print(f"🏆 Best trial: #{study.best_trial.number}")
nb_print(f"   Combined metric (lower=better): {study.best_value:.4f}")
nb_print(f"   Params: {study.best_params}")
nb_print("=" * 60)

# Save all trials, best (lowest score) first
trials_df = study.trials_dataframe()
trials_df = trials_df.sort_values('value', ascending=True)

timestamp = datetime.now().strftime("%Y%m%d_%H%M")
filename = f"DS_Optuna_{N_TRIALS}t_{N_FOLDS}f_{timestamp}.csv"
trials_df.to_csv(filename, index=False)
# BUGFIX: the message previously printed a literal placeholder instead of
# interpolating the output path (cf. the captured output, which shows the name).
nb_print(f"💾 Results saved to: {filename}")
nb_print(f"⏱️ Total Time: {(time.time() - start_time) / 60:.2f} min")
⚡ Starting DeepSurv Optuna tuning: 100 trials, 5-fold CV, device=cuda
[I 2026-03-16 14:56:26,789] A new study created in memory with name: deepsurv_cs_tuning

  0%|          | 0/100 [00:00<?, ?it/s]
  Trial   0 | combined=0.2649 | [512] lr=9.1e-04 do=0.28
Best trial: 0. Best value: 0.264927:   1%|          | 1/100 [01:50<3:02:10, 110.40s/it]
  Trial   1 | combined=0.2644 | [512, 128, 256, 256] lr=7.3e-04 do=0.21
Best trial: 1. Best value: 0.264445:   2%|▏         | 2/100 [06:07<5:20:50, 196.43s/it]
  Trial   2 | combined=0.2616 | [64, 256, 64, 256] lr=9.4e-03 do=0.18
Best trial: 2. Best value: 0.261556:   3%|▎         | 3/100 [09:35<5:26:40, 202.07s/it]
  Trial   3 | combined=0.2655 | [64] lr=1.5e-03 do=0.22
Best trial: 2. Best value: 0.261556:   4%|▍         | 4/100 [11:36<4:31:46, 169.86s/it]
  Trial   4 | combined=0.2664 | [128, 64] lr=4.7e-04 do=0.23
Best trial: 2. Best value: 0.261556:   5%|▌         | 5/100 [16:28<5:38:42, 213.92s/it]
  Trial   5 | combined=0.2731 | [64, 128, 64] lr=6.0e-05 do=0.18
Best trial: 2. Best value: 0.261556:   6%|▌         | 6/100 [20:46<5:58:46, 229.01s/it]
  Trial   6 | combined=0.2623 | [512, 256] lr=3.2e-04 do=0.54
Best trial: 2. Best value: 0.261556:   7%|▋         | 7/100 [23:26<5:20:02, 206.48s/it]
  Trial   7 | combined=0.2771 | [256, 256] lr=4.1e-03 do=0.15
Best trial: 2. Best value: 0.261556:   8%|▊         | 8/100 [24:56<4:19:46, 169.42s/it]
  Trial   8 | combined=0.2594 | [128, 64, 64] lr=5.1e-04 do=0.16
Best trial: 8. Best value: 0.259373:   9%|▉         | 9/100 [29:28<5:05:35, 201.49s/it]
  Trial   9 | combined=0.3128 | [64, 128] lr=1.1e-05 do=0.18
Best trial: 8. Best value: 0.259373:  10%|█         | 10/100 [33:14<5:13:17, 208.86s/it]
  Trial  10 | combined=0.2766 | [128, 64, 128] lr=1.0e-04 do=0.07
Best trial: 8. Best value: 0.259373:  11%|█         | 11/100 [37:14<5:23:54, 218.36s/it]
  Trial  11 | combined=0.2893 | [128, 512, 64, 64] lr=9.0e-03 do=0.40
Best trial: 8. Best value: 0.259373:  12%|█▏        | 12/100 [41:03<5:25:08, 221.69s/it]
  Trial  12 | combined=0.2564 | [256, 64, 64, 256] lr=4.1e-03 do=0.06
Best trial: 12. Best value: 0.256403:  13%|█▎        | 13/100 [47:15<6:27:21, 267.15s/it]
  Trial  13 | combined=0.2621 | [256, 64, 512] lr=3.1e-03 do=0.05
Best trial: 12. Best value: 0.256403:  14%|█▍        | 14/100 [52:12<6:35:54, 276.22s/it]
  Trial  14 | combined=0.2576 | [256, 64, 64, 512] lr=2.3e-04 do=0.37
Best trial: 12. Best value: 0.256403:  15%|█▌        | 15/100 [1:02:16<8:51:17, 375.03s/it]
  Trial  15 | combined=0.2638 | [256, 64, 64, 512] lr=1.2e-04 do=0.45
Best trial: 12. Best value: 0.256403:  16%|█▌        | 16/100 [1:14:00<11:03:35, 473.99s/it]
  Trial  16 | combined=0.3196 | [256, 512, 128, 128] lr=2.4e-05 do=0.37
Best trial: 12. Best value: 0.256403:  17%|█▋        | 17/100 [1:23:33<11:36:49, 503.73s/it]
  Trial  17 | combined=0.2911 | [256, 64, 512, 512] lr=2.5e-04 do=0.52
Best trial: 12. Best value: 0.256403:  18%|█▊        | 18/100 [1:28:52<10:12:55, 448.49s/it]
  Trial  18 | combined=0.2612 | [256, 64, 256] lr=1.7e-03 do=0.32
Best trial: 12. Best value: 0.256403:  19%|█▉        | 19/100 [1:30:38<7:46:34, 345.61s/it] 
  Trial  19 | combined=0.2690 | [256, 64, 64, 512] lr=4.6e-05 do=0.46
Best trial: 12. Best value: 0.256403:  20%|██        | 20/100 [1:43:03<10:20:37, 465.47s/it]
  Trial  20 | combined=0.2566 | [256, 512, 64] lr=2.0e-04 do=0.58
Best trial: 12. Best value: 0.256403:  21%|██        | 21/100 [1:53:36<11:18:58, 515.68s/it]
  Trial  21 | combined=0.2559 | [256, 512, 64] lr=1.7e-04 do=0.57
Best trial: 21. Best value: 0.255863:  22%|██▏       | 22/100 [2:04:49<12:11:48, 562.93s/it]
  Trial  22 | combined=0.2604 | [256, 512, 64] lr=1.1e-04 do=0.60
Best trial: 21. Best value: 0.255863:  23%|██▎       | 23/100 [2:17:57<13:29:05, 630.46s/it]
  Trial  23 | combined=0.2618 | [256, 512, 64] lr=1.5e-04 do=0.58
Best trial: 21. Best value: 0.255863:  24%|██▍       | 24/100 [2:28:45<13:25:17, 635.75s/it]
  Trial  24 | combined=0.2631 | [256, 512, 64] lr=4.4e-05 do=0.49
Best trial: 21. Best value: 0.255863:  25%|██▌       | 25/100 [2:45:34<15:34:42, 747.77s/it]
  Trial  25 | combined=0.2650 | [256, 512] lr=1.8e-04 do=0.56
Best trial: 21. Best value: 0.255863:  26%|██▌       | 26/100 [2:54:08<13:55:48, 677.68s/it]
  Trial  26 | combined=0.2721 | [256, 512, 128] lr=6.8e-05 do=0.44
Best trial: 21. Best value: 0.255863:  27%|██▋       | 27/100 [2:58:28<11:12:04, 552.39s/it]
  Trial  27 | combined=0.2593 | [256, 512, 512] lr=1.2e-03 do=0.50
Best trial: 21. Best value: 0.255863:  28%|██▊       | 28/100 [3:00:43<8:32:21, 426.96s/it] 
  Trial  28 | combined=0.2634 | [512, 512] lr=4.4e-04 do=0.29
Best trial: 21. Best value: 0.255863:  29%|██▉       | 29/100 [3:02:44<6:36:48, 335.33s/it]
  Trial  29 | combined=0.2666 | [512] lr=2.6e-03 do=0.42
Best trial: 21. Best value: 0.255863:  30%|███       | 30/100 [3:04:31<5:11:04, 266.63s/it]
  Trial  30 | combined=0.2662 | [256, 512, 256, 256] lr=1.1e-03 do=0.11
Best trial: 21. Best value: 0.255863:  31%|███       | 31/100 [3:08:15<4:51:59, 253.90s/it]
  Trial  31 | combined=0.2621 | [256, 64, 64, 128] lr=3.0e-04 do=0.35
Best trial: 21. Best value: 0.255863:  32%|███▏      | 32/100 [3:14:49<5:35:32, 296.06s/it]
  Trial  32 | combined=0.2579 | [256, 64, 64, 64] lr=7.8e-04 do=0.59
Best trial: 21. Best value: 0.255863:  33%|███▎      | 33/100 [3:22:57<6:34:51, 353.60s/it]
  Trial  33 | combined=0.2620 | [256, 128, 64, 256] lr=1.9e-04 do=0.53
Best trial: 21. Best value: 0.255863:  34%|███▍      | 34/100 [3:32:54<7:49:18, 426.64s/it]
  Trial  34 | combined=0.2661 | [256, 256, 64, 512] lr=8.7e-05 do=0.27
Best trial: 21. Best value: 0.255863:  35%|███▌      | 35/100 [3:41:19<8:07:39, 450.15s/it]
  Trial  35 | combined=0.2531 | [64, 128, 64] lr=5.8e-04 do=0.39
Best trial: 35. Best value: 0.253092:  36%|███▌      | 36/100 [3:48:25<7:52:16, 442.76s/it]
  Trial  36 | combined=0.2609 | [512, 128, 64] lr=5.8e-03 do=0.48
Best trial: 35. Best value: 0.253092:  37%|███▋      | 37/100 [3:54:53<7:27:45, 426.44s/it]
  Trial  37 | combined=0.2597 | [64, 128, 64] lr=5.7e-04 do=0.26
Best trial: 35. Best value: 0.253092:  38%|███▊      | 38/100 [4:00:27<6:51:59, 398.70s/it]
  Trial  38 | combined=0.2584 | [64, 128] lr=2.2e-03 do=0.55
Best trial: 35. Best value: 0.253092:  39%|███▉      | 39/100 [4:03:13<5:34:21, 328.88s/it]
  Trial  39 | combined=0.2715 | [64, 128, 64] lr=4.1e-04 do=0.11
Best trial: 35. Best value: 0.253092:  40%|████      | 40/100 [4:05:09<4:25:08, 265.15s/it]
  Trial  40 | combined=0.2639 | [64, 256] lr=6.9e-04 do=0.21
Best trial: 35. Best value: 0.253092:  41%|████      | 41/100 [4:09:26<4:18:11, 262.56s/it]
  Trial  41 | combined=0.2629 | [64] lr=2.3e-04 do=0.38
Best trial: 35. Best value: 0.253092:  42%|████▏     | 42/100 [4:17:30<5:17:57, 328.92s/it]
  Trial  42 | combined=0.2603 | [128, 64, 64, 512] lr=3.7e-04 do=0.34
Best trial: 35. Best value: 0.253092:  43%|████▎     | 43/100 [4:26:58<6:20:36, 400.63s/it]
  Trial  43 | combined=0.2732 | [64, 128, 256] lr=1.6e-04 do=0.57
Best trial: 35. Best value: 0.253092:  44%|████▍     | 44/100 [4:37:58<7:26:36, 478.51s/it]
  Trial  44 | combined=0.2610 | [256, 64, 64] lr=4.6e-03 do=0.43
Best trial: 35. Best value: 0.253092:  45%|████▌     | 45/100 [4:44:03<6:47:18, 444.34s/it]
  Trial  45 | combined=0.2600 | [128, 256, 64, 256] lr=2.8e-04 do=0.39
Best trial: 35. Best value: 0.253092:  46%|████▌     | 46/100 [4:49:45<6:12:18, 413.68s/it]
  Trial  46 | combined=0.2705 | [256, 512, 64] lr=2.7e-05 do=0.31
Best trial: 35. Best value: 0.253092:  47%|████▋     | 47/100 [5:01:49<7:27:38, 506.76s/it]
  Trial  47 | combined=0.2587 | [512, 64, 64, 64] lr=7.5e-05 do=0.41
Best trial: 35. Best value: 0.253092:  48%|████▊     | 48/100 [5:17:21<9:09:52, 634.47s/it]
  Trial  48 | combined=0.2898 | [64, 128, 128] lr=1.3e-04 do=0.51
Best trial: 35. Best value: 0.253092:  49%|████▉     | 49/100 [5:19:20<6:47:50, 479.81s/it]
  Trial  49 | combined=0.2611 | [256, 512, 64, 128] lr=6.0e-04 do=0.36
Best trial: 35. Best value: 0.253092:  50%|█████     | 50/100 [5:22:01<5:20:09, 384.18s/it]
  Trial  50 | combined=0.2725 | [256, 64, 512] lr=9.8e-03 do=0.53
Best trial: 35. Best value: 0.253092:  51%|█████     | 51/100 [5:24:17<4:12:54, 309.69s/it]
  Trial  51 | combined=0.2579 | [256, 64, 64, 64] lr=7.7e-04 do=0.59
Best trial: 35. Best value: 0.253092:  52%|█████▏    | 52/100 [5:32:15<4:48:09, 360.20s/it]
  Trial  52 | combined=0.2537 | [256, 64, 64, 64] lr=1.6e-03 do=0.58
Best trial: 35. Best value: 0.253092:  53%|█████▎    | 53/100 [5:39:53<5:05:11, 389.61s/it]
  Trial  53 | combined=0.2587 | [256, 64, 64, 64] lr=2.0e-03 do=0.56
Best trial: 35. Best value: 0.253092:  54%|█████▍    | 54/100 [5:47:11<5:09:42, 403.96s/it]
  Trial  54 | combined=0.2556 | [256, 64, 64, 64] lr=1.4e-03 do=0.48
Best trial: 35. Best value: 0.253092:  55%|█████▌    | 55/100 [5:53:28<4:56:53, 395.86s/it]
  Trial  55 | combined=0.2587 | [128, 64, 64, 64] lr=1.6e-03 do=0.48
Best trial: 35. Best value: 0.253092:  56%|█████▌    | 56/100 [6:00:54<5:01:19, 410.91s/it]
  Trial  56 | combined=0.2606 | [256, 64, 64, 64] lr=6.7e-03 do=0.57
Best trial: 35. Best value: 0.253092:  57%|█████▋    | 57/100 [6:05:06<4:20:27, 363.44s/it]
  Trial  57 | combined=0.2574 | [256, 512, 256] lr=3.1e-03 do=0.52
Best trial: 35. Best value: 0.253092:  58%|█████▊    | 58/100 [6:10:05<4:00:42, 343.87s/it]
  Trial  58 | combined=0.2567 | [256, 64, 128, 64] lr=1.3e-03 do=0.49
Best trial: 35. Best value: 0.253092:  59%|█████▉    | 59/100 [6:16:43<4:06:13, 360.33s/it]
  Trial  59 | combined=0.2609 | [256, 256, 64] lr=4.0e-03 do=0.56
Best trial: 35. Best value: 0.253092:  60%|██████    | 60/100 [6:23:13<4:06:02, 369.07s/it]
  Trial  60 | combined=0.2570 | [64, 512, 512] lr=1.0e-03 do=0.51
Best trial: 35. Best value: 0.253092:  61%|██████    | 61/100 [6:25:43<3:17:18, 303.56s/it]
  Trial  61 | combined=0.2587 | [256, 64, 128, 64] lr=1.4e-03 do=0.47
Best trial: 35. Best value: 0.253092:  62%|██████▏   | 62/100 [6:32:04<3:26:48, 326.55s/it]
  Trial  62 | combined=0.2566 | [256, 64, 128, 64] lr=1.0e-03 do=0.50
Best trial: 35. Best value: 0.253092:  63%|██████▎   | 63/100 [6:39:26<3:42:49, 361.34s/it]
  Trial  63 | combined=0.2537 | [256, 64, 128, 64] lr=9.3e-04 do=0.54
Best trial: 35. Best value: 0.253092:  64%|██████▍   | 64/100 [6:47:10<3:55:18, 392.18s/it]
  Trial  64 | combined=0.2616 | [256, 64, 128, 64] lr=2.4e-03 do=0.54
Best trial: 35. Best value: 0.253092:  65%|██████▌   | 65/100 [6:53:50<3:50:01, 394.33s/it]
  Trial  65 | combined=0.2582 | [256, 64, 64, 256] lr=5.0e-04 do=0.58
Best trial: 35. Best value: 0.253092:  66%|██████▌   | 66/100 [7:02:29<4:04:44, 431.89s/it]
  Trial  66 | combined=0.2570 | [256, 512, 64, 256] lr=1.7e-03 do=0.60
Best trial: 35. Best value: 0.253092:  67%|██████▋   | 67/100 [7:09:56<4:00:04, 436.52s/it]
  Trial  67 | combined=0.2581 | [512, 64, 128, 64] lr=8.6e-04 do=0.55
Best trial: 35. Best value: 0.253092:  68%|██████▊   | 68/100 [7:12:42<3:09:28, 355.27s/it]
  Trial  68 | combined=0.2828 | [256, 128, 64] lr=3.5e-04 do=0.52
Best trial: 35. Best value: 0.253092:  69%|██████▉   | 69/100 [7:17:27<2:52:40, 334.20s/it]
  Trial  69 | combined=0.2589 | [256, 512] lr=3.2e-03 do=0.45
Best trial: 35. Best value: 0.253092:  70%|███████   | 70/100 [7:22:10<2:39:21, 318.70s/it]
  Trial  70 | combined=0.2544 | [128, 64, 64] lr=6.3e-04 do=0.58
Best trial: 35. Best value: 0.253092:  71%|███████   | 71/100 [7:30:06<2:56:49, 365.83s/it]
  Trial  71 | combined=0.2587 | [128, 64, 64] lr=5.9e-04 do=0.58
Best trial: 35. Best value: 0.253092:  72%|███████▏  | 72/100 [7:37:48<3:04:15, 394.82s/it]
  Trial  72 | combined=0.2618 | [128, 64, 64] lr=1.2e-03 do=0.56
Best trial: 35. Best value: 0.253092:  73%|███████▎  | 73/100 [7:43:43<2:52:14, 382.76s/it]
  Trial  73 | combined=0.2570 | [128, 64, 64] lr=6.8e-04 do=0.54
Best trial: 35. Best value: 0.253092:  74%|███████▍  | 74/100 [7:49:36<2:42:02, 373.94s/it]
  Trial  74 | combined=0.2559 | [128, 64] lr=4.0e-04 do=0.58
Best trial: 35. Best value: 0.253092:  75%|███████▌  | 75/100 [7:56:40<2:42:05, 389.02s/it]
  Trial  75 | combined=0.2989 | [128, 64] lr=4.2e-04 do=0.20
Best trial: 35. Best value: 0.253092:  76%|███████▌  | 76/100 [8:00:30<2:16:32, 341.36s/it]
  Trial  76 | combined=0.2636 | [128] lr=8.4e-04 do=0.15
Best trial: 35. Best value: 0.253092:  77%|███████▋  | 77/100 [8:03:23<1:51:24, 290.62s/it]
  Trial  77 | combined=0.2554 | [128, 64] lr=5.2e-04 do=0.43
Best trial: 35. Best value: 0.253092:  78%|███████▊  | 78/100 [8:08:50<1:50:34, 301.55s/it]
  Trial  78 | combined=0.2550 | [128, 64] lr=5.0e-04 do=0.43
Best trial: 35. Best value: 0.253092:  79%|███████▉  | 79/100 [8:13:48<1:45:12, 300.62s/it]
  Trial  79 | combined=0.2545 | [128, 64] lr=5.0e-04 do=0.43
Best trial: 35. Best value: 0.253092:  80%|████████  | 80/100 [8:18:50<1:40:18, 300.91s/it]
  Trial  80 | combined=0.2571 | [128, 64] lr=5.1e-04 do=0.42
Best trial: 35. Best value: 0.253092:  81%|████████  | 81/100 [8:23:43<1:34:36, 298.75s/it]
  Trial  81 | combined=0.2628 | [128, 64] lr=3.0e-04 do=0.40
Best trial: 35. Best value: 0.253092:  82%|████████▏ | 82/100 [8:29:07<1:31:54, 306.35s/it]
  Trial  82 | combined=0.2551 | [128, 64] lr=6.5e-04 do=0.43
Best trial: 35. Best value: 0.253092:  83%|████████▎ | 83/100 [8:33:53<1:25:03, 300.19s/it]
  Trial  83 | combined=0.2555 | [128, 64] lr=6.8e-04 do=0.40
Best trial: 35. Best value: 0.253092:  84%|████████▍ | 84/100 [8:38:35<1:18:33, 294.61s/it]
  Trial  84 | combined=0.2554 | [128, 64] lr=6.6e-04 do=0.44
Best trial: 35. Best value: 0.253092:  85%|████████▌ | 85/100 [8:44:00<1:15:55, 303.73s/it]
  Trial  85 | combined=0.2557 | [128, 64] lr=9.3e-04 do=0.44
Best trial: 35. Best value: 0.253092:  86%|████████▌ | 86/100 [8:45:46<57:03, 244.56s/it]  
  Trial  86 | combined=0.2570 | [128, 64] lr=5.0e-04 do=0.42
Best trial: 35. Best value: 0.253092:  87%|████████▋ | 87/100 [8:50:17<54:38, 252.23s/it]
  Trial  87 | combined=0.2601 | [128, 64] lr=3.5e-04 do=0.35
Best trial: 35. Best value: 0.253092:  88%|████████▊ | 88/100 [8:52:57<44:57, 224.78s/it]
  Trial  88 | combined=0.2581 | [128, 64] lr=6.4e-04 do=0.46
Best trial: 35. Best value: 0.253092:  89%|████████▉ | 89/100 [8:58:04<45:42, 249.36s/it]
  Trial  89 | combined=0.2532 | [128, 128] lr=7.7e-04 do=0.38
Best trial: 35. Best value: 0.253092:  90%|█████████ | 90/100 [9:02:25<42:08, 252.89s/it]
  Trial  90 | combined=0.2649 | [128] lr=1.1e-03 do=0.38
Best trial: 35. Best value: 0.253092:  91%|█████████ | 91/100 [9:05:04<33:40, 224.55s/it]
  Trial  91 | combined=0.2567 | [128, 128] lr=7.6e-04 do=0.43
Best trial: 35. Best value: 0.253092:  92%|█████████▏| 92/100 [9:10:25<33:48, 253.59s/it]
  Trial  92 | combined=0.2601 | [128, 128] lr=5.1e-04 do=0.41
Best trial: 35. Best value: 0.253092:  93%|█████████▎| 93/100 [9:14:15<28:46, 246.63s/it]
  Trial  93 | combined=0.2590 | [128, 128] lr=2.5e-04 do=0.45
Best trial: 35. Best value: 0.253092:  94%|█████████▍| 94/100 [9:20:01<27:38, 276.42s/it]
  Trial  94 | combined=0.2542 | [128, 128] lr=8.8e-04 do=0.38
Best trial: 35. Best value: 0.253092:  95%|█████████▌| 95/100 [9:24:16<22:30, 270.00s/it]
  Trial  95 | combined=0.2527 | [128, 128] lr=9.3e-04 do=0.33
Best trial: 95. Best value: 0.252654:  96%|█████████▌| 96/100 [9:28:53<18:08, 272.06s/it]
  Trial  96 | combined=0.2567 | [128, 128] lr=2.0e-03 do=0.33
Best trial: 95. Best value: 0.252654:  97%|█████████▋| 97/100 [9:33:06<13:18, 266.19s/it]
  Trial  97 | combined=0.2902 | [128, 128] lr=8.9e-04 do=0.38
Best trial: 95. Best value: 0.252654:  98%|█████████▊| 98/100 [9:36:49<08:26, 253.43s/it]
  Trial  98 | combined=0.2564 | [64, 128] lr=1.5e-03 do=0.37
Best trial: 95. Best value: 0.252654:  99%|█████████▉| 99/100 [9:40:52<04:10, 250.24s/it]
  Trial  99 | combined=0.2570 | [128, 128] lr=1.1e-03 do=0.28
Best trial: 95. Best value: 0.252654: 100%|██████████| 100/100 [9:45:12<00:00, 351.12s/it]
============================================================
🏆 Best trial: #95
   Combined metric (lower=better): 0.2527
   Params: {'n_layers': 2, 'n_units_l0': 128, 'n_units_l1': 128, 'lr': 0.0009295223057515964, 'weight_decay': 0.0058088363839024935, 'batch_size': 256, 'dropout': 0.3273981278421983, 'epochs': 192, 'patience': 15}
============================================================
💾 Results saved to: DS_Optuna_100t_5f_20260317_0041.csv
⏱️ Total Time: 585.20 min

~10 hrs.

Code
import pandas as pd
import numpy as np

# Load the study results
# (Optuna trial log exported to CSV by the Phase 1 tuning run shown above).
df = pd.read_csv('DS_Optuna_100t_5f_20260317_0041.csv')

# Build architecture string from n_layers and n_units columns
def build_arch(row):
    """Return a compact architecture label (e.g. "[128,64]") for one trial row.

    Reads `params_n_layers` and the per-layer `params_n_units_l{i}` columns;
    layers whose unit column is absent or NaN are skipped.
    """
    n_layers = int(row['params_n_layers'])
    unit_cols = [f'params_n_units_l{i}' for i in range(n_layers)]
    widths = [int(row[c]) for c in unit_cols if c in row and pd.notna(row[c])]
    # str([...]) renders "[128, 64]"; drop the spaces for a compact label.
    return str(widths).replace(" ", "")

# Derive a per-trial architecture label, then tidy column names for display.
df['Architecture'] = df.apply(build_arch, axis=1)

RENAME_MAP = {
    'value': 'Score (lower=better)',
    'params_dropout': 'Dropout',
    'params_weight_decay': 'Weight Decay',
    'params_batch_size': 'Batch',
    'params_lr': 'LR',
    'params_epochs': 'Epochs',
    'params_patience': 'Patience'
}
df = df.rename(columns=RENAME_MAP)

# Restrict to the columns relevant for ranking trials.
cols = ['number', 'Score (lower=better)', 'Architecture', 'Dropout',
        'Weight Decay', 'Batch', 'LR', 'Epochs', 'duration']
df_display = df[cols].copy()
df_display['Rank'] = df_display['Score (lower=better)'].rank(method='min').astype(int)
df_display = df_display.sort_values('Score (lower=better)')

# Show top 15
nb_print(">>> TOP 15 TRIALS")
nb_print(df_display.head(15).to_string(index=False))

def _score_summary(frame, key):
    """Mean / best / count of the score grouped by `key`, sorted by mean score."""
    agg = frame.groupby(key).agg({
        'Score (lower=better)': ['mean', 'min', 'count']
    }).round(4)
    return agg.sort_values(('Score (lower=better)', 'mean'))

# Summary statistics by architecture
nb_print("\n>>> PERFORMANCE BY ARCHITECTURE")
arch_summary = _score_summary(df_display, 'Architecture')
nb_print(arch_summary)

# Summary by batch size
nb_print("\n>>> PERFORMANCE BY BATCH SIZE")
batch_summary = _score_summary(df_display, 'Batch')
nb_print(batch_summary)
>>> TOP 15 TRIALS
 number  Score (lower=better)    Architecture  Dropout  Weight Decay  Batch       LR  Epochs               duration  Rank
     95              0.252654       [128,128] 0.327398      0.005809    256 0.000930     192 0 days 00:04:36.864264     1
     35              0.253092     [64,128,64] 0.393491      0.000560    256 0.000585     141 0 days 00:07:05.511931     2
     89              0.253187       [128,128] 0.376655      0.006366    256 0.000774     160 0 days 00:04:21.112906     3
     63              0.253676 [256,64,128,64] 0.543012      0.000475    256 0.000935     116 0 days 00:07:44.138161     4
     52              0.253681  [256,64,64,64] 0.583006      0.000920    256 0.001563      80 0 days 00:07:38.224442     5
     94              0.254230       [128,128] 0.376809      0.005664    256 0.000881     187 0 days 00:04:15.041723     6
     70              0.254421     [128,64,64] 0.581950      0.000498    256 0.000628      80 0 days 00:07:55.794516     7
     79              0.254464        [128,64] 0.433082      0.008241    256 0.000503     126 0 days 00:05:01.605609     8
     78              0.255014        [128,64] 0.432809      0.009161    256 0.000505     135 0 days 00:04:58.438799     9
     82              0.255056        [128,64] 0.433394      0.009323    256 0.000645     135 0 days 00:04:45.826871    10
     84              0.255355        [128,64] 0.442922      0.009044    256 0.000659     134 0 days 00:05:25.010301    11
     77              0.255363        [128,64] 0.429382      0.008141    256 0.000519     134 0 days 00:05:27.045874    12
     83              0.255522        [128,64] 0.399823      0.009554    256 0.000679     135 0 days 00:04:41.587367    13
     54              0.255583  [256,64,64,64] 0.478271      0.000427    256 0.001402      95 0 days 00:06:16.954964    14
     85              0.255729        [128,64] 0.437742      0.006872   2048 0.000928     124 0 days 00:01:46.473880    15
>>> PERFORMANCE BY ARCHITECTURE
                  Score (lower=better)              
                                  mean     min count
Architecture                                        
[256,512,64,256]                0.2570  0.2570     1
[64,512,512]                    0.2570  0.2570     1
[256,64,64,256]                 0.2573  0.2564     2
[256,512,256]                   0.2574  0.2574     1
[256,64,64,64]                  0.2574  0.2537     6
[256,64,128,64]                 0.2575  0.2537     5
[512,64,128,64]                 0.2581  0.2581     1
[128,64,64]                     0.2583  0.2544     5
[128,64,64,64]                  0.2587  0.2587     1
[512,64,64,64]                  0.2587  0.2587     1
[256,512,512]                   0.2593  0.2593     1
[128,128]                       0.2600  0.2527     9
[128,256,64,256]                0.2600  0.2600     1
[128,64]                        0.2602  0.2545    15
[128,64,64,512]                 0.2603  0.2603     1
[256,256,64]                    0.2609  0.2609     1
[512,128,64]                    0.2609  0.2609     1
[256,64,64]                     0.2610  0.2610     1
[256,512,64,128]                0.2611  0.2611     1
[256,64,256]                    0.2612  0.2612     1
[256,512,64]                    0.2614  0.2559     6
[64,256,64,256]                 0.2616  0.2616     1
[256,512]                       0.2620  0.2589     2
[256,128,64,256]                0.2620  0.2620     1
[256,64,64,128]                 0.2621  0.2621     1
[512,256]                       0.2623  0.2623     1
[512,512]                       0.2634  0.2634     1
[256,64,64,512]                 0.2635  0.2576     3
[64,256]                        0.2639  0.2639     1
[64]                            0.2642  0.2629     2
[64,128,64]                     0.2643  0.2531     4
[128]                           0.2643  0.2636     2
[512,128,256,256]               0.2644  0.2644     1
[512]                           0.2658  0.2649     2
[256,256,64,512]                0.2661  0.2661     1
[256,512,256,256]               0.2662  0.2662     1
[256,64,512]                    0.2673  0.2621     2
[256,512,128]                   0.2721  0.2721     1
[64,128,256]                    0.2732  0.2732     1
[64,128]                        0.2759  0.2564     3
[128,64,128]                    0.2766  0.2766     1
[256,256]                       0.2771  0.2771     1
[256,128,64]                    0.2828  0.2828     1
[128,512,64,64]                 0.2893  0.2893     1
[64,128,128]                    0.2898  0.2898     1
[256,64,512,512]                0.2911  0.2911     1
[256,512,128,128]               0.3196  0.3196     1
>>> PERFORMANCE BY BATCH SIZE
      Score (lower=better)              
                      mean     min count
Batch                                   
256                 0.2616  0.2527    75
1024                0.2649  0.2581     9
2048                0.2662  0.2557     8
512                 0.2765  0.2600     8
Code
#@title 🔬 Phase 1.2: Validate Top 10 from Original 100-Trial Study on All Imputations

import pandas as pd
import numpy as np
import torch
import gc
import time
from datetime import datetime
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sksurv.metrics import concordance_index_ipcw, integrated_brier_score
import tempfile
import os
from contextlib import contextmanager

@contextmanager
def local_workdir():
    """Run the enclosed code from a fresh local temporary directory.

    Intended to avoid Google Drive file locks when the notebook's normal
    working directory is Drive-mounted. Yields the temp dir path (str);
    the previous working directory is restored on exit, and the temp dir
    is removed by TemporaryDirectory's own cleanup.
    """
    prev_dir = os.getcwd()
    with tempfile.TemporaryDirectory() as scratch:
        try:
            os.chdir(scratch)
            yield scratch
        finally:
            # Always return to where the caller started.
            os.chdir(prev_dir)
def compute_ibs(model, X_val_s, y_train_struct, y_val_struct, horizons):
    """Integrated Brier score of `model` on the validation set.

    The score is evaluated only at the requested `horizons` that fall
    strictly inside the model's predicted time grid. Returns np.nan when
    the score cannot be computed (empty prediction, fewer than two usable
    horizons, or a metric failure).
    """
    surv_df = model.predict_surv_df(X_val_s)
    if surv_df is None or surv_df.shape[0] == 0:
        return np.nan

    grid = surv_df.index.values

    # Shrink the horizon window slightly inside the grid so the metric
    # never evaluates exactly at (or beyond) the grid endpoints.
    lo = max(horizons[0], grid[0] + 1e-4)
    hi = min(horizons[-1], grid[-1] - 1e-4)
    if lo >= hi:
        return np.nan

    eval_times = np.array([t for t in horizons if lo <= t <= hi])
    if eval_times.size < 2:
        return np.nan

    # Step-function lookup: survival value at the last grid time <= horizon.
    row_ids = [
        int(np.clip(np.searchsorted(grid, t, side='right') - 1, 0, len(grid) - 1))
        for t in eval_times
    ]
    surv_matrix = np.column_stack([surv_df.iloc[r].values for r in row_ids])

    try:
        return integrated_brier_score(y_train_struct, y_val_struct, surv_matrix, eval_times)
    except Exception:
        # Metric failures (e.g. censoring-distribution issues) are treated
        # as "no score" rather than aborting the evaluation loop.
        return np.nan

# --- CONFIGURATION ---
# Number of top Phase 1 trials to re-validate across all imputations.
N_TOP_CANDIDATES = 10
# Stratified CV folds per imputed dataset.
N_FOLDS = 5
# Horizons at which discrimination is scored (presumably months — TODO confirm
# against the data-preparation cells).
EVAL_HORIZONS = [3, 6, 12, 36, 60]
# Per-horizon weights for the weighted C-index (sum to 1.0).
HORIZON_WEIGHTS = {3: 0.125, 6: 0.125, 12: 0.25, 36: 0.25, 60: 0.25}
# Seed shared by every StratifiedKFold split for reproducibility.
SEED = 2125

nb_print(f"\n{'='*70}")
nb_print(f"🔬 PHASE 2: VALIDATING TOP {N_TOP_CANDIDATES} FROM 100-TRIAL STUDY")
nb_print(f"{'='*70}")

# --- LOAD AND PARSE THE ORIGINAL 100-TRIAL STUDY ---
df = pd.read_csv('DS_Optuna_100t_5f_20260317_0041.csv')

# Filter completed trials and sort by score (lower combined metric = better)
df = df[df['state'] == 'COMPLETE'].sort_values('value', ascending=True)

def extract_config_from_csv(row):
    """Rebuild a training-config dict from one row of the Optuna results CSV.

    Args:
        row: a pandas Series for one trial, with Optuna's `params_*`
            column naming.

    Returns:
        dict with keys: trial_number, phase1_score, lr, weight_decay,
        batch_size, dropout, nodes (list of layer widths), epochs, patience.
    """
    # Rebuild architecture from n_layers and n_units_l* columns.
    n_layers = int(row['params_n_layers'])
    nodes = []
    for i in range(n_layers):
        col_name = f'params_n_units_l{i}'
        if col_name in row and pd.notna(row[col_name]):
            nodes.append(int(row[col_name]))

    def _int_param(name, default):
        # Guard both a missing column AND a NaN cell: trials that never
        # sampled a parameter leave NaN in the merged study CSV, and
        # int(nan) raises — the original only handled the missing-column case.
        if name in row and pd.notna(row[name]):
            return int(row[name])
        return default

    config = {
        'trial_number': int(row['number']),
        'phase1_score': row['value'],  # Combined metric (lower is better)
        'lr': row['params_lr'],
        'weight_decay': row['params_weight_decay'],
        'batch_size': int(row['params_batch_size']),
        'dropout': row['params_dropout'],
        'nodes': nodes,
        'epochs': _int_param('params_epochs', 200),
        'patience': _int_param('params_patience', 15),
    }
    return config

# Parse the N best Phase 1 trials into config dicts.
top_configs = [
    extract_config_from_csv(row)
    for _, row in df.head(N_TOP_CANDIDATES).iterrows()
]

nb_print(f"\n📋 Loaded Top {N_TOP_CANDIDATES} Configurations from DS_Optuna_100t_5f_20260317_0041.csv:")
nb_print(f"   (Lower 'Combined Metric' = Better - from Phase 1)")
for rank, cfg in enumerate(top_configs, start=1):
    nb_print(f"  Rank {rank}: Trial #{cfg['trial_number']} | "
             f"Arch={cfg['nodes']} | LR={cfg['lr']:.2e} | "
             f"Dropout={cfg['dropout']:.2f} | "
             f"Combined={cfg['phase1_score']:.4f}")

# --- VALIDATION FUNCTION (Same as before) ---
def validate_on_all_imputations(params, n_imputations=None):
    """Cross-validate one hyper-parameter config on every imputed dataset.

    For each imputation a stratified K-fold split is made; per fold, two
    cause-specific DeepSurv models are fit (event code 1 = one cause,
    code 2 = the other), scored with a horizon-weighted IPCW C-index and
    an integrated Brier score, then combined as
    sqrt((1 - weighted_C)^2 + IBS^2) — lower is better.

    Args:
        params: hyper-parameter dict (as produced by extract_config_from_csv)
            passed through to fit_deepsurv.
        n_imputations: how many imputed datasets to evaluate; defaults to
            len(imputations_list_mar26).

    Returns:
        (mean_c, std_c, mean_ibs, std_ibs, mean_combined, std_combined,
         imp_c_scores, imp_ibs_scores, imp_combined_scores) — the first six
        aggregate across imputations; the last three are per-imputation
        fold means (may contain NaN for fully failed imputations).
    """
    if n_imputations is None:
        n_imputations = len(imputations_list_mar26)

    imp_c_scores = []
    imp_ibs_scores = []
    imp_combined_scores = []

    for imp_idx in range(n_imputations):
        X_all, events_all, times_all, strat_labels = prepare_stratified_data(df_idx=imp_idx)

        skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=SEED)
        fold_c_scores = []
        fold_ibs_scores = []
        fold_combined_scores = []

        for fold_idx, (train_idx, val_idx) in enumerate(skf.split(X_all, strat_labels)):
            model_d = None
            model_r = None

            try:
                # Work in a local temp dir to dodge Google Drive file locks.
                with local_workdir():
                    # Data split
                    X_train, X_val = X_all.iloc[train_idx], X_all.iloc[val_idx]
                    e_train_raw, e_val_raw = events_all[train_idx], events_all[val_idx]
                    t_train, t_val = times_all[train_idx], times_all[val_idx]

                    # Scale on the training fold only (no leakage into validation).
                    scaler = StandardScaler().fit(X_train)
                    X_train_s = scaler.transform(X_train).astype('float32')
                    X_val_s = scaler.transform(X_val).astype('float32')

                    # Cause-specific event masks (event code 1 vs. 2).
                    e_train_d = (e_train_raw == 1)
                    e_val_d = (e_val_raw == 1)
                    e_train_r = (e_train_raw == 2)
                    e_val_r = (e_val_raw == 2)

                    # Fit one model per competing risk.
                    model_d = fit_deepsurv(X_train_s, t_train, e_train_d, X_val_s, t_val, e_val_d, params)
                    model_r = fit_deepsurv(X_train_s, t_train, e_train_r, X_val_s, t_val, e_val_r, params)

                    # Structured arrays for scikit-survival metrics.
                    y_tr_d = to_structured(t_train, e_train_d)
                    y_va_d = to_structured(t_val, e_val_d)
                    y_tr_r = to_structured(t_train, e_train_r)
                    y_va_r = to_structured(t_val, e_val_r)

                    # Discrimination at each horizon; a horizon can fail for one
                    # cause (e.g. no comparable pairs), so each cause is scored
                    # independently and surviving values are averaged.
                    horizon_cindices = []
                    weights_used = []

                    for h in EVAL_HORIZONS:
                        risk_d = risk_at_horizon(model_d, X_val_s, h)
                        risk_r = risk_at_horizon(model_r, X_val_s, h)

                        # Fix: narrowed from bare `except:` so KeyboardInterrupt/
                        # SystemExit are no longer swallowed during long sweeps.
                        try:
                            c_d = concordance_index_ipcw(y_tr_d, y_va_d, risk_d, tau=h)[0]
                        except Exception:
                            c_d = np.nan
                        try:
                            c_r = concordance_index_ipcw(y_tr_r, y_va_r, risk_r, tau=h)[0]
                        except Exception:
                            c_r = np.nan

                        c_avg = np.nanmean([c_d, c_r])
                        if not np.isnan(c_avg):
                            horizon_cindices.append(c_avg)
                            weights_used.append(HORIZON_WEIGHTS[h])

                    if len(horizon_cindices) > 0:
                        weighted_c = np.average(horizon_cindices, weights=weights_used)

                        ibs_d = compute_ibs(model_d, X_val_s, y_tr_d, y_va_d, EVAL_HORIZONS)
                        ibs_r = compute_ibs(model_r, X_val_s, y_tr_r, y_va_r, EVAL_HORIZONS)

                        ibs_vals = [v for v in [ibs_d, ibs_r] if not np.isnan(v)]
                        avg_ibs = np.mean(ibs_vals) if len(ibs_vals) > 0 else np.nan

                        # Combined metric: Euclidean distance from the ideal
                        # point (C = 1, IBS = 0); fall back to 1 - C when the
                        # IBS is unavailable.
                        if not np.isnan(avg_ibs):
                            combined = np.sqrt((1.0 - weighted_c) ** 2 + avg_ibs ** 2)
                        else:
                            combined = 1.0 - weighted_c

                        fold_c_scores.append(weighted_c)
                        fold_ibs_scores.append(avg_ibs)
                        fold_combined_scores.append(combined)

            except Exception as e:
                # A failed fold is logged and skipped rather than aborting the sweep.
                nb_print(f"    [Imp {imp_idx} Fold {fold_idx} FAIL] {str(e)[:80]}")

            finally:
                # Free model memory before the next fold; empty_cache is a
                # no-op when CUDA was never initialized.
                if model_d is not None:
                    del model_d
                if model_r is not None:
                    del model_r
                gc.collect()
                torch.cuda.empty_cache()

        # Average across folds for this imputation (NaN folds excluded).
        valid_fold_c = [x for x in fold_c_scores if not np.isnan(x)]
        valid_fold_ibs = [x for x in fold_ibs_scores if not np.isnan(x)]
        valid_fold_combined = [x for x in fold_combined_scores if not np.isnan(x)]

        imp_c_scores.append(np.mean(valid_fold_c) if len(valid_fold_c) > 0 else np.nan)
        imp_ibs_scores.append(np.mean(valid_fold_ibs) if len(valid_fold_ibs) > 0 else np.nan)
        imp_combined_scores.append(np.mean(valid_fold_combined) if len(valid_fold_combined) > 0 else np.nan)

    # Aggregate across imputations (NaN imputations excluded).
    valid_c = [s for s in imp_c_scores if not np.isnan(s)]
    valid_ibs = [s for s in imp_ibs_scores if not np.isnan(s)]
    valid_combined = [s for s in imp_combined_scores if not np.isnan(s)]

    mean_c = np.mean(valid_c) if len(valid_c) > 0 else np.nan
    std_c = np.std(valid_c) if len(valid_c) > 0 else np.nan

    mean_ibs = np.mean(valid_ibs) if len(valid_ibs) > 0 else np.nan
    std_ibs = np.std(valid_ibs) if len(valid_ibs) > 0 else np.nan

    mean_combined = np.mean(valid_combined) if len(valid_combined) > 0 else np.nan
    std_combined = np.std(valid_combined) if len(valid_combined) > 0 else np.nan

    return mean_c, std_c, mean_ibs, std_ibs, mean_combined, std_combined, imp_c_scores, imp_ibs_scores, imp_combined_scores

# --- RUN PHASE 2 VALIDATION ---
n_imps = len(imputations_list_mar26)
est_fits = N_TOP_CANDIDATES * N_FOLDS * n_imps

nb_print(f"\n🚀 Starting Phase 2 Validation...")
# NOTE(review): each fold fits TWO cause-specific models, so the actual
# number of fits is 2 × est_fits — the message undercounts; confirm intent.
nb_print(f"   {N_TOP_CANDIDATES} configs × {N_FOLDS} folds × {n_imps} imputations = ~{est_fits} model fits")
# NOTE(review): est_fits * 4 / 60 reads as minutes-per-fit → hours, but the
# label says "~4s per fit" (which would give minutes) — units look
# inconsistent; verify against the actual run time logged below.
nb_print(f"   Estimated time: ~{est_fits * 4 / 60:.1f} hours (assuming ~4s per fit)")

results = []
start_time = time.time()

# Validate each candidate in Phase 1 rank order (best first).
for rank, config in enumerate(top_configs, 1):
    nb_print(f"\n  ⚙️  [{rank}/{N_TOP_CANDIDATES}] Validating Trial #{config['trial_number']}...")
    nb_print(f"      Arch={config['nodes']}, LR={config['lr']:.2e}, WD={config['weight_decay']:.2e}")
    
    mean_c, std_c, mean_ibs, std_ibs, mean_combined, std_combined, per_imp_c, per_imp_ibs, per_imp_combined = validate_on_all_imputations(config)
    
    # One record per candidate; Phase 1 vs Phase 2 columns allow rank-change analysis.
    result = {
        'Phase1_Rank': rank,
        'Trial_Number': config['trial_number'],
        'Phase1_Combined_Metric': config['phase1_score'],
        'Phase2_Mean_CIndex': mean_c,
        'Phase2_Std_CIndex': std_c,
        'Phase2_Mean_IBS': mean_ibs,
        'Phase2_Std_IBS': std_ibs,
        'Phase2_Mean_Combined': mean_combined,
        'Phase2_Std_Combined': std_combined,
        'Architecture': str(config['nodes']),
        'LR': config['lr'],
        'Weight_Decay': config['weight_decay'],
        'Dropout': config['dropout'],
        'Batch_Size': config['batch_size']
    }

    results.append(result)
    
    # NaN mean C-index means every imputation failed for this candidate.
    status = "✅" if not np.isnan(mean_c) else "❌"
    nb_print(
        f"      {status} Phase 2: "
        f"C={mean_c:.4f} ± {std_c:.4f} | "
        f"IBS={mean_ibs:.4f} ± {std_ibs:.4f} | "
        f"Combined={mean_combined:.4f} ± {std_combined:.4f}"
    )

# --- SUMMARY ---
# Re-rank candidates by the Phase 2 combined metric (lower is better).
results_df = pd.DataFrame(results)
results_df = results_df.sort_values('Phase2_Mean_Combined', ascending=True)
results_df['Phase2_Rank'] = range(1, len(results_df) + 1)

nb_print(f"\n{'='*70}")
nb_print("🏆 PHASE 2 FINAL RANKING (Validated on All Imputations):")
nb_print(f"{'='*70}")

display_cols = [
    'Phase2_Rank', 'Trial_Number',
    'Phase2_Mean_CIndex', 'Phase2_Std_CIndex',
    'Phase2_Mean_IBS', 'Phase2_Std_IBS',
    'Phase2_Mean_Combined', 'Phase2_Std_Combined',
    'Phase1_Rank', 'Phase1_Combined_Metric', 'Architecture'
]
nb_print(results_df[display_cols].to_string(index=False))

# Identify winners and rank changes
best_config = results_df.iloc[0]
nb_print(f"\n🥇 BEST CONFIGURATION (Phase 2):")
nb_print(f"   Trial #{int(best_config['Trial_Number'])} | {best_config['Architecture']}")
nb_print(f"   Phase 2 C-Index: {best_config['Phase2_Mean_CIndex']:.4f} ± {best_config['Phase2_Std_CIndex']:.4f}")
nb_print(f"   (Was Phase 1 Rank {int(best_config['Phase1_Rank'])} with Combined Metric {best_config['Phase1_Combined_Metric']:.4f})")

# Check if winner changed
if int(best_config['Phase1_Rank']) != 1:
    original_winner = results_df[results_df['Phase1_Rank'] == 1].iloc[0]
    nb_print(f"\n⚠️  RANKING CHANGE!")
    nb_print(f"   Phase 1 Winner (Trial #{int(original_winner['Trial_Number'])}) "
             f"dropped to Rank {int(original_winner['Phase2_Rank'])} in Phase 2")
    nb_print(f"   New winner generalizes better across imputations.")

# Save results
timestamp = datetime.now().strftime("%Y%m%d_%H%M")
filename = f"DS_Phase2_Top10_Validation_{timestamp}.csv"
results_df.to_csv(filename, index=False)
# Fix: report the actual output file (the captured run output shows the real
# filename was printed; the source had a stale "(unknown)" placeholder here).
nb_print(f"\n💾 Saved to: {filename}")
nb_print(f"⏱️  Phase 2 Time: {(time.time() - start_time)/60:.2f} min")
======================================================================
🔬 PHASE 2: VALIDATING TOP 10 FROM 100-TRIAL STUDY
======================================================================
📋 Loaded Top 10 Configurations from DS_Optuna_100t_5f_20260317_0041.csv:
   (Lower 'Combined Metric' = Better - from Phase 1)
  Rank 1: Trial #95 | Arch=[128, 128] | LR=9.30e-04 | Dropout=0.33 | Combined=0.2527
  Rank 2: Trial #35 | Arch=[64, 128, 64] | LR=5.85e-04 | Dropout=0.39 | Combined=0.2531
  Rank 3: Trial #89 | Arch=[128, 128] | LR=7.74e-04 | Dropout=0.38 | Combined=0.2532
  Rank 4: Trial #63 | Arch=[256, 64, 128, 64] | LR=9.35e-04 | Dropout=0.54 | Combined=0.2537
  Rank 5: Trial #52 | Arch=[256, 64, 64, 64] | LR=1.56e-03 | Dropout=0.58 | Combined=0.2537
  Rank 6: Trial #94 | Arch=[128, 128] | LR=8.81e-04 | Dropout=0.38 | Combined=0.2542
  Rank 7: Trial #70 | Arch=[128, 64, 64] | LR=6.28e-04 | Dropout=0.58 | Combined=0.2544
  Rank 8: Trial #79 | Arch=[128, 64] | LR=5.03e-04 | Dropout=0.43 | Combined=0.2545
  Rank 9: Trial #78 | Arch=[128, 64] | LR=5.05e-04 | Dropout=0.43 | Combined=0.2550
  Rank 10: Trial #82 | Arch=[128, 64] | LR=6.45e-04 | Dropout=0.43 | Combined=0.2551
🚀 Starting Phase 2 Validation...
   10 configs × 5 folds × 5 imputations = ~250 model fits
   Estimated time: ~16.7 hours (assuming ~4s per fit)
  ⚙️  [1/10] Validating Trial #95...
      Arch=[128, 128], LR=9.30e-04, WD=5.81e-03
      ✅ Phase 2: C=0.7446 ± 0.0024 | IBS=0.0283 ± 0.0000 | Combined=0.2570 ± 0.0024
  ⚙️  [2/10] Validating Trial #35...
      Arch=[64, 128, 64], LR=5.85e-04, WD=5.60e-04
      ✅ Phase 2: C=0.7459 ± 0.0036 | IBS=0.0283 ± 0.0000 | Combined=0.2557 ± 0.0036
  ⚙️  [3/10] Validating Trial #89...
      Arch=[128, 128], LR=7.74e-04, WD=6.37e-03
      ✅ Phase 2: C=0.7454 ± 0.0013 | IBS=0.0283 ± 0.0000 | Combined=0.2562 ± 0.0013
  ⚙️  [4/10] Validating Trial #63...
      Arch=[256, 64, 128, 64], LR=9.35e-04, WD=4.75e-04
      ✅ Phase 2: C=0.7470 ± 0.0020 | IBS=0.0283 ± 0.0000 | Combined=0.2546 ± 0.0020
  ⚙️  [5/10] Validating Trial #52...
      Arch=[256, 64, 64, 64], LR=1.56e-03, WD=9.20e-04
      ✅ Phase 2: C=0.7457 ± 0.0014 | IBS=0.0284 ± 0.0000 | Combined=0.2559 ± 0.0014
  ⚙️  [6/10] Validating Trial #94...
      Arch=[128, 128], LR=8.81e-04, WD=5.66e-03
      ✅ Phase 2: C=0.7455 ± 0.0022 | IBS=0.0283 ± 0.0000 | Combined=0.2561 ± 0.0022
  ⚙️  [7/10] Validating Trial #70...
      Arch=[128, 64, 64], LR=6.28e-04, WD=4.98e-04
      ✅ Phase 2: C=0.7464 ± 0.0032 | IBS=0.0284 ± 0.0001 | Combined=0.2552 ± 0.0032
  ⚙️  [8/10] Validating Trial #79...
      Arch=[128, 64], LR=5.03e-04, WD=8.24e-03
      ✅ Phase 2: C=0.7459 ± 0.0031 | IBS=0.0283 ± 0.0000 | Combined=0.2557 ± 0.0031
  ⚙️  [9/10] Validating Trial #78...
      Arch=[128, 64], LR=5.05e-04, WD=9.16e-03
      ✅ Phase 2: C=0.7402 ± 0.0116 | IBS=0.0283 ± 0.0001 | Combined=0.2614 ± 0.0116
  ⚙️  [10/10] Validating Trial #82...
      Arch=[128, 64], LR=6.45e-04, WD=9.32e-03
      ✅ Phase 2: C=0.7363 ± 0.0125 | IBS=0.0284 ± 0.0001 | Combined=0.2653 ± 0.0125
======================================================================
🏆 PHASE 2 FINAL RANKING (Validated on All Imputations):
======================================================================
 Phase2_Rank  Trial_Number  Phase2_Mean_CIndex  Phase2_Std_CIndex  Phase2_Mean_IBS  Phase2_Std_IBS  Phase2_Mean_Combined  Phase2_Std_Combined  Phase1_Rank  Phase1_Combined_Metric       Architecture
           1            63            0.747027           0.001992         0.028331        0.000027              0.254563             0.001982            4                0.253676 [256, 64, 128, 64]
           2            70            0.746412           0.003212         0.028381        0.000074              0.255180             0.003199            7                0.254421      [128, 64, 64]
           3            79            0.745884           0.003080         0.028295        0.000025              0.255696             0.003062            8                0.254464          [128, 64]
           4            35            0.745859           0.003627         0.028331        0.000042              0.255725             0.003609            2                0.253092      [64, 128, 64]
           5            52            0.745729           0.001416         0.028413        0.000033              0.255863             0.001406            5                0.253681  [256, 64, 64, 64]
           6            94            0.745497           0.002206         0.028295        0.000016              0.256078             0.002194            6                0.254230         [128, 128]
           7            89            0.745425           0.001302         0.028301        0.000017              0.256152             0.001296            3                0.253187         [128, 128]
           8            95            0.744607           0.002437         0.028309        0.000008              0.256966             0.002423            1                0.252654         [128, 128]
           9            78            0.740175           0.011617         0.028340        0.000061              0.261383             0.011572            9                0.255014          [128, 64]
          10            82            0.736254           0.012495         0.028366        0.000072              0.265296             0.012454           10                0.255056          [128, 64]
🥇 BEST CONFIGURATION (Phase 2):
   Trial #63 | [256, 64, 128, 64]
   Phase 2 C-Index: 0.7470 ± 0.0020
   (Was Phase 1 Rank 4 with Combined Metric 0.2537)
⚠️  RANKING CHANGE!
   Phase 1 Winner (Trial #95) dropped to Rank 8 in Phase 2
   New winner generalizes better across imputations.
💾 Saved to: DS_Phase2_Top10_Validation_20260401_1249.csv
⏱️  Phase 2 Time: 236.21 min
Code
nb_print(best_config)
Phase1_Rank                                4
Trial_Number                              63
Phase1_Combined_Metric              0.253676
Phase2_Mean_CIndex                  0.747027
Phase2_Std_CIndex                   0.001992
Phase2_Mean_IBS                     0.028331
Phase2_Std_IBS                      0.000027
Phase2_Mean_Combined                0.254563
Phase2_Std_Combined                 0.001982
Architecture              [256, 64, 128, 64]
LR                                  0.000935
Weight_Decay                        0.000475
Dropout                             0.543012
Batch_Size                               256
Phase2_Rank                                1
Name: 3, dtype: object
Code
#@title 📝 Take-Home Message: Interpretation of Best DeepSurv Configuration

import pandas as pd
from IPython.display import display

# UPDATED 2026-04-01:
# Previous winner:
#   Nodes [64, 128, 64] | LR 0.000585 | WD 0.00056 | Dropout 0.393 | Batch 256
# New Phase 2 winner:
#   Nodes [256, 64, 128, 64] | LR 0.000935 | WD 0.000475 | Dropout 0.543 | Batch 256
#   Phase1 Combined Metric = 0.253676
#   Phase2 Mean C-Index    = 0.747027
#   Phase2 Std C-Index     = 0.001992
#   Phase2 Mean IBS        = 0.028331
#   Phase2 Std IBS         = 0.000027
#   Phase2 Mean Combined   = 0.254563
#   Phase2 Std Combined    = 0.001982

# Hand-written narrative interpretation of the Phase 2 winner (Trial #63);
# the numeric values are copied from the Phase 2 validation results above.
config_interpretation = pd.DataFrame([
    {
        'Component': 'Regularization (Stronger Stochastic Shield)',
        'Selected Value': 'Dropout: 0.543 | Weight Decay: 0.000475',
        'Interpretation': 'The Phase 2 winner uses stronger dropout than the previous best model, but still keeps weight decay relatively light. This suggests the network benefits more from aggressively disrupting noisy co-adaptations during training than from heavily shrinking coefficients. In practical terms, the model needs freedom to express nonlinear risk structure, but it also needs strong protection against memorizing unstable patterns.'
    },
    {
        'Component': 'Model Capacity (Deep Bottleneck Funnel)',
        'Selected Value': 'Nodes: [256, 64, 128, 64]',
        'Interpretation': 'The winning architecture is deeper and more structured than the earlier compact funnel. It starts wide, compresses sharply, expands again, and then contracts before output. That pattern is consistent with hierarchical feature extraction: broad first-pass interaction capture, bottleneck-based denoising, mid-level representation rebuilding, and final compression into a stable prognostic signal. The result implies the data supports more complex nonlinear structure than the earlier 3-layer winner suggested.'
    },
    {
        'Component': 'Optimization Mechanics',
        'Selected Value': 'Batch: 256 | LR: 0.000935',
        'Interpretation': 'The optimizer again favored batch size 256, reinforcing the pattern that smaller batches work better than larger ones for this Cox setup. The learning rate moved upward relative to the previous Phase 2 winner, indicating that once the architecture became more expressive, slightly faster optimization helped the model reach a better ranking solution without losing stability.'
    },
    {
        'Component': 'Performance Context (Best Generalization Across Imputations)',
        'Selected Value': 'Phase1 Combined: 0.253676 | Phase2 C-Index: 0.747027 ± 0.001992 | Phase2 IBS: 0.028331 ± 0.000027 | Phase2 Combined: 0.254563 ± 0.001982',
        'Interpretation': 'This model is especially interesting because it was only Rank 4 in Phase 1, yet it emerged as Rank 1 after validation across all imputations. That means the final winner was not simply the best single-study Optuna result; it was the configuration that generalized most reliably. IBS remained extremely stable across candidates, so the Phase 2 win was driven mainly by better discrimination while preserving essentially the same calibration level.'
    }
])

nb_print("\n>>> TAKE-HOME MESSAGE: UPDATED OPTUNA-OPTIMIZED DEEPSURV CONFIGURATION (PHASE 2 WINNER)")
# Show the full text of each interpretation cell (no truncation).
pd.set_option('display.max_colwidth', None)

# Left-aligned, wrapped styling so the long prose cells render readably.
styled_table = (
    config_interpretation.style
    .set_properties(**{
        "text-align": "left",
        "white-space": "pre-wrap",
        "font-size": "14px",
        "vertical-align": "top"
    })
    .set_table_styles([
        {"selector": "th", "props": [("background-color", "#f0f2f6"), ("font-weight", "bold"), ("font-size", "14px")]},
        {"selector": "td", "props": [("padding", "12px"), ("border-bottom", "1px solid #ddd")]}
    ])
)

display(styled_table)
>>> TAKE-HOME MESSAGE: UPDATED OPTUNA-OPTIMIZED DEEPSURV CONFIGURATION (PHASE 2 WINNER)
  Component Selected Value Interpretation
0 Regularization (Stronger Stochastic Shield) Dropout: 0.543 | Weight Decay: 0.000475 The Phase 2 winner uses stronger dropout than the previous best model, but still keeps weight decay relatively light. This suggests the network benefits more from aggressively disrupting noisy co-adaptations during training than from heavily shrinking coefficients. In practical terms, the model needs freedom to express nonlinear risk structure, but it also needs strong protection against memorizing unstable patterns.
1 Model Capacity (Deep Bottleneck Funnel) Nodes: [256, 64, 128, 64] The winning architecture is deeper and more structured than the earlier compact funnel. It starts wide, compresses sharply, expands again, and then contracts before output. That pattern is consistent with hierarchical feature extraction: broad first-pass interaction capture, bottleneck-based denoising, mid-level representation rebuilding, and final compression into a stable prognostic signal. The result implies the data supports more complex nonlinear structure than the earlier 3-layer winner suggested.
2 Optimization Mechanics Batch: 256 | LR: 0.000935 The optimizer again favored batch size 256, reinforcing the pattern that smaller batches work better than larger ones for this Cox setup. The learning rate moved upward relative to the previous Phase 2 winner, indicating that once the architecture became more expressive, slightly faster optimization helped the model reach a better ranking solution without losing stability.
3 Performance Context (Best Generalization Across Imputations) Phase1 Combined: 0.253676 | Phase2 C-Index: 0.747027 ± 0.001992 | Phase2 IBS: 0.028331 ± 0.000027 | Phase2 Combined: 0.254563 ± 0.001982 This model is especially interesting because it was only Rank 4 in Phase 1, yet it emerged as Rank 1 after validation across all imputations. That means the final winner was not simply the best single-study Optuna result; it was the configuration that generalized most reliably. IBS remained extremely stable across candidates, so the Phase 2 win was driven mainly by better discrimination while preserving essentially the same calibration level.
Code
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.ticker as mtick
import seaborn as sns
from pathlib import Path
from datetime import datetime
import re

try:
    from adjustText import adjust_text
    HAS_ADJUST_TEXT = True
except ImportError:
    HAS_ADJUST_TEXT = False


# Resolve all output locations relative to the notebook's working directory.
PROJECT_ROOT = Path.cwd()

# Figures are written into a dedicated "_figs" subfolder; create it up front
# so later savefig calls never fail on a missing directory.
FIGS_DIR = PROJECT_ROOT.joinpath("_figs")
FIGS_DIR.mkdir(parents=True, exist_ok=True)

def find_latest_phase2_csv(project_root):
    """Locate the most recent Phase 2 validation CSV.

    Searches ``project_root`` and its ``_out`` subfolder for files named
    ``DS_Phase2_Top10_Validation_YYYYMMDD_HHMM.csv`` and returns the Path
    whose embedded timestamp is the newest.

    Parameters
    ----------
    project_root : Path
        Folder to search (its ``_out`` child is searched as well).

    Returns
    -------
    Path
        The matching file with the latest timestamp.

    Raises
    ------
    FileNotFoundError
        If no matching file exists in either folder.
    """
    name_re = re.compile(r"^DS_Phase2_Top10_Validation_(\d{8}_\d{4})\.csv$")
    search_dirs = (project_root, project_root / "_out")

    stamped = []
    for directory in search_dirs:
        if not directory.exists():
            continue
        for csv_file in directory.glob("DS_Phase2_Top10_Validation_*.csv"):
            hit = name_re.match(csv_file.name)
            if hit is None:
                continue
            # Timestamp is validated by the regex, so strptime cannot fail here.
            when = datetime.strptime(hit.group(1), "%Y%m%d_%H%M")
            stamped.append((when, csv_file))

    if not stamped:
        raise FileNotFoundError("No DS_Phase2_Top10_Validation_YYYYMMDD_HHMM.csv files found.")

    # max() keeps the first-seen file on timestamp ties, matching the original.
    return max(stamped, key=lambda pair: pair[0])[1]

# ---- Load the newest Phase 2 validation results ----
csv_path = find_latest_phase2_csv(PROJECT_ROOT)
print(f"Using latest CSV: {csv_path}")

df = pd.read_csv(csv_path)

# Columns expected to be numeric. Coerce bad cells to NaN rather than
# failing, and skip any column absent from this CSV revision.
num_cols = [
    "Phase1_Rank", "Trial_Number", "Phase1_Combined_Metric",
    "Phase2_Mean_CIndex", "Phase2_Std_CIndex",
    "Phase2_Mean_IBS", "Phase2_Std_IBS",
    "LR", "Weight_Decay", "Dropout", "Batch_Size", "Phase2_Rank"
]
for col in num_cols:
    if col in df.columns:
        df[col] = pd.to_numeric(df[col], errors="coerce")

# These three columns drive the Pareto plot below and must be present.
required_cols = ["Trial_Number", "Phase2_Mean_IBS", "Phase2_Mean_CIndex"]
missing = [c for c in required_cols if c not in df.columns]
if missing:
    raise ValueError(f"CSV is missing required columns: {missing}")

df = df.dropna(subset=required_cols).copy()
df["Trial_Number"] = df["Trial_Number"].astype(int)

# Batch_Size is optional and may contain NaN after coercion (dropna above
# only covers required_cols); a plain .astype(int) crashed in either case.
# Guard on presence and use the NaN-tolerant nullable Int64 dtype.
if "Batch_Size" in df.columns:
    df["Batch_Size"] = df["Batch_Size"].astype("Int64")

# Sort by Phase 2 rank when available so downstream tables read best-first.
if "Phase2_Rank" in df.columns:
    df = df.sort_values("Phase2_Rank").reset_index(drop=True)

def compute_pareto_front(df):
    """Return the Pareto-efficient rows of ``df``.

    A configuration is Pareto-efficient when no other configuration has both
    a lower (better) ``Phase2_Mean_IBS`` and a higher-or-equal (better)
    ``Phase2_Mean_CIndex``.

    The previous implementation kept rows whose C-index merely *tied* the
    running best, which admitted weakly dominated points (equal C-index at
    strictly worse IBS), and its arbitrary ordering of IBS ties could admit
    dominated points outright. Sorting IBS ascending with C-index descending
    as the tie-break and requiring a strict C-index improvement yields the
    true front.

    Parameters
    ----------
    df : pd.DataFrame
        Must contain "Phase2_Mean_IBS" and "Phase2_Mean_CIndex" columns.

    Returns
    -------
    pd.DataFrame
        The Pareto-efficient rows, ordered by increasing IBS.
    """
    ordered = df.sort_values(
        ["Phase2_Mean_IBS", "Phase2_Mean_CIndex"],
        ascending=[True, False],
    )
    front_rows = []
    best_cindex = -np.inf
    for _, row in ordered.iterrows():
        # Strict improvement required: ties in C-index at worse IBS are dominated.
        if row["Phase2_Mean_CIndex"] > best_cindex:
            front_rows.append(row)
            best_cindex = row["Phase2_Mean_CIndex"]
    return pd.DataFrame(front_rows)

# Headline configurations to highlight on the plot.
combined_metric = df["Phase1_Combined_Metric"]
phase1_winner = df.loc[combined_metric.idxmin()]      # best (lowest) Phase 1 combined score

cindex_series = df["Phase2_Mean_CIndex"]
phase2_winner = df.loc[cindex_series.idxmax()]        # best discrimination in Phase 2

ibs_series = df["Phase2_Mean_IBS"]
ibs_winner = df.loc[ibs_series.idxmin()]              # best calibration in Phase 2

pareto_df = compute_pareto_front(df)

# Publication-style typography: serif fonts, TrueType embedding (fonttype 42
# keeps text editable in vector editors), and print-quality resolution.
mpl.rcParams.update({
    "font.family": "serif",
    "font.serif": ["Times New Roman", "Times", "DejaVu Serif"],
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
    "axes.labelsize": 13,
    "axes.titlesize": 14,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "figure.dpi": 300,
})

# One stable color per architecture string, assigned in sorted order so the
# architecture-to-color mapping is reproducible across runs.
arch_names = sorted(df["Architecture"].dropna().unique())
arch_colors = sns.color_palette("Set2", n_colors=df["Architecture"].nunique())
palette = dict(zip(arch_names, arch_colors))

# Build the IBS (x, lower = better) vs C-index (y, higher = better) scatter.
# Layering is controlled via zorder: error bars (1) < Pareto line (2) <
# points (3) < winner highlights (5-6).
fig, ax = plt.subplots(figsize=(8.8, 6.2))

# Error bars
# +/- 1 SD whiskers across imputations; the Std columns are optional, so
# fall back to no whisker when a column is absent from this CSV revision.
for _, row in df.iterrows():
    ax.errorbar(
        row["Phase2_Mean_IBS"],
        row["Phase2_Mean_CIndex"],
        xerr=row["Phase2_Std_IBS"] if "Phase2_Std_IBS" in df.columns else None,
        yerr=row["Phase2_Std_CIndex"] if "Phase2_Std_CIndex" in df.columns else None,
        fmt="none",  # whiskers only; the marker is drawn by ax.scatter below
        ecolor="#999999",
        elinewidth=1.0,
        alpha=0.8,
        capsize=2,
        zorder=1
    )

# Scatter by architecture
# One scatter call per architecture so each gets its own legend entry and
# its palette color (assigned earlier from sorted architecture names).
for arch, sub in df.groupby("Architecture"):
    ax.scatter(
        sub["Phase2_Mean_IBS"], sub["Phase2_Mean_CIndex"],
        s=90, color=palette[arch], alpha=0.9, linewidth=0,
        zorder=3, label=arch
    )

# Pareto front
# Dashed step line through the non-dominated configurations, sorted by IBS
# so the step function advances left-to-right.
pf = pareto_df.sort_values("Phase2_Mean_IBS")
ax.step(
    pf["Phase2_Mean_IBS"], pf["Phase2_Mean_CIndex"],
    where="post", color="#333333", alpha=0.85, linewidth=1.7, linestyle="--",
    zorder=2, label="Pareto front"
)

# Highlight Phase 1 winner
# Hollow ring around the configuration with the best Phase 1 combined metric.
ax.scatter(
    phase1_winner["Phase2_Mean_IBS"], phase1_winner["Phase2_Mean_CIndex"],
    facecolors="none", edgecolors="#1f1f1f", s=260, linewidth=2.0, zorder=5,
    label=f"Best Phase 1 (T{int(phase1_winner['Trial_Number'])})"
)

# Highlight Phase 2 winner
# Red diamond marks the best Phase 2 mean C-index.
ax.scatter(
    phase2_winner["Phase2_Mean_IBS"], phase2_winner["Phase2_Mean_CIndex"],
    marker="D", s=120, color="#c0392b", edgecolor="white", linewidth=0.8, zorder=6,
    label=f"Best Phase 2 C-index (T{int(phase2_winner['Trial_Number'])})"
)

# Optional: highlight IBS winner if different
# Skipped when the same trial already carries the Phase 2 C-index diamond,
# to avoid stacking two markers on one point.
if int(ibs_winner["Trial_Number"]) != int(phase2_winner["Trial_Number"]):
    ax.scatter(
        ibs_winner["Phase2_Mean_IBS"], ibs_winner["Phase2_Mean_CIndex"],
        marker="s", s=110, color="#2d6a4f", edgecolor="white", linewidth=0.8, zorder=6,
        label=f"Best IBS (T{int(ibs_winner['Trial_Number'])})"
    )

# Labels, offset right/up from points
# Offsets are a small fraction of the data span, with absolute floors so
# labels stay readable even when all points cluster tightly.
x_span = df["Phase2_Mean_IBS"].max() - df["Phase2_Mean_IBS"].min()
y_span = df["Phase2_Mean_CIndex"].max() - df["Phase2_Mean_CIndex"].min()
x_offset = max(x_span * 0.0018, 0.000002)
y_offset = max(y_span * 0.0035, 0.00005)

# One "T<trial>" text label per configuration, nudged up-right of its point.
texts = []
for _, row in df.iterrows():
    texts.append(
        ax.text(
            row["Phase2_Mean_IBS"] + x_offset,
            row["Phase2_Mean_CIndex"] + y_offset,
            f"T{int(row['Trial_Number'])}",
            fontsize=8, color="#444444", ha="left", va="bottom"
        )
    )

# When the optional adjustText package imported successfully, let it repel
# overlapping labels and draw thin leader lines back to the points.
if HAS_ADJUST_TEXT and texts:
    adjust_text(
        texts, ax=ax,
        arrowprops=dict(arrowstyle="-", color="#bbbbbb", lw=0.5),
        expand=(1.1, 1.2),
        force_text=(0.5, 0.8),
    )

# Axis formatting
# 15% padding around the data extent, again with absolute floors to avoid
# a degenerate zero-width axis when values coincide.
x_margin = max((df["Phase2_Mean_IBS"].max() - df["Phase2_Mean_IBS"].min()) * 0.15, 0.00003)
y_margin = max((df["Phase2_Mean_CIndex"].max() - df["Phase2_Mean_CIndex"].min()) * 0.15, 0.001)

ax.set_xlim(df["Phase2_Mean_IBS"].min() - x_margin, df["Phase2_Mean_IBS"].max() + x_margin)
ax.set_ylim(df["Phase2_Mean_CIndex"].min() - y_margin, df["Phase2_Mean_CIndex"].max() + y_margin)

ax.set_xlabel("Phase 2 Mean IBS")
ax.set_ylabel("Phase 2 Mean C-Index")
# ax.set_title("DeepSurv Phase 2 Validation Shortlist", fontsize=14, fontweight="bold", pad=10)

# Fixed decimal tick formats: IBS differences live in the 5th decimal place,
# C-index differences in the 4th.
ax.xaxis.set_major_formatter(mtick.FormatStrFormatter("%.5f"))
ax.yaxis.set_major_formatter(mtick.FormatStrFormatter("%.4f"))

ax.legend(
    fontsize=8.5, loc="lower left",
    frameon=True, facecolor="white", edgecolor="#cccccc", framealpha=0.95
)
ax.grid(True, linestyle=":", alpha=0.4)

# Final cosmetics, then save to timestamped PNG (raster) + PDF (vector)
# before plt.show() so the figure is persisted even in non-interactive runs.
sns.despine()
plt.tight_layout()

timestamp = datetime.now().strftime("%Y%m%d_%H%M")
fig_path_png = FIGS_DIR / f"deepsurv_phase2_ibs_cindex_{timestamp}.png"
fig_path_pdf = FIGS_DIR / f"deepsurv_phase2_ibs_cindex_{timestamp}.pdf"

plt.savefig(fig_path_png, dpi=300, bbox_inches="tight")
plt.savefig(fig_path_pdf, bbox_inches="tight")
plt.show()

print(f"Saved: {fig_path_png}")
print(f"Saved: {fig_path_pdf}")

# Rich notebook table of the shortlist, best Phase 2 rank first.
# `display` is the Jupyter built-in; this cell assumes a notebook context.
display(
    df[
        [
            "Phase2_Rank", "Trial_Number", "Phase1_Rank",
            "Phase2_Mean_CIndex", "Phase2_Std_CIndex",
            "Phase2_Mean_IBS", "Phase2_Std_IBS",
            "Phase1_Combined_Metric",
            "Architecture", "LR", "Weight_Decay", "Dropout", "Batch_Size"
        ]
    ].sort_values("Phase2_Rank").reset_index(drop=True)
)

Phase2_Rank Trial_Number Phase1_Rank Phase2_Mean_CIndex Phase2_Std_CIndex Phase2_Mean_IBS Phase2_Std_IBS Phase1_Combined_Metric Architecture LR Weight_Decay Dropout Batch_Size
0 1 63 4 0.747027 0.001992 0.028331 0.000027 0.253676 [256, 64, 128, 64] 0.000935 0.000475 0.543012 256
1 2 70 7 0.746412 0.003212 0.028381 0.000074 0.254421 [128, 64, 64] 0.000628 0.000498 0.581950 256
2 3 79 8 0.745884 0.003080 0.028295 0.000025 0.254464 [128, 64] 0.000503 0.008241 0.433082 256
3 4 35 2 0.745859 0.003627 0.028331 0.000042 0.253092 [64, 128, 64] 0.000585 0.000560 0.393491 256
4 5 52 5 0.745729 0.001416 0.028413 0.000033 0.253681 [256, 64, 64, 64] 0.001563 0.000920 0.583006 256
5 6 94 6 0.745497 0.002206 0.028295 0.000016 0.254230 [128, 128] 0.000881 0.005664 0.376809 256
6 7 89 3 0.745425 0.001302 0.028301 0.000017 0.253187 [128, 128] 0.000774 0.006366 0.376655 256
7 8 95 1 0.744607 0.002437 0.028309 0.000008 0.252654 [128, 128] 0.000930 0.005809 0.327398 256
8 9 78 9 0.740175 0.011617 0.028340 0.000061 0.255014 [128, 64] 0.000505 0.009161 0.432809 256
9 10 82 10 0.736254 0.012495 0.028366 0.000072 0.255056 [128, 64] 0.000645 0.009323 0.433394 256
Back to top