Hyperparameter tuning based on death discrimination

This notebook tunes and validates an XGBoost survival model (survival:cox) for all-cause mortality prediction in a substance use disorder (SUD) treatment cohort. It uses five multiply imputed datasets and ~56 predictors (demographic, clinical, socioeconomic, and treatment features).

Hyperparameter optimization is performed with Optuna using 5-fold cross-validation and dual stratification (treatment plan type + death event status), with fallback to simpler stratification when rare strata make 5-fold splitting infeasible.
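The fallback logic described above can be sketched as follows. `stratification_labels` is a hypothetical helper, not the notebook's exact implementation: it builds combined plan-type × event labels and falls back to event-only labels whenever some combined stratum is too small to appear in every fold.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

def stratification_labels(plan_type, death_event, n_splits=5):
    """Labels for StratifiedKFold: plan type x event status when every
    combined stratum can populate all folds, otherwise event status alone.
    (Hypothetical helper sketching the fallback, not the notebook's code.)"""
    combined = np.array([f"{p}|{e}" for p, e in zip(plan_type, death_event)])
    _, counts = np.unique(combined, return_counts=True)
    if counts.min() >= n_splits:
        return combined
    return np.asarray(death_event).astype(str)

# Toy example: only 4 deaths per plan type, so 5-fold dual stratification
# is infeasible and the labels fall back to event status alone
plan = np.array(["pg-pai", "m-pr"] * 100)
event = np.array([1] * 8 + [0] * 192)
labels = stratification_labels(plan, event, n_splits=5)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=2125)
folds = list(cv.split(np.zeros(len(labels)), labels))
```

With event-only labels, every test fold still receives at least one death, which keeps the fold-level metrics computable.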

The search is bi-objective (Pareto optimization), maximizing multi-horizon IPCW Uno’s C-index and minimizing Integrated Brier Score (IBS). Metrics are evaluated at 3, 6, 12, 36, and 60 months. For IBS, XGBoost risk scores are converted to absolute survival probabilities using Breslow baseline estimation.

The workflow uses stochastic multi-imputation tuning (one imputation per trial) followed by rescoring top Pareto candidates across all imputations to select a final robust configuration. Final performance is reported with bootstrap optimism correction and 95% confidence intervals.
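The optimism-correction step can be sketched generically as follows. Ordinary least squares and R² stand in for the survival model and its metrics; `optimism_corrected` is an illustrative helper, not the notebook's implementation.

```python
import numpy as np

def optimism_corrected(metric, fit, X, y, n_boot=200, seed=2125):
    """Harrell-style bootstrap optimism correction (generic sketch):
    corrected = apparent - mean(boot_apparent - boot_tested_on_original).
    `fit(X, y)` returns a prediction function; `metric(pred, y)` a scalar."""
    rng = np.random.default_rng(seed)
    apparent = metric(fit(X, y)(X), y)
    optimism = []
    for _ in range(n_boot):
        idx = rng.integers(0, len(y), size=len(y))
        model = fit(X[idx], y[idx])  # refit on the bootstrap resample
        optimism.append(metric(model(X[idx]), y[idx]) - metric(model(X), y))
    return apparent - float(np.mean(optimism))

# Toy stand-ins for the survival model and metric
rng = np.random.default_rng(0)
X = rng.normal(size=(300, 2))
y = X @ np.array([0.5, -0.3]) + rng.normal(size=300)

def ols_fit(Xtr, ytr):
    Z = np.c_[np.ones(len(Xtr)), Xtr]
    beta = np.linalg.lstsq(Z, ytr, rcond=None)[0]
    return lambda Xn: np.c_[np.ones(len(Xn)), Xn] @ beta

def r2(pred, ytrue):
    return 1.0 - np.sum((ytrue - pred) ** 2) / np.sum((ytrue - ytrue.mean()) ** 2)

apparent = r2(ols_fit(X, y)(X), y)
corrected = optimism_corrected(r2, ols_fit, X, y, n_boot=100)
```

The corrected estimate sits below the apparent one because each bootstrap model looks better on its own resample than on the original data; the mean of that gap is the optimism subtracted at the end.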

Runs are CPU-only and use seed 2125 for reproducibility.

Author

ags

Published

March 5, 2026

Hyperparameter tuning for XGBoost (death as the reference outcome)

0. Package loading and installation

Automatically generated by Colab.

Original file is located at https://colab.research.google.com/drive/1FMHIud9Hi0rIxnMqRfRFdzBQpKEzI796

Code
# Reset the interactive namespace and free memory (Jupyter/Colab only)
%reset -f
import gc
gc.collect()

import numpy as np
import pandas as pd
import time

#conda install -c conda-forge \
#    numpy \
#    scipy \
#    pandas \
#    pyarrow \
#    scikit-survival \
#    spyder \
#    lifelines

# conda install -c conda-forge fastparquet
# conda install -c conda-forge xgboost
# conda install -c conda-forge pytorch cpuonly
# conda install -c pytorch pytorch cpuonly
# conda install -c conda-forge matplotlib
# conda install -c conda-forge seaborn
# conda install spyder-notebook -c spyder-ide
# conda install notebook nbformat nbconvert
# conda install -c conda-forge xlsxwriter
# conda install -c conda-forge shap

# import subprocess, sys

# subprocess.check_call([
#     sys.executable,
#     "-m",
#     "pip",
#     "install",
#     "matplotlib"
# ])

# subprocess.check_call([
#     sys.executable,
#     "-m",
#     "pip",
#     "install",
#     "seaborn"
# ])

print("numpy:", np.__version__)


from sksurv.metrics import (
    concordance_index_ipcw,
    brier_score,
    integrated_brier_score
)
from sksurv.util import Surv

# R-style dput: dump a DataFrame as a dict literal
def dput_df(df, digits=6):
    data = {
        "columns": list(df.columns),
        "data": [
            [round(x, digits) if isinstance(x, (float, np.floating)) else x
             for x in row]
            for row in df.to_numpy()
        ]
    }
    print(data)


# dplyr-style glimpse of a DataFrame
def glimpse(df, max_width=80):
    print(f"Rows: {df.shape[0]} | Columns: {df.shape[1]}")
    for col in df.columns:
        dtype = df[col].dtype
        preview = df[col].astype(str).head(5).tolist()
        preview_str = ", ".join(preview)
        if len(preview_str) > max_width:
            preview_str = preview_str[:max_width] + "..."
        print(f"{col:<30} {str(dtype):<15} {preview_str}")
# janitor-style tabyl: counts and proportions
def tabyl(series):
    counts = series.value_counts(dropna=False)
    props = series.value_counts(normalize=True, dropna=False)
    return pd.DataFrame({"value": counts.index,
                         "n": counts.values,
                         "percent": props.values})
# janitor-style clean_names
import re

def clean_names(df):
    """
    Mimic janitor::clean_names for pandas DataFrames.
    - Lowercase
    - Replace spaces and special chars with underscores
    - Remove non-alphanumeric/underscore
    """
    new_cols = []
    for col in df.columns:
        # lowercase
        col = col.lower()
        # replace spaces and special chars with underscore
        col = re.sub(r"[^\w]+", "_", col)
        # strip leading/trailing underscores
        col = col.strip("_")
        new_cols.append(col)
    df.columns = new_cols
    return df
numpy: 2.0.1

Load data

Code

from pathlib import Path

BASE_DIR = Path(
    r"G:\My Drive\Alvacast\SISTRAT 2023\data\20241015_out\pred1"
)

import pickle

with open(BASE_DIR / "imputations_list_jan26.pkl", "rb") as f:
    imputations_list_jan26 = pickle.load(f)

imputation_1 = pd.read_parquet(
    BASE_DIR / "imputation_1.parquet",
    engine="fastparquet"
)
Code

import pandas as pd

for i in range(1, 6):
    globals()[f"imputation_nodum_{i}"] = pd.read_parquet(
        BASE_DIR / f"imputation_nondum_{i}.parquet",
        engine="fastparquet"
    )
Code
from IPython.display import display, HTML
import io
import sys

def fold_output(title, func):
    # Capture stdout while func runs, restoring it even if func raises
    buffer = io.StringIO()
    sys.stdout = buffer
    try:
        func()
    finally:
        sys.stdout = sys.__stdout__

    html = f"""
    <details>
      <summary>{title}</summary>
      <pre>{buffer.getvalue()}</pre>
    </details>
    """
    display(HTML(html))


fold_output(
    "Show imputation_nodum_1 structure",
    lambda: imputation_nodum_1.info()
)

fold_output(
    "Show imputation_1 structure",
    lambda: imputation_1.info()
)
Show imputation_nodum_1 structure

RangeIndex: 88504 entries, 0 to 88503
Data columns (total 43 columns):
 #   Column                         Non-Null Count  Dtype  
---  ------                         --------------  -----  
 0   readmit_time_from_adm_m        88504 non-null  float64
 1   death_time_from_adm_m          88504 non-null  float64
 2   adm_age_rec3                   88504 non-null  float64
 3   porc_pobr                      88504 non-null  float64
 4   dit_m                          88504 non-null  float64
 5   sex_rec                        88504 non-null  object 
 6   tenure_status_household        88504 non-null  object 
 7   cohabitation                   88504 non-null  object 
 8   sub_dep_icd10_status           88504 non-null  object 
 9   any_violence                   88504 non-null  object 
 10  prim_sub_freq_rec              88504 non-null  object 
 11  tr_outcome                     88504 non-null  object 
 12  adm_motive                     88504 non-null  object 
 13  first_sub_used                 88504 non-null  object 
 14  primary_sub_mod                88504 non-null  object 
 15  tipo_de_vivienda_rec2          88504 non-null  object 
 16  national_foreign               88504 non-null  int32  
 17  plan_type_corr                 88504 non-null  object 
 18  occupation_condition_corr24    88504 non-null  object 
 19  marital_status_rec             88504 non-null  object 
 20  urbanicity_cat                 88504 non-null  object 
 21  ed_attainment_corr             88504 non-null  object 
 22  evaluacindelprocesoteraputico  88504 non-null  object 
 23  eva_consumo                    88504 non-null  object 
 24  eva_fam                        88504 non-null  object 
 25  eva_relinterp                  88504 non-null  object 
 26  eva_ocupacion                  88504 non-null  object 
 27  eva_sm                         88504 non-null  object 
 28  eva_fisica                     88504 non-null  object 
 29  eva_transgnorma                88504 non-null  object 
 30  ethnicity                      88504 non-null  float64
 31  dg_psiq_cie_10_instudy         88504 non-null  bool   
 32  dg_psiq_cie_10_dg              88504 non-null  bool   
 33  dx_f3_mood                     88504 non-null  int32  
 34  dx_f6_personality              88504 non-null  int32  
 35  dx_f_any_severe_mental         88504 non-null  bool   
 36  any_phys_dx                    88504 non-null  bool   
 37  polysubstance_strict           88504 non-null  int32  
 38  readmit_event                  88504 non-null  float64
 39  death_event                    88504 non-null  int32  
 40  readmit_time_from_disch_m      88504 non-null  float64
 41  death_time_from_disch_m        88504 non-null  float64
 42  center_id                      88475 non-null  object 
dtypes: bool(4), float64(9), int32(5), object(25)
memory usage: 25.0+ MB
Show imputation_1 structure

RangeIndex: 88504 entries, 0 to 88503
Data columns (total 78 columns):
 #   Column                                                              Non-Null Count  Dtype  
---  ------                                                              --------------  -----  
 0   readmit_time_from_adm_m                                             88504 non-null  float64
 1   death_time_from_adm_m                                               88504 non-null  float64
 2   adm_age_rec3                                                        88504 non-null  float64
 3   porc_pobr                                                           88504 non-null  float64
 4   dit_m                                                               88504 non-null  float64
 5   national_foreign                                                    88504 non-null  int32  
 6   ethnicity                                                           88504 non-null  float64
 7   dg_psiq_cie_10_instudy                                              88504 non-null  bool   
 8   dg_psiq_cie_10_dg                                                   88504 non-null  bool   
 9   dx_f3_mood                                                          88504 non-null  int32  
 10  dx_f6_personality                                                   88504 non-null  int32  
 11  dx_f_any_severe_mental                                              88504 non-null  bool   
 12  any_phys_dx                                                         88504 non-null  bool   
 13  polysubstance_strict                                                88504 non-null  int32  
 14  readmit_time_from_disch_m                                           88504 non-null  float64
 15  readmit_event                                                       88504 non-null  float64
 16  death_time_from_disch_m                                             88504 non-null  float64
 17  death_event                                                         88504 non-null  int32  
 18  sex_rec_woman                                                       88504 non-null  float64
 19  tenure_status_household_illegal_settlement                          88504 non-null  float64
 20  tenure_status_household_owner_transferred_dwellings_pays_dividends  88504 non-null  float64
 21  tenure_status_household_renting                                     88504 non-null  float64
 22  tenure_status_household_stays_temporarily_with_a_relative           88504 non-null  float64
 23  cohabitation_alone                                                  88504 non-null  float64
 24  cohabitation_with_couple_children                                   88504 non-null  float64
 25  cohabitation_family_of_origin                                       88504 non-null  float64
 26  sub_dep_icd10_status_drug_dependence                                88504 non-null  float64
 27  any_violence_1_domestic_violence_sex_abuse                          88504 non-null  float64
 28  prim_sub_freq_rec_2_2_6_days_wk                                     88504 non-null  float64
 29  prim_sub_freq_rec_3_daily                                           88504 non-null  float64
 30  tr_outcome_adm_discharge_adm_reasons                                88504 non-null  float64
 31  tr_outcome_adm_discharge_rule_violation_undet                       88504 non-null  float64
 32  tr_outcome_completion                                               88504 non-null  float64
 33  tr_outcome_dropout                                                  88504 non-null  float64
 34  tr_outcome_referral                                                 88504 non-null  float64
 35  adm_motive_another_sud_facility_fonodrogas_senda_previene           88504 non-null  float64
 36  adm_motive_justice_sector                                           88504 non-null  float64
 37  adm_motive_sanitary_sector                                          88504 non-null  float64
 38  adm_motive_spontaneous_consultation                                 88504 non-null  float64
 39  first_sub_used_alcohol                                              88504 non-null  float64
 40  first_sub_used_cocaine_paste                                        88504 non-null  float64
 41  first_sub_used_cocaine_powder                                       88504 non-null  float64
 42  first_sub_used_marijuana                                            88504 non-null  float64
 43  first_sub_used_opioids                                              88504 non-null  float64
 44  first_sub_used_tranquilizers_hypnotics                              88504 non-null  float64
 45  primary_sub_mod_cocaine_paste                                       88504 non-null  float64
 46  primary_sub_mod_cocaine_powder                                      88504 non-null  float64
 47  primary_sub_mod_alcohol                                             88504 non-null  float64
 48  primary_sub_mod_marijuana                                           88504 non-null  float64
 49  tipo_de_vivienda_rec2_other_unknown                                 88504 non-null  float64
 50  plan_type_corr_m_pai                                                88504 non-null  float64
 51  plan_type_corr_m_pr                                                 88504 non-null  float64
 52  plan_type_corr_pg_pai                                               88504 non-null  float64
 53  plan_type_corr_pg_pr                                                88504 non-null  float64
 54  occupation_condition_corr24_inactive                                88504 non-null  float64
 55  occupation_condition_corr24_unemployed                              88504 non-null  float64
 56  marital_status_rec_separated_divorced_annulled_widowed              88504 non-null  float64
 57  marital_status_rec_single                                           88504 non-null  float64
 58  urbanicity_cat_1_rural                                              88504 non-null  float64
 59  urbanicity_cat_2_mixed                                              88504 non-null  float64
 60  ed_attainment_corr_2_completed_high_school_or_less                  88504 non-null  float64
 61  ed_attainment_corr_3_completed_primary_school_or_less               88504 non-null  float64
 62  evaluacindelprocesoteraputico_logro_intermedio                      88504 non-null  float64
 63  evaluacindelprocesoteraputico_logro_minimo                          88504 non-null  float64
 64  eva_consumo_logro_intermedio                                        88504 non-null  float64
 65  eva_consumo_logro_minimo                                            88504 non-null  float64
 66  eva_fam_logro_intermedio                                            88504 non-null  float64
 67  eva_fam_logro_minimo                                                88504 non-null  float64
 68  eva_relinterp_logro_intermedio                                      88504 non-null  float64
 69  eva_relinterp_logro_minimo                                          88504 non-null  float64
 70  eva_ocupacion_logro_intermedio                                      88504 non-null  float64
 71  eva_ocupacion_logro_minimo                                          88504 non-null  float64
 72  eva_sm_logro_intermedio                                             88504 non-null  float64
 73  eva_sm_logro_minimo                                                 88504 non-null  float64
 74  eva_fisica_logro_intermedio                                         88504 non-null  float64
 75  eva_fisica_logro_minimo                                             88504 non-null  float64
 76  eva_transgnorma_logro_intermedio                                    88504 non-null  float64
 77  eva_transgnorma_logro_minimo                                        88504 non-null  float64
dtypes: bool(4), float64(69), int32(5)
memory usage: 48.6 MB
Code
from IPython.display import display, Markdown

if isinstance(imputations_list_jan26, list) and len(imputations_list_jan26) > 0:
    display(Markdown(f"**First element type:** `{type(imputations_list_jan26[0])}`"))

    if isinstance(imputations_list_jan26[0], dict):
        display(Markdown(f"**First element keys:** `{list(imputations_list_jan26[0].keys())}`"))

    elif isinstance(imputations_list_jan26[0], (pd.DataFrame, np.ndarray)):
        display(Markdown(f"**First element shape:** `{imputations_list_jan26[0].shape}`"))

First element type: <class 'pandas.DataFrame'>

First element shape: (88504, 56)

This code block:

  1. Imports the pickle library, which implements binary protocols for serializing and de-serializing Python object structures.
  2. Builds the file path from BASE_DIR, pointing at the target .pkl file.
  3. Opens the file in binary read mode ('rb'), as required for pickle files.
  4. Loads the object: pickle.load(f) reads the serialized object from the file and reconstructs it in memory.
  5. Prints confirmation and basic information: the type of the loaded object and, when it is a list of common data structures, the keys or shape of its first element.

Format data

Due to inconsistencies and structural heterogeneity across previously merged datasets, we decided not to proceed with a direct inspection and comparison of column names between the first imputed dataset from imputations_list_jan26 (which likely included dummy-encoded variables) and imputation_nodum_1 (which likely retained non–dummy-encoded variables).

Instead, we reconstructed the analytic datasets de novo using the most recent source files available in the original directory (BASE_DIR). Time-to-event variables were re-derived to ensure internal consistency. Variables that could introduce information leakage (e.g., time from admission) were excluded, and the center identifier variable was removed prior to modeling.

Code
# 1.2. Build Surv objects for each imputed dataset
from IPython.display import display, Markdown
from sksurv.util import Surv

for i in range(1, 6):
    # Get the DataFrame
    df = globals()[f"imputation_nodum_{i}"]

    # Extract time and event arrays
    time_readm  = df["readmit_time_from_disch_m"].to_numpy()
    event_readm = (df["readmit_event"].to_numpy() == 1)
    time_death  = df["death_time_from_disch_m"].to_numpy()
    event_death = (df["death_event"].to_numpy() == 1)

    # Create survival objects
    y_surv_readm = Surv.from_arrays(event=event_readm, time=time_readm)
    y_surv_death = Surv.from_arrays(event=event_death, time=time_death)

    # Store in global variables for later use
    globals()[f"y_surv_readm_{i}"]  = y_surv_readm
    globals()[f"y_surv_death_{i}"]  = y_surv_death

    # Print info
    display(Markdown(f"\n--- Imputation {i} ---"))
    display(Markdown(
    f"**y_surv_readm dtype:** {y_surv_readm.dtype}  \n"
    f"**shape:** {y_surv_readm.shape}"
    ))
    display(Markdown(
    f"**y_surv_death dtype:** {y_surv_death.dtype}  \n"
    f"**shape:** {y_surv_death.shape}"
    ))

— Imputation 1 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

— Imputation 2 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

— Imputation 3 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

— Imputation 4 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

— Imputation 5 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

Code
fold_output(
    "Show imputation_nodum_1 (newer database) glimpse",
    lambda: glimpse(imputation_nodum_1)
)
fold_output(
    "Show first db of imputations_list_jan26 (older) glimpse",
    lambda: glimpse(imputations_list_jan26[0])
)
Show imputation_nodum_1 (newer database) glimpse
Rows: 88504 | Columns: 43
readmit_time_from_adm_m        float64         84.93548387096774, 12.833333333333334, 13.733333333333333, 11.966666666666667, 1...
death_time_from_adm_m          float64         84.93548387096774, 87.16129032258064, 117.2258064516129, 98.93548387096774, 37.9...
adm_age_rec3                   float64         31.53, 20.61, 42.52, 60.61, 45.08
porc_pobr                      float64         0.175679117441177, 0.187835901975632, 0.130412444472313, 0.133759185671806, 0.08...
dit_m                          float64         15.967741935483872, 5.833333333333334, 0.4752688172043005, 6.966666666666667, 6....
sex_rec                        object          man, man, man, woman, man
tenure_status_household        object          stays temporarily with a relative, owner/transferred dwellings/pays dividends, s...
cohabitation                   object          alone, family of origin, with couple/children, with couple/children, family of o...
sub_dep_icd10_status           object          drug dependence, hazardous consumption, drug dependence, drug dependence, drug d...
any_violence                   object          0.No domestic violence/sex abuse, 0.No domestic violence/sex abuse, 0.No domesti...
prim_sub_freq_rec              object          2.2–6 days/wk, 3.Daily, 3.Daily, 3.Daily, 3.Daily
tr_outcome                     object          referral, dropout, adm discharge - rule violation/undet, dropout, completion
adm_motive                     object          sanitary sector, spontaneous consultation, sanitary sector, sanitary sector, spo...
first_sub_used                 object          alcohol, alcohol, alcohol, cocaine paste, alcohol
primary_sub_mod                object          alcohol, cocaine paste, cocaine paste, cocaine paste, cocaine paste
tipo_de_vivienda_rec2          object          other/unknown, formal housing, formal housing, formal housing, formal housing
national_foreign               int32           0, 0, 0, 0, 0
plan_type_corr                 object          pg-pab, pg-pab, pg-pr, m-pr, pg-pai
occupation_condition_corr24    object          unemployed, employed, employed, inactive, unemployed
marital_status_rec             object          single, single, single, married/cohabiting, single
urbanicity_cat                 object          3.Urban, 3.Urban, 3.Urban, 3.Urban, 3.Urban
ed_attainment_corr             object          2-Completed high school or less, 3-Completed primary school or less, 2-Completed...
evaluacindelprocesoteraputico  object          logro alto, logro minimo, logro minimo, logro minimo, logro alto
eva_consumo                    object          logro alto, logro minimo, logro minimo, logro intermedio, logro alto
eva_fam                        object          logro intermedio, logro minimo, logro minimo, logro intermedio, logro alto
eva_relinterp                  object          logro alto, logro minimo, logro minimo, logro intermedio, logro alto
eva_ocupacion                  object          logro alto, logro minimo, logro minimo, logro minimo, logro intermedio
eva_sm                         object          logro intermedio, logro minimo, logro minimo, logro intermedio, logro minimo
eva_fisica                     object          logro alto, logro minimo, logro intermedio, logro intermedio, logro alto
eva_transgnorma                object          logro alto, logro minimo, logro minimo, logro minimo, logro intermedio
ethnicity                      float64         0.0, 0.0, 0.0, 0.0, 0.0
dg_psiq_cie_10_instudy         bool            False, False, True, False, False
dg_psiq_cie_10_dg              bool            True, False, False, True, False
dx_f3_mood                     int32           0, 0, 0, 0, 0
dx_f6_personality              int32           0, 0, 0, 1, 0
dx_f_any_severe_mental         bool            True, False, False, False, False
any_phys_dx                    bool            False, False, False, True, False
polysubstance_strict           int32           0, 1, 1, 1, 1
readmit_event                  float64         0.0, 1.0, 1.0, 1.0, 1.0
death_event                    int32           0, 0, 0, 0, 0
readmit_time_from_disch_m      float64         68.96774193548387, 7.0, 13.258064516129032, 5.0, 7.354838709677419
death_time_from_disch_m        float64         68.96774193548387, 81.3225806451613, 116.74193548387096, 91.96774193548387, 31.0...
center_id                      object          330, 489, 341, 411, 176
Show first db of imputations_list_jan26 (older) glimpse
Rows: 88504 | Columns: 56
adm_age_rec3                   float64         31.53, 20.61, 42.52, 60.61, 45.08
porc_pobr                      float64         0.175679117441177, 0.187835901975632, 0.130412444472313, 0.133759185671806, 0.08...
dit_m                          float64         15.967741935483872, 5.833333333333334, 0.4752688172043005, 6.966666666666667, 6....
tenure_status_household        int64           3, 0, 3, 0, 3
prim_sub_freq_rec              int64           1, 2, 2, 2, 2
national_foreign               int32           0, 0, 0, 0, 0
urbanicity_cat                 int64           0, 0, 0, 0, 0
ed_attainment_corr             float64         1.0, 2.0, 1.0, 1.0, 2.0
evaluacindelprocesoteraputico  int64           0, 2, 2, 2, 0
eva_consumo                    int64           0, 2, 2, 1, 0
eva_fam                        int64           1, 2, 2, 1, 0
eva_relinterp                  int64           0, 2, 2, 1, 0
eva_ocupacion                  int64           0, 2, 2, 2, 1
eva_sm                         int64           1, 2, 2, 1, 2
eva_fisica                     int64           0, 2, 1, 1, 0
eva_transgnorma                int64           0, 2, 2, 2, 1
ethnicity                      float64         0.0, 0.0, 0.0, 0.0, 0.0
dg_psiq_cie_10_instudy         bool            False, False, True, False, False
dg_psiq_cie_10_dg              bool            True, False, False, True, False
dx_f3_mood                     int32           0, 0, 0, 0, 0
dx_f6_personality              int32           0, 0, 0, 1, 0
dx_f_any_severe_mental         bool            True, False, False, False, False
any_phys_dx                    bool            False, False, False, True, False
polysubstance_strict           int32           0, 1, 1, 1, 1
sex_rec_woman                  float64         0.0, 0.0, 0.0, 1.0, 0.0
cohabitation_family_of_origin  float64         0.0, 1.0, 0.0, 0.0, 1.0
cohabitation_with_couple_children float64         0.0, 0.0, 1.0, 1.0, 0.0
cohabitation_others            float64         0.0, 0.0, 0.0, 0.0, 0.0
sub_dep_icd10_status_drug_dependence float64         1.0, 0.0, 1.0, 1.0, 1.0
any_violence_1_domestic_violence_sex_abuse float64         0.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_referral            float64         1.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_dropout             float64         0.0, 1.0, 0.0, 1.0, 0.0
tr_outcome_adm_discharge_rule_violation_undet float64         0.0, 0.0, 1.0, 0.0, 0.0
tr_outcome_adm_discharge_adm_reasons float64         0.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_other               float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_sanitary_sector     float64         1.0, 0.0, 1.0, 1.0, 0.0
adm_motive_another_sud_facility_fonodrogas_senda_previene float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_justice_sector      float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_other               float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_alcohol         float64         1.0, 1.0, 1.0, 0.0, 1.0
first_sub_used_cocaine_paste   float64         0.0, 0.0, 0.0, 1.0, 0.0
first_sub_used_cocaine_powder  float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_cocaine_powder float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_alcohol        float64         1.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_marijuana      float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_others         float64         0.0, 0.0, 0.0, 0.0, 0.0
tipo_de_vivienda_rec2_other_unknown float64         1.0, 0.0, 0.0, 0.0, 0.0
plan_type_corr_pg_pr           float64         0.0, 0.0, 1.0, 0.0, 0.0
plan_type_corr_m_pr            float64         0.0, 0.0, 0.0, 1.0, 0.0
plan_type_corr_pg_pai          float64         0.0, 0.0, 0.0, 0.0, 1.0
plan_type_corr_m_pai           float64         0.0, 0.0, 0.0, 0.0, 0.0
occupation_condition_corr24_unemployed float64         1.0, 0.0, 0.0, 0.0, 1.0
occupation_condition_corr24_inactive float64         0.0, 0.0, 0.0, 1.0, 0.0
marital_status_rec_single      float64         1.0, 1.0, 1.0, 0.0, 1.0
marital_status_rec_separated_divorced_annulled_widowed float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_other           float64         0.0, 0.0, 0.0, 0.0, 0.0

For each imputed dataset (1–5), we identified and removed predictors with zero variance, as they provide no useful information and can destabilize models. We printed the dropped variables and produced a cleaned version of each design matrix. This ensures that all downstream analyses use only informative predictors.

Code
# Keep only these objects (note the per-imputation suffixes: the Surv
# objects are named y_surv_readm_1..5 and y_surv_death_1..5)
objects_to_keep = {
    "objects_to_keep",
    "imputations_list_jan26",
    *(f"imputation_nodum_{i}" for i in range(1, 6)),
    *(f"y_surv_readm_{i}" for i in range(1, 6)),
    *(f"y_surv_death_{i}" for i in range(1, 6)),
}

import types

for name in list(globals().keys()):
    obj = globals()[name]
    if (
        name not in objects_to_keep
        and not name.startswith("_")
        and not callable(obj)
        and not isinstance(obj, types.ModuleType)  # <- protects ALL modules
    ):
        del globals()[name]
Code
from IPython.display import display, Markdown

# 1. Define columns to exclude (same as before)
target_cols = [
    "readmit_time_from_disch_m",
    "readmit_event",
    "death_time_from_disch_m",
    "death_event",
]

leak_time_cols = [
    "readmit_time_from_adm_m",
    "death_time_from_adm_m",
]

center_id = ["center_id"]

cols_to_exclude = target_cols + center_id  + leak_time_cols

# 2. Create list of your EXISTING imputation DataFrames (1-5)
imputed_dfs = [
    imputation_nodum_1,
    imputation_nodum_2,
    imputation_nodum_3,
    imputation_nodum_4,
    imputation_nodum_5
]

# 3. Preprocessing loop
X_reduced_list = []

for d, df in enumerate(imputed_dfs):
    imputation_num = d + 1  # Convert 0-index to 1-index for display

    display(Markdown(f"\n=== Imputation dataset {imputation_num} ==="))

    # a) Identify and drop constant predictors
    const_mask = (df.nunique(dropna=False) <= 1)
    dropped_const = df.columns[const_mask].tolist()
    display(Markdown(f"**Constant predictors dropped ({len(dropped_const)}):**"))
    display(Markdown(f"{dropped_const if dropped_const else 'None'}"))

    # b) Remove constant columns
    X_reduced = df.loc[:, ~const_mask]

    # c) Drop target/leakage columns (if present)
    cols_to_drop = [col for col in cols_to_exclude if col in X_reduced.columns]
    if cols_to_drop:
        X_reduced = X_reduced.drop(columns=cols_to_drop)
        display(Markdown(f"**Dropped target/leakage columns:** {cols_to_drop}"))
    else:
        display(Markdown("No target/leakage columns found to drop"))

    # d) Store cleaned DataFrame
    X_reduced_list.append(X_reduced)

    # e) Report shapes
    display(Markdown(f"**Original shape:** {df.shape}"))
    display(Markdown(
        f"**Cleaned shape:** {X_reduced.shape} "
        f"(removed {df.shape[1] - X_reduced.shape[1]} columns)"
    ))

display(Markdown("\n✅ **Preprocessing complete! X_reduced_list contains 5 cleaned DataFrames.**"))

=== Imputation dataset 1 ===

Constant predictors dropped (0):

None

Dropped target/leakage columns: [‘readmit_time_from_disch_m’, ‘readmit_event’, ‘death_time_from_disch_m’, ‘death_event’, ‘center_id’, ‘readmit_time_from_adm_m’, ‘death_time_from_adm_m’]

Original shape: (88504, 43)

Cleaned shape: (88504, 36) (removed 7 columns)

=== Imputation dataset 2 ===

Constant predictors dropped (0):

None

Dropped target/leakage columns: [‘readmit_time_from_disch_m’, ‘readmit_event’, ‘death_time_from_disch_m’, ‘death_event’, ‘center_id’, ‘readmit_time_from_adm_m’, ‘death_time_from_adm_m’]

Original shape: (88504, 43)

Cleaned shape: (88504, 36) (removed 7 columns)

=== Imputation dataset 3 ===

Constant predictors dropped (0):

None

Dropped target/leakage columns: [‘readmit_time_from_disch_m’, ‘readmit_event’, ‘death_time_from_disch_m’, ‘death_event’, ‘center_id’, ‘readmit_time_from_adm_m’, ‘death_time_from_adm_m’]

Original shape: (88504, 43)

Cleaned shape: (88504, 36) (removed 7 columns)

=== Imputation dataset 4 ===

Constant predictors dropped (0):

None

Dropped target/leakage columns: [‘readmit_time_from_disch_m’, ‘readmit_event’, ‘death_time_from_disch_m’, ‘death_event’, ‘center_id’, ‘readmit_time_from_adm_m’, ‘death_time_from_adm_m’]

Original shape: (88504, 43)

Cleaned shape: (88504, 36) (removed 7 columns)

=== Imputation dataset 5 ===

Constant predictors dropped (0):

None

Dropped target/leakage columns: [‘readmit_time_from_disch_m’, ‘readmit_event’, ‘death_time_from_disch_m’, ‘death_event’, ‘center_id’, ‘readmit_time_from_adm_m’, ‘death_time_from_adm_m’]

Original shape: (88504, 43)

Cleaned shape: (88504, 36) (removed 7 columns)

Preprocessing complete! X_reduced_list contains 5 cleaned DataFrames.
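The constant-predictor rule used in step (a) above can be sketched on toy data; because `dropna=False` counts NaN as a value, an all-NaN column also registers as constant:

```python
import numpy as np
import pandas as pd

# Toy frame (hypothetical values): 'b' is constant, 'c' is all-NaN
df = pd.DataFrame({
    "a": [1, 2, 3],
    "b": [7, 7, 7],
    "c": [np.nan, np.nan, np.nan],
})

# Same rule as the preprocessing loop: at most one distinct value per column
const_mask = df.nunique(dropna=False) <= 1
dropped = df.columns[const_mask].tolist()
kept = df.loc[:, ~const_mask]

print(dropped)                 # ['b', 'c']
print(kept.columns.tolist())   # ['a']
```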

Dummify

A structured preprocessing pipeline was implemented prior to modeling. Ordered categorical variables (e.g., housing status, educational attainment, clinical evaluations, and substance use frequency) were manually mapped to numeric scales reflecting their natural ordering. For nominal categorical variables, prespecified reference categories were enforced to ensure consistent baseline comparisons across imputations. All remaining categorical predictors were then converted to dummy variables using one-hot encoding with the first category dropped to prevent multicollinearity. The procedure was applied consistently across all imputed datasets to ensure harmonized model inputs.

Code
import numpy as np
import pandas as pd
from pandas.api.types import CategoricalDtype

def preprocess_features_robust(df):
    df_clean = df.copy()

    # ---------------------------------------------------------
    # 1. Ordinal encoding (your existing code)
    # ---------------------------------------------------------
    ordered_mappings = {
        # --- NEW: Housing & Urbanicity ---
        "tenure_status_household": {
            "illegal settlement": 4,                       # Situación Calle
            "stays temporarily with a relative": 3,        # Allegado
            "others": 2,                                   # En pensión / Otros
            "renting": 1,                                  # Arrendando
            "owner/transferred dwellings/pays dividends": 0 # Vivienda Propia
        },
        "urbanicity_cat": {
            "1.Rural": 2,
            "2.Mixed": 1,
            "3.Urban": 0
        },

        # --- Clinical Evaluations (Minimo -> Intermedio -> Alto) ---
        "evaluacindelprocesoteraputico": {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_consumo":      {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_fam":          {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_relinterp":    {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_ocupacion":    {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_sm":           {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_fisica":       {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_transgnorma":  {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},

        # --- Frequency (Less freq -> More freq) ---
        "prim_sub_freq_rec": {
            "1.≤1 day/wk": 0,
            "2.2–6 days/wk": 1,
            "3.Daily": 2
        },

        # --- Education (Less -> More) ---
        "ed_attainment_corr": {
            "3-Completed primary school or less": 2,
            "2-Completed high school or less": 1,
            "1-More than high school": 0
        }
    }

    for col, mapping in ordered_mappings.items():
        if col in df_clean.columns:
            df_clean[col] = df_clean[col].astype(str).str.strip()
            df_clean[col] = df_clean[col].map(mapping)

            n_missing = df_clean[col].isnull().sum()
            if n_missing > 0:
                if n_missing == len(df_clean):
                    print(f"⚠️ WARNING: Mapping failed completely for '{col}'.")
                mode_val = df_clean[col].mode()[0]
                df_clean[col] = df_clean[col].fillna(mode_val)

    # ---------------------------------------------------------
    # 2. FORCE reference categories for dummies
    # ---------------------------------------------------------
    dummy_reference = {
        "sex_rec": "man",
        "marital_status_rec": "married/cohabiting",
        "cohabitation": "alone",
        "sub_dep_icd10_status": "hazardous consumption",
        "tr_outcome": "completion",
        "adm_motive": "spontaneous consultation",
        "tipo_de_vivienda_rec2": "formal housing",
        # A duplicate "ambulatory" entry was removed: in a Python dict the last
        # occurrence of a repeated key silently wins, so "pg-pab" is the reference.
        "plan_type_corr": "pg-pab",
        "occupation_condition_corr24": "employed",
        "any_violence": "0.No domestic violence/sex abuse",
        "first_sub_used": "marijuana",
        }

    for col, ref in dummy_reference.items():
        if col in df_clean.columns:
            df_clean[col] = df_clean[col].astype(str).str.strip()
            cats = df_clean[col].unique().tolist()

            if ref in cats:
                new_order = [ref] + [c for c in cats if c != ref]
                cat_type = CategoricalDtype(categories=new_order, ordered=False)
                df_clean[col] = df_clean[col].astype(cat_type)
            else:
                print(f"⚠️ Reference '{ref}' not found in {col}")

    # ---------------------------------------------------------
    # 3. One-hot encoding
    # ---------------------------------------------------------
    df_final = pd.get_dummies(df_clean, drop_first=True, dtype=float)

    return df_final

X_encoded_list_final = [preprocess_features_robust(X) for X in X_reduced_list]
X_encoded_list_final = [clean_names(X) for X in X_encoded_list_final]
Code
from IPython.display import display, Markdown

# 1. DIAGNOSTIC: Check exact string values
display(Markdown("### --- Diagnostic Check ---"))
sample_df = X_encoded_list_final[0]

if 'tenure_status_household' in sample_df.columns:
    display(Markdown("**Unique values in 'tenure_status_household':**"))
    display(Markdown(str(sample_df['tenure_status_household'].unique())))
else:
    display(Markdown("❌ 'tenure_status_household' is missing entirely from input data!"))

if 'urbanicity_cat' in sample_df.columns:
    display(Markdown("**Unique values in 'urbanicity_cat':**"))
    display(Markdown(str(sample_df['urbanicity_cat'].unique())))

if 'ed_attainment_corr' in sample_df.columns:
    display(Markdown("**Unique values in 'ed_attainment_corr':**"))
    display(Markdown(str(sample_df['ed_attainment_corr'].unique())))

— Diagnostic Check —

Unique values in ‘tenure_status_household’:

[3 0 1 2 4]

Unique values in ‘urbanicity_cat’:

[0 1 2]

Unique values in ‘ed_attainment_corr’:

[1 2 0]

We recoded first substance used so that rare categories are collapsed into a single “Other” category.

Code
# Columns to combine
cols_to_group = [
    "first_sub_used_opioids",
    "first_sub_used_others",
    "first_sub_used_hallucinogens",
    "first_sub_used_inhalants",
    "first_sub_used_tranquilizers_hypnotics",
    "first_sub_used_amphetamine_type_stimulants",
]

# Loop over datasets 0–4 and modify in place
for i in range(5):
    df = X_encoded_list_final[i].copy()
    # Collapse into one dummy: if any of these == 1, mark as 1
    df["first_sub_used_other"] = df[cols_to_group].max(axis=1)
    # Drop the original rare-category dummies (the combined column has a different name)
    df = df.drop(columns=cols_to_group)
    # Replace the dataset in the original list
    X_encoded_list_final[i] = df
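As a quick sanity check, the row-wise `max` used above behaves as a logical OR across 0/1 dummies; this toy frame (hypothetical values, only two of the grouped columns) illustrates:

```python
import pandas as pd

# Row 0 hits the first dummy, row 1 hits the second, row 2 hits neither
toy = pd.DataFrame({
    "first_sub_used_opioids":   [1.0, 0.0, 0.0],
    "first_sub_used_inhalants": [0.0, 1.0, 0.0],
})

# Row-wise max over 0/1 indicators acts as an OR
toy["first_sub_used_other"] = toy[["first_sub_used_opioids",
                                   "first_sub_used_inhalants"]].max(axis=1)
collapsed = toy.drop(columns=["first_sub_used_opioids", "first_sub_used_inhalants"])
print(collapsed["first_sub_used_other"].tolist())  # [1.0, 1.0, 0.0]
```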
Code
import sys
fold_output(
    "Show first db of X_encoded_list_final (newer) glimpse",
    lambda: glimpse(X_encoded_list_final[0])
)
Show first db of X_encoded_list_final (newer) glimpse
Rows: 88504 | Columns: 56
adm_age_rec3                   float64         31.53, 20.61, 42.52, 60.61, 45.08
porc_pobr                      float64         0.175679117441177, 0.187835901975632, 0.130412444472313, 0.133759185671806, 0.08...
dit_m                          float64         15.967741935483872, 5.833333333333334, 0.4752688172043005, 6.966666666666667, 6....
tenure_status_household        int64           3, 0, 3, 0, 3
prim_sub_freq_rec              int64           1, 2, 2, 2, 2
national_foreign               int32           0, 0, 0, 0, 0
urbanicity_cat                 int64           0, 0, 0, 0, 0
ed_attainment_corr             int64           1, 2, 1, 1, 2
evaluacindelprocesoteraputico  int64           0, 2, 2, 2, 0
eva_consumo                    int64           0, 2, 2, 1, 0
eva_fam                        int64           1, 2, 2, 1, 0
eva_relinterp                  int64           0, 2, 2, 1, 0
eva_ocupacion                  int64           0, 2, 2, 2, 1
eva_sm                         int64           1, 2, 2, 1, 2
eva_fisica                     int64           0, 2, 1, 1, 0
eva_transgnorma                int64           0, 2, 2, 2, 1
ethnicity                      float64         0.0, 0.0, 0.0, 0.0, 0.0
dg_psiq_cie_10_instudy         bool            False, False, True, False, False
dg_psiq_cie_10_dg              bool            True, False, False, True, False
dx_f3_mood                     int32           0, 0, 0, 0, 0
dx_f6_personality              int32           0, 0, 0, 1, 0
dx_f_any_severe_mental         bool            True, False, False, False, False
any_phys_dx                    bool            False, False, False, True, False
polysubstance_strict           int32           0, 1, 1, 1, 1
sex_rec_woman                  float64         0.0, 0.0, 0.0, 1.0, 0.0
cohabitation_family_of_origin  float64         0.0, 1.0, 0.0, 0.0, 1.0
cohabitation_with_couple_children float64         0.0, 0.0, 1.0, 1.0, 0.0
cohabitation_others            float64         0.0, 0.0, 0.0, 0.0, 0.0
sub_dep_icd10_status_drug_dependence float64         1.0, 0.0, 1.0, 1.0, 1.0
any_violence_1_domestic_violence_sex_abuse float64         0.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_referral            float64         1.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_dropout             float64         0.0, 1.0, 0.0, 1.0, 0.0
tr_outcome_adm_discharge_rule_violation_undet float64         0.0, 0.0, 1.0, 0.0, 0.0
tr_outcome_adm_discharge_adm_reasons float64         0.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_other               float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_sanitary_sector     float64         1.0, 0.0, 1.0, 1.0, 0.0
adm_motive_another_sud_facility_fonodrogas_senda_previene float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_justice_sector      float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_other               float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_alcohol         float64         1.0, 1.0, 1.0, 0.0, 1.0
first_sub_used_cocaine_paste   float64         0.0, 0.0, 0.0, 1.0, 0.0
first_sub_used_cocaine_powder  float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_cocaine_paste  float64         0.0, 1.0, 1.0, 1.0, 1.0
primary_sub_mod_cocaine_powder float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_marijuana      float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_others         float64         0.0, 0.0, 0.0, 0.0, 0.0
tipo_de_vivienda_rec2_other_unknown float64         1.0, 0.0, 0.0, 0.0, 0.0
plan_type_corr_pg_pr           float64         0.0, 0.0, 1.0, 0.0, 0.0
plan_type_corr_m_pr            float64         0.0, 0.0, 0.0, 1.0, 0.0
plan_type_corr_pg_pai          float64         0.0, 0.0, 0.0, 0.0, 1.0
plan_type_corr_m_pai           float64         0.0, 0.0, 0.0, 0.0, 0.0
occupation_condition_corr24_unemployed float64         1.0, 0.0, 0.0, 0.0, 1.0
occupation_condition_corr24_inactive float64         0.0, 0.0, 0.0, 1.0, 0.0
marital_status_rec_single      float64         1.0, 1.0, 1.0, 0.0, 1.0
marital_status_rec_separated_divorced_annulled_widowed float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_other           float64         0.0, 0.0, 0.0, 0.0, 0.0

For each imputed dataset, we fitted two regularized Cox models (one for readmission and one for death) using Coxnet, which applies elastic-net penalization with a strong LASSO component to enable variable selection. The loop fits both models on every imputation, prints basic model information, and stores all fitted models so they can later be combined or compared across imputations.
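The elastic-net penalty Coxnet applies can be written out directly. The sketch below is illustrative only; the `alpha` and `l1_ratio` values are hypothetical, not the tuned ones:

```python
import numpy as np

def elastic_net_penalty(beta, alpha=1.0, l1_ratio=0.9):
    """Elastic-net penalty: alpha * (l1_ratio * ||b||_1 + 0.5*(1-l1_ratio) * ||b||_2^2).

    A high l1_ratio (strong LASSO component) drives coefficients exactly to zero,
    which is what enables variable selection in Coxnet.
    """
    beta = np.asarray(beta, dtype=float)
    l1 = np.abs(beta).sum()            # LASSO term
    l2 = 0.5 * (beta ** 2).sum()       # ridge term
    return alpha * (l1_ratio * l1 + (1.0 - l1_ratio) * l2)

# l1 = 3, l2 = 2.5  ->  0.9*3 + 0.1*2.5 = 2.95
print(elastic_net_penalty([0.0, 2.0, -1.0], alpha=1.0, l1_ratio=0.9))
```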

Create bins for follow-up (landmarks)

We extracted the observed event times and corresponding event indicators directly from the structured survival objects (y_surv_readm and y_surv_death). Using the observed event times, we constructed evaluation grids based on the 5th to 95th percentiles of the event-time distribution. These grids define standardized time points at which model performance is assessed for both readmission and mortality outcomes.

Code
import numpy as np
from IPython.display import display, Markdown

# Extract event times directly from structured arrays
event_times_readm = y_surv_readm["time"][y_surv_readm["event"]]
event_times_death = y_surv_death["time"][y_surv_death["event"]]

# Build evaluation grids (5th–95th percentiles, 50 points)
times_eval_readm = np.unique(
    np.quantile(event_times_readm, np.linspace(0.05, 0.95, 50))
)

times_eval_death = np.unique(
    np.quantile(event_times_death, np.linspace(0.05, 0.95, 50))
)

# Display only final result
display(Markdown(
    f"**Eval times (readmission):** `{times_eval_readm[:5]}` ... `{times_eval_readm[-5:]}`"
))

display(Markdown(
    f"**Eval times (death):** `{times_eval_death[:5]}` ... `{times_eval_death[-5:]}`"
))

Eval times (readmission): [0.38709677 0.67741935 1.03225806 1.41935484 1.76666667] ... [46.81833443 50.96030612 55.16129032 60.84848585 67.08322581]

Eval times (death): [0. 0.09677419 1.06666667 2.1691691 3.34812377] ... [74.72632653 78.4516129 82.39472203 86.41935484 92.36311828]

Correct immortal time bias

First, we corrected for immortal time bias: patients who died would otherwise appear in the readmission analysis as event-free follow-up extending past their death.

This correction is essentially the cause-specific hazard preparation. It is the correct way to handle Aim 3 unless you switch to a Fine-Gray model (which codes death as a distinct event type, 2, rather than as censoring, 0). For cause-specific models such as RSF or Coxnet, censoring at the time of death (event = 0) is the correct approach.

Code
import numpy as np

# Step 3. Replicate across imputations (safe copies)
n_imputations = len(X_encoded_list_final)
y_surv_readm_list = [y_surv_readm.copy() for _ in range(n_imputations)]
y_surv_death_list = [y_surv_death.copy() for _ in range(n_imputations)]

def correct_competing_risks(y_readm_list, y_death_list):
    corrected = []
    for y_readm, y_death in zip(y_readm_list, y_death_list):
        y_corr = y_readm.copy()

        # death observed and occurring strictly before the readmission/censoring time
        mask = (y_death["event"]) & (y_death["time"] < y_corr["time"])

        y_corr["event"][mask] = False
        y_corr["time"][mask] = y_death["time"][mask]

        corrected.append(y_corr)
    return corrected

# Step 4. Apply correction
y_surv_readm_list_corrected = correct_competing_risks(
    y_surv_readm_list,
    y_surv_death_list
)
Code
# Check type and length
type(y_surv_readm_list_corrected), len(y_surv_readm_list_corrected)

# Look at the first element
y_surv_readm_list_corrected[0][:5]   # first 5 rows
array([(False, 68.96774194), ( True,  7.        ), ( True, 13.25806452),
       ( True,  5.        ), ( True,  7.35483871)],
      dtype=[('event', '?'), ('time', '<f8')])
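The censoring logic of `correct_competing_risks` can be verified on a two-row toy array with the same structured dtype:

```python
import numpy as np

dt = [("event", "?"), ("time", "<f8")]

# Row 0: death at t=5 precedes the readmission record at t=10 -> censor at 5
# Row 1: no observed death -> readmission record left untouched
y_readm = np.array([(True, 10.0), (True, 4.0)], dtype=dt)
y_death = np.array([(True, 5.0), (False, 60.0)], dtype=dt)

y_corr = y_readm.copy()
mask = y_death["event"] & (y_death["time"] < y_corr["time"])
y_corr["event"][mask] = False          # readmission can no longer occur
y_corr["time"][mask] = y_death["time"][mask]  # truncate follow-up at death

print(y_corr.tolist())  # [(False, 5.0), (True, 4.0)]
```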
Code
from IPython.display import display, HTML
import html

def nb_print(*args, sep=" "):
    msg = sep.join(str(a) for a in args)
    display(HTML(f"<pre style='margin:0'>{html.escape(msg)}</pre>"))

The fully preprocessed and encoded feature matrices were renamed from X_encoded_list_final to imputations_list_mar26 to reflect the finalized March 2026 analytic version of the imputed datasets.

This object contains the harmonized, ordinal-encoded, and one-hot encoded predictor matrices for all five imputations and will serve as the definitive input for subsequent modeling procedures.

Code
imputations_list_mar26 = X_encoded_list_final
del X_encoded_list_final
Code
import numpy as np
import pandas as pd
from IPython.display import display, Markdown

# ── Build exclusion mask (same for all imputations, based on imputation 0) ──
df0 = imputations_list_mar26[0]

# Condition 1: tr_outcome_adm_discharge_adm_reasons == 1 AND death within ~7 days (0.23 months)
mask_adm_death = (
    (df0["tr_outcome_adm_discharge_adm_reasons"] == 1)
    & (y_surv_death["event"] == True)
    & (y_surv_death["time"] <= 0.23)
)

# Condition 2: tr_outcome_other == 1 (any time)
mask_other = df0["tr_outcome_other"] == 1

# Combined exclusion mask
exclude = mask_adm_death | mask_other
keep = ~exclude

# ── Report ──
n_total = len(df0)
n_excl_adm = mask_adm_death.sum()
n_excl_other = mask_other.sum()
n_excl_both = (mask_adm_death & mask_other).sum()
n_excl_total = exclude.sum()
n_remaining = keep.sum()

report = f"""### Exclusion Report

| Criterion | n excluded |
|---|---:|
| `tr_outcome_adm_discharge_adm_reasons == 1` & death time ≤ 7 days | {n_excl_adm} |
| `tr_outcome_other == 1` (any time) | {n_excl_other} |
| Both criteria (overlap) | {n_excl_both} |
| **Total unique excluded** | **{n_excl_total}** |
| **Remaining observations** | **{n_remaining}** / {n_total} |
"""
display(Markdown(report))

# ── Apply filter to all imputation lists and outcome arrays ──
imputations_list_mar26 = [df.loc[keep].reset_index(drop=True) for df in imputations_list_mar26]
y_surv_readm_list_mar26 = [y[keep] for y in y_surv_readm_list_corrected]
y_surv_death_list_mar26 = [y[keep] for y in y_surv_death_list]
y_surv_readm_list_corrected_mar26 = [y[keep] for y in y_surv_readm_list_corrected]

# Single (non-list) outcome arrays for convenience
y_surv_readm_mar26 = y_surv_readm[keep]
y_surv_death_mar26 = y_surv_death[keep]

# Rebuild eval time grids on the filtered data
event_times_readm_mar26 = y_surv_readm_mar26["time"][y_surv_readm_mar26["event"]]
event_times_death_mar26 = y_surv_death_mar26["time"][y_surv_death_mar26["event"]]

times_eval_readm_mar26 = np.unique(
    np.quantile(event_times_readm_mar26, np.linspace(0.05, 0.95, 50))
)
times_eval_death_mar26 = np.unique(
    np.quantile(event_times_death_mar26, np.linspace(0.05, 0.95, 50))
)

Exclusion Report

| Criterion | n excluded |
|---|---:|
| `tr_outcome_adm_discharge_adm_reasons == 1` & death time ≤ 7 days | 137 |
| `tr_outcome_other == 1` (any time) | 215 |
| Both criteria (overlap) | 0 |
| **Total unique excluded** | **352** |
| **Remaining observations** | **88152** / 88504 |

Train / test split (80/20)

  1. Sets a fixed random seed to make the 80/20 split exactly reproducible.

  2. Verifies required datasets exist (features and survival outcomes) before doing anything.

  3. Creates a “death-corrected” outcome list if it was not already available.

  4. Derives stratification labels from treatment plan and completion categories plus readmission/death events.

  5. Uses a step-down strategy if some strata are too rare, simplifying the stratification to keep it feasible.

  6. Caches a “full snapshot” of all imputations and outcomes so reruns don’t silently change the split.

  7. Checks row alignment so every imputation and every outcome has the same number of observations.

  8. Optionally checks stability across imputations for plan/completion columns (should not vary much).

  9. Loads split indices from disk when available, ensuring the exact same train/test split across sessions.

  10. Builds train/test datasets consistently for all imputations, then runs strict diagnostics to confirm balance.
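The step-down logic in point 5 can be sketched in isolation: composite labels are kept only when every stratum holds at least two rows; rare strata are first re-labeled with a coarser fallback, and only then is the full fallback used. The labels below are hypothetical toy strata, not the notebook's plan/outcome codes:

```python
import pandas as pd

def pick_strata(full_labels, fallback_labels, min_count=2):
    """Return the finest stratification whose smallest stratum has >= min_count rows."""
    full = pd.Series(full_labels)
    fallback = pd.Series(fallback_labels)

    if full.map(full.value_counts()).min() >= min_count:
        return full, "full"

    # Replace only the rare full strata with the coarser fallback label
    mixed = full.where(full.map(full.value_counts()) >= min_count, fallback)
    if mixed.map(mixed.value_counts()).min() >= min_count:
        return mixed, "mixed"

    if fallback.map(fallback.value_counts()).min() >= min_count:
        return fallback, "fallback"
    raise ValueError("No feasible stratification with >= min_count rows per stratum.")

# "1_2" and "1_3" are singleton full strata; both coarsen to "1", which then has 2 rows
full = ["1_1", "1_1", "1_2", "1_3"]
fb   = ["1",   "1",   "1",   "1"]
strata, mode = pick_strata(full, fb)
print(mode)                # mixed
print(strata.tolist())     # ['1_1', '1_1', '1', '1']
```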

Code
#@title 🧪 / 🎓 Reproducible 80/20 split before ML (integrated + idempotent + persisted)
# Stratification hierarchy:
#   1) plan + completion + readm_event + death_event
#   2) mixed fallback for rare full strata (<2) -> plan + readm + death
#   3) full fallback -> plan + readm + death

import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split
from IPython.display import display, Markdown

SEED = 2125
TEST_SIZE = 0.20
FORCE_RESPLIT = False          # True to force new split
STRICT_SPLIT_CHECKS = True     # CI-style hard checks
MAX_EVENT_GAP = 0.01           # 1% tolerance
PERSIST_SPLIT_INDICES = True

# --- Project-root anchored paths ---

def find_project_root(markers=("AGENTS.md", ".git")):
    try:
        cur = Path.cwd().resolve()
    except OSError as e:
        raise RuntimeError(
            "Invalid working directory. Run this notebook from inside the project folder."
        ) from e

    for p in (cur, *cur.parents):
        if any((p / m).exists() for m in markers):
            return p

    raise RuntimeError(
        f"Could not locate project root starting from {cur}. "
        f"Expected one of markers: {markers}."
    )

PROJECT_ROOT = find_project_root()
OUT_DIR = PROJECT_ROOT / "_out"
OUT_DIR.mkdir(parents=True, exist_ok=True)

SPLIT_FILE = OUT_DIR / f"death_split_seed{SEED}_test{int(TEST_SIZE*100)}_mar26.npz"

def nb_print_md(msg):
    display(Markdown(str(msg)))

nb_print_md(f"**Project root:** `{PROJECT_ROOT}`")


# ---------- Requirements ----------
required = [
    "imputations_list_mar26",
    "y_surv_readm_list_mar26",
    "y_surv_readm_list_corrected_mar26",
    "y_surv_death_list_mar26",
]
missing = [v for v in required if v not in globals()]
if missing:
    raise ValueError(f"Missing required objects: {missing}")

if "y_surv_death_list_corrected_mar26" not in globals():
    y_surv_death_list_corrected_mar26 = [y.copy() for y in y_surv_death_list_mar26]

# ---------- Helpers ----------
def get_plan_labels(df):
    labels = np.zeros(len(df), dtype=int)
    if "plan_type_corr_pg_pr" in df.columns:
        labels[pd.to_numeric(df["plan_type_corr_pg_pr"], errors="coerce").fillna(0).to_numpy() == 1] = 1
    if "plan_type_corr_m_pr" in df.columns:
        labels[pd.to_numeric(df["plan_type_corr_m_pr"], errors="coerce").fillna(0).to_numpy() == 1] = 2
    if "plan_type_corr_pg_pai" in df.columns:
        labels[pd.to_numeric(df["plan_type_corr_pg_pai"], errors="coerce").fillna(0).to_numpy() == 1] = 3
    if "plan_type_corr_m_pai" in df.columns:
        labels[pd.to_numeric(df["plan_type_corr_m_pai"], errors="coerce").fillna(0).to_numpy() == 1] = 4
    return labels

def get_completion_labels(df):
    labels = np.zeros(len(df), dtype=int)
    if "tr_outcome_referral" in df.columns:
        labels[pd.to_numeric(df["tr_outcome_referral"], errors="coerce").fillna(0).to_numpy() == 1] = 1
    if "tr_outcome_dropout" in df.columns:
        labels[pd.to_numeric(df["tr_outcome_dropout"], errors="coerce").fillna(0).to_numpy() == 1] = 2
    if "tr_outcome_adm_discharge_rule_violation_undet" in df.columns:
        labels[pd.to_numeric(df["tr_outcome_adm_discharge_rule_violation_undet"], errors="coerce").fillna(0).to_numpy() == 1] = 3
    if "tr_outcome_adm_discharge_adm_reasons" in df.columns:
        labels[pd.to_numeric(df["tr_outcome_adm_discharge_adm_reasons"], errors="coerce").fillna(0).to_numpy() == 1] = 4
    if "tr_outcome_other" in df.columns:
        labels[pd.to_numeric(df["tr_outcome_other"], errors="coerce").fillna(0).to_numpy() == 1] = 5
    return labels

def build_strata(X0, y_readm0, y_death0):
    """
    Build stratification labels with progressive fallback:

    1) full: plan + completion + readmission_event + death_event
    2) mixed: only rare full strata (<2 rows) are replaced by fallback labels
       (plan + readmission_event + death_event)
    3) fallback: plan + readmission_event + death_event for all rows

    Returns:
        strata (np.ndarray), mode (str), readm_evt (np.ndarray), death_evt (np.ndarray)
    """
    plan = get_plan_labels(X0)
    comp = get_completion_labels(X0)
    readm_evt = y_readm0["event"].astype(int)
    death_evt = y_death0["event"].astype(int)

    full = pd.Series(plan.astype(str) + "_" + comp.astype(str) + "_" + readm_evt.astype(str) + "_" + death_evt.astype(str))
    if full.value_counts().min() >= 2:
        return full.to_numpy(), "full(plan+completion+readm+death)", readm_evt, death_evt

    fb = pd.Series(plan.astype(str) + "_" + readm_evt.astype(str) + "_" + death_evt.astype(str))
    mixed = full.copy()
    rare = mixed.map(mixed.value_counts()) < 2
    mixed[rare] = fb[rare]
    if mixed.value_counts().min() >= 2:
        return mixed.to_numpy(), "mixed(rare->plan+readm+death)", readm_evt, death_evt

    if fb.value_counts().min() >= 2:
        return fb.to_numpy(), "fallback(plan+readm+death)", readm_evt, death_evt

    raise ValueError("Could not build stratification labels with >=2 rows per stratum.")

def split_df_list(df_list, tr_idx, te_idx):
    tr = [df.iloc[tr_idx].reset_index(drop=True).copy() for df in df_list]
    te = [df.iloc[te_idx].reset_index(drop=True).copy() for df in df_list]
    return tr, te

def split_surv_list(y_list, tr_idx, te_idx):
    tr = [y[tr_idx].copy() for y in y_list]
    te = [y[te_idx].copy() for y in y_list]
    return tr, te

# ---------- Cache full data once (idempotent re-runs) ----------
if "_split_cache_death_mar26" not in globals():
    _split_cache_death_mar26 = {}
cache = _split_cache_death_mar26

if FORCE_RESPLIT:
    cache.pop("idx", None)

if FORCE_RESPLIT or "full" not in cache:
    cache["full"] = {
        "X": [df.reset_index(drop=True).copy() for df in imputations_list_mar26],
        "y_readm": [y.copy() for y in y_surv_readm_list_mar26],
        "y_readm_corr": [y.copy() for y in y_surv_readm_list_corrected_mar26],
        "y_death": [y.copy() for y in y_surv_death_list_mar26],
        "y_death_corr": [y.copy() for y in y_surv_death_list_corrected_mar26],
    }

full = cache["full"]

# ---------- Consistency checks ----------
n_imp = len(full["X"])
n = len(full["X"][0])

if any(len(df) != n for df in full["X"]):
    raise ValueError("Row mismatch inside full X list.")

for name, obj in [
    ("y_readm", full["y_readm"]),
    ("y_readm_corr", full["y_readm_corr"]),
    ("y_death", full["y_death"]),
    ("y_death_corr", full["y_death_corr"]),
]:
    if len(obj) != n_imp:
        raise ValueError(f"{name} length ({len(obj)}) != n_imputations ({n_imp})")
    if any(len(y) != n for y in obj):
        raise ValueError(f"Row mismatch between X and {name}.")

# ---------- Optional diagnostic: plan/completion consistency across imputations ----------
plan_comp_cols = [
    c for c in [
        "plan_type_corr_pg_pr",
        "plan_type_corr_m_pr",
        "plan_type_corr_pg_pai",
        "plan_type_corr_m_pai",
        "tr_outcome_referral",
        "tr_outcome_dropout",
        "tr_outcome_adm_discharge_rule_violation_undet",
        "tr_outcome_adm_discharge_adm_reasons"#,
        #"tr_outcome_other", #2026-03-26: Excluded from consistency check since it was used as an exclusion criterion and thus may differ by design across imputations
    ] if c in full["X"][0].columns
]

max_diff_rows = 0
if plan_comp_cols:
    base_pc = full["X"][0][plan_comp_cols].astype("string").fillna("__NA__").reset_index(drop=True)
    for i in range(1, n_imp):
        cur_pc = full["X"][i][plan_comp_cols].astype("string").fillna("__NA__").reset_index(drop=True)
        diff_rows = int((base_pc != cur_pc).any(axis=1).sum())
        max_diff_rows = max(max_diff_rows, diff_rows)

# ---------- Try loading indices from disk ----------
loaded_from_disk = False
if PERSIST_SPLIT_INDICES and (not FORCE_RESPLIT) and SPLIT_FILE.exists() and ("idx" not in cache):
    z = np.load(SPLIT_FILE, allow_pickle=False)
    tr = z["train_idx"].astype(int)
    te = z["test_idx"].astype(int)
    n_disk = int(z["n_full"][0]) if "n_full" in z else n
    if n_disk == n and tr.max() < n and te.max() < n:
        cache["idx"] = (np.sort(tr), np.sort(te))
        cache["strat_mode"] = str(z["strat_mode"][0]) if "strat_mode" in z else "loaded_from_disk"
        loaded_from_disk = True

# ---------- Compute or reuse split indices ----------
if FORCE_RESPLIT or "idx" not in cache:
    strata_used, strat_mode, readm_evt_all, death_evt_all = build_strata(
        full["X"][0], full["y_readm"][0], full["y_death"][0]
    )
    idx = np.arange(n)
    train_idx, test_idx = train_test_split(
        idx, test_size=TEST_SIZE, random_state=SEED, shuffle=True, stratify=strata_used
    )
    train_idx = np.sort(train_idx)
    test_idx = np.sort(test_idx)
    cache["idx"] = (train_idx, test_idx)
    cache["strat_mode"] = strat_mode

    if PERSIST_SPLIT_INDICES:
        np.savez_compressed(
            SPLIT_FILE,
            train_idx=train_idx,
            test_idx=test_idx,
            n_full=np.array([n], dtype=int),
            seed=np.array([SEED], dtype=int),
            test_size=np.array([TEST_SIZE], dtype=float),
            strat_mode=np.array([strat_mode], dtype="U64"),
        )
else:
    train_idx, test_idx = cache["idx"]
    train_idx = np.sort(train_idx)
    test_idx = np.sort(test_idx)
    readm_evt_all = full["y_readm"][0]["event"].astype(int)
    death_evt_all = full["y_death"][0]["event"].astype(int)

# ---------- Build train/test from full snapshot every run ----------
imputations_list_mar26_train, imputations_list_mar26_test = split_df_list(full["X"], train_idx, test_idx)

y_surv_readm_list_train, y_surv_readm_list_test = split_surv_list(full["y_readm"], train_idx, test_idx)
y_surv_readm_list_corrected_train, y_surv_readm_list_corrected_test = split_surv_list(full["y_readm_corr"], train_idx, test_idx)

y_surv_death_list_train, y_surv_death_list_test = split_surv_list(full["y_death"], train_idx, test_idx)
y_surv_death_list_corrected_train, y_surv_death_list_corrected_test = split_surv_list(full["y_death_corr"], train_idx, test_idx)

# Downstream code uses TRAIN only
imputations_list_mar26 = imputations_list_mar26_train
y_surv_readm_list = y_surv_readm_list_train
y_surv_readm_list_corrected = y_surv_readm_list_corrected_train
y_surv_death_list = y_surv_death_list_train
y_surv_death_list_corrected = y_surv_death_list_corrected_train

# ---------- Diagnostics + strict checks ----------
strata_diag, strat_mode_diag, _, _ = build_strata(full["X"][0], full["y_readm"][0], full["y_death"][0])
sdiag = pd.Series(strata_diag)
train_strata = set(sdiag.iloc[train_idx].unique())
test_strata = set(sdiag.iloc[test_idx].unique())
missing_in_test = sorted(train_strata - test_strata)
missing_in_train = sorted(test_strata - train_strata)

readm_gap = abs(readm_evt_all[train_idx].mean() - readm_evt_all[test_idx].mean())
death_gap = abs(death_evt_all[train_idx].mean() - death_evt_all[test_idx].mean())

# full-strata rarity report (before fallback)
strata_full = pd.Series(
    get_plan_labels(full["X"][0]).astype(str) + "_" +
    get_completion_labels(full["X"][0]).astype(str) + "_" +
    full["y_readm"][0]["event"].astype(int).astype(str) + "_" +
    full["y_death"][0]["event"].astype(int).astype(str)
)
vc_full = strata_full.value_counts()
rare_rows = int((strata_full.map(vc_full) < 2).sum())

if STRICT_SPLIT_CHECKS:
    assert len(np.intersect1d(train_idx, test_idx)) == 0, "Train/Test index overlap detected."
    assert (len(train_idx) + len(test_idx)) == n, "Train/Test sizes do not sum to n."
    assert len(missing_in_test) == 0, f"Strata missing in test: {missing_in_test}"
    assert len(missing_in_train) == 0, f"Strata missing in train: {missing_in_train}"
    assert readm_gap < MAX_EVENT_GAP, f"Readmission rate imbalance > {MAX_EVENT_GAP:.0%} (gap={readm_gap:.4f})"
    assert death_gap < MAX_EVENT_GAP, f"Death rate imbalance > {MAX_EVENT_GAP:.0%} (gap={death_gap:.4f})"

# ---------- Summary ----------
nb_print_md(f"**Loaded indices from disk:** `{loaded_from_disk}`")
nb_print_md(f"**Split file:** `{SPLIT_FILE}`")
nb_print_md(f"**Split mode used:** `{cache.get('strat_mode', strat_mode_diag)}`")
nb_print_md(f"**Plan/completion diff rows across imputations (max vs imp0):** `{max_diff_rows}`")
nb_print_md(f"**Full strata count:** `{vc_full.size}` | **Min full stratum size:** `{int(vc_full.min())}` | **Rows in rare full strata (<2):** `{rare_rows}`")
nb_print_md(f"**Train/Test sizes:** `{len(train_idx)}` ({len(train_idx)/n:.1%}) / `{len(test_idx)}` ({len(test_idx)/n:.1%})")
nb_print_md(
    "**Readmission rate all/train/test:** "
    f"`{readm_evt_all.mean():.3%}` / `{readm_evt_all[train_idx].mean():.3%}` / `{readm_evt_all[test_idx].mean():.3%}`"
)
nb_print_md(
    "**Death rate all/train/test:** "
    f"`{death_evt_all.mean():.3%}` / `{death_evt_all[train_idx].mean():.3%}` / `{death_evt_all[test_idx].mean():.3%}`"
)
nb_print_md(
    f"**Strata in train/test:** `{len(train_strata)}` / `{len(test_strata)}` | "
    f"**Missing train→test:** `{len(missing_in_test)}` | **Missing test→train:** `{len(missing_in_train)}`"
)

Project root: G:\My Drive\Alvacast\SISTRAT 2023\cons

Loaded indices from disk: True

Split file: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\death_split_seed2125_test20_mar26.npz

Split mode used: fallback(plan+readm+death)

Plan/completion diff rows across imputations (max vs imp0): 0

Full strata count: 95 | Min full stratum size: 1 | Rows in rare full strata (<2): 4

Train/Test sizes: 70521 (80.0%) / 17631 (20.0%)

Readmission rate all/train/test: 21.622% / 21.621% / 21.627%

Death rate all/train/test: 4.310% / 4.309% / 4.311%

Strata in train/test: 20 / 20 | Missing train→test: 0 | Missing test→train: 0

Code
# counts per stratum in train/test
train_counts = sdiag.iloc[train_idx].value_counts()
test_counts  = sdiag.iloc[test_idx].value_counts()

min_train = int(train_counts.min())
min_test  = int(test_counts.min())

nb_print_md(f"**Min stratum count in TRAIN (used strata):** `{min_train}`")
nb_print_md(f"**Min stratum count in TEST (used strata):** `{min_test}`")

# strata that got 0 in test or 0 in train
zero_in_test = sorted(set(train_counts.index) - set(test_counts.index))
zero_in_train = sorted(set(test_counts.index) - set(train_counts.index))

nb_print_md(f"**Strata with 0 in TEST:** `{len(zero_in_test)}`")
nb_print_md(f"**Strata with 0 in TRAIN:** `{len(zero_in_train)}`")

# show examples with their full-data counts
if len(zero_in_test) > 0:
    ex = zero_in_test[:10]
    nb_print_md(f"**Examples 0 in TEST (up to 10):** `{ex}`")
    nb_print_md(f"**Full-data counts:** `{[int(sdiag.value_counts()[k]) for k in ex]}`")

Min stratum count in TRAIN (used strata): 18

Min stratum count in TEST (used strata): 5

Strata with 0 in TEST: 0

Strata with 0 in TRAIN: 0

Code
strata_full = pd.Series(
    get_plan_labels(full["X"][0]).astype(str) + "_" +
    get_completion_labels(full["X"][0]).astype(str) + "_" +
    full["y_readm"][0]["event"].astype(int).astype(str) + "_" +
    full["y_death"][0]["event"].astype(int).astype(str)
)

vc = strata_full.value_counts()
display(Markdown(f"**# full strata:** `{vc.size}`"))
display(Markdown(f"**Min stratum size (full):** `{int(vc.min())}`"))
display(Markdown(f"**# strata with count < 2:** `{int((vc < 2).sum())}`"))

# full strata: 95

Min stratum size (full): 1

# strata with count < 2: 4

Code
plan = get_plan_labels(full["X"][0])
readm_evt = full["y_readm"][0]["event"].astype(int)
death_evt = full["y_death"][0]["event"].astype(int)

fb = pd.Series(plan.astype(str) + "_" + readm_evt.astype(str) + "_" + death_evt.astype(str))

rare_mask = strata_full.map(strata_full.value_counts()) < 2
n_rare = int(rare_mask.sum())

display(Markdown(f"**Rows in rare full-strata (<2):** `{n_rare}`"))
if n_rare > 0:
    display(Markdown(
        f"**Rare rows proportion:** `{n_rare/len(strata_full):.3%}`"
    ))

Rows in rare full-strata (<2): 4

Rare rows proportion: 0.005%

Code
# Use the actual stratification mode that was used to split
strata_used, strat_mode, _, _ = build_strata(full["X"][0], full["y_readm"][0], full["y_death"][0])

s = pd.Series(strata_used)
train_strata = set(s.iloc[train_idx].unique())
test_strata = set(s.iloc[test_idx].unique())

missing_in_test = sorted(train_strata - test_strata)
missing_in_train = sorted(test_strata - train_strata)

display(Markdown(f"**Strata used:** `{strat_mode}`"))
display(Markdown(f"**# strata in train:** `{len(train_strata)}` | **# strata in test:** `{len(test_strata)}`"))
display(Markdown(f"**Strata present in train but missing in test:** `{len(missing_in_test)}`"))
display(Markdown(f"**Strata present in test but missing in train:** `{len(missing_in_train)}`"))

Strata used: fallback(plan+readm+death)

# strata in train: 20 | # strata in test: 20

Strata present in train but missing in test: 0

Strata present in test but missing in train: 0

Code
from pathlib import Path
import pandas as pd
import numpy as np
from IPython.display import display, Markdown

PROJECT_ROOT = find_project_root()   # no hardcoded absolute path
OUT_DIR = PROJECT_ROOT / "_out"
OUT_DIR.mkdir(parents=True, exist_ok=True)

SPLIT_PARQUET = OUT_DIR / f"death_split_seed{SEED}_test{int(TEST_SIZE*100)}_mar26.parquet"

split_df = pd.DataFrame({
    "row_id": np.arange(n),
    "is_train": np.isin(np.arange(n), train_idx)
})

split_df.to_parquet(SPLIT_PARQUET, index=False)

display(Markdown(f"**Project root:** `{PROJECT_ROOT}`"))
display(Markdown(f"**Saved split to:** `{SPLIT_PARQUET}`"))

Project root: G:\My Drive\Alvacast\SISTRAT 2023\cons

Saved split to: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\death_split_seed2125_test20_mar26.parquet

Code
import pandas as pd
import numpy as np
from pathlib import Path

SEED = 2125
TEST_SIZE = 0.20

# Use the first imputation (complete data)
X_full = full["X"][0]
y_death_full = full["y_death"][0]

# Find admission age column
age_col = 'adm_age_rec3' if 'adm_age_rec3' in X_full.columns else \
          [c for c in X_full.columns if 'adm_age' in c][0]

# Create the 4-column split file
split_export = pd.DataFrame({
    'row_id': np.arange(1, len(X_full) + 1),  # 1-based for R
    'is_train': np.isin(np.arange(len(X_full)), train_idx),  # vectorized membership test
    'death_time_from_disch_m': np.round(y_death_full['time'], 2),
    'adm_age_rec3': X_full[age_col]
})

out_dir = PROJECT_ROOT / "_out"
out_dir.mkdir(exist_ok=True)

# Export
fname = out_dir / f"death_split_seed{SEED}_test{int(TEST_SIZE*100)}_mar26.parquet"
split_export.to_parquet(fname, index=False)

nb_print(f"Exported: {fname}")
nb_print(f"Total: {len(split_export)} rows")
nb_print(f"Train: {split_export['is_train'].sum()} ({100*split_export['is_train'].mean():.1f}%)")
nb_print(f"Test: {(~split_export['is_train']).sum()} ({100*(~split_export['is_train']).mean():.1f}%)")
nb_print(f"\nFirst 5 rows:")
nb_print(split_export.head())
Exported: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\death_split_seed2125_test20_mar26.parquet
Total: 88152 rows
Train: 70521 (80.0%)
Test: 17631 (20.0%)
First 5 rows:
   row_id  is_train  death_time_from_disch_m  adm_age_rec3
0       1      True                    68.97         31.53
1       2      True                    81.32         20.61
2       3      True                   116.74         42.52
3       4     False                    91.97         60.61
4       5      True                    31.03         45.08
Code
df0 = imputations_list_mar26[0]

# Quick descriptive checks on the first imputation
mean_age = df0["adm_age_rec3"].mean()
count_foreign = (df0["national_foreign"] == 1).sum()

# Print results
nb_print(f"Mean of adm_age_rec3: {mean_age:.4f}")
nb_print(f"Count of national_foreign == 1: {count_foreign}")
Mean of adm_age_rec3: 35.7256
Count of national_foreign == 1: 453

We cleaned our environment safely so that:

  • Old models

  • Temporary objects

  • Large intermediate datasets

do not interfere with the next modeling block.

Code
# Safe cleanup before Readmission XGBoost blocks
import types
import gc

# Ensure logger exists (some target cells expect it)
if "nb_print" not in globals():
    def nb_print(*args, **kwargs):
        print(*args, **kwargs)

# Compatibility: one Optuna/Bootstrap block checks jan26 naming
#if "imputations_list_jan26" not in globals() and "imputations_list_mar26" in globals():
#    imputations_list_jan26 = imputations_list_mar26

KEEP = {
    "nb_print", "study",
    "imputations_list_mar26", #"imputations_list",#"imputations_list_jan26"
    "X_train", "y_surv_readm_list_corrected", "y_surv_readm_list", "y_surv_death_list",
    # Optional plot config objects:
    "plt", "sns", "matplotlib", "mpl", "rcParams", "PROJECT_ROOT"
}

# ensure both variants are kept
KEEP.update({
    "y_surv_readm_list_corrected_mar26", "y_surv_readm_list_corrected",
    "y_surv_readm_list_mar26", "y_surv_death_list_mar26"
})

# after cleanup, sanity-check alignment before tuning
if "imputations_list_mar26" in globals() and "y_surv_readm_list_corrected" in globals():
    assert len(imputations_list_mar26[0]) == len(y_surv_readm_list_corrected[0]), \
        f"Row mismatch: X={len(imputations_list_mar26[0])}, y={len(y_surv_readm_list_corrected[0])}"
        
for name, obj in list(globals().items()):
    if name in KEEP or name.startswith("_"):
        continue
    if isinstance(obj, types.ModuleType):   # keep imports
        continue
    if callable(obj):                        # keep functions/classes
        continue
    del globals()[name]

gc.collect()

required = ["y_surv_readm_list_corrected", "y_surv_readm_list", "y_surv_death_list"]
missing = [x for x in required if x not in globals()]
nb_print("Missing required objects:", missing)
Missing required objects: []

ML

Advanced Survival Modeling: XGBoost & Stratified Evaluation

In this section, we transition to a Gradient Boosted Decision Tree (GBDT) framework using XGBoost. This serves as a robust non-linear benchmark to complement the neural network analysis for low-event survival data (approximately 4% death events).

Methodological Framework

  • Cox-Objective Boosting: We use the survival:cox objective, which optimizes the Cox partial log-likelihood within a boosting architecture. This enables flexible non-linear risk modeling and interaction learning while retaining the proportional hazards formulation.

  • 5-Fold Cross-Validation with Stratification (Death Model): We use 5-fold cross-validation with stratification to preserve key data structure across folds. In the death model, stratification is based on a combined label of treatment plan type and event status, with fallback to simpler stratification when rare strata make 5-fold splitting infeasible. (The readmission model currently uses plan-type-only stratification.)

  • Censoring-Aware Evaluation: Hyperparameter selection is based on Uno’s C-Index (IPCW), which is appropriate under right censoring. We also report the Integrated Brier Score (IBS) as a complementary calibration/overall prediction error metric.

Hyperparameter Optimization

Given the low event rate, we run a randomized search over a dense parameter grid. The search emphasizes regularization and tree-complexity controls (min_child_weight, gamma, reg_alpha, reg_lambda) to improve generalization and reduce overfitting.

Breslow Estimation

Because XGBoost with survival:cox outputs relative risk scores, we estimate the baseline hazard/survival using the Breslow estimator to derive absolute survival probabilities 𝑆(𝑡∣𝑥), enabling time-specific calibration metrics such as IBS.
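A simplified version of this conversion can be sketched as follows (a toy implementation, not the notebook's exact code). One assumption to keep in mind: XGBoost's `survival:cox` predictions are on the hazard-ratio scale by default, so scores should be converted to the log-hazard scale (e.g., via `predict(..., output_margin=True)`) before use here:

```python
import numpy as np

def breslow_baseline(times, events, log_risk):
    """Breslow estimator of the baseline cumulative hazard H0(t).

    times, events : follow-up times and 0/1 event indicators
    log_risk      : model scores on the log-hazard scale
    Returns (unique event times, H0 evaluated at those times).
    """
    order = np.argsort(times)
    t, e, r = times[order], events[order].astype(bool), np.exp(log_risk[order])
    # Risk-set sums: total exp(score) of subjects still at risk at each time
    at_risk = np.cumsum(r[::-1])[::-1]
    uniq_t = np.unique(t[e])
    h0 = np.zeros_like(uniq_t)
    for k, ut in enumerate(uniq_t):
        d = np.sum(e & (t == ut))                # events at this time
        denom = at_risk[np.searchsorted(t, ut)]  # risk set at ut
        h0[k] = d / denom
    return uniq_t, np.cumsum(h0)

def survival_probability(eval_times, uniq_t, H0, log_risk_new):
    """S(t|x) = exp(-H0(t) * exp(score_x)), step-function lookup in t."""
    idx = np.searchsorted(uniq_t, eval_times, side="right") - 1
    H0_at = np.where(idx >= 0, H0[np.clip(idx, 0, None)], 0.0)
    return np.exp(-np.outer(np.exp(log_risk_new), H0_at))

# Toy check on synthetic data (not the cohort)
rng = np.random.default_rng(0)
tt = rng.exponential(12.0, 500)
ev = (rng.random(500) < 0.5).astype(int)
lr = rng.normal(0.0, 0.3, 500)
ut, H0 = breslow_baseline(tt, ev, lr)
S = survival_probability(np.array([3.0, 6.0, 12.0]), ut, H0, lr[:5])
```

The resulting absolute probabilities S(t|x) are what IBS and other time-specific calibration metrics require; relative risk scores alone are not enough.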

Parameter tuning

Code
#@title ⚡ XGBoost Death Robust Tuning (100 Iterations, CPU Only, Dual Stratification + Fallback)
import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import StratifiedKFold, KFold, ParameterSampler
from sksurv.metrics import concordance_index_ipcw, concordance_index_censored
import time
import gc
import os
from datetime import datetime
import warnings
from pathlib import Path

warnings.filterwarnings("ignore")

# Fallback in case nb_print is not defined globally
if 'nb_print' not in globals():
    def nb_print(*args, **kwargs):
        print(*args, **kwargs)

total_start_time = time.time()

# --- CPU CONFIGURATION ---
N_CORES = max(1, (os.cpu_count() or 1) - 2)  # guard: os.cpu_count() may return None
nb_print(f"⚙️ Parallel Execution Configured: Using {N_CORES} CPU cores.")

# --- 1. SETUP & DATA ---
nb_print("Preparing data for Robust XGBoost Tuning (Death)...")

try:
    if 'imputations_list_mar26' in locals():
        df_tune = imputations_list_mar26[0].copy()
        y_tune_struct = y_surv_death_list[0]
    else:
        df_tune = X_train.copy()
        y_tune_struct = y_surv_death_list[0]

    # Row alignment checks
    assert len(df_tune) == len(y_tune_struct), (
        f"X/y mismatch: df_tune={len(df_tune)}, y_tune_struct={len(y_tune_struct)}"
    )

    nb_print(f"  Data Shape    : {df_tune.shape}")
    nb_print(f"  Target        : Death (Events: {np.asarray(y_tune_struct['event']).sum()}, "
             f"Rate: {np.asarray(y_tune_struct['event']).mean():.3%})")

except Exception as e:
    raise ValueError(f"Data Error: {e}. Please run data loading and split cells first.")

# --- 2. STRATIFICATION HELPERS (Dual + Fallback) ---
def get_plan_stratification_labels(df):
    labels = np.zeros(len(df), dtype=int)
    if 'plan_type_corr_pg_pr' in df.columns:
        labels[df['plan_type_corr_pg_pr'] == 1] = 1
    if 'plan_type_corr_m_pr' in df.columns:
        labels[df['plan_type_corr_m_pr'] == 1] = 2
    if 'plan_type_corr_pg_pai' in df.columns:
        labels[df['plan_type_corr_pg_pai'] == 1] = 3
    if 'plan_type_corr_m_pai' in df.columns:
        labels[df['plan_type_corr_m_pai'] == 1] = 4
    return labels

def get_dual_stratification_labels(df, y_struct):
    labels = get_plan_stratification_labels(df)
    event_status = np.asarray(y_struct['event']).astype(int)
    return (labels * 10) + event_status

def choose_stratification(df, y_struct, n_splits=5):
    dual_labels = get_dual_stratification_labels(df, y_struct)
    dual_u, dual_c = np.unique(dual_labels, return_counts=True)

    if len(dual_u) > 1 and dual_c.min() >= n_splits:
        nb_print("Stratification mode: dual (plan_type x event).")
        return dual_labels, "dual", True

    rare_dual = {int(k): int(v) for k, v in zip(dual_u, dual_c) if v < n_splits}
    nb_print(f"[Fallback triggered] Dual stratification has classes with < {n_splits} samples: {rare_dual}")

    plan_labels = get_plan_stratification_labels(df)
    plan_u, plan_c = np.unique(plan_labels, return_counts=True)
    if len(plan_u) > 1 and plan_c.min() >= n_splits:
        nb_print("Stratification mode: plan_type only (fallback).")
        return plan_labels, "plan_only", True

    event_labels = np.asarray(y_struct['event']).astype(int)
    ev_u, ev_c = np.unique(event_labels, return_counts=True)
    if len(ev_u) > 1 and ev_c.min() >= n_splits:
        nb_print("Stratification mode: event only (fallback).")
        return event_labels, "event_only", True

    nb_print(f"[Fallback triggered] No valid stratification for {n_splits} folds. Using unstratified KFold.")
    return None, "kfold", False

strat_labels, strat_mode, use_stratified = choose_stratification(df_tune, y_tune_struct, n_splits=5)
y_xgb_label = np.where(np.asarray(y_tune_struct['event']), 
                       np.asarray(y_tune_struct['time']), 
                       -np.asarray(y_tune_struct['time']))

# --- 3. SEARCH SPACE ---
param_grid = {
    'learning_rate': [0.005, 0.01, 0.02, 0.05, 0.1],
    'max_depth': [3, 4, 5, 6, 8],
    'min_child_weight': [1, 5, 10, 20, 50],
    'subsample': [0.6, 0.7, 0.8, 0.9],
    'colsample_bytree': [0.5, 0.6, 0.7, 0.8],
    'reg_alpha': [0, 0.1, 1, 5, 10],
    'reg_lambda': [0.1, 1, 5, 10, 20],
    'gamma': [0, 0.1, 0.5, 1, 2]
}

N_ITER = 100
param_list = list(ParameterSampler(param_grid, n_iter=N_ITER, random_state=2125))

# --- 4. TUNING LOOP ---
nb_print(f"\n🚀 Starting Randomized Search ({N_ITER} combos)...")
nb_print(f"  Strategy: 5-Fold CV | Stratification mode: {strat_mode}")
nb_print(f"  Metric: Uno's C-Index (IPCW)")

results = []

if use_stratified:
    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=2125)
else:
    cv = KFold(n_splits=5, shuffle=True, random_state=2125)

for i, sampled_params in enumerate(param_list):
    iter_start = time.time()
    
    params = sampled_params.copy()
    params['objective'] = 'survival:cox'
    params['eval_metric'] = 'cox-nloglik'
    params['tree_method'] = 'hist'
    params['seed'] = 2125
    params['nthread'] = N_CORES
    params['device'] = 'cpu'
    params['verbosity'] = 0

    fold_scores = []

    if use_stratified:
        split_iter = cv.split(df_tune, strat_labels)
    else:
        split_iter = cv.split(df_tune)

    for train_idx, val_idx in split_iter:
        X_tr, X_va = df_tune.iloc[train_idx], df_tune.iloc[val_idx]
        y_tr_xgb, y_va_xgb = y_xgb_label[train_idx], y_xgb_label[val_idx]
        y_tr_struct, y_va_struct = y_tune_struct[train_idx], y_tune_struct[val_idx]

        dtrain = xgb.DMatrix(X_tr, label=y_tr_xgb)
        dval = xgb.DMatrix(X_va, label=y_va_xgb)

        model = xgb.train(
            params, dtrain, num_boost_round=1500,
            evals=[(dval, 'val')], early_stopping_rounds=30,
            verbose_eval=False
        )

        risk_scores = model.predict(dval)

        try:
            c_val = concordance_index_ipcw(y_tr_struct, y_va_struct, risk_scores)[0]
        except Exception:
            c_val = concordance_index_censored(
                np.asarray(y_va_struct['event']),
                np.asarray(y_va_struct['time']),
                risk_scores
            )[0]

        fold_scores.append(c_val)

        del model, dtrain, dval, risk_scores
        gc.collect()

    avg_score = np.mean(fold_scores)
    std_score = np.std(fold_scores)
    results.append({**params, 'Unos_C_Index': avg_score, 'Std_Dev': std_score, 'Strat_Mode': strat_mode})

    if (i + 1) % 5 == 0:
        elapsed_min = (time.time() - total_start_time) / 60
        best_so_far = max(r['Unos_C_Index'] for r in results)
        nb_print(f"  [{i+1}/{N_ITER}] Best: {best_so_far:.4f} | Current: {avg_score:.4f} | Elapsed: {elapsed_min:.2f} min")

# --- 5. FINALIZE & EXPORT ---
total_duration_min = (time.time() - total_start_time) / 60
nb_print(f"\n🏁 Total Execution Time: {total_duration_min:.2f} minutes")

if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the project-root setup cell first.")

OUT_DIR = Path(PROJECT_ROOT) / "_out"
OUT_DIR.mkdir(parents=True, exist_ok=True)

df_results = pd.DataFrame(results).sort_values(by='Unos_C_Index', ascending=False)
best_config = df_results.iloc[0].to_dict()

timestamp_str = datetime.now().strftime("%Y%m%d_%H%M")
filename = OUT_DIR / f"XGB_Death_Robust_Tuning_5Fold_{timestamp_str}_mar26.csv"
df_results.to_csv(filename, index=False)

nb_print(f"\n🏆 Tuning Complete!")
nb_print(f"  Best C-Index: {best_config['Unos_C_Index']:.4f}")
nb_print(f"  Stratification used: {best_config['Strat_Mode']}")
nb_print(f"Saved to: {filename}")
⚙️ Parallel Execution Configured: Using 30 CPU cores.
Preparing data for Robust XGBoost Tuning (Death)...
  Data Shape    : (70521, 56)
  Target        : Death (Events: 3039, Rate: 4.309%)
Stratification mode: dual (plan_type x event).
🚀 Starting Randomized Search (100 combos)...
  Strategy: 5-Fold CV | Stratification mode: dual
  Metric: Uno's C-Index (IPCW)
  [5/100] Best: 0.7476 | Current: 0.7461 | Elapsed: 0.61 min
  [10/100] Best: 0.7476 | Current: 0.7439 | Elapsed: 1.47 min
  [15/100] Best: 0.7476 | Current: 0.7468 | Elapsed: 2.78 min
  [20/100] Best: 0.7476 | Current: 0.7436 | Elapsed: 3.63 min
  [25/100] Best: 0.7476 | Current: 0.7454 | Elapsed: 4.54 min
  [30/100] Best: 0.7486 | Current: 0.7436 | Elapsed: 5.49 min
  [35/100] Best: 0.7486 | Current: 0.7434 | Elapsed: 6.57 min
  [40/100] Best: 0.7486 | Current: 0.7430 | Elapsed: 7.62 min
  [45/100] Best: 0.7486 | Current: 0.7424 | Elapsed: 8.58 min
  [50/100] Best: 0.7486 | Current: 0.7452 | Elapsed: 9.27 min
  [55/100] Best: 0.7486 | Current: 0.7467 | Elapsed: 10.40 min
  [60/100] Best: 0.7486 | Current: 0.7460 | Elapsed: 11.79 min
  [65/100] Best: 0.7486 | Current: 0.7454 | Elapsed: 13.46 min
  [70/100] Best: 0.7486 | Current: 0.7459 | Elapsed: 14.37 min
  [75/100] Best: 0.7486 | Current: 0.7453 | Elapsed: 15.14 min
  [80/100] Best: 0.7486 | Current: 0.7483 | Elapsed: 16.50 min
  [85/100] Best: 0.7486 | Current: 0.7453 | Elapsed: 17.43 min
  [90/100] Best: 0.7486 | Current: 0.7455 | Elapsed: 18.81 min
  [95/100] Best: 0.7486 | Current: 0.7461 | Elapsed: 19.68 min
  [100/100] Best: 0.7486 | Current: 0.7442 | Elapsed: 20.36 min
🏁 Total Execution Time: 20.36 minutes
🏆 Tuning Complete!
  Best C-Index: 0.7486
  Stratification used: dual
Saved to: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\XGB_Death_Robust_Tuning_5Fold_20260305_1832_mar26.csv
Code
nb_print(f"\nTuning Complete!")
nb_print(f"  Best C-Index: {best_config['Unos_C_Index']:.4f}")
Tuning Complete!
  Best C-Index: 0.7486
Code
nb_print(best_config)
{'subsample': 0.7, 'reg_lambda': 5.0, 'reg_alpha': 0.0, 'min_child_weight': 5, 'max_depth': 6, 'learning_rate': 0.01, 'gamma': 0.0, 'colsample_bytree': 0.5, 'objective': 'survival:cox', 'eval_metric': 'cox-nloglik', 'tree_method': 'hist', 'seed': 2125, 'nthread': 30, 'device': 'cpu', 'verbosity': 0, 'Unos_C_Index': 0.7486327059193654, 'Std_Dev': 0.0054521701933017835, 'Strat_Mode': 'dual'}
Code
import pandas as pd
from IPython.display import HTML, display

# Relax display limits so pandas shows the full results table
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', None)

# Convert DataFrame to HTML and wrap in a scrollable div
html_table = df_results.to_html()
scroll_box = f"""
<div style="max-height:500px; max-width:1000px; overflow-y:auto; overflow-x:auto; border:1px solid #ccc;">
{html_table}
</div>
"""
display(HTML(scroll_box))
subsample reg_lambda reg_alpha min_child_weight max_depth learning_rate gamma colsample_bytree objective eval_metric tree_method seed nthread device verbosity Unos_C_Index Std_Dev Strat_Mode
26 0.7 5.0 0.0 5 6 0.010 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.748633 0.005452 dual
79 0.7 10.0 0.0 5 6 0.005 1.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.748276 0.005658 dual
62 0.7 1.0 0.0 20 6 0.005 2.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.748194 0.006634 dual
35 0.7 1.0 1.0 1 5 0.005 0.1 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.747850 0.004984 dual
3 0.6 5.0 1.0 1 4 0.010 0.1 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.747594 0.004926 dual
27 0.6 20.0 0.0 1 4 0.020 0.1 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.747508 0.003497 dual
33 0.6 5.0 1.0 5 6 0.010 0.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.747469 0.005059 dual
2 0.7 5.0 5.0 20 5 0.010 0.5 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.747456 0.004672 dual
11 0.7 10.0 0.0 20 5 0.005 2.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.747445 0.004820 dual
15 0.6 10.0 5.0 5 5 0.005 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.747320 0.005146 dual
72 0.7 10.0 0.0 5 4 0.020 0.1 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.747297 0.005253 dual
43 0.9 5.0 0.0 20 6 0.020 2.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.747149 0.006435 dual
71 0.7 5.0 0.1 1 4 0.020 0.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.747131 0.005806 dual
10 0.7 10.0 1.0 20 5 0.020 0.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.747074 0.004762 dual
82 0.9 1.0 5.0 20 8 0.005 1.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.747065 0.005677 dual
86 0.9 10.0 1.0 10 5 0.005 2.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.747049 0.004922 dual
22 0.6 20.0 0.0 1 3 0.020 0.1 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.747031 0.005368 dual
67 0.9 10.0 10.0 1 4 0.005 0.5 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746991 0.004775 dual
56 0.8 1.0 0.1 5 5 0.020 0.5 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746977 0.005150 dual
91 0.7 1.0 1.0 5 4 0.010 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746947 0.005518 dual
58 0.7 5.0 5.0 5 6 0.005 0.1 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746900 0.003416 dual
12 0.7 0.1 1.0 5 4 0.005 0.1 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746858 0.005409 dual
65 0.9 0.1 1.0 10 8 0.010 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746855 0.007099 dual
31 0.7 20.0 0.0 10 8 0.005 0.5 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746854 0.006460 dual
14 0.7 1.0 0.1 20 5 0.020 1.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746807 0.005310 dual
96 0.6 10.0 10.0 1 5 0.005 0.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746805 0.003293 dual
6 0.9 5.0 10.0 1 8 0.010 0.5 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746795 0.003845 dual
13 0.8 1.0 0.0 50 3 0.010 0.5 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746751 0.005893 dual
54 0.8 20.0 0.1 1 8 0.050 0.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746686 0.004723 dual
85 0.9 1.0 5.0 20 4 0.100 0.1 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746650 0.003268 dual
88 0.8 10.0 10.0 10 5 0.005 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746614 0.004183 dual
70 0.8 0.1 5.0 50 3 0.010 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746609 0.005504 dual
55 0.8 5.0 0.0 1 3 0.005 0.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746509 0.006427 dual
97 0.7 20.0 5.0 20 8 0.100 2.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746488 0.004795 dual
80 0.8 20.0 1.0 5 3 0.020 2.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746470 0.005792 dual
30 0.8 1.0 5.0 10 3 0.010 2.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746444 0.006092 dual
81 0.8 10.0 5.0 1 6 0.020 0.5 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746414 0.005448 dual
52 0.9 20.0 5.0 20 6 0.005 1.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746392 0.004390 dual
45 0.9 20.0 5.0 10 4 0.100 1.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746388 0.004534 dual
61 0.8 10.0 5.0 10 4 0.020 0.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746343 0.005409 dual
40 0.8 1.0 10.0 1 6 0.005 1.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746301 0.003124 dual
63 0.9 0.1 10.0 50 6 0.005 0.5 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746170 0.005178 dual
18 0.6 20.0 5.0 20 3 0.020 0.5 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746168 0.005169 dual
92 0.7 20.0 1.0 5 3 0.050 0.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746167 0.004773 dual
83 0.6 20.0 5.0 5 5 0.100 2.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746137 0.006410 dual
87 0.8 5.0 5.0 50 5 0.005 1.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746125 0.004781 dual
94 0.7 5.0 0.0 50 6 0.005 0.1 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746111 0.005550 dual
78 0.7 20.0 5.0 20 5 0.010 0.1 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746109 0.004160 dual
41 0.7 1.0 5.0 10 8 0.020 0.5 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746105 0.004690 dual
60 0.9 0.1 5.0 50 3 0.010 2.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746105 0.006236 dual
90 0.8 5.0 0.1 5 3 0.050 0.1 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746086 0.006413 dual
50 0.8 5.0 1.0 50 8 0.020 2.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746079 0.004857 dual
4 0.9 5.0 0.1 10 6 0.050 1.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746071 0.007413 dual
21 0.9 10.0 10.0 5 3 0.050 0.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.746046 0.004949 dual
53 0.8 5.0 10.0 50 5 0.020 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745978 0.004913 dual
59 0.7 1.0 5.0 5 6 0.050 2.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745967 0.006495 dual
77 0.8 10.0 5.0 50 3 0.010 0.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745966 0.005550 dual
25 0.8 20.0 5.0 5 5 0.050 2.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745948 0.003693 dual
42 0.6 0.1 0.1 50 6 0.010 0.5 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745914 0.006393 dual
69 0.8 0.1 10.0 5 4 0.050 1.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745909 0.004917 dual
5 0.6 5.0 0.0 50 3 0.020 0.5 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745903 0.006027 dual
36 0.6 0.1 10.0 10 6 0.005 0.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745796 0.003382 dual
57 0.7 1.0 10.0 20 6 0.005 0.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745774 0.003330 dual
20 0.7 10.0 10.0 20 5 0.005 1.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745716 0.003717 dual
76 0.6 5.0 0.0 10 5 0.050 2.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745712 0.005461 dual
68 0.8 0.1 0.1 50 3 0.050 0.1 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745617 0.005055 dual
28 0.6 5.0 5.0 20 3 0.005 0.5 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745573 0.006040 dual
89 0.9 20.0 0.1 50 5 0.050 0.1 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745544 0.003696 dual
93 0.9 0.1 5.0 5 6 0.100 1.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745509 0.005026 dual
24 0.7 0.1 1.0 50 3 0.050 1.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745376 0.005351 dual
64 0.7 20.0 10.0 50 5 0.010 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745371 0.005037 dual
74 0.9 10.0 10.0 10 5 0.100 0.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745284 0.002852 dual
84 0.8 5.0 10.0 10 3 0.100 2.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745284 0.005229 dual
49 0.9 20.0 5.0 50 3 0.100 1.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745243 0.004750 dual
8 0.6 10.0 0.0 50 3 0.010 0.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745231 0.006324 dual
48 0.6 1.0 1.0 50 8 0.005 1.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745203 0.006860 dual
51 0.6 20.0 0.1 50 6 0.010 1.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.745116 0.005527 dual
98 0.8 20.0 1.0 5 6 0.100 0.1 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.744950 0.003910 dual
95 0.8 20.0 0.0 50 8 0.100 2.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.744926 0.004955 dual
16 0.6 5.0 5.0 20 3 0.100 0.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.744840 0.003910 dual
38 0.6 10.0 1.0 5 4 0.050 0.1 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.744817 0.005888 dual
7 0.9 0.1 10.0 10 6 0.100 1.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.744559 0.003551 dual
75 0.8 20.0 10.0 5 8 0.020 0.1 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.744545 0.003332 dual
37 0.6 0.1 0.1 10 4 0.050 1.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.744354 0.006533 dual
66 0.6 0.1 5.0 50 6 0.050 1.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.744255 0.006016 dual
99 0.9 20.0 10.0 10 6 0.100 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.744248 0.005113 dual
46 0.8 1.0 10.0 10 8 0.050 0.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.743995 0.004066 dual
9 0.7 20.0 1.0 20 8 0.100 2.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.743941 0.003981 dual
23 0.8 1.0 1.0 50 8 0.100 0.1 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.743693 0.007909 dual
29 0.7 5.0 1.0 50 4 0.050 0.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.743627 0.005538 dual
19 0.7 20.0 1.0 10 5 0.100 0.5 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.743562 0.005778 dual
17 0.7 5.0 0.0 50 4 0.050 1.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.743515 0.005438 dual
0 0.7 20.0 0.0 10 3 0.100 1.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.743424 0.007104 dual
73 0.8 1.0 0.1 50 8 0.100 2.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.743418 0.009505 dual
34 0.9 5.0 1.0 5 5 0.100 0.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.743384 0.006301 dual
47 0.7 0.1 0.0 1 6 0.050 0.5 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.743314 0.005718 dual
32 0.6 0.1 0.1 10 6 0.050 2.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.743075 0.006047 dual
39 0.6 1.0 0.0 5 8 0.020 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.743034 0.006412 dual
44 0.6 1.0 5.0 20 3 0.100 0.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.742399 0.006569 dual
1 0.6 1.0 5.0 5 6 0.100 0.1 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.741166 0.006817 dual

Optuna

  • Multi-objective tuning balances C-index and IBS.
  • Uses 5-fold cross-validation with dual stratification (plan type + event status).
  • Evaluates performance at 5 clinical time horizons.
  • Averages time-specific C-indices for robustness.
  • Computes survival via Breslow baseline hazard.
  • Converts risk scores into survival probabilities.
  • Uses IPCW C-index for censoring adjustment.
  • Applies early pruning for poor-performing trials.
  • Returns Pareto-optimal models, not a single winner.
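Before the full tuning cell, a minimal, self-contained sketch (toy arrays, not the study cohort) of the signed-time label encoding that XGBoost's `survival:cox` objective expects, and which the cell below builds via `np.where(event, time, -time)`:

```python
import numpy as np

# Toy outcome data (hypothetical, not the study cohort)
time = np.array([3.0, 12.0, 36.0, 60.0])
event = np.array([True, False, True, False])

# survival:cox label convention: +time for events, -time for censored
y_label = np.where(event, time, -time)  # +3, -12, +36, -60

# The encoding is lossless: event status and time are recoverable
recovered_event = y_label > 0
recovered_time = np.abs(y_label)
```

The sign carries censoring status, so a single label vector suffices for `xgb.DMatrix(X, label=y_label)`.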
Code
# @title Phase 1 - Death Optuna Multi-Objective (Stochastic MI: 1 imputation per trial)
import optuna
import numpy as np
import pandas as pd
import xgboost as xgb
import gc
import os
import warnings
import time
import joblib
from datetime import datetime
from sklearn.model_selection import StratifiedKFold, KFold
from sksurv.metrics import concordance_index_ipcw, brier_score
from pathlib import Path

start_time = time.time()
warnings.filterwarnings("ignore")

if 'nb_print' not in globals():
    def nb_print(*args, **kwargs):
        print(*args, **kwargs)

SEED = 2125
np.random.seed(SEED)

# Project-root and I/O directory configuration
if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the project-root setup cell first.")
PROJECT_ROOT = Path(PROJECT_ROOT).resolve()
INPUT_DIR = PROJECT_ROOT / "_input"
OUTPUT_DIR = PROJECT_ROOT / "_out"

INPUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

nb_print(f"PROJECT_ROOT: {PROJECT_ROOT}")
nb_print(f"INPUT_DIR: {INPUT_DIR}")
nb_print(f"OUTPUT_DIR: {OUTPUT_DIR}")

N_CORES = max(1, os.cpu_count() - 2)
N_TRIALS = 100
N_OPTUNA_JOBS = N_CORES  # parallel trials; set to 1 for strongest reproducibility
EVAL_HORIZONS = [3, 6, 12, 36, 60]
N_SPLITS = 5

nb_print("Preparing data for Death Phase 1 stochastic-MI tuning...")
nb_print(f"CPU available for notebook: {N_CORES} cores")

# ---------------------------
# Data setup
# ---------------------------
# Build y_death_struct_list FIRST, before the assert
n_imputations = len(imputations_tune)  # defined before use in the fallback branch below

if 'y_surv_death_list' in locals() and isinstance(y_surv_death_list, list) and len(y_surv_death_list) > 0:
    y_death_struct_list = [y.copy() for y in y_surv_death_list]
elif 'y_surv_death_list_mar26' in locals() and isinstance(y_surv_death_list_mar26, list) and len(y_surv_death_list_mar26) > 0:
    y_death_struct_list = [y.copy() for y in y_surv_death_list_mar26]
elif 'y_surv_death' in locals():
    y_death_struct_list = [y_surv_death.copy() for _ in range(n_imputations)]
else:
    raise ValueError("No y_surv_death_list, y_surv_death_list_mar26 or y_surv_death found.")

if len(y_death_struct_list) == 1 and n_imputations > 1:
    y_death_struct_list = [y_death_struct_list[0].copy() for _ in range(n_imputations)]

if len(y_death_struct_list) != n_imputations:
    raise ValueError(
        f"Mismatch: {n_imputations} imputations vs {len(y_death_struct_list)} death outcomes."
    )

# NOW assert is safe — both imputations_tune and y_death_struct_list exist
assert all(
    len(imp) == len(yd)
    for imp, yd in zip(imputations_tune, y_death_struct_list)
), f"Length mismatch! X={len(imputations_tune[0])}, y_death={len(y_death_struct_list[0])}"

# ---------------------------
# Stratification helpers
# ---------------------------
def get_plan_stratification_labels(df):
    labels = np.zeros(len(df), dtype=np.int32)
    if 'plan_type_corr_pg_pr' in df.columns: labels[df['plan_type_corr_pg_pr'] == 1] = 1
    if 'plan_type_corr_m_pr' in df.columns: labels[df['plan_type_corr_m_pr'] == 1] = 2
    if 'plan_type_corr_pg_pai' in df.columns: labels[df['plan_type_corr_pg_pai'] == 1] = 3
    if 'plan_type_corr_m_pai' in df.columns: labels[df['plan_type_corr_m_pai'] == 1] = 4
    return labels

def get_dual_stratification_labels(df, y_struct):
    plan_labels = get_plan_stratification_labels(df)
    event_status = np.asarray(y_struct['event']).astype(np.int32)
    return (plan_labels * 10) + event_status

def _is_stratifiable(labels, n_splits=5):
    unique, counts = np.unique(labels, return_counts=True)
    return (len(unique) > 1) and (counts.min() >= n_splits)

def _rare_classes(labels, n_splits=5):
    unique, counts = np.unique(labels, return_counts=True)
    return {int(k): int(v) for k, v in zip(unique, counts) if v < n_splits}

def make_cv_splits(X_imp, y_imp, n_splits=5, imp_idx=0):
    dual_labels = get_dual_stratification_labels(X_imp, y_imp)
    if _is_stratifiable(dual_labels, n_splits):
        cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
        return list(cv.split(X_imp, dual_labels)), "dual"

    nb_print(f"[Fallback triggered][imp {imp_idx+1}] Dual not feasible. Rare classes: {_rare_classes(dual_labels, n_splits)}")

    plan_labels = get_plan_stratification_labels(X_imp)
    if _is_stratifiable(plan_labels, n_splits):
        cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
        return list(cv.split(X_imp, plan_labels)), "plan_only"

    event_labels = np.asarray(y_imp['event']).astype(np.int32)
    if _is_stratifiable(event_labels, n_splits):
        cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
        return list(cv.split(X_imp, event_labels)), "event_only"

    nb_print(f"[Fallback triggered][imp {imp_idx+1}] No stratification feasible. Using unstratified KFold.")
    cv = KFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    return list(cv.split(X_imp)), "kfold"

# ---------------------------
# Survival probability helper (Breslow)
# ---------------------------
def predict_survival_probs_breslow(y_tr, risk_tr, risk_va, eval_times):
    risk_tr = np.asarray(risk_tr, dtype=float)
    risk_va = np.asarray(risk_va, dtype=float)

    # Heuristic: hazard ratios are strictly positive, so any nonpositive
    # value indicates the model emitted log-risk; exponentiate in that case.
    if np.any(risk_tr <= 0):
        risk_tr = np.exp(risk_tr)
        risk_va = np.exp(risk_va)

    t_train = np.asarray(y_tr['time'])
    e_train = np.asarray(y_tr['event']).astype(bool)

    order = np.argsort(t_train)
    t_ord = t_train[order]
    e_ord = e_train[order]
    r_ord = risk_tr[order]

    unique_event_times = np.unique(t_ord[e_ord])
    if len(unique_event_times) == 0:
        return np.ones((len(risk_va), len(eval_times)), dtype=float)

    dH0 = np.zeros(len(unique_event_times), dtype=float)
    for i, t in enumerate(unique_event_times):
        at_risk = (t_ord >= t)
        denom = np.sum(r_ord[at_risk])
        if denom > 0:
            d = np.sum((t_ord == t) & e_ord)
            dH0[i] = d / denom

    H0 = np.cumsum(dH0)

    surv_probs = np.ones((len(risk_va), len(eval_times)), dtype=float)
    for j, tau in enumerate(eval_times):
        idx = np.searchsorted(unique_event_times, tau, side='right') - 1
        h0_tau = H0[idx] if idx >= 0 else 0.0
        surv_probs[:, j] = np.exp(-h0_tau * risk_va)

    return np.clip(surv_probs, 1e-6, 1.0)

# ---------------------------
# Precompute per-imputation payloads
# ---------------------------
imp_payloads = []
for imp_idx, (X_imp, y_imp) in enumerate(zip(imputations_tune, y_death_struct_list)):
    y_label = np.where(np.asarray(y_imp['event']), np.asarray(y_imp['time']), -np.asarray(y_imp['time']))
    splits, strat_mode = make_cv_splits(X_imp, y_imp, n_splits=N_SPLITS, imp_idx=imp_idx)

    imp_payloads.append({
        "imp_idx": imp_idx,
        "X": X_imp,
        "y_struct": y_imp,
        "y_label": y_label,
        "splits": splits,
        "strat_mode": strat_mode
    })
    nb_print(f"Imputation {imp_idx+1}: strat_mode={strat_mode}")

# ---------------------------
# Objective
# ---------------------------
def objective(trial):
    # Stochastic MI assignment (balanced by trial number)
    imp_idx = trial.number % n_imputations
    payload = imp_payloads[imp_idx]

    trial.set_user_attr("Sampled_Imputation", int(imp_idx + 1))
    trial.set_user_attr("Strat_Mode", payload["strat_mode"])

    params = {
        'objective': 'survival:cox',
        'eval_metric': 'cox-nloglik',
        'tree_method': 'hist',
        'device': 'cpu',
        'nthread': 1,  # avoid nested parallel oversubscription in Optuna
        'verbosity': 0,
        'seed': SEED,

        # death-focused search space
        'learning_rate': trial.suggest_float('learning_rate', 0.003, 0.05, log=True),
        'max_depth': trial.suggest_int('max_depth', 3, 7),
        'min_child_weight': trial.suggest_int('min_child_weight', 3, 40),
        'subsample': trial.suggest_float('subsample', 0.6, 0.85),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.4, 0.65),
        'reg_alpha': trial.suggest_float('reg_alpha', 1e-4, 5.0, log=True),   # log scale needs a positive lower bound (1e-4, not 0.0)
        'reg_lambda': trial.suggest_float('reg_lambda', 0.5, 15.0, log=True),
        'gamma': trial.suggest_float('gamma', 0.0, 1.0),                       # linear scale, so 0.0 is a valid lower bound
    }

    X = payload["X"]
    y_struct = payload["y_struct"]
    y_label = payload["y_label"]
    splits = payload["splits"]

    fold_c_indices = []
    fold_ib_scores = []
    fold_global_c_indices = []

    for fold_idx, (train_idx, val_idx) in enumerate(splits):
        X_tr, X_va = X.iloc[train_idx], X.iloc[val_idx]
        y_tr, y_va = y_label[train_idx], y_label[val_idx]
        y_tr_struct, y_va_struct = y_struct[train_idx], y_struct[val_idx]

        dtrain = xgb.DMatrix(X_tr, label=y_tr)
        dval = xgb.DMatrix(X_va, label=y_va)

        model = xgb.train(
            params,
            dtrain,
            num_boost_round=1500,
            evals=[(dval, 'val')],
            early_stopping_rounds=30,
            verbose_eval=False
        )

        risk_tr = model.predict(dtrain)
        risk_va = model.predict(dval)

        # 1) Multi-horizon C-index
        h_c = []
        for tau in EVAL_HORIZONS:
            try:
                c_val = concordance_index_ipcw(y_tr_struct, y_va_struct, risk_va, tau=tau)[0]
                h_c.append(float(c_val))
            except Exception:
                pass
        avg_c = float(np.mean(h_c)) if len(h_c) > 0 else 0.5
        fold_c_indices.append(avg_c)

        # 2) Global C-index
        try:
            global_c = float(concordance_index_ipcw(y_tr_struct, y_va_struct, risk_va)[0])
        except Exception:
            global_c = 0.5
        fold_global_c_indices.append(global_c)

        # 3) IBS
        try:
            surv_probs_va = predict_survival_probs_breslow(y_tr_struct, risk_tr, risk_va, EVAL_HORIZONS)
            _, brier_at_tau = brier_score(y_tr_struct, y_va_struct, surv_probs_va, EVAL_HORIZONS)
            avg_ibs = float(np.mean(brier_at_tau))
        except Exception:
            avg_ibs = 0.25
        fold_ib_scores.append(avg_ibs)

        del model, dtrain, dval, risk_tr, risk_va
        gc.collect()

        # Prune aggressively: stop once the running mean C-index is below 0.60 after the second fold
        if fold_idx >= 1 and np.mean(fold_c_indices) < 0.60:
            raise optuna.TrialPruned()

    trial.set_user_attr("Global_C_Index", float(np.mean(fold_global_c_indices)))
    return float(np.mean(fold_c_indices)), float(np.mean(fold_ib_scores))

# ---------------------------
# Run study
# ---------------------------
sampler = optuna.samplers.NSGAIISampler(seed=SEED)
study = optuna.create_study(
    directions=['maximize', 'minimize'],
    sampler=sampler,
    study_name="XGB_Death_Optuna_Study_StochMI"
)

nb_print(f"Starting Phase 1 optimization with {N_TRIALS} trials | seed={SEED} | optuna_jobs={N_OPTUNA_JOBS}")
study.optimize(
    objective,
    n_trials=N_TRIALS,
    n_jobs=N_OPTUNA_JOBS,
    show_progress_bar=True,
    gc_after_trial=True
)

pareto_trials = [t for t in study.best_trials if t.values is not None and len(t.values) == 2]
nb_print(f"\nPhase 1 Pareto models: {len(pareto_trials)}")

phase1_rows = []
for t in pareto_trials:
    row = {
        "trial_id": int(t.number),
        "Phase1_Multi_Horizon_C_Index": float(t.values[0]),
        "Phase1_IBS": float(t.values[1]),
        "Phase1_Global_C_Index": float(t.user_attrs.get("Global_C_Index", np.nan)),
        "Sampled_Imputation": t.user_attrs.get("Sampled_Imputation", np.nan),
        "Strat_Mode": t.user_attrs.get("Strat_Mode", "NA"),
    }
    row.update(t.params)
    phase1_rows.append(row)

df_phase1 = pd.DataFrame(phase1_rows)
if len(df_phase1) > 0:
    df_phase1["Phase1_Distance_to_Ideal"] = np.sqrt(
        (1.0 - df_phase1["Phase1_Multi_Horizon_C_Index"])**2 + (df_phase1["Phase1_IBS"])**2
    )
    df_phase1 = df_phase1.sort_values("Phase1_Distance_to_Ideal", ascending=True).reset_index(drop=True)
    nb_print(df_phase1.head(10))

timestamp_str = datetime.now().strftime("%Y%m%d_%H%M")

study_path = INPUT_DIR / f"XGB_Death_Optuna_Study_StochMI_{timestamp_str}_mar26.pkl"
joblib.dump(study, study_path)
nb_print(f"\nStudy saved: {study_path}")

if len(df_phase1) > 0:
    phase1_csv = OUTPUT_DIR / f"XGB_Death_Optuna_Study_StochMI_{timestamp_str}_mar26.csv"
    df_phase1.to_csv(phase1_csv, index=False)
    nb_print(f"Phase 1 summary saved: {phase1_csv}")

elapsed_minutes = (time.time() - start_time) / 60
nb_print(f"Time taken: {elapsed_minutes:.2f} minutes")
PROJECT_ROOT: G:\My Drive\Alvacast\SISTRAT 2023\cons
INPUT_DIR: G:\My Drive\Alvacast\SISTRAT 2023\cons\_input
OUTPUT_DIR: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out
Preparing data for Death Phase 1 stochastic-MI tuning...
CPU available for notebook: 30 cores
Imputation 1: strat_mode=dual
Imputation 2: strat_mode=dual
Imputation 3: strat_mode=dual
Imputation 4: strat_mode=dual
Imputation 5: strat_mode=dual
[I 2026-03-05 19:56:23,527] A new study created in memory with name: XGB_Death_Optuna_Study_StochMI
Starting Phase 1 optimization with 100 trials | seed=2125 | optuna_jobs=30
100%|██████████| 100/100 [17:31<00:00, 10.51s/it] 
Phase 1 Pareto models: 5
   trial_id  Phase1_Multi_Horizon_C_Index  Phase1_IBS  Phase1_Global_C_Index  Sampled_Imputation Strat_Mode  learning_rate  max_depth  min_child_weight  \
0        55                      0.763571    0.017723               0.746154                   1       dual       0.030967          6                15   
1        95                      0.763288    0.017722               0.746774                   1       dual       0.030967          6                15   
2        35                      0.762946    0.017709               0.747810                   1       dual       0.028376          6                30   
3        75                      0.762676    0.017697               0.746381                   1       dual       0.045269          4                30   
4        46                      0.761964    0.017696               0.745657                   2       dual       0.038173          3                14   

   subsample  colsample_bytree  reg_alpha  reg_lambda     gamma  Phase1_Distance_to_Ideal  
0   0.778743          0.616890   0.013708    2.536795  0.982286                  0.237093  
1   0.778743          0.616890   0.530266    2.536795  0.982286                  0.237374  
2   0.792929          0.491632   1.406209    2.548292  0.369767                  0.237715  
3   0.779111          0.635094   0.004008    4.227691  0.158302                  0.237983  
4   0.708255          0.552737   0.216076    3.956397  0.000897                  0.238693  
Study saved: G:\My Drive\Alvacast\SISTRAT 2023\cons\_input\XGB_Death_Optuna_Study_StochMI_20260305_2013_mar26.pkl
Phase 1 summary saved: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\XGB_Death_Optuna_Study_StochMI_20260305_2013_mar26.csv
Time taken: 17.54 minutes
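The Breslow conversion in `predict_survival_probs_breslow` can be sanity-checked by hand: when all hazard ratios equal 1, the cumulative baseline hazard reduces to the Nelson-Aalen estimator. A toy check under that assumption (three hypothetical subjects, not study data):

```python
import numpy as np

# Events at t=1 and t=2, censoring at t=3; null model (all hazard ratios = 1)
t_train = np.array([1.0, 2.0, 3.0])
e_train = np.array([True, True, False])
risk = np.ones(3)

event_times = np.unique(t_train[e_train])  # event times: 1.0 and 2.0
dH0 = np.array([
    ((t_train == t) & e_train).sum() / risk[t_train >= t].sum()
    for t in event_times
])  # increments: 1/3 (3 at risk at t=1), then 1/2 (2 at risk at t=2)
H0 = np.cumsum(dH0)  # cumulative baseline hazard: [1/3, 5/6]

# Survival at t=2 for a subject with hazard ratio 1: exp(-5/6)
S_at_2 = np.exp(-H0[-1] * 1.0)
```

Matching these hand-computable increments against the notebook helper is a quick way to confirm the at-risk set and tie handling are correct.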
Code
import pandas as pd
from IPython.display import HTML, display

def show_scrollable_df(df, max_height=500, max_width=1200):
    html_table = df.to_html(index=False)
    scroll_box = f"""
    <div style="max-height:{max_height}px; max-width:{max_width}px;
                overflow-y:auto; overflow-x:auto; border:1px solid #ccc;">
    {html_table}
    </div>
    """
    display(HTML(scroll_box))
Code
nb_print("\nOptimal Models found (Pareto Front):")
best_trials = study.best_trials
for t in best_trials:
    global_c_val = t.user_attrs.get("Global_C_Index", "N/A")
    nb_print(f"Trial {t.number} -> Multi-Horizon C: {t.values[0]:.4f} | IBS: {t.values[1]:.4f} | Global C: {global_c_val:.4f}")
Optimal Models found (Pareto Front):
Trial 35 -> Multi-Horizon C: 0.7629 | IBS: 0.0177 | Global C: 0.7478
Trial 46 -> Multi-Horizon C: 0.7620 | IBS: 0.0177 | Global C: 0.7457
Trial 55 -> Multi-Horizon C: 0.7636 | IBS: 0.0177 | Global C: 0.7462
Trial 75 -> Multi-Horizon C: 0.7627 | IBS: 0.0177 | Global C: 0.7464
Trial 95 -> Multi-Horizon C: 0.7633 | IBS: 0.0177 | Global C: 0.7468
Code
show_scrollable_df(df_phase1)
trial_id Phase1_Multi_Horizon_C_Index Phase1_IBS Phase1_Global_C_Index Sampled_Imputation Strat_Mode learning_rate max_depth min_child_weight subsample colsample_bytree reg_alpha reg_lambda gamma Phase1_Distance_to_Ideal
55 0.763571 0.017723 0.746154 1 dual 0.030967 6 15 0.778743 0.616890 0.013708 2.536795 0.982286 0.237093
95 0.763288 0.017722 0.746774 1 dual 0.030967 6 15 0.778743 0.616890 0.530266 2.536795 0.982286 0.237374
35 0.762946 0.017709 0.747810 1 dual 0.028376 6 30 0.792929 0.491632 1.406209 2.548292 0.369767 0.237715
75 0.762676 0.017697 0.746381 1 dual 0.045269 4 30 0.779111 0.635094 0.004008 4.227691 0.158302 0.237983
46 0.761964 0.017696 0.745657 2 dual 0.038173 3 14 0.708255 0.552737 0.216076 3.956397 0.000897 0.238693
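Candidate ranking uses Euclidean distance to the ideal point (C-index = 1, IBS = 0), the same formula as `Phase1_Distance_to_Ideal` above. A toy illustration of the selection rule on hypothetical (C, IBS) pairs:

```python
import numpy as np

# Hypothetical Pareto candidates as (multi-horizon C-index, IBS) pairs
c_index = np.array([0.76, 0.74, 0.70])
ibs = np.array([0.018, 0.010, 0.005])

# Distance to the ideal point (C = 1, IBS = 0); smaller is better
dist = np.sqrt((1.0 - c_index) ** 2 + ibs ** 2)
winner = int(np.argmin(dist))  # index 0: the C-index gap dominates the small IBS differences
```

Because IBS values sit near 0.02 while C-index shortfalls are near 0.25, the distance is dominated by discrimination unless two candidates have nearly identical C-indices.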
Code
# @title Phase 2 - Death: Re-score top Pareto candidates on ALL imputations + final winner
import os
import re
import glob
import gc
import time
import joblib
import numpy as np
import pandas as pd
import xgboost as xgb
from datetime import datetime
from collections import defaultdict
from sklearn.model_selection import StratifiedKFold, KFold
from sksurv.metrics import concordance_index_ipcw, brier_score
from joblib import Parallel, delayed, parallel_backend
from pathlib import Path

start_time = time.time()

if 'nb_print' not in globals():
    def nb_print(*args, **kwargs):
        print(*args, **kwargs)

SEED = 2125
np.random.seed(SEED)

TOP_K = 20
EVAL_HORIZONS = [3, 6, 12, 36, 60]
N_SPLITS = 5
PROJECT_ROOT = Path(globals().get("PROJECT_ROOT", Path.cwd())).resolve()
INPUT_DIR = PROJECT_ROOT / "_input"
OUTPUT_DIR = PROJECT_ROOT / "_out"
INPUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

nb_print(f"PROJECT_ROOT: {PROJECT_ROOT}")

# ---------------------------
# Load study (in-memory or latest file)
# ---------------------------
expected_study_name = "XGB_Death_Optuna_Study_StochMI"
if 'study' in globals() and getattr(study, 'study_name', None) == expected_study_name:
    nb_print(f"Using in-memory study '{expected_study_name}'.")
else:
    pattern_name = "XGB_Death_Optuna_Study_StochMI_*_mar26.pkl"
    files = list(INPUT_DIR.glob(pattern_name))
    if not files:
        raise FileNotFoundError(f"No study found with pattern: {INPUT_DIR / pattern_name}")

    ts_files = []
    for f in files:
        m = re.search(r"XGB_Death_Optuna_Study_StochMI_(\d{8}_\d{4})_mar26\.pkl$", f.name)
        if m:
            try:
                ts_files.append((datetime.strptime(m.group(1), "%Y%m%d_%H%M"), f))
            except ValueError:
                pass

    if not ts_files:
        raise FileNotFoundError("No timestamp-valid Death study file found.")

    latest_file = max(ts_files, key=lambda x: x[0])[1]
    study = joblib.load(latest_file)
    nb_print(f"Loaded study from: {latest_file}")

# ---------------------------
# Data setup
# ---------------------------
# Reuse imputations_tune if already defined; otherwise fall back to saved imputations or X_train
if 'imputations_tune' not in globals():
    if 'imputations_list_mar26' in globals():
        imputations_tune = [df.copy() for df in imputations_list_mar26]
    elif 'X_train' in locals():
        imputations_tune = [X_train.copy()]
    else:
        raise ValueError("No imputations_tune, imputations_list_mar26, or X_train found.")
# if imputations_tune already exists, reuse it as-is

n_imputations = len(imputations_tune)

if 'y_surv_death_list' in locals() and isinstance(y_surv_death_list, list) and len(y_surv_death_list) > 0:
    y_death_struct_list = [y.copy() for y in y_surv_death_list]
elif 'y_surv_death' in locals():
    y_death_struct_list = [y_surv_death.copy() for _ in range(n_imputations)]
else:
    raise ValueError("No y_surv_death_list or y_surv_death found.")

if len(y_death_struct_list) == 1 and n_imputations > 1:
    y_death_struct_list = [y_death_struct_list[0].copy() for _ in range(n_imputations)]

if len(y_death_struct_list) != n_imputations:
    raise ValueError(
        f"Mismatch: {n_imputations} imputations vs {len(y_death_struct_list)} death outcomes."
    )

# ---------------------------
# Stratification helpers
# ---------------------------
def get_plan_stratification_labels(df):
    labels = np.zeros(len(df), dtype=np.int32)
    if 'plan_type_corr_pg_pr' in df.columns: labels[df['plan_type_corr_pg_pr'] == 1] = 1
    if 'plan_type_corr_m_pr' in df.columns: labels[df['plan_type_corr_m_pr'] == 1] = 2
    if 'plan_type_corr_pg_pai' in df.columns: labels[df['plan_type_corr_pg_pai'] == 1] = 3
    if 'plan_type_corr_m_pai' in df.columns: labels[df['plan_type_corr_m_pai'] == 1] = 4
    return labels

def get_dual_stratification_labels(df, y_struct):
    plan_labels = get_plan_stratification_labels(df)
    event_status = np.asarray(y_struct['event']).astype(np.int32)
    return (plan_labels * 10) + event_status

def _is_stratifiable(labels, n_splits=5):
    unique, counts = np.unique(labels, return_counts=True)
    return (len(unique) > 1) and (counts.min() >= n_splits)

def _rare_classes(labels, n_splits=5):
    unique, counts = np.unique(labels, return_counts=True)
    return {int(k): int(v) for k, v in zip(unique, counts) if v < n_splits}

def make_cv_splits(X_imp, y_imp, n_splits=5, imp_idx=0):
    dual_labels = get_dual_stratification_labels(X_imp, y_imp)
    if _is_stratifiable(dual_labels, n_splits):
        cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
        return list(cv.split(X_imp, dual_labels)), "dual"

    nb_print(f"[Fallback triggered][imp {imp_idx+1}] Dual not feasible. Rare classes: {_rare_classes(dual_labels, n_splits)}")

    plan_labels = get_plan_stratification_labels(X_imp)
    if _is_stratifiable(plan_labels, n_splits):
        cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
        return list(cv.split(X_imp, plan_labels)), "plan_only"

    event_labels = np.asarray(y_imp['event']).astype(np.int32)
    if _is_stratifiable(event_labels, n_splits):
        cv = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
        return list(cv.split(X_imp, event_labels)), "event_only"

    nb_print(f"[Fallback triggered][imp {imp_idx+1}] No stratification feasible. Using unstratified KFold.")
    cv = KFold(n_splits=n_splits, shuffle=True, random_state=SEED)
    return list(cv.split(X_imp)), "kfold"

# ---------------------------
# Survival probability helper (Breslow)
# ---------------------------
def predict_survival_probs_breslow(y_tr, risk_tr, risk_va, eval_times):
    risk_tr = np.asarray(risk_tr, dtype=float)
    risk_va = np.asarray(risk_va, dtype=float)

    # Heuristic: hazard ratios are strictly positive, so any nonpositive
    # value indicates the model emitted log-risk; exponentiate in that case.
    if np.any(risk_tr <= 0):
        risk_tr = np.exp(risk_tr)
        risk_va = np.exp(risk_va)

    t_train = np.asarray(y_tr['time'])
    e_train = np.asarray(y_tr['event']).astype(bool)

    order = np.argsort(t_train)
    t_ord = t_train[order]
    e_ord = e_train[order]
    r_ord = risk_tr[order]

    unique_event_times = np.unique(t_ord[e_ord])
    if len(unique_event_times) == 0:
        return np.ones((len(risk_va), len(eval_times)), dtype=float)

    dH0 = np.zeros(len(unique_event_times), dtype=float)
    for i, t in enumerate(unique_event_times):
        at_risk = (t_ord >= t)
        denom = np.sum(r_ord[at_risk])
        if denom > 0:
            d = np.sum((t_ord == t) & e_ord)
            dH0[i] = d / denom

    H0 = np.cumsum(dH0)

    surv_probs = np.ones((len(risk_va), len(eval_times)), dtype=float)
    for j, tau in enumerate(eval_times):
        idx = np.searchsorted(unique_event_times, tau, side='right') - 1
        h0_tau = H0[idx] if idx >= 0 else 0.0
        surv_probs[:, j] = np.exp(-h0_tau * risk_va)

    return np.clip(surv_probs, 1e-6, 1.0)

# ---------------------------
# Build Pareto table from Phase 1
# ---------------------------
pareto_trials = [t for t in study.best_trials if t.values is not None and len(t.values) == 2]
if len(pareto_trials) == 0:
    raise ValueError("No valid Pareto trials in study.")

phase1_rows = []
for t in pareto_trials:
    row = {
        "trial_id": int(t.number),
        "Phase1_Multi_Horizon_C_Index": float(t.values[0]),
        "Phase1_IBS": float(t.values[1]),
        "Phase1_Global_C_Index": float(t.user_attrs.get("Global_C_Index", np.nan)),
        "Sampled_Imputation": t.user_attrs.get("Sampled_Imputation", np.nan),
        "Strat_Mode": t.user_attrs.get("Strat_Mode", "NA"),
    }
    row.update(t.params)
    phase1_rows.append(row)

df_phase1 = pd.DataFrame(phase1_rows)
df_phase1["Phase1_Distance_to_Ideal"] = np.sqrt(
    (1.0 - df_phase1["Phase1_Multi_Horizon_C_Index"])**2 + (df_phase1["Phase1_IBS"])**2
)
df_phase1 = df_phase1.sort_values("Phase1_Distance_to_Ideal", ascending=True).reset_index(drop=True)

top_k = min(TOP_K, len(df_phase1))
candidate_ids = df_phase1.head(top_k)["trial_id"].astype(int).tolist()
trial_map = {t.number: t for t in pareto_trials}
candidate_trials = [trial_map[i] for i in candidate_ids]

nb_print(f"Phase 1 Pareto count: {len(df_phase1)} | Phase 2 candidates: {len(candidate_trials)}")

# ---------------------------
# Precompute imputation payloads
# ---------------------------
imp_payloads = []
for imp_idx, (X_imp, y_imp) in enumerate(zip(imputations_tune, y_death_struct_list)):
    y_label = np.where(np.asarray(y_imp['event']), np.asarray(y_imp['time']), -np.asarray(y_imp['time']))
    splits, strat_mode = make_cv_splits(X_imp, y_imp, n_splits=N_SPLITS, imp_idx=imp_idx)

    imp_payloads.append({
        "imp_id": imp_idx,
        "X": X_imp,
        "y_label": y_label,
        "y_struct": y_imp,
        "splits": splits,
        "strat_mode": strat_mode
    })
    nb_print(f"Imputation {imp_idx+1}: strat_mode={strat_mode}")

# ---------------------------
# Parallel config
# ---------------------------
N_CORES = max(1, os.cpu_count() - 2)
CANDIDATE_PARALLEL = 3  # 2 is safer on RAM; 3 is faster if RAM allows
TOTAL_WORKERS = min(N_CORES, max(1, CANDIDATE_PARALLEL) * len(imp_payloads))
#XGB_THREADS = max(1, N_CORES // max(1, TOTAL_WORKERS))
#XGB_THREADS = min(XGB_THREADS, 2)
XGB_THREADS = 1  # consistent with Phase 1: 1 thread/model, parallelism via joblib workers

nb_print(
    f"Phase 2 parallel config -> workers={TOTAL_WORKERS}, "
    f"xgb_threads/model={XGB_THREADS}, candidate_parallel~{CANDIDATE_PARALLEL}"
)

# ---------------------------
# Evaluation function
# ---------------------------
def evaluate_trial_imputation(trial_id, trial_params, payload):
    X_imp = payload["X"]
    y_label = payload["y_label"]
    y_struct = payload["y_struct"]
    splits = payload["splits"]
    imp_id = payload["imp_id"]

    params = {
        'objective': 'survival:cox',
        'eval_metric': 'cox-nloglik',
        'tree_method': 'hist',
        'device': 'cpu',
        'nthread': XGB_THREADS,
        'verbosity': 0,
        'seed': SEED,
        **trial_params
    }

    fold_c = []
    fold_ibs = []
    fold_global_c = []

    for tr_idx, va_idx in splits:
        X_tr, X_va = X_imp.iloc[tr_idx], X_imp.iloc[va_idx]
        y_tr, y_va = y_label[tr_idx], y_label[va_idx]
        y_tr_struct, y_va_struct = y_struct[tr_idx], y_struct[va_idx]

        dtr = xgb.DMatrix(X_tr, label=y_tr)
        dva = xgb.DMatrix(X_va, label=y_va)

        model = xgb.train(
            params,
            dtr,
            num_boost_round=1500,
            evals=[(dva, "val")],
            early_stopping_rounds=30,
            verbose_eval=False
        )

        risk_tr = model.predict(dtr)
        risk_va = model.predict(dva)

        # Multi-horizon C-index
        h_c = []
        for tau in EVAL_HORIZONS:
            try:
                h_c.append(float(concordance_index_ipcw(y_tr_struct, y_va_struct, risk_va, tau=tau)[0]))
            except Exception:
                pass
        fold_c.append(float(np.mean(h_c)) if len(h_c) > 0 else 0.5)

        # Global C-index
        try:
            fold_global_c.append(float(concordance_index_ipcw(y_tr_struct, y_va_struct, risk_va)[0]))
        except Exception:
            fold_global_c.append(0.5)

        # IBS
        try:
            surv_probs_va = predict_survival_probs_breslow(y_tr_struct, risk_tr, risk_va, EVAL_HORIZONS)
            _, brier_at_tau = brier_score(y_tr_struct, y_va_struct, surv_probs_va, EVAL_HORIZONS)
            fold_ibs.append(float(np.mean(brier_at_tau)))
        except Exception:
            fold_ibs.append(0.25)

        del model, dtr, dva, risk_tr, risk_va
        gc.collect()

    return (
        int(trial_id),
        int(imp_id),
        float(np.mean(fold_c)),
        float(np.mean(fold_ibs)),
        float(np.mean(fold_global_c)),
        payload["strat_mode"]
    )

# ---------------------------
# Run all candidate x imputation tasks
# ---------------------------
tasks = [
    (t.number, t.params, payload)
    for t in candidate_trials
    for payload in imp_payloads
]

nb_print(f"Launching {len(tasks)} tasks ({len(candidate_trials)} candidates x {len(imp_payloads)} imputations)...")

with parallel_backend("threading", n_jobs=TOTAL_WORKERS):
    results = Parallel(verbose=10)(
        delayed(evaluate_trial_imputation)(trial_id, trial_params, payload)
        for trial_id, trial_params, payload in tasks
    )

# ---------------------------
# Aggregate to trial-level
# ---------------------------
trial_metrics = defaultdict(lambda: {"c": {}, "ibs": {}, "global_c": {}, "modes": {}})

for trial_id, imp_id, c_val, ibs_val, g_val, mode in results:
    trial_metrics[trial_id]["c"][imp_id] = c_val
    trial_metrics[trial_id]["ibs"][imp_id] = ibs_val
    trial_metrics[trial_id]["global_c"][imp_id] = g_val
    trial_metrics[trial_id]["modes"][imp_id] = mode

phase2_rows = []
phase2_detail = {}

for t in candidate_trials:
    m = trial_metrics[t.number]
    if len(m["c"]) != len(imp_payloads):
        raise RuntimeError(f"Trial {t.number} is missing imputation results.")

    c_per_imp = [m["c"][i] for i in range(len(imp_payloads))]
    ibs_per_imp = [m["ibs"][i] for i in range(len(imp_payloads))]
    global_c_per_imp = [m["global_c"][i] for i in range(len(imp_payloads))]
    modes_per_imp = [m["modes"][i] for i in range(len(imp_payloads))]

    row = {
        "trial_id": int(t.number),
        "Phase2_Multi_Horizon_C_Index": float(np.mean(c_per_imp)),
        "Phase2_IBS": float(np.mean(ibs_per_imp)),
        "Phase2_Global_C_Index": float(np.mean(global_c_per_imp)),
        "C_Index_SD_across_imputations": float(np.std(c_per_imp)),
        "IBS_SD_across_imputations": float(np.std(ibs_per_imp)),
        "Global_C_SD_across_imputations": float(np.std(global_c_per_imp)),
        "Strat_Modes": "|".join(modes_per_imp),
    }
    row.update(t.params)
    phase2_rows.append(row)

    phase2_detail[t.number] = {
        "c_per_imp": c_per_imp,
        "ibs_per_imp": ibs_per_imp,
        "global_c_per_imp": global_c_per_imp,
        "strat_modes_per_imp": modes_per_imp
    }

df_phase2 = pd.DataFrame(phase2_rows)
df_phase2["Distance_to_Ideal"] = np.sqrt(
    (1.0 - df_phase2["Phase2_Multi_Horizon_C_Index"])**2 + (df_phase2["Phase2_IBS"])**2
)
df_phase2 = df_phase2.sort_values(
    ["Distance_to_Ideal", "Phase2_Global_C_Index", "C_Index_SD_across_imputations"],
    ascending=[True, False, True]
).reset_index(drop=True)

winner = df_phase2.iloc[0]
winner_trial_id = int(winner["trial_id"])

nb_print("\nFinal winner from Phase 2 (all imputations):")
nb_print(f"  Trial ID: {winner_trial_id}")
nb_print(f"  Multi-Horizon C-Index: {winner['Phase2_Multi_Horizon_C_Index']:.4f}")
nb_print(f"  IBS: {winner['Phase2_IBS']:.4f}")
nb_print(f"  Global C-Index: {winner['Phase2_Global_C_Index']:.4f}")
nb_print(f"  Distance to Ideal: {winner['Distance_to_Ideal']:.4f}")
nb_print(f"  C-index SD across imputations: {winner['C_Index_SD_across_imputations']:.4f}")
nb_print(f"  IBS SD across imputations: {winner['IBS_SD_across_imputations']:.4f}")
nb_print(f"  Stratification modes: {winner['Strat_Modes']}")

# Optional: winner params dict
param_keys = sorted({k for t in candidate_trials for k in t.params.keys()})
winner_params = {k: winner[k] for k in param_keys}
nb_print("\nWinner hyperparameters:")
nb_print(winner_params)

# ---------------------------
# Save outputs
# ---------------------------
timestamp_str = datetime.now().strftime("%Y%m%d_%H%M")

phase1_csv = OUTPUT_DIR / f"Death_Pareto_Phase1_{timestamp_str}_mar26.csv"
phase2_csv = OUTPUT_DIR / f"Death_Pareto_Phase2_{timestamp_str}_mar26.csv"

df_phase1.to_csv(phase1_csv, index=False)
df_phase2.to_csv(phase2_csv, index=False)

nb_print(f"\nSaved: {phase1_csv}")
nb_print(f"Saved: {phase2_csv}")

elapsed_minutes = (time.time() - start_time) / 60
nb_print(f"Time taken: {elapsed_minutes:.2f} minutes")
PROJECT_ROOT: G:\My Drive\Alvacast\SISTRAT 2023\cons
Using in-memory study 'XGB_Death_Optuna_Study_StochMI'.
Phase 1 Pareto count: 5 | Phase 2 candidates: 5
Imputation 1: strat_mode=dual
Imputation 2: strat_mode=dual
Imputation 3: strat_mode=dual
Imputation 4: strat_mode=dual
Imputation 5: strat_mode=dual
Phase 2 parallel config -> workers=15, xgb_threads/model=1, candidate_parallel~3
Launching 25 tasks (5 candidates x 5 imputations)...
[Parallel(n_jobs=15)]: Using backend ThreadingBackend with 15 concurrent workers.
[Parallel(n_jobs=15)]: Done   2 out of  25 | elapsed:  2.0min remaining: 23.2min
[Parallel(n_jobs=15)]: Done   5 out of  25 | elapsed:  2.2min remaining:  8.6min
[Parallel(n_jobs=15)]: Done   8 out of  25 | elapsed:  2.2min remaining:  4.6min
[Parallel(n_jobs=15)]: Done  11 out of  25 | elapsed:  2.2min remaining:  2.9min
[Parallel(n_jobs=15)]: Done  14 out of  25 | elapsed:  2.3min remaining:  1.8min
[Parallel(n_jobs=15)]: Done  17 out of  25 | elapsed:  3.5min remaining:  1.6min
[Parallel(n_jobs=15)]: Done  20 out of  25 | elapsed:  3.5min remaining:   52.9s
[Parallel(n_jobs=15)]: Done  23 out of  25 | elapsed:  3.6min remaining:   18.6s
[Parallel(n_jobs=15)]: Done  25 out of  25 | elapsed:  3.6min finished
Final winner from Phase 2 (all imputations):
  Trial ID: 55
  Multi-Horizon C-Index: 0.7622
  IBS: 0.0177
  Global C-Index: 0.7459
  Distance to Ideal: 0.2385
  C-index SD across imputations: 0.0009
  IBS SD across imputations: 0.0000
  Stratification modes: dual|dual|dual|dual|dual
Winner hyperparameters:
{'colsample_bytree': np.float64(0.6168899316303945), 'gamma': np.float64(0.9822857777079363), 'learning_rate': np.float64(0.03096732484594198), 'max_depth': np.int64(6), 'min_child_weight': np.int64(15), 'reg_alpha': np.float64(0.01370823890516611), 'reg_lambda': np.float64(2.536795457202958), 'subsample': np.float64(0.7787433700020623)}
Saved: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\Death_Pareto_Phase1_20260305_2039_mar26.csv
Saved: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\Death_Pareto_Phase2_20260305_2039_mar26.csv
Time taken: 3.60 minutes
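The Distance_to_Ideal ranking used above can be verified by hand. A minimal sketch recomputing it from the winner's (trial 55) Phase 2 mean C-index and IBS, as printed in the output above:

```python
import numpy as np

# Equal-weight Euclidean distance to the ideal point (C = 1, IBS = 0);
# a lower distance means a better compromise on the Pareto front.
c_index = 0.762197   # Phase2_Multi_Horizon_C_Index for trial 55
ibs = 0.017730       # Phase2_IBS for trial 55
distance = float(np.sqrt((1.0 - c_index) ** 2 + ibs ** 2))
print(f"{distance:.6f}")  # 0.238463, matching the Distance_to_Ideal column
```

Because the IBS is an order of magnitude smaller than 1 − C, the C-index term dominates this rule; the SD-based tie-breakers in the sort only matter when distances are nearly equal.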
Code
show_scrollable_df(df_phase2)
trial_id Phase2_Multi_Horizon_C_Index Phase2_IBS Phase2_Global_C_Index C_Index_SD_across_imputations IBS_SD_across_imputations Global_C_SD_across_imputations Strat_Modes learning_rate max_depth min_child_weight subsample colsample_bytree reg_alpha reg_lambda gamma Distance_to_Ideal
55 0.762197 0.017730 0.745860 0.000859 0.000007 0.000883 dual|dual|dual|dual|dual 0.030967 6 15 0.778743 0.616890 0.013708 2.536795 0.982286 0.238463
75 0.762089 0.017703 0.747055 0.000653 0.000006 0.000876 dual|dual|dual|dual|dual 0.045269 4 30 0.779111 0.635094 0.004008 4.227691 0.158302 0.238568
46 0.761746 0.017702 0.746341 0.000435 0.000007 0.000832 dual|dual|dual|dual|dual 0.038173 3 14 0.708255 0.552737 0.216076 3.956397 0.000897 0.238911
95 0.761729 0.017733 0.746544 0.001330 0.000008 0.000649 dual|dual|dual|dual|dual 0.030967 6 15 0.778743 0.616890 0.530266 2.536795 0.982286 0.238930
35 0.761695 0.017713 0.747041 0.000641 0.000005 0.000932 dual|dual|dual|dual|dual 0.028376 6 30 0.792929 0.491632 1.406209 2.548292 0.369767 0.238962
Code
from IPython.display import display, HTML

html_content = """
<div style="font-family: Arial; line-height: 1.6;">

<h2>📊 Pareto Front Analysis (Death Outcome)</h2>

<table style="border-collapse: collapse; width: 100%; font-size: 14px;">
<thead>
<tr style="background-color:#f5f5f5;">
<th style="border:1px solid #ccc; padding:8px;">Component</th>
<th style="border:1px solid #ccc; padding:8px;">Trial 55 (Phase 2 Winner)</th>
<th style="border:1px solid #ccc; padding:8px;">Interpretation</th>
</tr>
</thead>
<tbody>

<tr>
<td style="border:1px solid #ccc; padding:8px;"><b>Multi-Horizon C-Index</b></td>
<td style="border:1px solid #ccc; padding:8px;">0.7622</td>
<td style="border:1px solid #ccc; padding:8px;">
Strong time-specific discrimination across clinically relevant horizons (3–60 months), evaluated in Phase 2 across imputations.
</td>
</tr>

<tr>
<td style="border:1px solid #ccc; padding:8px;"><b>Integrated Brier Score (IBS)</b></td>
<td style="border:1px solid #ccc; padding:8px;">0.0177</td>
<td style="border:1px solid #ccc; padding:8px;">
Very low average probabilistic prediction error across the selected horizons. This supports strong overall predictive quality, though IBS alone is not a standalone proof of calibration.
</td>
</tr>

<tr>
<td style="border:1px solid #ccc; padding:8px;"><b>Global C-Index</b></td>
<td style="border:1px solid #ccc; padding:8px;">0.7459</td>
<td style="border:1px solid #ccc; padding:8px;">
Robust overall discrimination across full follow-up. Its lower value vs multi-horizon C-index suggests discrimination is not uniform across all time windows.
</td>
</tr>

<tr>
<td style="border:1px solid #ccc; padding:8px;"><b>C-Index SD (Across Imputations)</b></td>
<td style="border:1px solid #ccc; padding:8px;">0.0009</td>
<td style="border:1px solid #ccc; padding:8px;">
Very low internal variability across imputations, indicating stable model ranking performance under missing-data uncertainty.
</td>
</tr>

<tr>
<td style="border:1px solid #ccc; padding:8px;"><b>Distance to Ideal (C=1, IBS=0)</b></td>
<td style="border:1px solid #ccc; padding:8px;">0.2385</td>
<td style="border:1px solid #ccc; padding:8px;">
Best compromise point on the Pareto set under the chosen equal-weight Euclidean distance rule.
</td>
</tr>

</tbody>
</table>

<br>

<h3>🧠 Hyperparameter Robustness Interpretation (Trial 55)</h3>

<table style="border-collapse: collapse; width: 100%; font-size: 14px;">
<thead>
<tr style="background-color:#f5f5f5;">
<th style="border:1px solid #ccc; padding:8px;">Hyperparameter</th>
<th style="border:1px solid #ccc; padding:8px;">Value</th>
<th style="border:1px solid #ccc; padding:8px;">Statistical Meaning</th>
</tr>
</thead>
<tbody>

<tr>
<td style="border:1px solid #ccc; padding:8px;"><b>learning_rate</b></td>
<td style="border:1px solid #ccc; padding:8px;">0.031</td>
<td style="border:1px solid #ccc; padding:8px;">
Moderate shrinkage (η ≈ 0.03) balancing learning speed and stability; faster convergence than conservative rates (0.01) while controlling overfitting risk.
</td>
</tr>

<tr>
<td style="border:1px solid #ccc; padding:8px;"><b>max_depth</b></td>
<td style="border:1px solid #ccc; padding:8px;">6</td>
<td style="border:1px solid #ccc; padding:8px;">
Deeper trees (vs. shallow configurations) allowing moderate interaction complexity between predictors, constrained by strong regularization to prevent overfitting.
</td>
</tr>

<tr>
<td style="border:1px solid #ccc; padding:8px;"><b>min_child_weight</b></td>
<td style="border:1px solid #ccc; padding:8px;">15</td>
<td style="border:1px solid #ccc; padding:8px;">
Minimum Hessian weight of 15 required for splits; provides moderate protection against unstable partitioning in low-event regions without being overly restrictive.
</td>
</tr>

<tr>
<td style="border:1px solid #ccc; padding:8px;"><b>subsample</b></td>
<td style="border:1px solid #ccc; padding:8px;">0.779</td>
<td style="border:1px solid #ccc; padding:8px;">
Stochastic row subsampling (~78% per tree) reducing variance and improving robustness against outliers while maintaining sufficient data exposure.
</td>
</tr>

<tr>
<td style="border:1px solid #ccc; padding:8px;"><b>colsample_bytree</b></td>
<td style="border:1px solid #ccc; padding:8px;">0.617</td>
<td style="border:1px solid #ccc; padding:8px;">
Feature subsampling (~62% of predictors per tree), promoting ensemble diversity and reducing dominance of high-cardinality variables.
</td>
</tr>

<tr>
<td style="border:1px solid #ccc; padding:8px;"><b>reg_alpha (L1)</b></td>
<td style="border:1px solid #ccc; padding:8px;">0.014</td>
<td style="border:1px solid #ccc; padding:8px;">
Minimal L1 regularization (α ≈ 0.01); effectively neutral on sparsity induction, allowing near-unconstrained leaf weight estimation within L2 constraints.
</td>
</tr>

<tr>
<td style="border:1px solid #ccc; padding:8px;"><b>reg_lambda (L2)</b></td>
<td style="border:1px solid #ccc; padding:8px;">2.54</td>
<td style="border:1px solid #ccc; padding:8px;">
Moderate-to-strong L2 regularization (λ ≈ 2.5) on leaf weights, providing primary defense against overfitting through ridge-style shrinkage and numerical stability.
</td>
</tr>

<tr>
<td style="border:1px solid #ccc; padding:8px;"><b>gamma</b></td>
<td style="border:1px solid #ccc; padding:8px;">0.98</td>
<td style="border:1px solid #ccc; padding:8px;">
High minimum gain threshold (γ ≈ 1.0) for splitting; enforces conservative tree growth by requiring substantial loss reduction, counterbalancing the deeper max_depth.
</td>
</tr>

</tbody>
</table>

<br>

<div style="background-color:#f9f9f9; padding:12px; border-left:4px solid #4CAF50;">
<p style="margin:0;"><b>Overall Interpretation:</b> Trial 55 represents a <b>complexity-constrained</b> configuration: deeper trees (depth 6) are tightly controlled by high gamma (0.98) and strong L2 regularization (λ=2.54), while minimal L1 (α≈0) preserves predictor inclusion. This trade-off achieves high discrimination (C=0.76) with very low Brier error (IBS=0.018) and excellent stability across imputations (SD=0.0009). The model leverages feature subsampling (62%) and row subsampling (78%) for ensemble diversity, making it robust for low-event-rate mortality prediction.</p>
</div>

</div>
"""

display(HTML(html_content))

📊 Pareto Front Analysis (Death Outcome)

Component Trial 55 (Phase 2 Winner) Interpretation
Multi-Horizon C-Index 0.7622 Strong time-specific discrimination across clinically relevant horizons (3–60 months), evaluated in Phase 2 across imputations.
Integrated Brier Score (IBS) 0.0177 Very low average probabilistic prediction error across the selected horizons. This supports strong overall predictive quality, though IBS alone is not a standalone proof of calibration.
Global C-Index 0.7459 Robust overall discrimination across full follow-up. Its lower value vs multi-horizon C-index suggests discrimination is not uniform across all time windows.
C-Index SD (Across Imputations) 0.0009 Very low internal variability across imputations, indicating stable model ranking performance under missing-data uncertainty.
Distance to Ideal (C=1, IBS=0) 0.2385 Best compromise point on the Pareto set under the chosen equal-weight Euclidean distance rule.

🧠 Hyperparameter Robustness Interpretation (Trial 55)

Hyperparameter Value Statistical Meaning
learning_rate 0.031 Moderate shrinkage (η ≈ 0.03) balancing learning speed and stability; faster convergence than conservative rates (0.01) while controlling overfitting risk.
max_depth 6 Deeper trees (vs. shallow configurations) allowing moderate interaction complexity between predictors, constrained by strong regularization to prevent overfitting.
min_child_weight 15 Minimum Hessian weight of 15 required for splits; provides moderate protection against unstable partitioning in low-event regions without being overly restrictive.
subsample 0.779 Stochastic row subsampling (~78% per tree) reducing variance and improving robustness against outliers while maintaining sufficient data exposure.
colsample_bytree 0.617 Feature subsampling (~62% of predictors per tree), promoting ensemble diversity and reducing dominance of high-cardinality variables.
reg_alpha (L1) 0.014 Minimal L1 regularization (α ≈ 0.01); effectively neutral on sparsity induction, allowing near-unconstrained leaf weight estimation within L2 constraints.
reg_lambda (L2) 2.54 Moderate-to-strong L2 regularization (λ ≈ 2.5) on leaf weights, providing primary defense against overfitting through ridge-style shrinkage and numerical stability.
gamma 0.98 High minimum gain threshold (γ ≈ 1.0) for splitting; enforces conservative tree growth by requiring substantial loss reduction, counterbalancing the deeper max_depth.

Overall Interpretation: Trial 55 represents a complexity-constrained configuration: deeper trees (depth 6) are tightly controlled by high gamma (0.98) and strong L2 regularization (λ=2.54), while minimal L1 (α≈0) preserves predictor inclusion. This trade-off achieves high discrimination (C=0.76) with very low Brier error (IBS=0.018) and excellent stability across imputations (SD=0.0009). The model leverages feature subsampling (62%) and row subsampling (78%) for ensemble diversity, making it robust for low-event-rate mortality prediction.

Optimism correction

🔟 Take-home messages (what the code does)

  • Implements Harrell’s bootstrap optimism correction.
  • Uses the final tuned XGBoost Cox model (Trial 55).
  • Estimates apparent C-index on full dataset.
  • Determines optimal boosting rounds via early stopping.
  • Trains final baseline model on 100% of data.
  • Runs 500 bootstrap resamples in parallel.
  • Retrains model inside each bootstrap sample.
  • Computes performance on bootstrap and original data.
  • Calculates optimism = apparent_boot − test_original.
  • Reports optimism-corrected C-index for internal validation.

🧩 Assumptions (5 key ones)

  • Bootstrap samples approximate the data-generating process.
  • Model structure and hyperparameters are fixed.
  • C-index is appropriate performance metric.
  • IPCW assumptions hold for censoring mechanism.
  • Sample size is large enough for stable bootstrap estimates.
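The steps above can be condensed into a small numerical sketch. The values here are toy numbers chosen only to mimic the reported magnitudes (they are not study data; the seed 2125 is reused purely for reproducibility of the sketch):

```python
import numpy as np

# Harrell's optimism correction, schematically:
# optimism_b = C(model_b on its own bootstrap sample) - C(model_b on original data)
# corrected  = apparent C-index - mean(optimism)
rng = np.random.default_rng(2125)
c_apparent = 0.80                                # apparent C on the full data (toy)
c_boot_app = rng.normal(0.82, 0.01, size=500)    # C on each bootstrap sample (toy)
c_boot_test = rng.normal(0.775, 0.01, size=500)  # same models tested on original data (toy)
optimism = c_boot_app - c_boot_test              # per-bootstrap optimism
c_corrected = c_apparent - optimism.mean()       # optimism-corrected estimate
lo, hi = np.quantile(c_apparent - optimism, [0.025, 0.975])  # percentile 95% CI
print(round(c_corrected, 3), round(lo, 3), round(hi, 3))
```

The real implementation below differs only in how each C value is obtained (an XGBoost Cox model refit per bootstrap, scored with IPCW concordance and a censored-concordance fallback).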
Code
# @title Harrell's Bootstrap Optimism Correction (Death, Parallelized, CPU-2, with 95% CI)
import numpy as np
import pandas as pd
import xgboost as xgb
import os
import gc
import time
import warnings
from sklearn.utils import resample
from sklearn.model_selection import train_test_split
from joblib import Parallel, delayed
from sksurv.metrics import concordance_index_ipcw, concordance_index_censored

start_time = time.time()
warnings.filterwarnings("ignore")

if "nb_print" not in globals():
    def nb_print(*args, **kwargs):
        print(*args, **kwargs)

SEED = 2125
B_ITERATIONS = 500
ALPHA = 0.05  # 95% CI
N_CORES = max(1, (os.cpu_count() or 2) - 2)

nb_print("Initializing Parallel Harrell's Bootstrap Optimism Correction for Death...")
nb_print(f"Parallel Execution Configured: Using {N_CORES} CPU cores.")

# --- 1) DATA SETUP ---
try:
    df_tune = imputations_list_mar26[0].copy()
    y_tune_struct = y_surv_death_list[0]

    assert len(df_tune) == len(y_tune_struct), (
        f"X/y mismatch: df_tune={len(df_tune)}, y_tune_struct={len(y_tune_struct)}"
    )
except Exception as e:
    raise ValueError(f"Data Error: {e}. Please ensure mar26 death structures are loaded.")

y_xgb_label = np.where(y_tune_struct["event"], y_tune_struct["time"], -y_tune_struct["time"])

# --- 2) STRATIFICATION HELPERS (dual + fallback for split only) ---
def get_plan_labels(df):
    labels = np.zeros(len(df), dtype=np.int32)
    if "plan_type_corr_pg_pr" in df.columns: labels[df["plan_type_corr_pg_pr"] == 1] = 1
    if "plan_type_corr_m_pr" in df.columns: labels[df["plan_type_corr_m_pr"] == 1] = 2
    if "plan_type_corr_pg_pai" in df.columns: labels[df["plan_type_corr_pg_pai"] == 1] = 3
    if "plan_type_corr_m_pai" in df.columns: labels[df["plan_type_corr_m_pai"] == 1] = 4
    return labels

def get_dual_labels(df, y_struct):
    return get_plan_labels(df) * 10 + np.asarray(y_struct["event"]).astype(np.int32)

def _strat_ok(labels, n_splits=5):
    u, c = np.unique(labels, return_counts=True)
    return len(u) > 1 and c.min() >= n_splits

def pick_split_labels(df, y_struct, n_splits=5):
    dual = get_dual_labels(df, y_struct)
    if _strat_ok(dual, n_splits):
        return dual, "dual"
    plan = get_plan_labels(df)
    if _strat_ok(plan, n_splits):
        nb_print("[Fallback triggered] dual -> plan_only for train/val split")
        return plan, "plan_only"
    evt = np.asarray(y_struct["event"]).astype(np.int32)
    if _strat_ok(evt, n_splits):
        nb_print("[Fallback triggered] dual -> event_only for train/val split")
        return evt, "event_only"
    nb_print("[Fallback triggered] dual -> unstratified split")
    return None, "none"

split_labels, split_mode = pick_split_labels(df_tune, y_tune_struct, n_splits=5)

# --- 3) WINNER HYPERPARAMETERS (use Phase 2 exact if available; fallback = hardcoded Trial 55 values) ---
if "df_phase2" in locals() and isinstance(df_phase2, pd.DataFrame) and len(df_phase2) > 0:
    w = df_phase2.iloc[0]
    params_winner = {
        "objective": "survival:cox",
        "eval_metric": "cox-nloglik",
        "tree_method": "hist",
        "device": "cpu",
        "verbosity": 0,
        "seed": SEED,
        "learning_rate": float(w["learning_rate"]),
        "max_depth": int(w["max_depth"]),
        "min_child_weight": int(w["min_child_weight"]),
        "subsample": float(w["subsample"]),
        "colsample_bytree": float(w["colsample_bytree"]),
        "reg_alpha": float(w["reg_alpha"]),
        "reg_lambda": float(w["reg_lambda"]),
        "gamma": float(w["gamma"]),
    }
    nb_print("Using exact winner hyperparameters from df_phase2 (row 0).")
else:
    params_winner = {
        "objective": "survival:cox",
        "eval_metric": "cox-nloglik",
        "tree_method": "hist",
        "device": "cpu",
        "verbosity": 0,
        "seed": SEED,
        "learning_rate": 0.03096732484594198,
        "max_depth": 6,
        "min_child_weight": 15,
        "subsample": 0.7787433700020623,
        "colsample_bytree": 0.6168899316303945,
        "reg_alpha": 0.01370823890516611,
        "reg_lambda": 2.536795457202958,
        "gamma": 0.9822857777079363,
    }
    nb_print("Using hardcoded Phase-2 winner hyperparameters (Trial 55).")

# --- 4) APPARENT PERFORMANCE ON ORIGINAL DATA ---
nb_print("Calculating apparent performance on the original full dataset...")

params_initial = params_winner.copy()
params_initial["nthread"] = N_CORES

split_kwargs = dict(test_size=0.2, random_state=SEED)
if split_labels is not None:
    split_kwargs["stratify"] = split_labels

X_train_app, X_val_app, y_train_xgb_app, y_val_xgb_app = train_test_split(
    df_tune, y_xgb_label, **split_kwargs
)

dtrain_app = xgb.DMatrix(X_train_app, label=y_train_xgb_app)
dval_app = xgb.DMatrix(X_val_app, label=y_val_xgb_app)

temp_model = xgb.train(
    params_initial,
    dtrain_app,
    num_boost_round=2000,
    evals=[(dval_app, "val")],
    early_stopping_rounds=30,
    verbose_eval=False
)

if getattr(temp_model, "best_iteration", None) is None or temp_model.best_iteration < 0:
    optimal_boost_rounds = 2000
else:
    optimal_boost_rounds = int(temp_model.best_iteration) + 1

nb_print(f"Optimal boosting rounds determined: {optimal_boost_rounds}")

dorig = xgb.DMatrix(df_tune, label=y_xgb_label)
baseline_model = xgb.train(
    params_initial,
    dorig,
    num_boost_round=optimal_boost_rounds,
    verbose_eval=False
)

risk_orig = baseline_model.predict(dorig)
try:
    c_apparent_orig = float(concordance_index_ipcw(y_tune_struct, y_tune_struct, risk_orig)[0])
except Exception:
    c_apparent_orig = float(
        concordance_index_censored(
            np.asarray(y_tune_struct["event"]),
            np.asarray(y_tune_struct["time"]),
            risk_orig
        )[0]
    )

nb_print(f"Baseline Apparent Global C-index: {c_apparent_orig:.4f}")

# --- 5) PARALLEL BOOTSTRAP WORKER ---
def parallel_bootstrap_worker(b, df_original, y_xgb_original, y_struct_orig, params, opt_rounds):
    boot_params = params.copy()
    boot_params["nthread"] = 1  # avoid CPU thrashing in parallel workers

    idx = np.arange(len(df_original))
    boot_idx = resample(idx, replace=True, n_samples=len(idx), random_state=SEED + b)

    X_boot = df_original.iloc[boot_idx]
    y_xgb_boot = y_xgb_original[boot_idx]
    y_struct_boot = y_struct_orig[boot_idx]

    dboot = xgb.DMatrix(X_boot, label=y_xgb_boot)
    dorig_local = xgb.DMatrix(df_original, label=y_xgb_original)

    model = xgb.train(
        boot_params,
        dboot,
        num_boost_round=opt_rounds,
        verbose_eval=False
    )

    risk_boot = model.predict(dboot)
    try:
        c_boot_app = float(concordance_index_ipcw(y_struct_boot, y_struct_boot, risk_boot)[0])
    except Exception:
        c_boot_app = float(
            concordance_index_censored(
                np.asarray(y_struct_boot["event"]),
                np.asarray(y_struct_boot["time"]),
                risk_boot
            )[0]
        )

    risk_test_orig = model.predict(dorig_local)
    try:
        c_boot_test = float(concordance_index_ipcw(y_struct_boot, y_struct_orig, risk_test_orig)[0])
    except Exception:
        c_boot_test = float(
            concordance_index_censored(
                np.asarray(y_struct_orig["event"]),
                np.asarray(y_struct_orig["time"]),
                risk_test_orig
            )[0]
        )

    optimism = c_boot_app - c_boot_test

    del model, dboot, dorig_local, risk_boot, risk_test_orig
    gc.collect()
    return optimism

# --- 6) RUN BOOTSTRAP ---
nb_print(f"\nLaunching {B_ITERATIONS} Parallel Bootstrap Iterations...")
optimism_values = Parallel(n_jobs=N_CORES, verbose=10)(
    delayed(parallel_bootstrap_worker)(
        b, df_tune, y_xgb_label, y_tune_struct, params_winner, optimal_boost_rounds
    )
    for b in range(B_ITERATIONS)
)

# --- 7) FINAL METRICS + 95% CI ---
optimism_values = np.asarray(optimism_values, dtype=float)
optimism_values = optimism_values[np.isfinite(optimism_values)]

if optimism_values.size == 0:
    raise ValueError("No valid bootstrap optimism values were produced.")

mean_optimism = float(np.mean(optimism_values))
c_index_corrected = float(c_apparent_orig - mean_optimism)

corrected_samples = c_apparent_orig - optimism_values
opt_ci_low, opt_ci_high = np.quantile(optimism_values, [ALPHA / 2, 1 - ALPHA / 2])
corr_ci_low, corr_ci_high = np.quantile(corrected_samples, [ALPHA / 2, 1 - ALPHA / 2])

nb_print("\n--------------------------------------------------")
nb_print("FINAL OPTIMISM-CORRECTED RESULTS (DEATH)")
nb_print("--------------------------------------------------")
nb_print(f"Split stratification mode                         : {split_mode}")
nb_print(f"Apparent C-Index (Original Data)                 : {c_apparent_orig:.4f}")
nb_print(f"Mean Optimism (from {optimism_values.size} boots): {mean_optimism:.4f}")
nb_print(f"Optimism 95% CI                                  : [{opt_ci_low:.4f}, {opt_ci_high:.4f}]")
nb_print(f"Optimism-Corrected C-Index                       : {c_index_corrected:.4f}")
nb_print(f"Corrected C-Index 95% CI                         : [{corr_ci_low:.4f}, {corr_ci_high:.4f}]")
nb_print("--------------------------------------------------")

# --- 8) EXPORT ---
os.makedirs("_out", exist_ok=True)
timestamp_str = pd.Timestamp.now().strftime("%Y%m%d_%H%M")

summary_df = pd.DataFrame({
    "Metric": ["Apparent_C_Index", "Mean_Optimism", "Corrected_C_Index"],
    "Value": [c_apparent_orig, mean_optimism, c_index_corrected],
    "CI_95_Lower": [np.nan, opt_ci_low, corr_ci_low],
    "CI_95_Upper": [np.nan, opt_ci_high, corr_ci_high],
    "Bootstrap_N": [np.nan, optimism_values.size, optimism_values.size],
    "Split_Strat_Mode": [split_mode, split_mode, split_mode]
})
summary_file = f"_out/XGB_Death_Bootstrap_Optimism_Results_{timestamp_str}_mar26.csv"
summary_df.to_csv(summary_file, index=False)

dist_df = pd.DataFrame({
    "optimism": optimism_values,
    "corrected_c_index_sample": corrected_samples
})
dist_file = f"_out/XGB_Death_Bootstrap_Optimism_Distribution_{timestamp_str}_mar26.csv"
dist_df.to_csv(dist_file, index=False)

nb_print(f"Results saved successfully to {summary_file}.")
nb_print(f"Bootstrap distribution saved to {dist_file}.")

elapsed_minutes = (time.time() - start_time) / 60
nb_print(f"Time taken: {elapsed_minutes:.2f} minutes")
Initializing Parallel Harrell's Bootstrap Optimism Correction for Death...
Parallel Execution Configured: Using 30 CPU cores.
Using exact winner hyperparameters from df_phase2 (row 0).
Calculating apparent performance on the original full dataset...
Optimal boosting rounds determined: 156
Baseline Apparent Global C-index: 0.8042
Launching 500 Parallel Bootstrap Iterations...
[Parallel(n_jobs=30)]: Using backend LokyBackend with 30 concurrent workers.
[Parallel(n_jobs=30)]: Done   1 tasks      | elapsed:   19.8s
[Parallel(n_jobs=30)]: Done  12 tasks      | elapsed:   22.7s
[Parallel(n_jobs=30)]: Done  25 tasks      | elapsed:   24.2s
[Parallel(n_jobs=30)]: Done  38 tasks      | elapsed:   39.0s
[Parallel(n_jobs=30)]: Done  53 tasks      | elapsed:   42.0s
[Parallel(n_jobs=30)]: Done  68 tasks      | elapsed:   54.2s
[Parallel(n_jobs=30)]: Done  85 tasks      | elapsed:  1.0min
[Parallel(n_jobs=30)]: Done 102 tasks      | elapsed:  1.2min
[Parallel(n_jobs=30)]: Done 121 tasks      | elapsed:  1.3min
[Parallel(n_jobs=30)]: Done 140 tasks      | elapsed:  1.5min
[Parallel(n_jobs=30)]: Done 161 tasks      | elapsed:  1.7min
[Parallel(n_jobs=30)]: Done 182 tasks      | elapsed:  2.0min
[Parallel(n_jobs=30)]: Done 205 tasks      | elapsed:  2.1min
[Parallel(n_jobs=30)]: Done 228 tasks      | elapsed:  2.4min
[Parallel(n_jobs=30)]: Done 253 tasks      | elapsed:  2.5min
[Parallel(n_jobs=30)]: Done 278 tasks      | elapsed:  2.8min
[Parallel(n_jobs=30)]: Done 305 tasks      | elapsed:  3.0min
[Parallel(n_jobs=30)]: Done 332 tasks      | elapsed:  3.3min
[Parallel(n_jobs=30)]: Done 361 tasks      | elapsed:  3.5min
[Parallel(n_jobs=30)]: Done 390 tasks      | elapsed:  3.7min
[Parallel(n_jobs=30)]: Done 421 tasks      | elapsed:  4.0min
[Parallel(n_jobs=30)]: Done 492 out of 500 | elapsed:  4.6min remaining:    4.4s
[Parallel(n_jobs=30)]: Done 500 out of 500 | elapsed:  4.6min finished
--------------------------------------------------
FINAL OPTIMISM-CORRECTED RESULTS (DEATH)
--------------------------------------------------
Split stratification mode                         : dual
Apparent C-Index (Original Data)                 : 0.8042
Mean Optimism (from 500 boots): 0.0457
Optimism 95% CI                                  : [0.0360, 0.0560]
Optimism-Corrected C-Index                       : 0.7584
Corrected C-Index 95% CI                         : [0.7482, 0.7681]
--------------------------------------------------
Results saved successfully to _out/XGB_Death_Bootstrap_Optimism_Results_20260305_2059_mar26.csv.
Bootstrap distribution saved to _out/XGB_Death_Bootstrap_Optimism_Distribution_20260305_2059_mar26.csv.
Time taken: 4.70 minutes
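As a quick arithmetic check on the printed results: because each corrected bootstrap sample is apparent minus optimism, the corrected CI bounds are simply the apparent C-index minus the flipped optimism CI bounds. Using the full-precision values reported in summary_df:

```python
# Reported values (full precision from the summary table above)
apparent = 0.804162
mean_optimism = 0.045736
opt_lo, opt_hi = 0.036018, 0.055972

corrected = apparent - mean_optimism
print(f"{corrected:.6f}")           # 0.758426 = corrected C-index
print(f"{apparent - opt_hi:.6f}")   # 0.748190 = corrected CI lower bound
print(f"{apparent - opt_lo:.6f}")   # 0.768144 = corrected CI upper bound
```

This identity holds exactly for percentile CIs taken from `c_apparent_orig - optimism_values`, which is how `corrected_samples` is defined in the code above.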
Code
from IPython.display import HTML, display
html_table = summary_df.to_html(index=True, escape=False)
scroll_box = f"""
<div style="max-height:600px; max-width:100%; overflow-y:auto; overflow-x:auto; border:1px solid #ddd; padding:6px;">
{html_table}
</div>
"""
display(HTML(scroll_box))
Metric Value CI_95_Lower CI_95_Upper Bootstrap_N Split_Strat_Mode
0 Apparent_C_Index 0.804162 NaN NaN NaN dual
1 Mean_Optimism 0.045736 0.036018 0.055972 500.0 dual
2 Corrected_C_Index 0.758426 0.748190 0.768144 500.0 dual