Hyperparameter tuning based on readmission discrimination

This notebook tunes and validates an XGBoost model for readmission risk in a substance use disorder (SUD) treatment cohort, using five multiply imputed datasets and ~56 predictors (demographic, clinical, socioeconomic, and treatment variables).

The model uses XGBoost's survival:cox objective to estimate the cause-specific hazard of readmission. Because death is a competing event, absolute-risk calibration is evaluated with an Aalen–Johansen-consistent cumulative incidence approximation derived from the model's risk scores.

Hyperparameters are optimized with an Optuna multi-objective Pareto search, maximizing the multi-horizon IPCW Uno's C-index and minimizing the mean time-specific Brier score at 3, 6, 12, 36, and 60 months.

Validation uses 5-fold cross-validation stratified by treatment plan type (readmission workflow is plan-stratified, not dual-stratified). The optimization follows a two-phase multiple-imputation strategy: Phase 1 stochastic MI (one imputation per trial), then Phase 2 rescoring of top Pareto candidates across all imputations to select a robust final configuration.
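The plan-stratified folds can be sketched with scikit-learn's StratifiedKFold on toy data (`plan_type_corr` values are taken from the cohort; sizes are illustrative):

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

rng = np.random.default_rng(2125)
n = 1000
plan_type = rng.choice(["pg-pab", "pg-pr", "pg-pai", "m-pr", "m-pai"], size=n)
X = rng.normal(size=(n, 5))

# Stratify on treatment plan type so each fold preserves
# the plan-type distribution of the full cohort
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=2125)
fold_sizes = []
for fold, (train_idx, test_idx) in enumerate(skf.split(X, plan_type), start=1):
    fold_sizes.append(len(test_idx))
    print(fold, len(train_idx), len(test_idx))
```

Stratifying on the plan type (rather than on the event indicator as well) matches the plan-stratified readmission workflow described above.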

Final reporting includes discrimination/calibration metrics, across-imputation stability, and bootstrap optimism correction with 95% confidence intervals. Runs are CPU-only with seed 2125 for reproducibility.
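The optimism correction follows the usual bootstrap recipe: refit on each bootstrap sample, score on both the bootstrap sample and the original data, and subtract the mean optimism from the apparent performance. A minimal sketch with a generic fit/score pair (toy linear model and R² metric; all names are illustrative, not the notebook's survival metrics):

```python
import numpy as np

rng = np.random.default_rng(2125)
n = 300
X = rng.normal(size=(n, 3))
y = X @ np.array([1.0, -0.5, 0.25]) + rng.normal(size=n)

def fit(X, y):
    # least-squares coefficients as a stand-in model
    return np.linalg.lstsq(X, y, rcond=None)[0]

def score(beta, X, y):
    # R^2 as a generic performance metric
    resid = y - X @ beta
    return 1 - resid.var() / y.var()

apparent = score(fit(X, y), X, y)

optimism = []
for _ in range(200):
    idx = rng.integers(0, n, n)            # bootstrap resample
    beta_b = fit(X[idx], y[idx])
    # optimism = (score on the sample it was fit to) - (score on original data)
    optimism.append(score(beta_b, X[idx], y[idx]) - score(beta_b, X, y))

corrected = apparent - np.mean(optimism)
print(round(apparent, 3), round(corrected, 3))
```

The bootstrap distribution of the corrected estimates also yields the 95% confidence intervals reported alongside the point estimates.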

Author

ags

Published

March 5, 2026

Hyperparameter tuning: XGBoost (readmission as a reference)

0. Package loading and installation

Code
# Jupyter/Colab only: reset the namespace and collect garbage
%reset -f
import gc
gc.collect()

import numpy as np
import pandas as pd
import time

#conda install -c conda-forge \
#    numpy \
#    scipy \
#    pandas \
#    pyarrow \
#    scikit-survival \
#    spyder \
#    lifelines

# conda install -c conda-forge fastparquet
# conda install -c conda-forge xgboost
# conda install -c conda-forge pytorch cpuonly
# conda install -c pytorch pytorch cpuonly
# conda install -c conda-forge matplotlib
# conda install -c conda-forge seaborn
# conda install spyder-notebook -c spyder-ide
# conda install notebook nbformat nbconvert
# conda install -c conda-forge xlsxwriter
# conda install -c conda-forge shap

# import subprocess, sys

# subprocess.check_call([
#     sys.executable,
#     "-m",
#     "pip",
#     "install",
#     "matplotlib"
# ])

# subprocess.check_call([
#     sys.executable,
#     "-m",
#     "pip",
#     "install",
#     "seaborn"
# ])

print("numpy:", np.__version__)


from sksurv.metrics import (
    concordance_index_ipcw,
    brier_score,
    integrated_brier_score
)
from sksurv.util import Surv

# dput_df: print a DataFrame as a plain-Python literal (akin to R's dput)
def dput_df(df, digits=6):
    data = {
        "columns": list(df.columns),
        "data": [
            [round(x, digits) if isinstance(x, (float, np.floating)) else x
             for x in row]
            for row in df.to_numpy()
        ]
    }
    print(data)


#Glimpse function
def glimpse(df, max_width=80):
    print(f"Rows: {df.shape[0]} | Columns: {df.shape[1]}")
    for col in df.columns:
        dtype = df[col].dtype
        preview = df[col].astype(str).head(5).tolist()
        preview_str = ", ".join(preview)
        if len(preview_str) > max_width:
            preview_str = preview_str[:max_width] + "..."
        print(f"{col:<30} {str(dtype):<15} {preview_str}")
# tabyl: frequency table with counts and percentages (akin to janitor::tabyl)
def tabyl(series):
    counts = series.value_counts(dropna=False)
    props = series.value_counts(normalize=True, dropna=False)
    return pd.DataFrame({"value": counts.index,
                         "n": counts.values,
                         "percent": props.values * 100})
#clean_names
import re

def clean_names(df):
    """
    Mimic janitor::clean_names for pandas DataFrames.
    - Lowercase
    - Replace spaces and special chars with underscores
    - Remove non-alphanumeric/underscore
    """
    new_cols = []
    for col in df.columns:
        # lowercase
        col = col.lower()
        # replace spaces and special chars with underscore
        col = re.sub(r"[^\w]+", "_", col)
        # strip leading/trailing underscores
        col = col.strip("_")
        new_cols.append(col)
    df.columns = new_cols
    return df
numpy: 2.0.1

Load data

Code

from pathlib import Path

BASE_DIR = Path(
    r"G:\My Drive\Alvacast\SISTRAT 2023\data\20241015_out\pred1"
)

import pickle

with open(BASE_DIR / "imputations_list_jan26.pkl", "rb") as f:
    imputations_list_jan26 = pickle.load(f)

imputation_1 = pd.read_parquet(
    BASE_DIR / "imputation_1.parquet",
    engine="fastparquet"
)
Code

import pandas as pd

for i in range(1, 6):
    # source files are named "nondum"; in-memory variables use the "nodum" prefix
    globals()[f"imputation_nodum_{i}"] = pd.read_parquet(
        BASE_DIR / f"imputation_nondum_{i}.parquet",
        engine="fastparquet"
    )
Code
from IPython.display import display, HTML
import io
import sys

def fold_output(title, func):
    buffer = io.StringIO()
    sys.stdout = buffer
    try:
        func()
    finally:
        # restore stdout even if func() raises
        sys.stdout = sys.__stdout__

    html = f"""
    <details>
      <summary>{title}</summary>
      <pre>{buffer.getvalue()}</pre>
    </details>
    """
    display(HTML(html))


fold_output(
    "Show imputation_nodum_1 structure",
    lambda: imputation_nodum_1.info()
)

fold_output(
    "Show imputation_1 structure",
    lambda: imputation_1.info()
)
Show imputation_nodum_1 structure

RangeIndex: 88504 entries, 0 to 88503
Data columns (total 43 columns):
 #   Column                         Non-Null Count  Dtype  
---  ------                         --------------  -----  
 0   readmit_time_from_adm_m        88504 non-null  float64
 1   death_time_from_adm_m          88504 non-null  float64
 2   adm_age_rec3                   88504 non-null  float64
 3   porc_pobr                      88504 non-null  float64
 4   dit_m                          88504 non-null  float64
 5   sex_rec                        88504 non-null  object 
 6   tenure_status_household        88504 non-null  object 
 7   cohabitation                   88504 non-null  object 
 8   sub_dep_icd10_status           88504 non-null  object 
 9   any_violence                   88504 non-null  object 
 10  prim_sub_freq_rec              88504 non-null  object 
 11  tr_outcome                     88504 non-null  object 
 12  adm_motive                     88504 non-null  object 
 13  first_sub_used                 88504 non-null  object 
 14  primary_sub_mod                88504 non-null  object 
 15  tipo_de_vivienda_rec2          88504 non-null  object 
 16  national_foreign               88504 non-null  int32  
 17  plan_type_corr                 88504 non-null  object 
 18  occupation_condition_corr24    88504 non-null  object 
 19  marital_status_rec             88504 non-null  object 
 20  urbanicity_cat                 88504 non-null  object 
 21  ed_attainment_corr             88504 non-null  object 
 22  evaluacindelprocesoteraputico  88504 non-null  object 
 23  eva_consumo                    88504 non-null  object 
 24  eva_fam                        88504 non-null  object 
 25  eva_relinterp                  88504 non-null  object 
 26  eva_ocupacion                  88504 non-null  object 
 27  eva_sm                         88504 non-null  object 
 28  eva_fisica                     88504 non-null  object 
 29  eva_transgnorma                88504 non-null  object 
 30  ethnicity                      88504 non-null  float64
 31  dg_psiq_cie_10_instudy         88504 non-null  bool   
 32  dg_psiq_cie_10_dg              88504 non-null  bool   
 33  dx_f3_mood                     88504 non-null  int32  
 34  dx_f6_personality              88504 non-null  int32  
 35  dx_f_any_severe_mental         88504 non-null  bool   
 36  any_phys_dx                    88504 non-null  bool   
 37  polysubstance_strict           88504 non-null  int32  
 38  readmit_event                  88504 non-null  float64
 39  death_event                    88504 non-null  int32  
 40  readmit_time_from_disch_m      88504 non-null  float64
 41  death_time_from_disch_m        88504 non-null  float64
 42  center_id                      88475 non-null  object 
dtypes: bool(4), float64(9), int32(5), object(25)
memory usage: 25.0+ MB
Show imputation_1 structure

RangeIndex: 88504 entries, 0 to 88503
Data columns (total 78 columns):
 #   Column                                                              Non-Null Count  Dtype  
---  ------                                                              --------------  -----  
 0   readmit_time_from_adm_m                                             88504 non-null  float64
 1   death_time_from_adm_m                                               88504 non-null  float64
 2   adm_age_rec3                                                        88504 non-null  float64
 3   porc_pobr                                                           88504 non-null  float64
 4   dit_m                                                               88504 non-null  float64
 5   national_foreign                                                    88504 non-null  int32  
 6   ethnicity                                                           88504 non-null  float64
 7   dg_psiq_cie_10_instudy                                              88504 non-null  bool   
 8   dg_psiq_cie_10_dg                                                   88504 non-null  bool   
 9   dx_f3_mood                                                          88504 non-null  int32  
 10  dx_f6_personality                                                   88504 non-null  int32  
 11  dx_f_any_severe_mental                                              88504 non-null  bool   
 12  any_phys_dx                                                         88504 non-null  bool   
 13  polysubstance_strict                                                88504 non-null  int32  
 14  readmit_time_from_disch_m                                           88504 non-null  float64
 15  readmit_event                                                       88504 non-null  float64
 16  death_time_from_disch_m                                             88504 non-null  float64
 17  death_event                                                         88504 non-null  int32  
 18  sex_rec_woman                                                       88504 non-null  float64
 19  tenure_status_household_illegal_settlement                          88504 non-null  float64
 20  tenure_status_household_owner_transferred_dwellings_pays_dividends  88504 non-null  float64
 21  tenure_status_household_renting                                     88504 non-null  float64
 22  tenure_status_household_stays_temporarily_with_a_relative           88504 non-null  float64
 23  cohabitation_alone                                                  88504 non-null  float64
 24  cohabitation_with_couple_children                                   88504 non-null  float64
 25  cohabitation_family_of_origin                                       88504 non-null  float64
 26  sub_dep_icd10_status_drug_dependence                                88504 non-null  float64
 27  any_violence_1_domestic_violence_sex_abuse                          88504 non-null  float64
 28  prim_sub_freq_rec_2_2_6_days_wk                                     88504 non-null  float64
 29  prim_sub_freq_rec_3_daily                                           88504 non-null  float64
 30  tr_outcome_adm_discharge_adm_reasons                                88504 non-null  float64
 31  tr_outcome_adm_discharge_rule_violation_undet                       88504 non-null  float64
 32  tr_outcome_completion                                               88504 non-null  float64
 33  tr_outcome_dropout                                                  88504 non-null  float64
 34  tr_outcome_referral                                                 88504 non-null  float64
 35  adm_motive_another_sud_facility_fonodrogas_senda_previene           88504 non-null  float64
 36  adm_motive_justice_sector                                           88504 non-null  float64
 37  adm_motive_sanitary_sector                                          88504 non-null  float64
 38  adm_motive_spontaneous_consultation                                 88504 non-null  float64
 39  first_sub_used_alcohol                                              88504 non-null  float64
 40  first_sub_used_cocaine_paste                                        88504 non-null  float64
 41  first_sub_used_cocaine_powder                                       88504 non-null  float64
 42  first_sub_used_marijuana                                            88504 non-null  float64
 43  first_sub_used_opioids                                              88504 non-null  float64
 44  first_sub_used_tranquilizers_hypnotics                              88504 non-null  float64
 45  primary_sub_mod_cocaine_paste                                       88504 non-null  float64
 46  primary_sub_mod_cocaine_powder                                      88504 non-null  float64
 47  primary_sub_mod_alcohol                                             88504 non-null  float64
 48  primary_sub_mod_marijuana                                           88504 non-null  float64
 49  tipo_de_vivienda_rec2_other_unknown                                 88504 non-null  float64
 50  plan_type_corr_m_pai                                                88504 non-null  float64
 51  plan_type_corr_m_pr                                                 88504 non-null  float64
 52  plan_type_corr_pg_pai                                               88504 non-null  float64
 53  plan_type_corr_pg_pr                                                88504 non-null  float64
 54  occupation_condition_corr24_inactive                                88504 non-null  float64
 55  occupation_condition_corr24_unemployed                              88504 non-null  float64
 56  marital_status_rec_separated_divorced_annulled_widowed              88504 non-null  float64
 57  marital_status_rec_single                                           88504 non-null  float64
 58  urbanicity_cat_1_rural                                              88504 non-null  float64
 59  urbanicity_cat_2_mixed                                              88504 non-null  float64
 60  ed_attainment_corr_2_completed_high_school_or_less                  88504 non-null  float64
 61  ed_attainment_corr_3_completed_primary_school_or_less               88504 non-null  float64
 62  evaluacindelprocesoteraputico_logro_intermedio                      88504 non-null  float64
 63  evaluacindelprocesoteraputico_logro_minimo                          88504 non-null  float64
 64  eva_consumo_logro_intermedio                                        88504 non-null  float64
 65  eva_consumo_logro_minimo                                            88504 non-null  float64
 66  eva_fam_logro_intermedio                                            88504 non-null  float64
 67  eva_fam_logro_minimo                                                88504 non-null  float64
 68  eva_relinterp_logro_intermedio                                      88504 non-null  float64
 69  eva_relinterp_logro_minimo                                          88504 non-null  float64
 70  eva_ocupacion_logro_intermedio                                      88504 non-null  float64
 71  eva_ocupacion_logro_minimo                                          88504 non-null  float64
 72  eva_sm_logro_intermedio                                             88504 non-null  float64
 73  eva_sm_logro_minimo                                                 88504 non-null  float64
 74  eva_fisica_logro_intermedio                                         88504 non-null  float64
 75  eva_fisica_logro_minimo                                             88504 non-null  float64
 76  eva_transgnorma_logro_intermedio                                    88504 non-null  float64
 77  eva_transgnorma_logro_minimo                                        88504 non-null  float64
dtypes: bool(4), float64(69), int32(5)
memory usage: 48.6 MB
Code
from IPython.display import display, Markdown

if isinstance(imputations_list_jan26, list) and len(imputations_list_jan26) > 0:
    display(Markdown(f"**First element type:** `{type(imputations_list_jan26[0])}`"))

    if isinstance(imputations_list_jan26[0], dict):
        display(Markdown(f"**First element keys:** `{list(imputations_list_jan26[0].keys())}`"))

    elif isinstance(imputations_list_jan26[0], (pd.DataFrame, np.ndarray)):
        display(Markdown(f"**First element shape:** `{imputations_list_jan26[0].shape}`"))

First element type: <class 'pandas.DataFrame'>

First element shape: (88504, 56)

This code block:

  1. Imports the pickle library, which implements binary protocols for serializing and de-serializing Python object structures.
  2. Builds the file path from BASE_DIR, pointing to the .pkl file of imputed datasets.
  3. Opens the file in binary read mode ('rb'), as required for pickle files.
  4. Loads the object: pickle.load(f) reads the serialized object from the file and reconstructs it in memory.
  5. Prints confirmation and basic information: it verifies that the file was loaded, shows the type of the loaded object, and reports details about the first element when the object is a list of common data structures.

Format data

Because of inconsistencies and structural heterogeneity across previously merged datasets, we did not directly inspect and compare column names between the first imputed dataset from imputations_list_jan26 (which likely included dummy-encoded variables) and imputation_nodum_1 (which likely retained non-dummy-encoded variables).

Instead, we reconstructed the analytic datasets de novo using the most recent source files available in the original directory (BASE_DIR). Time-to-event variables were re-derived to ensure internal consistency. Variables that could introduce information leakage (e.g., time from admission) were excluded, and the center identifier variable was removed prior to modeling.

Code
# 1.2. Build Surv objects from each imputed dataset
from IPython.display import display, Markdown
from sksurv.util import Surv

for i in range(1, 6):
    # Get the DataFrame
    df = globals()[f"imputation_nodum_{i}"]

    # Extract time and event arrays
    time_readm  = df["readmit_time_from_disch_m"].to_numpy()
    event_readm = (df["readmit_event"].to_numpy() == 1)
    time_death  = df["death_time_from_disch_m"].to_numpy()
    event_death = (df["death_event"].to_numpy() == 1)

    # Create survival objects
    y_surv_readm = Surv.from_arrays(event=event_readm, time=time_readm)
    y_surv_death = Surv.from_arrays(event=event_death, time=time_death)

    # Store under indexed global names for downstream use
    globals()[f"y_surv_readm_{i}"]  = y_surv_readm
    globals()[f"y_surv_death_{i}"]  = y_surv_death

    # Print info
    display(Markdown(f"\n--- Imputation {i} ---"))
    display(Markdown(
    f"**y_surv_readm dtype:** {y_surv_readm.dtype}  \n"
    f"**shape:** {y_surv_readm.shape}"
    ))
    display(Markdown(
    f"**y_surv_death dtype:** {y_surv_death.dtype}  \n"
    f"**shape:** {y_surv_death.shape}"
    ))

— Imputation 1 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

— Imputation 2 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

— Imputation 3 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

— Imputation 4 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

— Imputation 5 —

y_surv_readm dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

y_surv_death dtype: [(‘event’, ‘?’), (‘time’, ‘<f8’)]
shape: (88504,)

Code
fold_output(
    "Show imputation_nodum_1 (newer database) glimpse",
    lambda: glimpse(imputation_nodum_1)
)
fold_output(
    "Show first db of imputations_list_jan26 (older) glimpse",
    lambda: glimpse(imputations_list_jan26[0])
)
Show imputation_nodum_1 (newer database) glimpse
Rows: 88504 | Columns: 43
readmit_time_from_adm_m        float64         84.93548387096774, 12.833333333333334, 13.733333333333333, 11.966666666666667, 1...
death_time_from_adm_m          float64         84.93548387096774, 87.16129032258064, 117.2258064516129, 98.93548387096774, 37.9...
adm_age_rec3                   float64         31.53, 20.61, 42.52, 60.61, 45.08
porc_pobr                      float64         0.175679117441177, 0.187835901975632, 0.130412444472313, 0.133759185671806, 0.08...
dit_m                          float64         15.967741935483872, 5.833333333333334, 0.4752688172043005, 6.966666666666667, 6....
sex_rec                        object          man, man, man, woman, man
tenure_status_household        object          stays temporarily with a relative, owner/transferred dwellings/pays dividends, s...
cohabitation                   object          alone, family of origin, with couple/children, with couple/children, family of o...
sub_dep_icd10_status           object          drug dependence, hazardous consumption, drug dependence, drug dependence, drug d...
any_violence                   object          0.No domestic violence/sex abuse, 0.No domestic violence/sex abuse, 0.No domesti...
prim_sub_freq_rec              object          2.2–6 days/wk, 3.Daily, 3.Daily, 3.Daily, 3.Daily
tr_outcome                     object          referral, dropout, adm discharge - rule violation/undet, dropout, completion
adm_motive                     object          sanitary sector, spontaneous consultation, sanitary sector, sanitary sector, spo...
first_sub_used                 object          alcohol, alcohol, alcohol, cocaine paste, alcohol
primary_sub_mod                object          alcohol, cocaine paste, cocaine paste, cocaine paste, cocaine paste
tipo_de_vivienda_rec2          object          other/unknown, formal housing, formal housing, formal housing, formal housing
national_foreign               int32           0, 0, 0, 0, 0
plan_type_corr                 object          pg-pab, pg-pab, pg-pr, m-pr, pg-pai
occupation_condition_corr24    object          unemployed, employed, employed, inactive, unemployed
marital_status_rec             object          single, single, single, married/cohabiting, single
urbanicity_cat                 object          3.Urban, 3.Urban, 3.Urban, 3.Urban, 3.Urban
ed_attainment_corr             object          2-Completed high school or less, 3-Completed primary school or less, 2-Completed...
evaluacindelprocesoteraputico  object          logro alto, logro minimo, logro minimo, logro minimo, logro alto
eva_consumo                    object          logro alto, logro minimo, logro minimo, logro intermedio, logro alto
eva_fam                        object          logro intermedio, logro minimo, logro minimo, logro intermedio, logro alto
eva_relinterp                  object          logro alto, logro minimo, logro minimo, logro intermedio, logro alto
eva_ocupacion                  object          logro alto, logro minimo, logro minimo, logro minimo, logro intermedio
eva_sm                         object          logro intermedio, logro minimo, logro minimo, logro intermedio, logro minimo
eva_fisica                     object          logro alto, logro minimo, logro intermedio, logro intermedio, logro alto
eva_transgnorma                object          logro alto, logro minimo, logro minimo, logro minimo, logro intermedio
ethnicity                      float64         0.0, 0.0, 0.0, 0.0, 0.0
dg_psiq_cie_10_instudy         bool            False, False, True, False, False
dg_psiq_cie_10_dg              bool            True, False, False, True, False
dx_f3_mood                     int32           0, 0, 0, 0, 0
dx_f6_personality              int32           0, 0, 0, 1, 0
dx_f_any_severe_mental         bool            True, False, False, False, False
any_phys_dx                    bool            False, False, False, True, False
polysubstance_strict           int32           0, 1, 1, 1, 1
readmit_event                  float64         0.0, 1.0, 1.0, 1.0, 1.0
death_event                    int32           0, 0, 0, 0, 0
readmit_time_from_disch_m      float64         68.96774193548387, 7.0, 13.258064516129032, 5.0, 7.354838709677419
death_time_from_disch_m        float64         68.96774193548387, 81.3225806451613, 116.74193548387096, 91.96774193548387, 31.0...
center_id                      object          330, 489, 341, 411, 176
Show first db of imputations_list_jan26 (older) glimpse
Rows: 88504 | Columns: 56
adm_age_rec3                   float64         31.53, 20.61, 42.52, 60.61, 45.08
porc_pobr                      float64         0.175679117441177, 0.187835901975632, 0.130412444472313, 0.133759185671806, 0.08...
dit_m                          float64         15.967741935483872, 5.833333333333334, 0.4752688172043005, 6.966666666666667, 6....
tenure_status_household        int64           3, 0, 3, 0, 3
prim_sub_freq_rec              int64           1, 2, 2, 2, 2
national_foreign               int32           0, 0, 0, 0, 0
urbanicity_cat                 int64           0, 0, 0, 0, 0
ed_attainment_corr             float64         1.0, 2.0, 1.0, 1.0, 2.0
evaluacindelprocesoteraputico  int64           0, 2, 2, 2, 0
eva_consumo                    int64           0, 2, 2, 1, 0
eva_fam                        int64           1, 2, 2, 1, 0
eva_relinterp                  int64           0, 2, 2, 1, 0
eva_ocupacion                  int64           0, 2, 2, 2, 1
eva_sm                         int64           1, 2, 2, 1, 2
eva_fisica                     int64           0, 2, 1, 1, 0
eva_transgnorma                int64           0, 2, 2, 2, 1
ethnicity                      float64         0.0, 0.0, 0.0, 0.0, 0.0
dg_psiq_cie_10_instudy         bool            False, False, True, False, False
dg_psiq_cie_10_dg              bool            True, False, False, True, False
dx_f3_mood                     int32           0, 0, 0, 0, 0
dx_f6_personality              int32           0, 0, 0, 1, 0
dx_f_any_severe_mental         bool            True, False, False, False, False
any_phys_dx                    bool            False, False, False, True, False
polysubstance_strict           int32           0, 1, 1, 1, 1
sex_rec_woman                  float64         0.0, 0.0, 0.0, 1.0, 0.0
cohabitation_family_of_origin  float64         0.0, 1.0, 0.0, 0.0, 1.0
cohabitation_with_couple_children float64         0.0, 0.0, 1.0, 1.0, 0.0
cohabitation_others            float64         0.0, 0.0, 0.0, 0.0, 0.0
sub_dep_icd10_status_drug_dependence float64         1.0, 0.0, 1.0, 1.0, 1.0
any_violence_1_domestic_violence_sex_abuse float64         0.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_referral            float64         1.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_dropout             float64         0.0, 1.0, 0.0, 1.0, 0.0
tr_outcome_adm_discharge_rule_violation_undet float64         0.0, 0.0, 1.0, 0.0, 0.0
tr_outcome_adm_discharge_adm_reasons float64         0.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_other               float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_sanitary_sector     float64         1.0, 0.0, 1.0, 1.0, 0.0
adm_motive_another_sud_facility_fonodrogas_senda_previene float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_justice_sector      float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_other               float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_alcohol         float64         1.0, 1.0, 1.0, 0.0, 1.0
first_sub_used_cocaine_paste   float64         0.0, 0.0, 0.0, 1.0, 0.0
first_sub_used_cocaine_powder  float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_cocaine_powder float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_alcohol        float64         1.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_marijuana      float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_others         float64         0.0, 0.0, 0.0, 0.0, 0.0
tipo_de_vivienda_rec2_other_unknown float64         1.0, 0.0, 0.0, 0.0, 0.0
plan_type_corr_pg_pr           float64         0.0, 0.0, 1.0, 0.0, 0.0
plan_type_corr_m_pr            float64         0.0, 0.0, 0.0, 1.0, 0.0
plan_type_corr_pg_pai          float64         0.0, 0.0, 0.0, 0.0, 1.0
plan_type_corr_m_pai           float64         0.0, 0.0, 0.0, 0.0, 0.0
occupation_condition_corr24_unemployed float64         1.0, 0.0, 0.0, 0.0, 1.0
occupation_condition_corr24_inactive float64         0.0, 0.0, 0.0, 1.0, 0.0
marital_status_rec_single      float64         1.0, 1.0, 1.0, 0.0, 1.0
marital_status_rec_separated_divorced_annulled_widowed float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_other           float64         0.0, 0.0, 0.0, 0.0, 0.0

For each imputed dataset (1–5), we identified and removed predictors with zero variance, as they provide no useful information and can destabilize models. We printed the dropped variables and produced a cleaned version of each design matrix. This ensures that all downstream analyses use only informative predictors.

Code
# Keep only these objects (note the per-imputation suffixes on the Surv arrays)
objects_to_keep = {
    "objects_to_keep",
    *[f"imputation_nodum_{i}" for i in range(1, 6)],
    *[f"y_surv_readm_{i}" for i in range(1, 6)],
    *[f"y_surv_death_{i}" for i in range(1, 6)],
    "imputations_list_jan26",
}

import types

for name in list(globals().keys()):
    obj = globals()[name]
    if (
        name not in objects_to_keep
        and not name.startswith("_")
        and not callable(obj)
        and not isinstance(obj, types.ModuleType)  # <- protects ALL modules
    ):
        del globals()[name]
Code
from IPython.display import display, Markdown

# 1. Define columns to exclude (same as before)
target_cols = [
    "readmit_time_from_disch_m",
    "readmit_event",
    "death_time_from_disch_m",
    "death_event",
]

leak_time_cols = [
    "readmit_time_from_adm_m",
    "death_time_from_adm_m",
]

center_id = ["center_id"]

cols_to_exclude = target_cols + center_id + leak_time_cols

# 2. Create list of your EXISTING imputation DataFrames (1-5)
imputed_dfs = [
    imputation_nodum_1,
    imputation_nodum_2,
    imputation_nodum_3,
    imputation_nodum_4,
    imputation_nodum_5
]

# 3. Preprocessing loop
X_reduced_list = []

for d, df in enumerate(imputed_dfs):
    imputation_num = d + 1  # Convert 0-index to 1-index for display

    display(Markdown(f"\n=== Imputation dataset {imputation_num} ==="))

    # a) Identify and drop constant predictors
    const_mask = (df.nunique(dropna=False) <= 1)
    dropped_const = df.columns[const_mask].tolist()
    display(Markdown(f"**Constant predictors dropped ({len(dropped_const)}):**"))
    display(Markdown(f"{dropped_const if dropped_const else 'None'}"))

    # b) Remove constant columns
    X_reduced = df.loc[:, ~const_mask]

    # c) Drop target/leakage columns (if present)
    cols_to_drop = [col for col in cols_to_exclude if col in X_reduced.columns]
    if cols_to_drop:
        X_reduced = X_reduced.drop(columns=cols_to_drop)
        display(Markdown(f"**Dropped target/leakage columns:** {cols_to_drop}"))
    else:
        display(Markdown("No target/leakage columns found to drop"))

    # d) Store cleaned DataFrame
    X_reduced_list.append(X_reduced)

    # e) Report shapes
    display(Markdown(f"**Original shape:** {df.shape}"))
    display(Markdown(
        f"**Cleaned shape:** {X_reduced.shape} "
        f"(removed {df.shape[1] - X_reduced.shape[1]} columns)"
    ))

display(Markdown("\n✅ **Preprocessing complete! X_reduced_list contains 5 cleaned DataFrames.**"))

=== Imputation dataset 1 ===

Constant predictors dropped (0):

None

Dropped target/leakage columns: [‘readmit_time_from_disch_m’, ‘readmit_event’, ‘death_time_from_disch_m’, ‘death_event’, ‘center_id’, ‘readmit_time_from_adm_m’, ‘death_time_from_adm_m’]

Original shape: (88504, 43)

Cleaned shape: (88504, 36) (removed 7 columns)

=== Imputation datasets 2–5 ===

(Output identical to imputation dataset 1: no constant predictors dropped, the same 7 target/leakage columns removed, shapes (88504, 43) → (88504, 36).)

Preprocessing complete! X_reduced_list contains 5 cleaned DataFrames.

Dummify

A structured preprocessing pipeline was implemented prior to modeling. Ordered categorical variables (e.g., housing status, educational attainment, clinical evaluations, and substance use frequency) were manually mapped to numeric scales reflecting their natural ordering. For nominal categorical variables, prespecified reference categories were enforced to ensure consistent baseline comparisons across imputations. All remaining categorical predictors were then converted to dummy variables using one-hot encoding with the first category dropped to prevent multicollinearity. The procedure was applied consistently across all imputed datasets to ensure harmonized model inputs.
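The reference-category step above can be sketched on toy data. This is a minimal illustration (column and category names are made up, not from the study): forcing the desired reference level to be the first category means `drop_first=True` removes exactly that level.

```python
import pandas as pd
from pandas.api.types import CategoricalDtype

# Toy nominal variable; names are illustrative only.
df = pd.DataFrame({"plan": ["pg-pab", "pg-pr", "m-pr", "pg-pab", "pg-pai"]})

# Reorder categories so the chosen reference comes first.
ref = "pg-pab"
cats = df["plan"].unique().tolist()
order = [ref] + [c for c in cats if c != ref]
df["plan"] = df["plan"].astype(CategoricalDtype(categories=order, ordered=False))

# drop_first=True now drops the reference level: rows in the
# reference category become all-zero across the dummy columns.
dummies = pd.get_dummies(df, drop_first=True, dtype=float)
```

Because `get_dummies` drops the first category of the `CategoricalDtype`, the baseline is stable across imputations even if category order differs in the raw data.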

Code
import pandas as pd
import numpy as np
from pandas.api.types import CategoricalDtype

def preprocess_features_robust(df):
    df_clean = df.copy()

    # ---------------------------------------------------------
    # 1. Ordinal encoding (your existing code)
    # ---------------------------------------------------------
    ordered_mappings = {
        # --- NEW: Housing & Urbanicity ---
        "tenure_status_household": {
            "illegal settlement": 4,                       # Situación Calle
            "stays temporarily with a relative": 3,        # Allegado
            "others": 2,                                   # En pensión / Otros
            "renting": 1,                                  # Arrendando
            "owner/transferred dwellings/pays dividends": 0 # Vivienda Propia
        },
        "urbanicity_cat": {
            "1.Rural": 2,
            "2.Mixed": 1,
            "3.Urban": 0
        },

        # --- Clinical Evaluations (Minimo -> Intermedio -> Alto) ---
        "evaluacindelprocesoteraputico": {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_consumo":      {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_fam":          {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_relinterp":    {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_ocupacion":    {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_sm":           {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_fisica":       {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},
        "eva_transgnorma":  {"logro minimo": 2, "logro intermedio": 1, "logro alto": 0},

        # --- Frequency (Less freq -> More freq) ---
        "prim_sub_freq_rec": {
            "1.≤1 day/wk": 0,
            "2.2–6 days/wk": 1,
            "3.Daily": 2
        },

        # --- Education (Less -> More) ---
        "ed_attainment_corr": {
            "3-Completed primary school or less": 2,
            "2-Completed high school or less": 1,
            "1-More than high school": 0
        }
    }

    for col, mapping in ordered_mappings.items():
        if col in df_clean.columns:
            df_clean[col] = df_clean[col].astype(str).str.strip()
            df_clean[col] = df_clean[col].map(mapping)

            n_missing = df_clean[col].isnull().sum()
            if n_missing > 0:
                if n_missing == len(df_clean):
                    print(f"⚠️ WARNING: Mapping failed completely for '{col}'.")
                mode_val = df_clean[col].mode()[0]
                df_clean[col] = df_clean[col].fillna(mode_val)

    # ---------------------------------------------------------
    # 2. FORCE reference categories for dummies
    # ---------------------------------------------------------
    dummy_reference = {
        "sex_rec": "man",
        "plan_type_corr": "ambulatory",
        "marital_status_rec": "married/cohabiting",
        "cohabitation": "alone",
        "sub_dep_icd10_status": "hazardous consumption",
        "tr_outcome": "completion",
        "adm_motive": "spontaneous consultation",
        "tipo_de_vivienda_rec2": "formal housing",
        "plan_type_corr": "pg-pab",
        "occupation_condition_corr24": "employed",
        "any_violence": "0.No domestic violence/sex abuse",
        "first_sub_used": "marijuana",
        }

    for col, ref in dummy_reference.items():
        if col in df_clean.columns:
            df_clean[col] = df_clean[col].astype(str).str.strip()
            cats = df_clean[col].unique().tolist()

            if ref in cats:
                new_order = [ref] + [c for c in cats if c != ref]
                cat_type = CategoricalDtype(categories=new_order, ordered=False)
                df_clean[col] = df_clean[col].astype(cat_type)
            else:
                print(f"⚠️ Reference '{ref}' not found in {col}")

    # ---------------------------------------------------------
    # 3. One-hot encoding
    # ---------------------------------------------------------
    df_final = pd.get_dummies(df_clean, drop_first=True, dtype=float)

    return df_final

X_encoded_list_final = [preprocess_features_robust(X) for X in X_reduced_list]
X_encoded_list_final = [clean_names(X) for X in X_encoded_list_final]
Code
from IPython.display import display, Markdown

# 1. DIAGNOSTIC: Check exact string values
display(Markdown("### --- Diagnostic Check ---"))
sample_df = X_encoded_list_final[0]

if 'tenure_status_household' in sample_df.columns:
    display(Markdown("**Unique values in 'tenure_status_household':**"))
    display(Markdown(str(sample_df['tenure_status_household'].unique())))
else:
    display(Markdown("❌ 'tenure_status_household' is missing entirely from input data!"))

if 'urbanicity_cat' in sample_df.columns:
    display(Markdown("**Unique values in 'urbanicity_cat':**"))
    display(Markdown(str(sample_df['urbanicity_cat'].unique())))

if 'ed_attainment_corr' in sample_df.columns:
    display(Markdown("**Unique values in 'ed_attainment_corr':**"))
    display(Markdown(str(sample_df['ed_attainment_corr'].unique())))

— Diagnostic Check —

Unique values in ‘tenure_status_household’:

[3 0 1 2 4]

Unique values in ‘urbanicity_cat’:

[0 1 2]

Unique values in ‘ed_attainment_corr’:

[1 2 0]

We recoded the first substance used so that small categories are grouped into an "Others" category.

Code
# Columns to combine
cols_to_group = [
    "first_sub_used_opioids",
    "first_sub_used_others",
    "first_sub_used_hallucinogens",
    "first_sub_used_inhalants",
    "first_sub_used_tranquilizers_hypnotics",
    "first_sub_used_amphetamine_type_stimulants",
]

# Loop over all imputed datasets and modify in place
for i in range(len(X_encoded_list_final)):
    df = X_encoded_list_final[i].copy()
    # Collapse into one dummy: if any of these == 1, mark as 1
    df["first_sub_used_other"] = df[cols_to_group].max(axis=1)
    # Drop the rest except the new combined column
    df = df.drop(columns=[c for c in cols_to_group if c != "first_sub_used_other"])
    # Replace the dataset in the original list
    X_encoded_list_final[i] = df
Code
import sys
fold_output(
    "Show first db of X_encoded_list_final (newer) glimpse",
    lambda: glimpse(X_encoded_list_final[0])
)
Show first db of X_encoded_list_final (newer) glimpse
Rows: 88504 | Columns: 56
adm_age_rec3                   float64         31.53, 20.61, 42.52, 60.61, 45.08
porc_pobr                      float64         0.175679117441177, 0.187835901975632, 0.130412444472313, 0.133759185671806, 0.08...
dit_m                          float64         15.967741935483872, 5.833333333333334, 0.4752688172043005, 6.966666666666667, 6....
tenure_status_household        int64           3, 0, 3, 0, 3
prim_sub_freq_rec              int64           1, 2, 2, 2, 2
national_foreign               int32           0, 0, 0, 0, 0
urbanicity_cat                 int64           0, 0, 0, 0, 0
ed_attainment_corr             int64           1, 2, 1, 1, 2
evaluacindelprocesoteraputico  int64           0, 2, 2, 2, 0
eva_consumo                    int64           0, 2, 2, 1, 0
eva_fam                        int64           1, 2, 2, 1, 0
eva_relinterp                  int64           0, 2, 2, 1, 0
eva_ocupacion                  int64           0, 2, 2, 2, 1
eva_sm                         int64           1, 2, 2, 1, 2
eva_fisica                     int64           0, 2, 1, 1, 0
eva_transgnorma                int64           0, 2, 2, 2, 1
ethnicity                      float64         0.0, 0.0, 0.0, 0.0, 0.0
dg_psiq_cie_10_instudy         bool            False, False, True, False, False
dg_psiq_cie_10_dg              bool            True, False, False, True, False
dx_f3_mood                     int32           0, 0, 0, 0, 0
dx_f6_personality              int32           0, 0, 0, 1, 0
dx_f_any_severe_mental         bool            True, False, False, False, False
any_phys_dx                    bool            False, False, False, True, False
polysubstance_strict           int32           0, 1, 1, 1, 1
sex_rec_woman                  float64         0.0, 0.0, 0.0, 1.0, 0.0
cohabitation_family_of_origin  float64         0.0, 1.0, 0.0, 0.0, 1.0
cohabitation_with_couple_children float64         0.0, 0.0, 1.0, 1.0, 0.0
cohabitation_others            float64         0.0, 0.0, 0.0, 0.0, 0.0
sub_dep_icd10_status_drug_dependence float64         1.0, 0.0, 1.0, 1.0, 1.0
any_violence_1_domestic_violence_sex_abuse float64         0.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_referral            float64         1.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_dropout             float64         0.0, 1.0, 0.0, 1.0, 0.0
tr_outcome_adm_discharge_rule_violation_undet float64         0.0, 0.0, 1.0, 0.0, 0.0
tr_outcome_adm_discharge_adm_reasons float64         0.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_other               float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_sanitary_sector     float64         1.0, 0.0, 1.0, 1.0, 0.0
adm_motive_another_sud_facility_fonodrogas_senda_previene float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_justice_sector      float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_other               float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_alcohol         float64         1.0, 1.0, 1.0, 0.0, 1.0
first_sub_used_cocaine_paste   float64         0.0, 0.0, 0.0, 1.0, 0.0
first_sub_used_cocaine_powder  float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_cocaine_paste  float64         0.0, 1.0, 1.0, 1.0, 1.0
primary_sub_mod_cocaine_powder float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_marijuana      float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_others         float64         0.0, 0.0, 0.0, 0.0, 0.0
tipo_de_vivienda_rec2_other_unknown float64         1.0, 0.0, 0.0, 0.0, 0.0
plan_type_corr_pg_pr           float64         0.0, 0.0, 1.0, 0.0, 0.0
plan_type_corr_m_pr            float64         0.0, 0.0, 0.0, 1.0, 0.0
plan_type_corr_pg_pai          float64         0.0, 0.0, 0.0, 0.0, 1.0
plan_type_corr_m_pai           float64         0.0, 0.0, 0.0, 0.0, 0.0
occupation_condition_corr24_unemployed float64         1.0, 0.0, 0.0, 0.0, 1.0
occupation_condition_corr24_inactive float64         0.0, 0.0, 0.0, 1.0, 0.0
marital_status_rec_single      float64         1.0, 1.0, 1.0, 0.0, 1.0
marital_status_rec_separated_divorced_annulled_widowed float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_other           float64         0.0, 0.0, 0.0, 0.0, 0.0

For each imputed dataset, we fitted two regularized Cox models (one for readmission and one for death) using Coxnet, which applies elastic-net penalization with a strong LASSO component to enable variable selection. The loop fits both models on every imputation, prints basic model information, and stores all fitted models so they can later be combined or compared across imputations.

Create bins for follow-up (landmarks)

We extracted the observed event times and corresponding event indicators directly from the structured survival objects (y_surv_readm and y_surv_death). Using the observed event times, we constructed evaluation grids based on the 5th to 95th percentiles of the event-time distribution. These grids define standardized time points at which model performance is assessed for both readmission and mortality outcomes.
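The grid construction can be sketched on simulated event times (the real grids use the observed times from `y_surv_readm` and `y_surv_death`; the exponential draw here is purely illustrative):

```python
import numpy as np

# Illustrative event times in months; scale and seed are arbitrary.
rng = np.random.default_rng(2125)
event_times = rng.exponential(scale=12.0, size=1000)

# 50 evaluation points spanning the 5th–95th percentiles of the
# event-time distribution; np.unique removes any duplicate quantiles.
times_eval = np.unique(np.quantile(event_times, np.linspace(0.05, 0.95, 50)))
```

Trimming to the 5th–95th percentiles avoids evaluating the model in the sparse tails, where Brier and C-index estimates become unstable.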

Code
import numpy as np
from IPython.display import display, Markdown

# Extract event times directly from structured arrays
event_times_readm = y_surv_readm["time"][y_surv_readm["event"]]
event_times_death = y_surv_death["time"][y_surv_death["event"]]

# Build evaluation grids (5th–95th percentiles, 50 points)
times_eval_readm = np.unique(
    np.quantile(event_times_readm, np.linspace(0.05, 0.95, 50))
)

times_eval_death = np.unique(
    np.quantile(event_times_death, np.linspace(0.05, 0.95, 50))
)

# Display only final result
display(Markdown(
    f"**Eval times (readmission):** `{times_eval_readm[:5]}` ... `{times_eval_readm[-5:]}`"
))

display(Markdown(
    f"**Eval times (death):** `{times_eval_death[:5]}` ... `{times_eval_death[-5:]}`"
))

Eval times (readmission): [0.38709677 0.67741935 1.03225806 1.41935484 1.76666667] ... [46.81833443 50.96030612 55.16129032 60.84848585 67.08322581]

Eval times (death): [0. 0.09677419 1.06666667 2.1691691 3.34812377] ... [74.72632653 78.4516129 82.39472203 86.41935484 92.36311828]

Correct immortal time bias

First, we corrected immortal time bias: without this step, patients who die appear as readmission-free follow-up.

This correction is essentially the cause-specific hazard preparation. It is the correct way to handle Aim 3 unless you switch to a Fine–Gray model (which codes death as a distinct event type, 2, rather than as censoring, 0). For RSF/Coxnet, treating death as censoring (0) is the correct approach.
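The recode can be sketched on toy structured arrays (values are illustrative; the dtype mirrors the `(event, time)` survival arrays used in this notebook): when death is observed strictly before the readmission/censoring time, the readmission outcome is censored at the death time.

```python
import numpy as np

# Structured survival arrays with the same (event, time) dtype as above.
dt = [("event", "?"), ("time", "<f8")]
y_readm = np.array([(False, 24.0), (True, 6.0), (False, 18.0)], dtype=dt)
y_death = np.array([(True, 10.0), (False, 30.0), (True, 12.0)], dtype=dt)

y_corr = y_readm.copy()
# Death observed strictly before the readmission/censoring time:
# censor readmission at the death time (cause-specific hazard setup).
mask = y_death["event"] & (y_death["time"] < y_corr["time"])
y_corr["event"][mask] = False
y_corr["time"][mask] = y_death["time"][mask]
```

Subjects who die first contribute censored follow-up only up to death, so they no longer inflate the readmission-free person-time.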

Code
import numpy as np

# Step 3. Replicate across imputations (safe copies)
n_imputations = len(X_encoded_list_final)
y_surv_readm_list = [y_surv_readm.copy() for _ in range(n_imputations)]
y_surv_death_list = [y_surv_death.copy() for _ in range(n_imputations)]

def correct_competing_risks(y_readm_list, y_death_list):
    corrected = []
    for y_readm, y_death in zip(y_readm_list, y_death_list):
        y_corr = y_readm.copy()

        # death observed strictly before the readmission/censoring time
        mask = (y_death["event"]) & (y_death["time"] < y_corr["time"])

        y_corr["event"][mask] = False
        y_corr["time"][mask] = y_death["time"][mask]

        corrected.append(y_corr)
    return corrected

# Step 4. Apply correction
y_surv_readm_list_corrected = correct_competing_risks(
    y_surv_readm_list,
    y_surv_death_list
)
Code
# Check type and length
type(y_surv_readm_list_corrected), len(y_surv_readm_list_corrected)

# Look at the first element
y_surv_readm_list_corrected[0][:5]   # first 5 rows
array([(False, 68.96774194), ( True,  7.        ), ( True, 13.25806452),
       ( True,  5.        ), ( True,  7.35483871)],
      dtype=[('event', '?'), ('time', '<f8')])
Code
from IPython.display import display, HTML
import html

def nb_print(*args, sep=" "):
    msg = sep.join(str(a) for a in args)
    display(HTML(f"<pre style='margin:0'>{html.escape(msg)}</pre>"))

The fully preprocessed and encoded feature matrices were renamed from X_encoded_list_final to imputations_list_mar26 to reflect the finalized March 2026 analytic version of the imputed datasets.

This object contains the harmonized, ordinal-encoded, and one-hot encoded predictor matrices for all five imputations and will serve as the definitive input for subsequent modeling procedures.

Code
imputations_list_mar26 = X_encoded_list_final
del X_encoded_list_final
Code
import numpy as np
import pandas as pd
from IPython.display import display, Markdown

# ── Build exclusion mask (same for all imputations, based on imputation 0) ──
df0 = imputations_list_mar26[0]

# Condition 1: tr_outcome_adm_discharge_adm_reasons == 1 AND death time ≤ 0.23 months (≈ 7 days)
mask_adm_death = (
    (df0["tr_outcome_adm_discharge_adm_reasons"] == 1)
    & (y_surv_death["event"] == True)
    & (y_surv_death["time"] <= 0.23)
)

# Condition 2: tr_outcome_other == 1 (any time)
mask_other = df0["tr_outcome_other"] == 1

# Combined exclusion mask
exclude = mask_adm_death | mask_other
keep = ~exclude

# ── Report ──
n_total = len(df0)
n_excl_adm = mask_adm_death.sum()
n_excl_other = mask_other.sum()
n_excl_both = (mask_adm_death & mask_other).sum()
n_excl_total = exclude.sum()
n_remaining = keep.sum()

report = f"""### Exclusion Report

| Criterion | n excluded |
|---|---:|
| `tr_outcome_adm_discharge_adm_reasons == 1` & death time ≤ 7 days | {n_excl_adm} |
| `tr_outcome_other == 1` (any time) | {n_excl_other} |
| Both criteria (overlap) | {n_excl_both} |
| **Total unique excluded** | **{n_excl_total}** |
| **Remaining observations** | **{n_remaining}** / {n_total} |
"""
display(Markdown(report))

# ── Apply filter to all imputation lists and outcome arrays ──
imputations_list_mar26 = [df.loc[keep].reset_index(drop=True) for df in imputations_list_mar26]
y_surv_readm_list_mar26 = [y[keep] for y in y_surv_readm_list_corrected]
y_surv_death_list_mar26 = [y[keep] for y in y_surv_death_list]
y_surv_readm_list_corrected_mar26 = [y[keep] for y in y_surv_readm_list_corrected]

# Single (non-list) outcome arrays for convenience
y_surv_readm_mar26 = y_surv_readm[keep]
y_surv_death_mar26 = y_surv_death[keep]

# Rebuild eval time grids on the filtered data
event_times_readm_mar26 = y_surv_readm_mar26["time"][y_surv_readm_mar26["event"]]
event_times_death_mar26 = y_surv_death_mar26["time"][y_surv_death_mar26["event"]]

times_eval_readm_mar26 = np.unique(
    np.quantile(event_times_readm_mar26, np.linspace(0.05, 0.95, 50))
)
times_eval_death_mar26 = np.unique(
    np.quantile(event_times_death_mar26, np.linspace(0.05, 0.95, 50))
)

Exclusion Report

| Criterion | n excluded |
|---|---:|
| tr_outcome_adm_discharge_adm_reasons == 1 & death time ≤ 7 days | 137 |
| tr_outcome_other == 1 (any time) | 215 |
| Both criteria (overlap) | 0 |
| Total unique excluded | 352 |
| Remaining observations | 88152 / 88504 |

Train / test split (80/20)

  1. Sets a fixed random seed to make the 80/20 split exactly reproducible.

  2. Verifies required datasets exist (features and survival outcomes) before doing anything.

  3. Creates a “death-corrected” outcome list if it was not already available.

  4. Derives stratification labels from treatment plan and completion categories plus readmission/death events.

  5. Uses a step-down strategy if some strata are too rare, simplifying the stratification to keep it feasible.

  6. Caches a “full snapshot” of all imputations and outcomes so reruns don’t silently change the split.

  7. Checks row alignment so every imputation and every outcome has the same number of observations.

  8. Optionally checks stability across imputations for plan/completion columns (should not vary much).

  9. Loads split indices from disk when available, ensuring the exact same train/test split across sessions.

  10. Builds train/test datasets consistently for all imputations, then runs strict diagnostics to confirm balance.

Code
#@title 🧪 / 🎓 Reproducible 80/20 split before ML (integrated + idempotent + persisted)
# Stratification hierarchy:
#   1) plan + completion + readm_event + death_event
#   2) mixed fallback for rare full strata (<2) -> plan + readm + death
#   3) full fallback -> plan + readm + death

import numpy as np
import pandas as pd
from pathlib import Path
from sklearn.model_selection import train_test_split
from IPython.display import display, Markdown

SEED = 2125
TEST_SIZE = 0.20
FORCE_RESPLIT = False          # True to force new split
STRICT_SPLIT_CHECKS = True     # CI-style hard checks
MAX_EVENT_GAP = 0.01           # 1% tolerance
PERSIST_SPLIT_INDICES = True

# --- Project-root anchored paths ---

def find_project_root(markers=("AGENTS.md", ".git")):
    try:
        cur = Path.cwd().resolve()
    except OSError as e:
        raise RuntimeError(
            "Invalid working directory. Run this notebook from inside the project folder."
        ) from e

    for p in (cur, *cur.parents):
        if any((p / m).exists() for m in markers):
            return p

    raise RuntimeError(
        f"Could not locate project root starting from {cur}. "
        f"Expected one of markers: {markers}."
    )

PROJECT_ROOT = find_project_root()
OUT_DIR = PROJECT_ROOT / "_out"
OUT_DIR.mkdir(parents=True, exist_ok=True)

SPLIT_FILE = OUT_DIR / f"readm_split_seed{SEED}_test{int(TEST_SIZE*100)}_mar26.npz"

def nb_print_md(msg):
    display(Markdown(str(msg)))

nb_print_md(f"**Project root:** `{PROJECT_ROOT}`")


# ---------- Requirements ----------
required = [
    "imputations_list_mar26",
    "y_surv_readm_list_mar26",
    "y_surv_readm_list_corrected_mar26",
    "y_surv_death_list_mar26",
]
missing = [v for v in required if v not in globals()]
if missing:
    raise ValueError(f"Missing required objects: {missing}")

if "y_surv_death_list_corrected_mar26" not in globals():
    y_surv_death_list_corrected_mar26 = [y.copy() for y in y_surv_death_list_mar26]

# ---------- Helpers ----------
def get_plan_labels(df):
    labels = np.zeros(len(df), dtype=int)
    if "plan_type_corr_pg_pr" in df.columns:
        labels[pd.to_numeric(df["plan_type_corr_pg_pr"], errors="coerce").fillna(0).to_numpy() == 1] = 1
    if "plan_type_corr_m_pr" in df.columns:
        labels[pd.to_numeric(df["plan_type_corr_m_pr"], errors="coerce").fillna(0).to_numpy() == 1] = 2
    if "plan_type_corr_pg_pai" in df.columns:
        labels[pd.to_numeric(df["plan_type_corr_pg_pai"], errors="coerce").fillna(0).to_numpy() == 1] = 3
    if "plan_type_corr_m_pai" in df.columns:
        labels[pd.to_numeric(df["plan_type_corr_m_pai"], errors="coerce").fillna(0).to_numpy() == 1] = 4
    return labels

def get_completion_labels(df):
    labels = np.zeros(len(df), dtype=int)
    if "tr_outcome_referral" in df.columns:
        labels[pd.to_numeric(df["tr_outcome_referral"], errors="coerce").fillna(0).to_numpy() == 1] = 1
    if "tr_outcome_dropout" in df.columns:
        labels[pd.to_numeric(df["tr_outcome_dropout"], errors="coerce").fillna(0).to_numpy() == 1] = 2
    if "tr_outcome_adm_discharge_rule_violation_undet" in df.columns:
        labels[pd.to_numeric(df["tr_outcome_adm_discharge_rule_violation_undet"], errors="coerce").fillna(0).to_numpy() == 1] = 3
    if "tr_outcome_adm_discharge_adm_reasons" in df.columns:
        labels[pd.to_numeric(df["tr_outcome_adm_discharge_adm_reasons"], errors="coerce").fillna(0).to_numpy() == 1] = 4
    if "tr_outcome_other" in df.columns:
        labels[pd.to_numeric(df["tr_outcome_other"], errors="coerce").fillna(0).to_numpy() == 1] = 5
    return labels

def build_strata(X0, y_readm0, y_death0):
    """
    Build stratification labels with progressive fallback:

    1) full: plan + completion + readmission_event + death_event
    2) mixed: only rare full strata (<2 rows) are replaced by fallback labels
       (plan + readmission_event + death_event)
    3) fallback: plan + readmission_event + death_event for all rows

    Returns:
        strata (np.ndarray), mode (str), readm_evt (np.ndarray), death_evt (np.ndarray)
    """
    plan = get_plan_labels(X0)
    comp = get_completion_labels(X0)
    readm_evt = y_readm0["event"].astype(int)
    death_evt = y_death0["event"].astype(int)

    full = pd.Series(plan.astype(str) + "_" + comp.astype(str) + "_" + readm_evt.astype(str) + "_" + death_evt.astype(str))
    if full.value_counts().min() >= 2:
        return full.to_numpy(), "full(plan+completion+readm+death)", readm_evt, death_evt

    fb = pd.Series(plan.astype(str) + "_" + readm_evt.astype(str) + "_" + death_evt.astype(str))
    mixed = full.copy()
    rare = mixed.map(mixed.value_counts()) < 2
    mixed[rare] = fb[rare]
    if mixed.value_counts().min() >= 2:
        return mixed.to_numpy(), "mixed(rare->plan+readm+death)", readm_evt, death_evt

    if fb.value_counts().min() >= 2:
        return fb.to_numpy(), "fallback(plan+readm+death)", readm_evt, death_evt

    raise ValueError("Could not build stratification labels with >=2 rows per stratum.")

def split_df_list(df_list, tr_idx, te_idx):
    tr = [df.iloc[tr_idx].reset_index(drop=True).copy() for df in df_list]
    te = [df.iloc[te_idx].reset_index(drop=True).copy() for df in df_list]
    return tr, te

def split_surv_list(y_list, tr_idx, te_idx):
    tr = [y[tr_idx].copy() for y in y_list]
    te = [y[te_idx].copy() for y in y_list]
    return tr, te

# ---------- Cache full data once (idempotent re-runs) ----------
if "_split_cache_readm_mar26" not in globals():
    _split_cache_readm_mar26 = {}
cache = _split_cache_readm_mar26

if FORCE_RESPLIT:
    cache.pop("idx", None)

if FORCE_RESPLIT or "full" not in cache:
    cache["full"] = {
        "X": [df.reset_index(drop=True).copy() for df in imputations_list_mar26],
        "y_readm": [y.copy() for y in y_surv_readm_list_mar26],
        "y_readm_corr": [y.copy() for y in y_surv_readm_list_corrected_mar26],
        "y_death": [y.copy() for y in y_surv_death_list_mar26],
        "y_death_corr": [y.copy() for y in y_surv_death_list_corrected_mar26],
    }

full = cache["full"]

# ---------- Consistency checks ----------
n_imp = len(full["X"])
n = len(full["X"][0])

if any(len(df) != n for df in full["X"]):
    raise ValueError("Row mismatch inside full X list.")

for name, obj in [
    ("y_readm", full["y_readm"]),
    ("y_readm_corr", full["y_readm_corr"]),
    ("y_death", full["y_death"]),
    ("y_death_corr", full["y_death_corr"]),
]:
    if len(obj) != n_imp:
        raise ValueError(f"{name} length ({len(obj)}) != n_imputations ({n_imp})")
    if any(len(y) != n for y in obj):
        raise ValueError(f"Row mismatch between X and {name}.")

# ---------- Optional diagnostic: plan/completion consistency across imputations ----------
plan_comp_cols = [
    c for c in [
        "plan_type_corr_pg_pr",
        "plan_type_corr_m_pr",
        "plan_type_corr_pg_pai",
        "plan_type_corr_m_pai",
        "tr_outcome_referral",
        "tr_outcome_dropout",
        "tr_outcome_adm_discharge_rule_violation_undet",
        "tr_outcome_adm_discharge_adm_reasons"#,
        #"tr_outcome_other", #2026-03-26: Excluded from consistency check since it was used as an exclusion criterion and thus may differ by design across imputations
    ] if c in full["X"][0].columns
]

max_diff_rows = 0
if plan_comp_cols:
    base_pc = full["X"][0][plan_comp_cols].astype("string").fillna("__NA__").reset_index(drop=True)
    for i in range(1, n_imp):
        cur_pc = full["X"][i][plan_comp_cols].astype("string").fillna("__NA__").reset_index(drop=True)
        diff_rows = int((base_pc != cur_pc).any(axis=1).sum())
        max_diff_rows = max(max_diff_rows, diff_rows)

# ---------- Try loading indices from disk ----------
loaded_from_disk = False
if PERSIST_SPLIT_INDICES and (not FORCE_RESPLIT) and SPLIT_FILE.exists() and ("idx" not in cache):
    z = np.load(SPLIT_FILE, allow_pickle=False)
    tr = z["train_idx"].astype(int)
    te = z["test_idx"].astype(int)
    n_disk = int(z["n_full"][0]) if "n_full" in z else n
    if n_disk == n and tr.max() < n and te.max() < n:
        cache["idx"] = (np.sort(tr), np.sort(te))
        cache["strat_mode"] = str(z["strat_mode"][0]) if "strat_mode" in z else "loaded_from_disk"
        loaded_from_disk = True

# ---------- Compute or reuse split indices ----------
if FORCE_RESPLIT or "idx" not in cache:
    strata_used, strat_mode, readm_evt_all, death_evt_all = build_strata(
        full["X"][0], full["y_readm"][0], full["y_death"][0]
    )
    idx = np.arange(n)
    train_idx, test_idx = train_test_split(
        idx, test_size=TEST_SIZE, random_state=SEED, shuffle=True, stratify=strata_used
    )
    train_idx = np.sort(train_idx)
    test_idx = np.sort(test_idx)
    cache["idx"] = (train_idx, test_idx)
    cache["strat_mode"] = strat_mode

    if PERSIST_SPLIT_INDICES:
        np.savez_compressed(
            SPLIT_FILE,
            train_idx=train_idx,
            test_idx=test_idx,
            n_full=np.array([n], dtype=int),
            seed=np.array([SEED], dtype=int),
            test_size=np.array([TEST_SIZE], dtype=float),
            strat_mode=np.array([strat_mode], dtype="U64"),
        )
else:
    train_idx, test_idx = cache["idx"]
    train_idx = np.sort(train_idx)
    test_idx = np.sort(test_idx)
    readm_evt_all = full["y_readm"][0]["event"].astype(int)
    death_evt_all = full["y_death"][0]["event"].astype(int)

# ---------- Build train/test from full snapshot every run ----------
imputations_list_mar26_train, imputations_list_mar26_test = split_df_list(full["X"], train_idx, test_idx)

y_surv_readm_list_train, y_surv_readm_list_test = split_surv_list(full["y_readm"], train_idx, test_idx)
y_surv_readm_list_corrected_train, y_surv_readm_list_corrected_test = split_surv_list(full["y_readm_corr"], train_idx, test_idx)

y_surv_death_list_train, y_surv_death_list_test = split_surv_list(full["y_death"], train_idx, test_idx)
y_surv_death_list_corrected_train, y_surv_death_list_corrected_test = split_surv_list(full["y_death_corr"], train_idx, test_idx)

# Downstream code uses TRAIN only
imputations_list_mar26 = imputations_list_mar26_train
y_surv_readm_list = y_surv_readm_list_train
y_surv_readm_list_corrected = y_surv_readm_list_corrected_train
y_surv_death_list = y_surv_death_list_train
y_surv_death_list_corrected = y_surv_death_list_corrected_train

# ---------- Diagnostics + strict checks ----------
strata_diag, strat_mode_diag, _, _ = build_strata(full["X"][0], full["y_readm"][0], full["y_death"][0])
sdiag = pd.Series(strata_diag)
train_strata = set(sdiag.iloc[train_idx].unique())
test_strata = set(sdiag.iloc[test_idx].unique())
missing_in_test = sorted(train_strata - test_strata)
missing_in_train = sorted(test_strata - train_strata)

readm_gap = abs(readm_evt_all[train_idx].mean() - readm_evt_all[test_idx].mean())
death_gap = abs(death_evt_all[train_idx].mean() - death_evt_all[test_idx].mean())

# full-strata rarity report (before fallback)
strata_full = pd.Series(
    get_plan_labels(full["X"][0]).astype(str) + "_" +
    get_completion_labels(full["X"][0]).astype(str) + "_" +
    full["y_readm"][0]["event"].astype(int).astype(str) + "_" +
    full["y_death"][0]["event"].astype(int).astype(str)
)
vc_full = strata_full.value_counts()
rare_rows = int((strata_full.map(vc_full) < 2).sum())

if STRICT_SPLIT_CHECKS:
    assert len(np.intersect1d(train_idx, test_idx)) == 0, "Train/Test index overlap detected."
    assert (len(train_idx) + len(test_idx)) == n, "Train/Test sizes do not sum to n."
    assert len(missing_in_test) == 0, f"Strata missing in test: {missing_in_test}"
    assert len(missing_in_train) == 0, f"Strata missing in train: {missing_in_train}"
    assert readm_gap < MAX_EVENT_GAP, f"Readmission rate imbalance > {MAX_EVENT_GAP:.0%} (gap={readm_gap:.4f})"
    assert death_gap < MAX_EVENT_GAP, f"Death rate imbalance > {MAX_EVENT_GAP:.0%} (gap={death_gap:.4f})"

# ---------- Summary ----------
nb_print_md(f"**Loaded indices from disk:** `{loaded_from_disk}`")
nb_print_md(f"**Split file:** `{SPLIT_FILE}`")
nb_print_md(f"**Split mode used:** `{cache.get('strat_mode', strat_mode_diag)}`")
nb_print_md(f"**Plan/completion diff rows across imputations (max vs imp0):** `{max_diff_rows}`")
nb_print_md(f"**Full strata count:** `{vc_full.size}` | **Min full stratum size:** `{int(vc_full.min())}` | **Rows in rare full strata (<2):** `{rare_rows}`")
nb_print_md(f"**Train/Test sizes:** `{len(train_idx)}` ({len(train_idx)/n:.1%}) / `{len(test_idx)}` ({len(test_idx)/n:.1%})")
nb_print_md(
    "**Readmission rate all/train/test:** "
    f"`{readm_evt_all.mean():.3%}` / `{readm_evt_all[train_idx].mean():.3%}` / `{readm_evt_all[test_idx].mean():.3%}`"
)
nb_print_md(
    "**Death rate all/train/test:** "
    f"`{death_evt_all.mean():.3%}` / `{death_evt_all[train_idx].mean():.3%}` / `{death_evt_all[test_idx].mean():.3%}`"
)
nb_print_md(
    f"**Strata in train/test:** `{len(train_strata)}` / `{len(test_strata)}` | "
    f"**Missing train→test:** `{len(missing_in_test)}` | **Missing test→train:** `{len(missing_in_train)}`"
)

Project root: G:\My Drive\Alvacast\SISTRAT 2023\cons

Loaded indices from disk: True

Split file: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\readm_split_seed2125_test20_mar26.npz

Split mode used: fallback(plan+readm+death)

Plan/completion diff rows across imputations (max vs imp0): 0

Full strata count: 95 | Min full stratum size: 1 | Rows in rare full strata (<2): 4

Train/Test sizes: 70521 (80.0%) / 17631 (20.0%)

Readmission rate all/train/test: 21.622% / 21.621% / 21.627%

Death rate all/train/test: 4.310% / 4.309% / 4.311%

Strata in train/test: 20 / 20 | Missing train→test: 0 | Missing test→train: 0

Code
# counts per stratum in train/test
train_counts = sdiag.iloc[train_idx].value_counts()
test_counts  = sdiag.iloc[test_idx].value_counts()

min_train = int(train_counts.min())
min_test  = int(test_counts.min())

nb_print_md(f"**Min stratum count in TRAIN (used strata):** `{min_train}`")
nb_print_md(f"**Min stratum count in TEST (used strata):** `{min_test}`")

# strata that got 0 in test or 0 in train
zero_in_test = sorted(set(train_counts.index) - set(test_counts.index))
zero_in_train = sorted(set(test_counts.index) - set(train_counts.index))

nb_print_md(f"**Strata with 0 in TEST:** `{len(zero_in_test)}`")
nb_print_md(f"**Strata with 0 in TRAIN:** `{len(zero_in_train)}`")

# show examples with their full-data counts
if len(zero_in_test) > 0:
    ex = zero_in_test[:10]
    nb_print_md(f"**Examples 0 in TEST (up to 10):** `{ex}`")
    nb_print_md(f"**Full-data counts:** `{[int(sdiag.value_counts()[k]) for k in ex]}`")

Min stratum count in TRAIN (used strata): 18

Min stratum count in TEST (used strata): 5

Strata with 0 in TEST: 0

Strata with 0 in TRAIN: 0

Code
strata_full = pd.Series(
    get_plan_labels(full["X"][0]).astype(str) + "_" +
    get_completion_labels(full["X"][0]).astype(str) + "_" +
    full["y_readm"][0]["event"].astype(int).astype(str) + "_" +
    full["y_death"][0]["event"].astype(int).astype(str)
)

vc = strata_full.value_counts()
display(Markdown(f"**# full strata:** `{vc.size}`"))
display(Markdown(f"**Min stratum size (full):** `{int(vc.min())}`"))
display(Markdown(f"**# strata with count < 2:** `{int((vc < 2).sum())}`"))

# full strata: 95

Min stratum size (full): 1

# strata with count < 2: 4

Code
plan = get_plan_labels(full["X"][0])
readm_evt = full["y_readm"][0]["event"].astype(int)
death_evt = full["y_death"][0]["event"].astype(int)

# Fallback strata labels (plan + readm event + death event), mirroring build_strata's fallback mode
fb = pd.Series(plan.astype(str) + "_" + readm_evt.astype(str) + "_" + death_evt.astype(str))

rare_mask = strata_full.map(strata_full.value_counts()) < 2
n_rare = int(rare_mask.sum())

display(Markdown(f"**Rows in rare full-strata (<2):** `{n_rare}`"))
if n_rare > 0:
    display(Markdown(
        f"**Rare rows proportion:** `{n_rare/len(strata_full):.3%}`"
    ))

Rows in rare full-strata (<2): 4

Rare rows proportion: 0.005%
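Strata with a single member cannot appear on both sides of a train/test split, which is why the split cell reports a fallback mode rather than the full plan+completion+readm+death strata. A minimal sketch of that decision rule (illustrative helper and toy labels only; the notebook's actual `build_strata` is defined in an earlier section):

```python
from collections import Counter

def choose_strata(full_labels, fallback_labels, min_count=2):
    """Use the fine-grained strata if every stratum has >= min_count rows,
    otherwise fall back to the coarser labels (mirroring the
    'fallback(plan+readm+death)' mode reported above)."""
    counts = Counter(full_labels)
    if min(counts.values()) < min_count:
        return fallback_labels, "fallback"
    return full_labels, "full"

# Toy example: the full plan_completion_readm_death strata contain two
# singletons, so the coarser plan_readm_death labels are used instead.
full_strata = ["A_0_0_0", "A_0_0_0", "A_1_0_0", "A_1_0_0", "A_1_1_0", "A_0_1_0"]
fallback_strata = ["A_0_0", "A_0_0", "A_0_0", "A_0_0", "A_1_0", "A_1_0"]

labels, mode = choose_strata(full_strata, fallback_strata)
```

With the coarser labels every stratum has at least two rows, so 5-fold stratified splitting stays feasible.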

Code
# Use the actual stratification mode that was used to split
strata_used, strat_mode, _, _ = build_strata(full["X"][0], full["y_readm"][0], full["y_death"][0])

s = pd.Series(strata_used)
train_strata = set(s.iloc[train_idx].unique())
test_strata = set(s.iloc[test_idx].unique())

missing_in_test = sorted(train_strata - test_strata)
missing_in_train = sorted(test_strata - train_strata)

display(Markdown(f"**Strata used:** `{strat_mode}`"))
display(Markdown(f"**# strata in train:** `{len(train_strata)}` | **# strata in test:** `{len(test_strata)}`"))
display(Markdown(f"**Strata present in train but missing in test:** `{len(missing_in_test)}`"))
display(Markdown(f"**Strata present in test but missing in train:** `{len(missing_in_train)}`"))

Strata used: fallback(plan+readm+death)

# strata in train: 20 | # strata in test: 20

Strata present in train but missing in test: 0

Strata present in test but missing in train: 0

Code
from pathlib import Path
import pandas as pd
import numpy as np
from IPython.display import display, Markdown

PROJECT_ROOT = find_project_root()   # no hardcoded absolute path
OUT_DIR = PROJECT_ROOT / "_out"
OUT_DIR.mkdir(parents=True, exist_ok=True)

SPLIT_PARQUET = OUT_DIR / f"readm_split_seed{SEED}_test{int(TEST_SIZE*100)}_mar26.parquet"

split_df = pd.DataFrame({
    "row_id": np.arange(n),
    "is_train": np.isin(np.arange(n), train_idx)
})

split_df.to_parquet(SPLIT_PARQUET, index=False)

display(Markdown(f"**Project root:** `{PROJECT_ROOT}`"))
display(Markdown(f"**Saved split to:** `{SPLIT_PARQUET}`"))

Project root: G:\My Drive\Alvacast\SISTRAT 2023\cons

Saved split to: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\readm_split_seed2125_test20_mar26.parquet

Code
import pandas as pd
import numpy as np

SEED = 2125
TEST_SIZE = 0.20

# Use the first imputation (complete data)
X_full = full["X"][0]
y_death_full = full["y_death"][0]

# Find admission age column
age_col = 'adm_age_rec3' if 'adm_age_rec3' in X_full.columns else \
          [c for c in X_full.columns if 'adm_age' in c][0]

# Create the 4-column split file
split_export = pd.DataFrame({
    'row_id': np.arange(1, len(X_full) + 1),  # 1-based for R
    'is_train': np.isin(np.arange(len(X_full)), train_idx),  # vectorized membership (avoids an O(n^2) list scan)
    'death_time_from_disch_m': np.round(y_death_full['time'], 2),
    'adm_age_rec3': X_full[age_col]
})

# Export
fname = f"readm_split_seed{SEED}_test{int(TEST_SIZE*100)}_mar26.parquet"
split_export.to_parquet(fname, index=False)

nb_print(f"Exported: {fname}")
nb_print(f"Total: {len(split_export)} rows")
nb_print(f"Train: {split_export['is_train'].sum()} ({100*split_export['is_train'].mean():.1f}%)")
nb_print(f"Test: {(~split_export['is_train']).sum()} ({100*(~split_export['is_train']).mean():.1f}%)")
nb_print(f"\nFirst 5 rows:")
nb_print(split_export.head())
Exported: readm_split_seed2125_test20_mar26.parquet
Total: 88152 rows
Train: 70521 (80.0%)
Test: 17631 (20.0%)
First 5 rows:
   row_id  is_train  death_time_from_disch_m  adm_age_rec3
0       1      True                    68.97         31.53
1       2      True                    81.32         20.61
2       3      True                   116.74         42.52
3       4     False                    91.97         60.61
4       5      True                    31.03         45.08
Code
df0 = imputations_list_mar26[0]

# Calculate exactly what you need
mean_age = df0["adm_age_rec3"].mean()
count_foreign = (df0["national_foreign"] == 1).sum()

# Print results
nb_print(f"Mean of adm_age_rec3: {mean_age:.4f}")
nb_print(f"Count of national_foreign == 1: {count_foreign}")
Mean of adm_age_rec3: 35.7256
Count of national_foreign == 1: 453

We clean the environment before the next modeling block so that:

  • old models,

  • temporary objects, and

  • large intermediate datasets

Code
# Safe cleanup before Readmission XGBoost blocks
import types
import gc

# Ensure logger exists (some target cells expect it)
if "nb_print" not in globals():
    def nb_print(*args, **kwargs):
        print(*args, **kwargs)

# Compatibility: one Optuna/Bootstrap block checks jan26 naming
#if "imputations_list_jan26" not in globals() and "imputations_list_mar26" in globals():
#    imputations_list_jan26 = imputations_list_mar26

KEEP = {
    "nb_print", "study",
    "imputations_list_mar26", #"imputations_list",#"imputations_list_jan26"
    "X_train", "y_surv_readm_list_corrected", "y_surv_readm_list", "y_surv_death_list",
    # Optional plot config objects:
    "plt", "sns", "matplotlib", "mpl", "rcParams", "PROJECT_ROOT"
}

# ensure both variants are kept
KEEP.update({
    "y_surv_readm_list_corrected_mar26", "y_surv_readm_list_corrected",
    "y_surv_readm_list_mar26", "y_surv_death_list_mar26"
})

# after cleanup, sanity-check alignment before tuning
if "imputations_list_mar26" in globals() and "y_surv_readm_list_corrected" in globals():
    assert len(imputations_list_mar26[0]) == len(y_surv_readm_list_corrected[0]), \
        f"Row mismatch: X={len(imputations_list_mar26[0])}, y={len(y_surv_readm_list_corrected[0])}"
        
for name, obj in list(globals().items()):
    if name in KEEP or name.startswith("_"):
        continue
    if isinstance(obj, types.ModuleType):   # keep imports
        continue
    if callable(obj):                        # keep functions/classes
        continue
    del globals()[name]

gc.collect()

required = ["y_surv_readm_list_corrected", "y_surv_readm_list", "y_surv_death_list"]
missing = [x for x in required if x not in globals()]
nb_print("Missing required objects:", missing)
Missing required objects: []

ML

Advanced Survival Modeling: XGBoost & Stratified Evaluation

In this section, we transition to a Gradient Boosted Decision Tree (GBDT) framework using XGBoost. This serves as a robust non-linear benchmark to complement the neural network analysis for low-event survival data (approximately 4% death events).

Methodological Framework

  • Cox-Objective Boosting: We use the survival:cox objective, which optimizes the Cox partial log-likelihood within a boosting architecture. This enables flexible non-linear risk modeling and interaction learning while retaining the proportional hazards formulation.

  • 5-Fold Cross-Validation with Stratification (Death Model): We use 5-fold cross-validation with stratification to preserve key data structure across folds. In the death model, stratification is based on a combined label of treatment plan type and event status, with fallback to simpler stratification when rare strata make 5-fold splitting infeasible. (The readmission model currently uses plan-type-only stratification.)

  • Censoring-Aware Evaluation: Hyperparameter selection is based on Uno’s C-Index (IPCW), which is appropriate under right censoring. We also report the Integrated Brier Score (IBS) as a complementary calibration/overall prediction error metric.
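The IPCW machinery behind Uno's C weights each comparable event pair by the inverse of the censoring distribution G(t), itself estimated by Kaplan–Meier with the censoring indicator flipped. `sksurv`'s `concordance_index_ipcw` does this internally; the pure-Python sketch below (illustrative function names) only makes the mechanics concrete:

```python
def censoring_km(times, events):
    """Kaplan-Meier estimate of the censoring survival function G(t),
    obtained by treating censorings (event == 0) as the events of interest."""
    order = sorted(range(len(times)), key=lambda i: times[i])
    n_at_risk = len(times)
    t_points, surv = [], []
    g = 1.0
    i = 0
    while i < len(order):
        t = times[order[i]]
        d = 0          # censorings at time t
        n = n_at_risk  # risk-set size just before t
        while i < len(order) and times[order[i]] == t:
            if events[order[i]] == 0:
                d += 1
            n_at_risk -= 1
            i += 1
        if d > 0:
            g *= 1.0 - d / n
            t_points.append(t)
            surv.append(g)
    return t_points, surv

def G_minus(t_points, surv, t):
    """Left-continuous lookup G(t-): censoring survival just before t."""
    g = 1.0
    for tp, s in zip(t_points, surv):
        if tp < t:
            g = s
    return g

# Toy data: subject 2 is censored at t=2, so G drops to 2/3 after t=2
times  = [1, 2, 3, 4]
events = [1, 0, 1, 1]
tp, sv = censoring_km(times, events)

# Uno's C weights an event at time t by 1 / G(t-)^2,
# e.g. the event at t=3 gets weight 1 / (2/3)^2
w3 = 1.0 / G_minus(tp, sv, 3) ** 2
```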

Hyperparameter Optimization

Given the low event rate, we run a randomized search over a dense parameter grid. The search emphasizes regularization and tree-complexity controls (min_child_weight, gamma, reg_alpha, reg_lambda) to improve generalization and reduce overfitting.
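With five or four candidate values per parameter, the grid below spans 5^6 · 4^2 = 250,000 combinations, of which only 100 are sampled. A pure-Python sketch of what `sklearn`'s `ParameterSampler` does here (independent draws; `ParameterSampler` additionally samples without replacement when the grid is fully discrete):

```python
import random

def sample_params(grid, n_iter, seed):
    """Draw n_iter random configurations from a discrete parameter grid."""
    rng = random.Random(seed)
    return [{k: rng.choice(v) for k, v in grid.items()} for _ in range(n_iter)]

# Subset of the notebook's grid, same seed as the tuning cell
grid = {
    "learning_rate": [0.005, 0.01, 0.02, 0.05, 0.1],
    "max_depth": [3, 4, 5, 6, 8],
    "min_child_weight": [1, 5, 10, 20, 50],
}
param_list = sample_params(grid, n_iter=10, seed=2125)
```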

Breslow Estimation

Because XGBoost with survival:cox outputs relative risk scores, we estimate the baseline hazard/survival using the Breslow estimator to derive absolute survival probabilities 𝑆(𝑡∣𝑥), enabling time-specific calibration metrics such as IBS.
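Concretely, given risk scores on the hazard-ratio scale (exp(f(x)), which XGBoost's `survival:cox` `predict` returns by default), the Breslow cumulative baseline hazard and the implied survival probability can be sketched as follows. This is a minimal illustration with made-up helper names, not the notebook's production implementation:

```python
import math

def breslow_baseline(times, events, risk):
    """Breslow cumulative baseline hazard H0(t) at each distinct event time.
    `risk` holds exp(f(x)) for each subject."""
    order = sorted(range(len(times)), key=lambda i: times[i])
    total_risk = sum(risk)  # risk-set sum at t = 0
    t_pts, H0 = [], []
    h = 0.0
    i = 0
    while i < len(order):
        t = times[order[i]]
        d, removed = 0, 0.0
        while i < len(order) and times[order[i]] == t:
            j = order[i]
            d += events[j]        # events at time t
            removed += risk[j]    # leaves the risk set after t
            i += 1
        if d > 0:
            h += d / total_risk   # Breslow increment: d_t / sum(risk in risk set)
            t_pts.append(t)
            H0.append(h)
        total_risk -= removed
    return t_pts, H0

def surv_prob(t_pts, H0, t, risk_i):
    """S(t | x_i) = exp(-H0(t) * exp(f(x_i)))."""
    h = 0.0
    for tp, hv in zip(t_pts, H0):
        if tp <= t:
            h = hv
    return math.exp(-h * risk_i)
```

With the baseline in hand, time-specific metrics such as the Brier score can be computed from S(t|x) on a grid of horizons.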

Code
#@title ⚡ XGBoost Readmission Robust Tuning (100 Iterations, CPU Only)
import numpy as np
import pandas as pd
import xgboost as xgb
from sklearn.model_selection import StratifiedKFold, ParameterSampler
from sksurv.metrics import concordance_index_ipcw
import time
import gc
import os
from datetime import datetime
import warnings
from pathlib import Path

# Suppress warnings
warnings.filterwarnings("ignore")

# Fallback in case nb_print is not defined globally
if 'nb_print' not in globals():
    def nb_print(*args, **kwargs):
        print(*args, **kwargs)

# Start Timer
total_start_time = time.time()

# --- CPU CONFIGURATION ---
# Calculate total cores minus 2 (ensuring at least 1 core is used)
N_CORES = max(1, os.cpu_count() - 2)
nb_print(f"⚙️ Parallel Execution Configured: Using {N_CORES} CPU cores.")

# --- 1. SETUP & DATA ---
nb_print("Preparing data for Robust XGBoost Tuning (Readmission)...")

try:
    # imputations_list_mar26  → overwritten to train split (70,521 rows) by split cell ✅
    # y_surv_readm_list_corrected → overwritten to train split (70,521 rows) by split cell ✅
    # y_surv_readm_list_corrected_mar26 → still full (88,152 rows), do NOT use here
    df_tune       = imputations_list_mar26[0].copy()
    y_tune_struct = y_surv_readm_list_corrected[0]   # train split (70,521 rows)

    # Row alignment checks
    assert len(df_tune) == len(y_tune_struct), (
        f"X/y_readm mismatch: df_tune={len(df_tune)}, y_tune_struct={len(y_tune_struct)}"
    )
    assert len(df_tune) == len(y_surv_death_list[0]), (
        f"X/y_death mismatch: df_tune={len(df_tune)}, y_death={len(y_surv_death_list[0])}"
    )
    assert len(imputations_list_mar26) == len(y_surv_readm_list_corrected), (
        f"Imputation list length mismatch: "
        f"X={len(imputations_list_mar26)}, y={len(y_surv_readm_list_corrected)}"
    )

    nb_print(f"  Imputations   : {len(imputations_list_mar26)}")
    nb_print(f"  Data Shape    : {df_tune.shape}")
    nb_print(f"  Target        : Readmission (Events: {y_tune_struct['event'].sum()}, "
             f"Rate: {y_tune_struct['event'].mean():.3%})")
    nb_print(f"  Split         : train (mar26, exclusions applied)")

except Exception as e:
    raise ValueError(f"Data Error: {e}. Please run data loading and split cells first.") from e

# --- 2. STRATIFICATION HELPER ---
def get_stratification_labels(df):
    """Map one-hot plan-type columns to a single integer label (0 = reference plan)."""
    labels = np.zeros(len(df), dtype=int)
    plan_codes = {
        'plan_type_corr_pg_pr': 1,
        'plan_type_corr_m_pr': 2,
        'plan_type_corr_pg_pai': 3,
        'plan_type_corr_m_pai': 4,
    }
    for col, code in plan_codes.items():
        if col in df.columns:
            labels[df[col] == 1] = code
    return labels

strat_labels = get_stratification_labels(df_tune)
# survival:cox label convention: positive time = event, negative time = right-censored
y_xgb_label = np.where(y_tune_struct['event'], y_tune_struct['time'], -y_tune_struct['time'])

# --- 3. RANDOMIZED SEARCH SPACE ---
param_grid = {
    'learning_rate': [0.005, 0.01, 0.02, 0.05, 0.1],
    'max_depth': [3, 4, 5, 6, 8],
    'min_child_weight': [1, 5, 10, 20, 50],
    'subsample': [0.6, 0.7, 0.8, 0.9],
    'colsample_bytree': [0.5, 0.6, 0.7, 0.8],
    'reg_alpha': [0, 0.1, 1, 5, 10],
    'reg_lambda': [0.1, 1, 5, 10, 20],
    'gamma': [0, 0.1, 0.5, 1, 2]
}

N_ITER = 100
param_list = list(ParameterSampler(param_grid, n_iter=N_ITER, random_state=2125))

# --- 4. TUNING LOOP ---
nb_print(f"\n🚀 Starting Exhaustive Search ({N_ITER} combos)...")
nb_print(f"  Strategy: 5-Fold Stratified CV")
nb_print(f"  Metric: Uno's C-Index (IPCW)")

results = []

for i, params in enumerate(param_list):
    iter_start = time.time()

    # Fixed Parameters & Configuration (Strictly CPU)
    params['objective'] = 'survival:cox'
    params['eval_metric'] = 'cox-nloglik'
    params['tree_method'] = 'hist'
    params['seed'] = 2125            # Explicit seed for XGBoost reproducibility
    params['nthread'] = N_CORES    # Explicit CPU threading (Cores - 2)
    params['device'] = 'cpu'       # Hardcoded to CPU, removing GPU overrides
    params['verbosity'] = 0

    # 5-Fold matching your methodology
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=2125)
    fold_scores = []

    for tr_idx, va_idx in skf.split(df_tune, strat_labels):
        # Fold indices named tr_idx/va_idx so they do not clobber the
        # global train_idx/test_idx from the outer train/test split
        X_tr, X_va = df_tune.iloc[tr_idx], df_tune.iloc[va_idx]
        y_tr_xgb, y_va_xgb = y_xgb_label[tr_idx], y_xgb_label[va_idx]
        y_tr_struct, y_va_struct = y_tune_struct[tr_idx], y_tune_struct[va_idx]

        dtrain = xgb.DMatrix(X_tr, label=y_tr_xgb)
        dval = xgb.DMatrix(X_va, label=y_va_xgb)

        model = xgb.train(params, dtrain, num_boost_round=1500,
                          evals=[(dval, 'val')], early_stopping_rounds=30,
                          verbose_eval=False)

        risk_scores = model.predict(dval)

        try:
            c_val = concordance_index_ipcw(y_tr_struct, y_va_struct, risk_scores)[0]
            fold_scores.append(c_val)
        except Exception:
            # Fallback to Harrell's C if the IPCW estimate fails for this fold
            from sksurv.metrics import concordance_index_censored
            c_val = concordance_index_censored(y_va_struct['event'], y_va_struct['time'], risk_scores)[0]
            fold_scores.append(c_val)

        # Clean Memory
        del model, dtrain, dval, risk_scores
        gc.collect()

    # Average & Store
    avg_score = np.mean(fold_scores)
    std_score = np.std(fold_scores)
    results.append({**params, 'Unos_C_Index': avg_score, 'Std_Dev': std_score})

    if (i+1) % 5 == 0:
        elapsed_min = (time.time() - total_start_time) / 60
        best_so_far = max([r['Unos_C_Index'] for r in results])
        nb_print(f"  [{i+1}/{N_ITER}] Best: {best_so_far:.4f} | Current: {avg_score:.4f} | Elapsed: {elapsed_min:.2f} min")

# --- 5. FINALIZE & EXPORT ---
total_duration_min = (time.time() - total_start_time) / 60
nb_print(f"\n🏁 Total Execution Time: {total_duration_min:.2f} minutes")

if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the project-root setup cell first.")

OUT_DIR = Path(PROJECT_ROOT) / "_out"
OUT_DIR.mkdir(parents=True, exist_ok=True)


df_results = pd.DataFrame(results).sort_values(by='Unos_C_Index', ascending=False)
best_config = df_results.iloc[0].to_dict()

timestamp_str = datetime.now().strftime("%Y%m%d_%H%M")
filename = OUT_DIR / f"XGB_Readmission_Robust_Tuning_5Fold_{timestamp_str}_mar26.csv"
df_results.to_csv(filename, index=False)

nb_print(f"\n🏆 Tuning Complete!")
nb_print(f"  Best C-Index: {best_config['Unos_C_Index']:.4f}")
nb_print(f"Saved to: {filename}")
⚙️ Parallel Execution Configured: Using 30 CPU cores.
Preparing data for Robust XGBoost Tuning (Readmission)...
  Imputations   : 5
  Data Shape    : (70521, 56)
  Target        : Readmission (Events: 15247, Rate: 21.621%)
  Split         : train (mar26, exclusions applied)
🚀 Starting Exhaustive Search (100 combos)...
  Strategy: 5-Fold Stratified CV
  Metric: Uno's C-Index (IPCW)
  [5/100] Best: 0.6183 | Current: 0.6182 | Elapsed: 1.05 min
  [10/100] Best: 0.6183 | Current: 0.6138 | Elapsed: 2.25 min
  [15/100] Best: 0.6193 | Current: 0.6182 | Elapsed: 3.95 min
  [20/100] Best: 0.6193 | Current: 0.6174 | Elapsed: 5.17 min
  [25/100] Best: 0.6193 | Current: 0.6171 | Elapsed: 6.54 min
  [30/100] Best: 0.6193 | Current: 0.6182 | Elapsed: 7.90 min
  [35/100] Best: 0.6193 | Current: 0.6164 | Elapsed: 9.49 min
  [40/100] Best: 0.6193 | Current: 0.6182 | Elapsed: 11.06 min
  [45/100] Best: 0.6193 | Current: 0.6143 | Elapsed: 12.59 min
  [50/100] Best: 0.6193 | Current: 0.6169 | Elapsed: 13.61 min
  [55/100] Best: 0.6193 | Current: 0.6163 | Elapsed: 15.17 min
  [60/100] Best: 0.6195 | Current: 0.6189 | Elapsed: 17.21 min
  [65/100] Best: 0.6195 | Current: 0.6170 | Elapsed: 19.48 min
  [70/100] Best: 0.6195 | Current: 0.6168 | Elapsed: 20.76 min
  [75/100] Best: 0.6195 | Current: 0.6179 | Elapsed: 21.90 min
  [80/100] Best: 0.6195 | Current: 0.6178 | Elapsed: 23.56 min
  [85/100] Best: 0.6195 | Current: 0.6147 | Elapsed: 24.95 min
  [90/100] Best: 0.6195 | Current: 0.6176 | Elapsed: 26.64 min
  [95/100] Best: 0.6195 | Current: 0.6178 | Elapsed: 27.98 min
  [100/100] Best: 0.6195 | Current: 0.6175 | Elapsed: 28.83 min
🏁 Total Execution Time: 28.83 minutes
🏆 Tuning Complete!
  Best C-Index: 0.6195
Saved to: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\XGB_Readmission_Robust_Tuning_5Fold_20260305_1002_mar26.csv
Code
nb_print(best_config)
{'subsample': 0.8, 'reg_lambda': 1.0, 'reg_alpha': 0.1, 'min_child_weight': 5, 'max_depth': 5, 'learning_rate': 0.02, 'gamma': 0.5, 'colsample_bytree': 0.5, 'objective': 'survival:cox', 'eval_metric': 'cox-nloglik', 'tree_method': 'hist', 'seed': 2125, 'nthread': 30, 'device': 'cpu', 'verbosity': 0, 'Unos_C_Index': 0.6194559121252876, 'Std_Dev': 0.004722473659877666}
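Note that `Unos_C_Index` and `Std_Dev` are bookkeeping fields added by the tuning loop, not XGBoost parameters, so they should be stripped before the dictionary is reused in a final `xgb.train` call. A minimal sketch:

```python
# Winning configuration as printed above (metric fields included)
best_config = {
    'subsample': 0.8, 'reg_lambda': 1.0, 'reg_alpha': 0.1, 'min_child_weight': 5,
    'max_depth': 5, 'learning_rate': 0.02, 'gamma': 0.5, 'colsample_bytree': 0.5,
    'objective': 'survival:cox', 'eval_metric': 'cox-nloglik', 'tree_method': 'hist',
    'seed': 2125, 'nthread': 30, 'device': 'cpu', 'verbosity': 0,
    'Unos_C_Index': 0.6194559121252876, 'Std_Dev': 0.004722473659877666,
}

# Drop the performance bookkeeping, keep only true XGBoost parameters
METRIC_KEYS = {'Unos_C_Index', 'Std_Dev'}
final_params = {k: v for k, v in best_config.items() if k not in METRIC_KEYS}
```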
Code
import pandas as pd
from IPython.display import HTML, display

# Reset options so Pandas doesn't force everything
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)
pd.set_option('display.width', None)
pd.set_option('display.max_colwidth', None)

# Convert DataFrame to HTML and wrap in a scrollable div
html_table = df_results.to_html()
scroll_box = f"""
<div style="max-height:500px; max-width:1000px; overflow-y:auto; overflow-x:auto; border:1px solid #ccc;">
{html_table}
</div>
"""
display(HTML(scroll_box))
subsample reg_lambda reg_alpha min_child_weight max_depth learning_rate gamma colsample_bytree objective eval_metric tree_method seed nthread device verbosity Unos_C_Index Std_Dev
56 0.8 1.0 0.1 5 5 0.020 0.5 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.619456 0.004722
10 0.7 10.0 1.0 20 5 0.020 0.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.619302 0.005350
59 0.7 1.0 5.0 5 6 0.050 2.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618856 0.004825
72 0.7 10.0 0.0 5 4 0.020 0.1 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618822 0.005858
26 0.7 5.0 0.0 5 6 0.010 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618819 0.004419
81 0.8 10.0 5.0 1 6 0.020 0.5 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618767 0.005213
91 0.7 1.0 1.0 5 4 0.010 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618627 0.004390
93 0.9 0.1 5.0 5 6 0.100 1.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618511 0.005019
71 0.7 5.0 0.1 1 4 0.020 0.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618387 0.005214
33 0.6 5.0 1.0 5 6 0.010 0.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618333 0.004503
42 0.6 0.1 0.1 50 6 0.010 0.5 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618333 0.004934
65 0.9 0.1 1.0 10 8 0.010 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618324 0.004502
6 0.9 5.0 10.0 1 8 0.010 0.5 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618311 0.004786
2 0.7 5.0 5.0 20 5 0.010 0.5 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618310 0.005008
51 0.6 20.0 0.1 50 6 0.010 1.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618270 0.005510
14 0.7 1.0 0.1 20 5 0.020 1.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618236 0.005055
61 0.8 10.0 5.0 10 4 0.020 0.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618220 0.005668
3 0.6 5.0 1.0 1 4 0.010 0.1 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618211 0.005351
43 0.9 5.0 0.0 20 6 0.020 2.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618195 0.004476
29 0.7 5.0 1.0 50 4 0.050 0.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618187 0.006251
4 0.9 5.0 0.1 10 6 0.050 1.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618181 0.004645
39 0.6 1.0 0.0 5 8 0.020 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618180 0.004164
62 0.7 1.0 0.0 20 6 0.005 2.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618167 0.004228
68 0.8 0.1 0.1 50 3 0.050 0.1 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618080 0.005634
27 0.6 20.0 0.0 1 4 0.020 0.1 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.618032 0.005414
31 0.7 20.0 0.0 10 8 0.005 0.5 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617908 0.004190
48 0.6 1.0 1.0 50 8 0.005 1.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617891 0.005211
74 0.9 10.0 10.0 10 5 0.100 0.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617887 0.006528
53 0.8 5.0 10.0 50 5 0.020 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617873 0.006048
82 0.9 1.0 5.0 20 8 0.005 1.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617868 0.004563
79 0.7 10.0 0.0 5 6 0.005 1.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617835 0.003893
50 0.8 5.0 1.0 50 8 0.020 2.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617809 0.004952
17 0.7 5.0 0.0 50 4 0.050 1.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617788 0.006143
94 0.7 5.0 0.0 50 6 0.005 0.1 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617767 0.004982
85 0.9 1.0 5.0 20 4 0.100 0.1 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617764 0.005557
37 0.6 0.1 0.1 10 4 0.050 1.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617711 0.005144
78 0.7 20.0 5.0 20 5 0.010 0.1 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617631 0.005916
89 0.9 20.0 0.1 50 5 0.050 0.1 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617579 0.004104
21 0.9 10.0 10.0 5 3 0.050 0.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617530 0.005528
22 0.6 20.0 0.0 1 3 0.020 0.1 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617513 0.006197
99 0.9 20.0 10.0 10 6 0.100 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617490 0.004414
75 0.8 20.0 10.0 5 8 0.020 0.1 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617395 0.004882
45 0.9 20.0 5.0 10 4 0.100 1.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617370 0.004740
19 0.7 20.0 1.0 10 5 0.100 0.5 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617351 0.005455
35 0.7 1.0 1.0 1 5 0.005 0.1 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617281 0.005127
7 0.9 0.1 10.0 10 6 0.100 1.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617225 0.006046
58 0.7 5.0 5.0 5 6 0.005 0.1 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617169 0.005163
38 0.6 10.0 1.0 5 4 0.050 0.1 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617127 0.005941
24 0.7 0.1 1.0 50 3 0.050 1.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617120 0.005803
32 0.6 0.1 0.1 10 6 0.050 2.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617108 0.004508
98 0.8 20.0 1.0 5 6 0.100 0.1 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617054 0.005115
64 0.7 20.0 10.0 50 5 0.010 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.617021 0.005488
80 0.8 20.0 1.0 5 3 0.020 2.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616984 0.005461
52 0.9 20.0 5.0 20 6 0.005 1.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616947 0.004582
41 0.7 1.0 5.0 10 8 0.020 0.5 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616944 0.004940
49 0.9 20.0 5.0 50 3 0.100 1.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616926 0.005534
76 0.6 5.0 0.0 10 5 0.050 2.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616907 0.005501
57 0.7 1.0 10.0 20 6 0.005 0.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616848 0.004753
69 0.8 0.1 10.0 5 4 0.050 1.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616768 0.006026
18 0.6 20.0 5.0 20 3 0.020 0.5 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616764 0.005597
40 0.8 1.0 10.0 1 6 0.005 1.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616752 0.004915
5 0.6 5.0 0.0 50 3 0.020 0.5 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616737 0.005650
11 0.7 10.0 0.0 20 5 0.005 2.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616673 0.004372
0 0.7 20.0 0.0 10 3 0.100 1.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616663 0.006462
36 0.6 0.1 10.0 10 6 0.005 0.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616648 0.004916
90 0.8 5.0 0.1 5 3 0.050 0.1 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616579 0.005454
92 0.7 20.0 1.0 5 3 0.050 0.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616578 0.005616
86 0.9 10.0 1.0 10 5 0.005 2.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616502 0.004589
63 0.9 0.1 10.0 50 6 0.005 0.5 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616435 0.004754
34 0.9 5.0 1.0 5 5 0.100 0.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616422 0.005025
66 0.6 0.1 5.0 50 6 0.050 1.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616374 0.004808
54 0.8 20.0 0.1 1 8 0.050 0.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616321 0.005404
23 0.8 1.0 1.0 50 8 0.100 0.1 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616234 0.004115
13 0.8 1.0 0.0 50 3 0.010 0.5 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616192 0.004989
46 0.8 1.0 10.0 10 8 0.050 0.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616169 0.006018
25 0.8 20.0 5.0 5 5 0.050 2.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.616017 0.005568
15 0.6 10.0 5.0 5 5 0.005 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.615936 0.004430
70 0.8 0.1 5.0 50 3 0.010 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.615899 0.005315
87 0.8 5.0 5.0 50 5 0.005 1.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.615852 0.004818
77 0.8 10.0 5.0 50 3 0.010 0.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.615828 0.005353
12 0.7 0.1 1.0 5 4 0.005 0.1 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.615743 0.004836
30 0.8 1.0 5.0 10 3 0.010 2.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.615738 0.004966
8 0.6 10.0 0.0 50 3 0.010 0.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.615704 0.005449
83 0.6 20.0 5.0 5 5 0.100 2.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.615528 0.006431
60 0.9 0.1 5.0 50 3 0.010 2.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.615497 0.005154
95 0.8 20.0 0.0 50 8 0.100 2.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.615490 0.005709
96 0.6 10.0 10.0 1 5 0.005 0.0 0.8 survival:cox cox-nloglik hist 2125 30 cpu 0 0.615352 0.004964
47 0.7 0.1 0.0 1 6 0.050 0.5 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.615337 0.003581
20 0.7 10.0 10.0 20 5 0.005 1.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.615265 0.005248
88 0.8 10.0 10.0 10 5 0.005 0.0 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.615254 0.004668
73 0.8 1.0 0.1 50 8 0.100 2.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.615024 0.004906
16 0.6 5.0 5.0 20 3 0.100 0.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.614979 0.007727
97 0.7 20.0 5.0 20 8 0.100 2.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.614805 0.005893
84 0.8 5.0 10.0 10 3 0.100 2.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.614741 0.006891
1 0.6 1.0 5.0 5 6 0.100 0.1 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.614277 0.005217
44 0.6 1.0 5.0 20 3 0.100 0.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.614258 0.007332
67 0.9 10.0 10.0 1 4 0.005 0.5 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.613963 0.004614
9 0.7 20.0 1.0 20 8 0.100 2.0 0.7 survival:cox cox-nloglik hist 2125 30 cpu 0 0.613790 0.007101
55 0.8 5.0 0.0 1 3 0.005 0.0 0.6 survival:cox cox-nloglik hist 2125 30 cpu 0 0.612564 0.004760
28 0.6 5.0 5.0 20 3 0.005 0.5 0.5 survival:cox cox-nloglik hist 2125 30 cpu 0 0.612281 0.004844
Code
#@title Optimal XGBoost Configuration (Readmission – Reviewer-Proof Version)

import pandas as pd
from IPython.display import display

cfg = {
    'subsample': 0.8, 'reg_lambda': 1.0, 'reg_alpha': 0.1, 'min_child_weight': 5,
    'max_depth': 5, 'learning_rate': 0.02, 'gamma': 0.5, 'colsample_bytree': 0.5,
    'objective': 'survival:cox', 'eval_metric': 'cox-nloglik', 'tree_method': 'hist',
    'seed': 2125, 'nthread': 30, 'device': 'cpu', 'verbosity': 0,
    'Unos_C_Index': 0.6194559121252876, 'Std_Dev': 0.004722473659877666
}

data = [
    {"Category": "Performance", "Parameter": "Uno's C-Index (IPCW)", "Value": f"{cfg['Unos_C_Index']:.6f}",
     "Description": "Discriminative performance for readmission prediction under right-censoring (5-fold stratified CV)."},

    {"Category": "Stability", "Parameter": "Standard Deviation (CV)", "Value": f"±{cfg['Std_Dev']:.6f}",
     "Description": "Cross-validation variability across the 5 folds."},

    {"Category": "Tree Structure", "Parameter": "max_depth", "Value": cfg["max_depth"],
     "Description": "Maximum depth of a tree, restricting the order of feature interactions."},

    {"Category": "Imbalance Handling", "Parameter": "min_child_weight", "Value": cfg["min_child_weight"],
     "Description": "Minimum sum of instance weight (Hessian) needed in a child node."},

    {"Category": "Boosting", "Parameter": "learning_rate", "Value": cfg["learning_rate"],
     "Description": "Step size shrinkage applied to updates to prevent overfitting."},

    {"Category": "Tree Structure", "Parameter": "gamma", "Value": cfg["gamma"],
     "Description": "Minimum loss reduction required to make a further partition on a leaf node of the tree."},

    {"Category": "Regularization", "Parameter": "reg_alpha (L1)", "Value": cfg["reg_alpha"],
     "Description": "L1 regularization term on weights to encourage sparsity."},

    {"Category": "Regularization", "Parameter": "reg_lambda (L2)", "Value": cfg["reg_lambda"],
     "Description": "L2 regularization term on weights, penalizing large coefficients to stabilize estimates."},

    {"Category": "Stochasticity", "Parameter": "subsample", "Value": cfg["subsample"],
     "Description": "Subsample ratio of the training instances used to grow trees."},

    {"Category": "Stochasticity", "Parameter": "colsample_bytree", "Value": cfg["colsample_bytree"],
     "Description": "Subsample ratio of columns (features) evaluated when constructing each tree."},

    {"Category": "Model Specification", "Parameter": "objective", "Value": cfg["objective"],
     "Description": "Cox proportional hazards objective function for right-censored time-to-event data."},

    {"Category": "Model Specification", "Parameter": "eval_metric", "Value": cfg["eval_metric"],
     "Description": "Negative partial log-likelihood used during model optimization."},

    {"Category": "Computation", "Parameter": "tree_method", "Value": cfg["tree_method"],
     "Description": "Histogram-based tree construction for computational efficiency."},

    {"Category": "Computation", "Parameter": "nthread", "Value": cfg["nthread"],
     "Description": "Parallelized execution utilizing available CPU cores."},

    {"Category": "Reproducibility", "Parameter": "seed", "Value": cfg["seed"],
     "Description": "Fixed random seed ensuring reproducibility of cross-validation splits."},

    {"Category": "Computation", "Parameter": "device", "Value": cfg["device"],
     "Description": "CPU-based computation."}
]

pd.set_option('display.max_colwidth', None)
df_optimal_config = pd.DataFrame(data)
display(df_optimal_config)
Category Parameter Value Description
0 Performance Uno's C-Index (IPCW) 0.619456 Discriminative performance for readmission prediction under right-censoring (5-fold stratified CV).
1 Stability Standard Deviation (CV) ±0.004722 Cross-validation variability across the 5 folds.
2 Tree Structure max_depth 5 Maximum depth of a tree, restricting the order of feature interactions.
3 Imbalance Handling min_child_weight 5 Minimum sum of instance weight (Hessian) needed in a child node.
4 Boosting learning_rate 0.02 Step size shrinkage applied to updates to prevent overfitting.
5 Tree Structure gamma 0.5 Minimum loss reduction required to make a further partition on a leaf node of the tree.
6 Regularization reg_alpha (L1) 0.1 L1 regularization term on weights to encourage sparsity.
7 Regularization reg_lambda (L2) 1.0 L2 regularization term on weights, penalizing large coefficients to stabilize estimates.
8 Stochasticity subsample 0.8 Subsample ratio of the training instances used to grow trees.
9 Stochasticity colsample_bytree 0.5 Subsample ratio of columns (features) evaluated when constructing each tree.
10 Model Specification objective survival:cox Cox proportional hazards objective function for right-censored time-to-event data.
11 Model Specification eval_metric cox-nloglik Negative partial log-likelihood used during model optimization.
12 Computation tree_method hist Histogram-based tree construction for computational efficiency.
13 Computation nthread 30 Parallelized execution utilizing available CPU cores.
14 Reproducibility seed 2125 Fixed random seed ensuring reproducibility of cross-validation splits.
15 Computation device cpu CPU-based computation.

Optuna

To evaluate calibration in the presence of competing risks, we estimated the absolute probability of readmission using a pseudo-conditional Cumulative Incidence Function (CIF). Because standard Aalen-Johansen estimators are computationally intractable within iterative gradient-boosting loops, the CIF was approximated by integrating the XGBoost-derived, covariate-adjusted cause-specific hazard for readmission against the marginal Kaplan-Meier estimate of overall event-free survival. This approach accounts for competing mortality without requiring simultaneous estimation of all individual-level competing hazards.

Following Andersen et al. (2012; doi: 10.1093/ije/dyr213) and Binder et al. (doi: 10.1093/bioinformatics/btp088), the CIF was computed by integrating the cause-specific hazard for readmission against the marginal overall event-free survival (eq. 2): each increment of cumulative incidence equals the overall event-free survival just before the event time multiplied by the cause-specific hazard increment (Cumulative Incidence = Overall Survival × Cause-Specific Hazard).
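The identity in eq. 2 can be made concrete with a toy numeric sketch (made-up times, events, and exp-scale risk scores, not the study cohort) that mirrors the logic of `predict_cif_aalen_johansen_approx` used in the tuning code below:

```python
import numpy as np

# Toy training data (illustrative only): follow-up times in months,
# cause-specific event indicators, and exp-scale Cox risk scores.
t = np.array([2.0, 4.0, 6.0, 8.0])
e_readm = np.array([True, False, True, False])   # readmission events
e_death = np.array([False, True, False, False])  # competing deaths
risk = np.array([1.2, 0.8, 1.5, 0.9])

e_any = e_readm | e_death
baseline_cif, S = 0.0, 1.0  # S = overall event-free survival just before t
for time in np.unique(t[e_any]):
    at_risk = t >= time
    d_any = np.sum((t == time) & e_any)
    d_readm = np.sum((t == time) & e_readm)
    dH = d_readm / risk[at_risk].sum()  # Breslow cause-specific hazard increment
    baseline_cif += S * dH              # eq. 2: CIF increment = S(t-) x dLambda
    S *= 1.0 - d_any / at_risk.sum()

print(round(baseline_cif, 4))  # 0.4356
# Individual CIF scales the baseline by each subject's risk score:
print(np.round(risk * baseline_cif, 4))
```

Only readmission events contribute hazard increments, yet deaths still deplete the overall survival term, which is how competing mortality enters the absolute risk.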

Take-Home Messages

  • Phase 1 uses stochastic MI (1 imp/trial) for speed.
  • Phase 2 validates top candidates on all imputations.
  • Multi-objective optimizes C-Index and Brier Score.
  • Competing risks handled via cause-specific hazards.
  • CIF approximated using Aalen-Johansen identity.
  • Final winner selected by pooled MI metrics.
  • Parameter stability verified across imputations.
  • Optuna Pareto front identifies trade-offs.
  • Strategy balances compute cost and MI rigor.
  • Reproducible via saved study objects and seeds.

Assumptions

  • Cause-specific hazard approximates readmission risk.
  • CIF approximation is sufficiently accurate for tuning.
  • One imputation per trial represents MI variance.
  • Top Phase 1 candidates contain the global optimum.
  • IPCW assumptions for censoring hold across imputations.
Code
# @title Phase 1 - Optuna Multi-Objective (Stochastic MI: 1 imputation per trial)
import optuna
import numpy as np
import pandas as pd
import xgboost as xgb
import gc
import os
from sklearn.model_selection import StratifiedKFold
from sksurv.metrics import concordance_index_ipcw, brier_score
import joblib
from datetime import datetime
import warnings
import time
from pathlib import Path

start_time = time.time()

warnings.filterwarnings("ignore")

nb_print("Preparing data for Phase 1 stochastic-MI tuning...")

if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the project-root setup cell first.")

PROJECT_ROOT = Path(PROJECT_ROOT).resolve()
INPUT_DIR = PROJECT_ROOT / "_input"
INPUT_DIR.mkdir(parents=True, exist_ok=True)

N_CORES = max(1, os.cpu_count() - 2)
nb_print(f"Parallel Execution Configured: Using {N_CORES} CPU cores for Optuna Trials.")

N_trial = 100

# --- Setup ---
if 'imputations_list_mar26' in locals():
    imputations_tune = [df.copy() for df in imputations_list_mar26]
else:
    imputations_tune = [X_train.copy()]

n_imputations = len(imputations_tune)

# Use train-split versions (70521 rows) — NOT the _mar26 full versions (88152 rows)
y_readm_struct_list = [y.copy() for y in y_surv_readm_list_corrected]
y_death_struct_list  = [y.copy() for y in y_surv_death_list]
y_xgb_label_list     = [np.where(y['event'], y['time'], -y['time']) for y in y_surv_readm_list_corrected]

# Safety: all arrays must match imputations_tune row count
assert all(
    len(imp) == len(yr) == len(yd) == len(yl)
    for imp, yr, yd, yl in zip(imputations_tune, y_readm_struct_list, y_death_struct_list, y_xgb_label_list)
), (
    f"Length mismatch! X={len(imputations_tune[0])}, "
    f"y_readm={len(y_readm_struct_list[0])}, "
    f"y_death={len(y_death_struct_list[0])}, "
    f"y_label={len(y_xgb_label_list[0])}"
)

if not (len(y_readm_struct_list) == len(y_death_struct_list) == n_imputations):
    raise ValueError("Mismatch in lengths of imputations and outcome lists.")

def get_stratification_labels(df):
    labels = np.zeros(len(df), dtype=np.int32)
    if 'plan_type_corr_pg_pr' in df.columns: labels[df['plan_type_corr_pg_pr'] == 1] = 1
    if 'plan_type_corr_m_pr' in df.columns: labels[df['plan_type_corr_m_pr'] == 1] = 2
    if 'plan_type_corr_pg_pai' in df.columns: labels[df['plan_type_corr_pg_pai'] == 1] = 3
    if 'plan_type_corr_m_pai' in df.columns: labels[df['plan_type_corr_m_pai'] == 1] = 4
    return labels

strat_labels_list = [get_stratification_labels(df) for df in imputations_tune]

EVAL_HORIZONS = [3, 6, 12, 36, 60]

def predict_cif_aalen_johansen_approx(y_tr_readm, y_tr_death, risk_tr, risk_va, eval_times):
    # survival:cox predictions are exp(margin) hazard ratios (> 0); if any score
    # is <= 0 the model returned log-scale margins, so exponentiate first.
    if np.any(risk_tr <= 0):
        risk_tr = np.exp(risk_tr)
        risk_va = np.exp(risk_va)

    time_train = y_tr_readm['time']
    event_any = y_tr_readm['event'] | y_tr_death['event']

    order = np.argsort(time_train)
    t_ord = time_train[order]
    e_any_ord = event_any[order]
    e_readm_ord = y_tr_readm['event'][order]
    risk_tr_ord = risk_tr[order]

    unique_times = np.unique(t_ord[e_any_ord])

    S_all = np.ones(len(unique_times) + 1)
    baseline_hazard_readm = np.zeros(len(unique_times))

    current_S = 1.0
    for i, t in enumerate(unique_times):
        at_risk_mask = t_ord >= t
        n_at_risk_t = np.sum(at_risk_mask)
        events_any_t = np.sum((t_ord == t) & e_any_ord)
        events_readm_t = np.sum((t_ord == t) & e_readm_ord)
        if n_at_risk_t > 0:
            S_all[i + 1] = current_S * (1.0 - events_any_t / n_at_risk_t)
            current_S = S_all[i + 1]
            baseline_hazard_readm[i] = events_readm_t / np.sum(risk_tr_ord[at_risk_mask])

    cif_va = np.zeros((len(risk_va), len(eval_times)))
    for j, eval_t in enumerate(eval_times):
        valid_idx = np.where(unique_times <= eval_t)[0]
        if len(valid_idx) > 0:
            S_all_t_minus = S_all[valid_idx]
            dH_readm = baseline_hazard_readm[valid_idx]
            base_cif_increment = S_all_t_minus * dH_readm
            cif_va[:, j] = risk_va * np.sum(base_cif_increment)

    # brier_score expects survival-type probabilities, so return 1 - CIF
    return 1.0 - cif_va

# --- Objective: one imputation per trial ---
def objective(trial):
    params = {
        'objective': 'survival:cox',
        'eval_metric': 'cox-nloglik',
        'tree_method': 'hist',
        'nthread': 1,
        'verbosity': 0,
        'seed': 2125,
        # Both sources agree: low LR wins. Grid best=0.02, Optuna2 Pareto=0.003-0.01
        'learning_rate': trial.suggest_float('learning_rate', 0.002, 0.035, log=True),
        # Grid: 4-6 best. Optuna2: 6-9. Widen slightly.
        'max_depth': trial.suggest_int('max_depth', 4, 9),
        # Grid: 5 best but 1-50 flat. Optuna2: 12-30 preferred. Shift higher.
        'min_child_weight': trial.suggest_int('min_child_weight', 3, 30),
        # Both sources: insensitive, center ~0.75
        'subsample': trial.suggest_float('subsample', 0.6, 0.9),
        # Consistent: 0.45-0.65 sweet spot
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.4, 0.7),
        # Low range works; log-scale for proper exploration
        'reg_alpha': trial.suggest_float('reg_alpha', 1e-3, 2.0, log=True),
        # Grid: 1-10 best. Optuna2: 0.1-5.7. Cover both.
        'reg_lambda': trial.suggest_float('reg_lambda', 0.1, 10.0, log=True),
        # Grid: 0.0-0.5. Optuna2: 0.03-1.8. Log-scale covers well.
        'gamma': trial.suggest_float('gamma', 0.01, 2.0, log=True),
    }

    imp_idx = trial.number % n_imputations  # balanced stochastic assignment
    trial.set_user_attr("Sampled_Imputation", int(imp_idx + 1))

    df_tune = imputations_tune[imp_idx]
    y_readm_struct = y_readm_struct_list[imp_idx]
    y_death_struct = y_death_struct_list[imp_idx]
    strat_labels = strat_labels_list[imp_idx]
    y_xgb_label = y_xgb_label_list[imp_idx]

    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=2125)

    fold_c_indices = []
    fold_ib_scores = []
    fold_global_c_indices = []

    for fold_idx, (train_idx, val_idx) in enumerate(skf.split(df_tune, strat_labels)):
        X_tr, X_va = df_tune.iloc[train_idx], df_tune.iloc[val_idx]
        y_tr_readm_xgb, y_va_readm_xgb = y_xgb_label[train_idx], y_xgb_label[val_idx]

        y_tr_readm_struct, y_va_readm_struct = y_readm_struct[train_idx], y_readm_struct[val_idx]
        y_tr_death_struct = y_death_struct[train_idx]

        dtrain = xgb.DMatrix(X_tr, label=y_tr_readm_xgb)
        dval = xgb.DMatrix(X_va, label=y_va_readm_xgb)

        model = xgb.train(
            params, dtrain,
            num_boost_round=1500,
            evals=[(dval, 'val')],
            early_stopping_rounds=30,
            verbose_eval=False
        )

        risk_tr = model.predict(dtrain)
        risk_va = model.predict(dval)

        h_c_indices = []
        for tau_val in EVAL_HORIZONS:
            try:
                c_val = concordance_index_ipcw(y_tr_readm_struct, y_va_readm_struct, risk_va, tau=tau_val)[0]
                h_c_indices.append(c_val)
            except Exception:
                pass
        avg_c_index = np.mean(h_c_indices) if len(h_c_indices) > 0 else 0.5

        try:
            global_c = concordance_index_ipcw(y_tr_readm_struct, y_va_readm_struct, risk_va)[0]
        except Exception:
            global_c = 0.5
        fold_global_c_indices.append(global_c)

        try:
            surv_probs_va = predict_cif_aalen_johansen_approx(
                y_tr_readm_struct, y_tr_death_struct, risk_tr, risk_va, EVAL_HORIZONS
            )
            _, brier_scores_at_tau = brier_score(
                y_tr_readm_struct, y_va_readm_struct, surv_probs_va, EVAL_HORIZONS
            )
            avg_ibs = np.mean(brier_scores_at_tau)
        except Exception:
            avg_ibs = 0.25

        fold_c_indices.append(avg_c_index)
        fold_ib_scores.append(avg_ibs)

        del model, dtrain, dval, risk_tr, risk_va
        gc.collect()
        # Prune hopeless configurations: once at least two folds have finished
        # (fold_idx >= 1) and the running mean C-index is below 0.55, abandon the
        # trial immediately, skipping the remaining folds and the CIF/Brier math.
        if fold_idx >= 1 and np.mean(fold_c_indices) < 0.55:
            raise optuna.TrialPruned()

    trial.set_user_attr("Global_C_Index", float(np.mean(fold_global_c_indices)))
    return float(np.mean(fold_c_indices)), float(np.mean(fold_ib_scores))

study = optuna.create_study(
    directions=['maximize', 'minimize'],
    study_name="XGB_Readm_Optuna_Study_StochMI"
)

nb_print(
    f"Starting Phase 1 stochastic-MI optimization: 1 imputation/trial over {n_imputations} imputations."
)

study.optimize(objective, n_trials=N_trial, n_jobs=N_CORES, show_progress_bar=True)

# Check if it completed
for t in study.trials:
    print(f"Trial {t.number}: values={t.values}, state={t.state}")

nb_print("\nPhase 1 Pareto models:")
for t in study.best_trials:
    nb_print(
        f"Trial {t.number} | Imp={t.user_attrs.get('Sampled_Imputation', 'NA')} | "
        f"C={t.values[0]:.4f} | IBS={t.values[1]:.4f}"
    )

timestamp_str = datetime.now().strftime("%Y%m%d_%H%M")
study_filename = INPUT_DIR / f"XGB_Readm_Optuna_Study_StochMI_{timestamp_str}_mar26.pkl"
joblib.dump(study, study_filename)
nb_print(f"\nStudy saved: {study_filename}")

end_time = time.time()
elapsed_seconds = end_time - start_time
nb_print(f"Time taken: {elapsed_seconds/60:.2f} minutes")
Preparing data for Phase 1 stochastic-MI tuning...
Parallel Execution Configured: Using 30 CPU cores for Optuna Trials.
[I 2026-03-05 11:01:09,910] A new study created in memory with name: XGB_Readm_Optuna_Study_StochMI
Starting Phase 1 stochastic-MI optimization: 1 imputation/trial over 5 imputations.
100%|██████████| 100/100 [49:59<00:00, 30.00s/it]  
Phase 1 Pareto models:
Trial 0 | Imp=1 | C=0.6454 | IBS=0.1104
Trial 12 | Imp=3 | C=0.6427 | IBS=0.1097
Trial 20 | Imp=1 | C=0.6444 | IBS=0.1099
Trial 33 | Imp=4 | C=0.6454 | IBS=0.1104
Trial 52 | Imp=3 | C=0.6471 | IBS=0.1106
Trial 69 | Imp=5 | C=0.6471 | IBS=0.1111
Trial 73 | Imp=4 | C=0.6437 | IBS=0.1098
Trial 79 | Imp=5 | C=0.6456 | IBS=0.1105
Trial 89 | Imp=5 | C=0.6452 | IBS=0.1101
Study saved: G:\My Drive\Alvacast\SISTRAT 2023\cons\_input\XGB_Readm_Optuna_Study_StochMI_20260305_1151_mar26.pkl
Time taken: 49.99 minutes
Code
# Diagnostic: Check what's in your study
nb_print(f"Study object type: {type(study)}")
nb_print(f"Study name: {getattr(study, 'study_name', 'N/A')}")
nb_print(f"Directions: {study.directions if hasattr(study, 'directions') else 'N/A'}")
nb_print(f"Number of best_trials: {len(study.best_trials)}")
nb_print(f"\nFirst 5 trials:")
for t in study.trials[:5]:
    print(f"  Trial {t.number}: values={t.values}, state={t.state}")

nb_print(f"\nTrials with 2 values:")
trials_with_2 = [t for t in study.trials if t.values is not None and len(t.values) == 2]
nb_print(f"  Count: {len(trials_with_2)}")

nb_print(f"\nTrials with ANY values:")
trials_with_values = [t for t in study.trials if t.values is not None]
nb_print(f"  Count: {len(trials_with_values)}")
for t in trials_with_values[:3]:
    nb_print(f"    Trial {t.number}: values={t.values}")
Study object type: <class 'optuna.study.study.Study'>
Study name: XGB_Readm_Optuna_Study_StochMI
Directions: [<StudyDirection.MAXIMIZE: 2>, <StudyDirection.MINIMIZE: 1>]
Number of best_trials: 9
First 5 trials:
Trials with 2 values:
  Count: 100
Trials with ANY values:
  Count: 100
    Trial 0: values=[0.6454298647931938, 0.11044094336696626]
    Trial 1: values=[0.6453057220863746, 0.11119736130918115]
    Trial 2: values=[0.645393026111752, 0.11051291684728756]
Code
# @title Phase 2 - Re-score top Pareto candidates on ALL imputations + final winner
import os
import glob
import re
import joblib
import pandas as pd
import numpy as np
import xgboost as xgb
import gc
from datetime import datetime
from sklearn.model_selection import StratifiedKFold
from sksurv.metrics import concordance_index_ipcw, brier_score
import time
from collections import defaultdict
from joblib import Parallel, delayed, parallel_backend
from pathlib import Path

start_time = time.time()

nb_print("Phase 2: loading Phase 1 study and rescoring top candidates on all imputations...")

TOP_K = 20  # choose 10-20
if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the project-root setup cell first.")

PROJECT_ROOT = Path(PROJECT_ROOT).resolve()
INPUT_DIR = PROJECT_ROOT / "_input"
OUTPUT_DIR = PROJECT_ROOT / "_out"
INPUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# ---------------------------
# Open files
# ---------------------------
expected_study_name = "XGB_Readm_Optuna_Study_StochMI"
if 'study' in globals() and getattr(study, 'study_name', None) == expected_study_name:
    nb_print(f"Using in-memory study with study_name='{expected_study_name}'.")
else:
    pattern = str(INPUT_DIR / "XGB_Readm_Optuna_Study_StochMI_*_mar26.pkl")
    files = glob.glob(pattern)
    if not files:
        raise FileNotFoundError(f"No saved study files found matching pattern: {pattern}")

    timestamped_files = []
    for f in files:
        m = re.search(r"XGB_Readm_Optuna_Study_StochMI_(\d{8}_\d{4})\_mar26\.pkl$", f)
        if m:
            ts = m.group(1)
            try:
                dt = datetime.strptime(ts, "%Y%m%d_%H%M")
                timestamped_files.append((dt, f))
            except ValueError:
                pass

    if not timestamped_files:
        raise FileNotFoundError("No study files with valid timestamp found.")

    latest_file = max(timestamped_files, key=lambda x: x[0])[1]
    study = joblib.load(latest_file)
    nb_print(f"Loaded study object from: {latest_file}")

if not (len(imputations_list_mar26) == len(y_surv_readm_list_corrected) == len(y_surv_death_list)):
    raise ValueError("Mismatch in lengths of imputations and outcome lists.")

assert len(imputations_list_mar26[0]) == len(y_surv_readm_list_corrected[0]), (
    f"X/y_readm mismatch: {len(imputations_list_mar26[0])} vs {len(y_surv_readm_list_corrected[0])}"
)
assert len(imputations_list_mar26[0]) == len(y_surv_death_list[0]), (
    f"X/y_death mismatch: {len(imputations_list_mar26[0])} vs {len(y_surv_death_list[0])}"
)

# ---------------------------
# Helpers
# ---------------------------
def get_stratification_labels(df):
    labels = np.zeros(len(df), dtype=np.int32)
    if 'plan_type_corr_pg_pr' in df.columns: labels[df['plan_type_corr_pg_pr'] == 1] = 1
    if 'plan_type_corr_m_pr' in df.columns: labels[df['plan_type_corr_m_pr'] == 1] = 2
    if 'plan_type_corr_pg_pai' in df.columns: labels[df['plan_type_corr_pg_pai'] == 1] = 3
    if 'plan_type_corr_m_pai' in df.columns: labels[df['plan_type_corr_m_pai'] == 1] = 4
    return labels

def predict_cif_aalen_johansen_approx(y_tr_readm, y_tr_death, risk_tr, risk_va, eval_times):
    # survival:cox predictions are exp(margin) hazard ratios (> 0); if any score
    # is <= 0 the model returned log-scale margins, so exponentiate first.
    if np.any(risk_tr <= 0):
        risk_tr = np.exp(risk_tr)
        risk_va = np.exp(risk_va)

    time_train = y_tr_readm['time']
    event_any = y_tr_readm['event'] | y_tr_death['event']

    order = np.argsort(time_train)
    t_ord = time_train[order]
    e_any_ord = event_any[order]
    e_readm_ord = y_tr_readm['event'][order]
    risk_tr_ord = risk_tr[order]

    unique_times = np.unique(t_ord[e_any_ord])
    S_all = np.ones(len(unique_times) + 1)
    baseline_hazard_readm = np.zeros(len(unique_times))

    current_S = 1.0
    for i, t in enumerate(unique_times):
        at_risk_mask = t_ord >= t
        n_at_risk_t = np.sum(at_risk_mask)
        events_any_t = np.sum((t_ord == t) & e_any_ord)
        events_readm_t = np.sum((t_ord == t) & e_readm_ord)
        if n_at_risk_t > 0:
            S_all[i + 1] = current_S * (1.0 - events_any_t / n_at_risk_t)
            current_S = S_all[i + 1]
            baseline_hazard_readm[i] = events_readm_t / np.sum(risk_tr_ord[at_risk_mask])

    cif_va = np.zeros((len(risk_va), len(eval_times)))
    for j, eval_t in enumerate(eval_times):
        valid_idx = np.where(unique_times <= eval_t)[0]
        if len(valid_idx) > 0:
            S_all_t_minus = S_all[valid_idx]
            dH_readm = baseline_hazard_readm[valid_idx]
            base_cif_increment = S_all_t_minus * dH_readm
            cif_va[:, j] = risk_va * np.sum(base_cif_increment)

    # brier_score expects survival-type probabilities, so return 1 - CIF
    return 1.0 - cif_va

EVAL_HORIZONS = [3, 6, 12, 36, 60]

# ---------------------------
# Phase 1 Pareto
# ---------------------------
pareto_trials = [t for t in study.best_trials if t.values is not None and len(t.values) == 2]
if len(pareto_trials) == 0:
    raise ValueError("No valid Pareto trials in Phase 1 study.")

phase1_rows = []
for t in pareto_trials:
    row = {
        "trial_id": t.number,
        "Phase1_C_Index": float(t.values[0]),
        "Phase1_IBS": float(t.values[1]),
        "Phase1_Global_C_Index": float(t.user_attrs.get("Global_C_Index", np.nan)),
        "Sampled_Imputation": t.user_attrs.get("Sampled_Imputation", np.nan)
    }
    row.update(t.params)
    phase1_rows.append(row)

df_phase1 = pd.DataFrame(phase1_rows)
df_phase1["Phase1_Distance_to_Ideal"] = np.sqrt(
    (1.0 - df_phase1["Phase1_C_Index"])**2 + (df_phase1["Phase1_IBS"])**2
)
df_phase1 = df_phase1.sort_values("Phase1_Distance_to_Ideal", ascending=True).reset_index(drop=True)

top_k = min(TOP_K, len(df_phase1))
candidate_ids = df_phase1.head(top_k)["trial_id"].astype(int).tolist()
trial_map = {t.number: t for t in pareto_trials}
candidate_trials = [trial_map[i] for i in candidate_ids]

nb_print(f"Phase 1 Pareto trials: {len(df_phase1)} | Phase 2 candidates: {len(candidate_trials)}")

# ---------------------------
# Parallel config
# ---------------------------
n_imputations = len(imputations_list_mar26)
N_CORES = max(1, os.cpu_count() - 2)

# Tune this:
# 2 = safer on RAM, 3 = faster if RAM allows
CANDIDATE_PARALLEL = 3

# Total concurrent tasks ~= candidates in parallel * imputations in parallel
TOTAL_WORKERS = min(N_CORES, CANDIDATE_PARALLEL * len(imputations_list_mar26))

# Threads per XGBoost model (keep small to avoid oversubscription)
XGB_THREADS = max(1, N_CORES // TOTAL_WORKERS)
XGB_THREADS = min(XGB_THREADS, 2)

nb_print(
    f"Phase 2 parallel config -> workers={TOTAL_WORKERS}, "
    f"xgb_threads/model={XGB_THREADS}, candidate_parallel~{CANDIDATE_PARALLEL}"
)

# ---------------------------
# Precompute imputation payloads once
# ---------------------------
imp_payloads = []
for imp_id, (X_imp, y_readm_imp, y_death_imp) in enumerate(
    zip(imputations_list_mar26, y_surv_readm_list_corrected, y_surv_death_list)
):
    y_label = np.where(y_readm_imp['event'], y_readm_imp['time'], -y_readm_imp['time'])
    strat_labels = get_stratification_labels(X_imp)
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=2125)
    splits = list(skf.split(X_imp, strat_labels))

    imp_payloads.append({
        "imp_id": imp_id,
        "X_imp": X_imp,
        "y_label": y_label,
        "y_readm_imp": y_readm_imp,
        "y_death_imp": y_death_imp,
        "splits": splits,
    })

def evaluate_trial_imputation(trial_id, trial_params, payload):
    X_imp = payload["X_imp"]
    y_label = payload["y_label"]
    y_readm_imp = payload["y_readm_imp"]
    y_death_imp = payload["y_death_imp"]
    splits = payload["splits"]
    imp_id = payload["imp_id"]

    params = {
        'objective': 'survival:cox',
        'eval_metric': 'cox-nloglik',
        'tree_method': 'hist',
        'nthread': XGB_THREADS,
        'verbosity': 0,
        'seed': 2125,
        **trial_params
    }

    fold_c, fold_ibs, fold_global_c = [], [], []

    for tr_idx, va_idx in splits:
        X_tr, X_va = X_imp.iloc[tr_idx], X_imp.iloc[va_idx]
        y_tr, y_va = y_label[tr_idx], y_label[va_idx]
        y_tr_readm, y_va_readm = y_readm_imp[tr_idx], y_readm_imp[va_idx]
        y_tr_death = y_death_imp[tr_idx]

        dtr = xgb.DMatrix(X_tr, label=y_tr)
        dva = xgb.DMatrix(X_va, label=y_va)

        model = xgb.train(
            params, dtr,
            num_boost_round=1500,
            evals=[(dva, "val")],
            early_stopping_rounds=30,
            verbose_eval=False
        )

        risk_tr = model.predict(dtr)
        risk_va = model.predict(dva)

        # Multi-horizon C-index
        h_c = []
        for tau in EVAL_HORIZONS:
            try:
                h_c.append(concordance_index_ipcw(y_tr_readm, y_va_readm, risk_va, tau=tau)[0])
            except Exception:
                pass
        fold_c.append(float(np.mean(h_c)) if len(h_c) > 0 else 0.5)

        # Global C-index
        try:
            fold_global_c.append(float(concordance_index_ipcw(y_tr_readm, y_va_readm, risk_va)[0]))
        except Exception:
            fold_global_c.append(0.5)

        # IBS
        try:
            surv_probs_va = predict_cif_aalen_johansen_approx(
                y_tr_readm, y_tr_death, risk_tr, risk_va, EVAL_HORIZONS
            )
            _, brier_at_tau = brier_score(y_tr_readm, y_va_readm, surv_probs_va, EVAL_HORIZONS)
            fold_ibs.append(float(np.mean(brier_at_tau)))
        except Exception:
            fold_ibs.append(0.25)

        del model, dtr, dva, risk_tr, risk_va
        gc.collect()

    return (
        trial_id,
        imp_id,
        float(np.mean(fold_c)),
        float(np.mean(fold_ibs)),
        float(np.mean(fold_global_c)),
    )

# ---------------------------
# Run all (candidate, imputation) tasks in parallel
# ---------------------------
tasks = [
    (t.number, t.params, payload)
    for t in candidate_trials
    for payload in imp_payloads
]

nb_print(f"Launching {len(tasks)} parallel tasks ({len(candidate_trials)} candidates x {len(imp_payloads)} imputations).")

with parallel_backend("threading", n_jobs=TOTAL_WORKERS):
    results = Parallel(verbose=10)(
        delayed(evaluate_trial_imputation)(trial_id, trial_params, payload)
        for trial_id, trial_params, payload in tasks
    )

# ---------------------------
# Aggregate back to trial-level
# ---------------------------
trial_metrics = defaultdict(lambda: {"c": {}, "ibs": {}, "global_c": {}})

for trial_id, imp_id, c_val, ibs_val, g_val in results:
    trial_metrics[trial_id]["c"][imp_id] = c_val
    trial_metrics[trial_id]["ibs"][imp_id] = ibs_val
    trial_metrics[trial_id]["global_c"][imp_id] = g_val

phase2_rows = []
phase2_detail = {}

for t in candidate_trials:
    m = trial_metrics[t.number]

    if len(m["c"]) != len(imp_payloads):
        raise RuntimeError(f"Trial {t.number} missing imputation results.")

    c_per_imp = [m["c"][i] for i in range(len(imp_payloads))]
    ibs_per_imp = [m["ibs"][i] for i in range(len(imp_payloads))]
    global_c_per_imp = [m["global_c"][i] for i in range(len(imp_payloads))]

    row = {
        "trial_id": t.number,
        "Phase2_Multi_Horizon_C_Index": float(np.mean(c_per_imp)),
        "Phase2_Aalen_Johansen_Brier_Score": float(np.mean(ibs_per_imp)),
        "Phase2_Global_C_Index": float(np.mean(global_c_per_imp)),
        "C_Index_SD_across_imputations": float(np.std(c_per_imp)),
        "IBS_SD_across_imputations": float(np.std(ibs_per_imp))
    }
    row.update(t.params)
    phase2_rows.append(row)

    phase2_detail[t.number] = {"c_per_imp": c_per_imp, "ibs_per_imp": ibs_per_imp}

df_phase2 = pd.DataFrame(phase2_rows)
df_phase2["Distance_to_Ideal"] = np.sqrt(
    (1.0 - df_phase2["Phase2_Multi_Horizon_C_Index"])**2
    + (df_phase2["Phase2_Aalen_Johansen_Brier_Score"])**2
)
df_phase2 = df_phase2.sort_values("Distance_to_Ideal", ascending=True).reset_index(drop=True)

winner = df_phase2.iloc[0]
winner_trial_id = int(winner["trial_id"])
param_keys = list(candidate_trials[0].params.keys())
params_winner = {k: winner[k] for k in param_keys}

nb_print("\nFinal winner from Phase 2 (all imputations):")
nb_print(f"  Trial ID: {winner_trial_id}")
nb_print(f"  Multi-Horizon C-Index: {winner['Phase2_Multi_Horizon_C_Index']:.4f}")
nb_print(f"  Aalen-Johansen Brier: {winner['Phase2_Aalen_Johansen_Brier_Score']:.4f}")
nb_print(f"  Global C-Index: {winner['Phase2_Global_C_Index']:.4f}")
nb_print(f"  Distance to Ideal: {winner['Distance_to_Ideal']:.4f}")

c_indices_all_imp = phase2_detail[winner_trial_id]["c_per_imp"]
ibs_all_imp = phase2_detail[winner_trial_id]["ibs_per_imp"]
nb_print(f"C-index SD across imputations: {np.std(c_indices_all_imp):.4f}")
nb_print(f"IBS SD across imputations: {np.std(ibs_all_imp):.4f}")

timestamp_str = datetime.now().strftime("%Y%m%d_%H%M")
phase1_csv = OUTPUT_DIR / f"Readmission_Pareto_Phase1_{timestamp_str}_mar26.csv"
phase2_csv = OUTPUT_DIR / f"Readmission_Pareto_Phase2_{timestamp_str}_mar26.csv"
df_phase1.to_csv(phase1_csv, index=False)
df_phase2.to_csv(phase2_csv, index=False)
nb_print(f"\nSaved: {phase1_csv}")
nb_print(f"Saved: {phase2_csv}")

end_time = time.time()
elapsed_seconds = end_time - start_time
nb_print(f"Time taken: {elapsed_seconds/60:.2f} minutes")
Phase 2: loading Phase 1 study and rescoring top candidates on all imputations...
Using in-memory study with study_name='XGB_Readm_Optuna_Study_StochMI'.
Phase 1 Pareto trials: 9 | Phase 2 candidates: 9
Phase 2 parallel config -> workers=15, xgb_threads/model=2, candidate_parallel~3
Launching 45 parallel tasks (9 candidates x 5 imputations).
[Parallel(n_jobs=15)]: Using backend ThreadingBackend with 15 concurrent workers.
[Parallel(n_jobs=15)]: Done   2 tasks      | elapsed:  7.1min
[Parallel(n_jobs=15)]: Done  11 tasks      | elapsed:  8.1min
[Parallel(n_jobs=15)]: Done  21 out of  45 | elapsed: 14.7min remaining: 16.8min
[Parallel(n_jobs=15)]: Done  26 out of  45 | elapsed: 15.4min remaining: 11.3min
[Parallel(n_jobs=15)]: Done  31 out of  45 | elapsed: 22.9min remaining: 10.4min
[Parallel(n_jobs=15)]: Done  36 out of  45 | elapsed: 23.5min remaining:  5.9min
[Parallel(n_jobs=15)]: Done  41 out of  45 | elapsed: 23.8min remaining:  2.3min
[Parallel(n_jobs=15)]: Done  45 out of  45 | elapsed: 23.8min finished
Final winner from Phase 2 (all imputations):
  Trial ID: 52
  Multi-Horizon C-Index: 0.6467
  Aalen-Johansen Brier: 0.1105
  Global C-Index: 0.6176
  Distance to Ideal: 0.3702
C-index SD across imputations: 0.0002
IBS SD across imputations: 0.0001
Saved: _out/Readmission_Pareto_Phase1_20260305_1431_mar26.csv
Saved: _out/Readmission_Pareto_Phase2_20260305_1431_mar26.csv
Time taken: 23.85 minutes
Code
import pandas as pd
from IPython.display import HTML, display

def show_scrollable_df(df, max_height=500, max_width=1200):
    html_table = df.to_html(index=False)
    scroll_box = f"""
    <div style="max-height:{max_height}px; max-width:{max_width}px;
                overflow-y:auto; overflow-x:auto; border:1px solid #ccc;">
    {html_table}
    </div>
    """
    display(HTML(scroll_box))
Code
show_scrollable_df(df_phase1)
trial_id Phase1_C_Index Phase1_IBS Phase1_Global_C_Index Sampled_Imputation learning_rate max_depth min_child_weight subsample colsample_bytree reg_alpha reg_lambda gamma Phase1_Distance_to_Ideal
52 0.647085 0.110633 0.618108 3 0.003377 9 4 0.610275 0.405147 0.284944 0.668159 0.043185 0.369849
69 0.647123 0.111107 0.618033 5 0.009800 8 5 0.718240 0.449258 0.011134 0.101847 0.471845 0.369955
79 0.645644 0.110534 0.618633 5 0.007126 8 27 0.799758 0.478864 0.249842 2.949075 0.020444 0.371195
33 0.645409 0.110356 0.617976 4 0.005417 9 30 0.657228 0.487221 0.023960 0.304705 0.056264 0.371367
0 0.645430 0.110441 0.619597 1 0.009674 5 10 0.760740 0.556438 0.001052 9.824890 0.028099 0.371372
89 0.645234 0.110082 0.617488 5 0.002803 8 12 0.745269 0.476133 0.066481 0.650462 0.120555 0.371452
20 0.644388 0.109887 0.616894 1 0.003245 7 15 0.675051 0.407253 0.002712 3.084713 0.208060 0.372203
73 0.643748 0.109848 0.615733 4 0.002097 9 20 0.860766 0.511514 0.236068 0.328810 0.316959 0.372803
12 0.642659 0.109674 0.615176 3 0.002083 8 27 0.799758 0.478864 0.986947 2.949075 0.020444 0.373793
Code
show_scrollable_df(df_phase2)
trial_id Phase2_Multi_Horizon_C_Index Phase2_Aalen_Johansen_Brier_Score Phase2_Global_C_Index C_Index_SD_across_imputations IBS_SD_across_imputations learning_rate max_depth min_child_weight subsample colsample_bytree reg_alpha reg_lambda gamma Distance_to_Ideal
52 0.646658 0.110466 0.617593 0.000151 0.000056 0.003377 9 4 0.610275 0.405147 0.284944 0.668159 0.043185 0.370207
69 0.646548 0.111110 0.618781 0.000434 0.000042 0.009800 8 5 0.718240 0.449258 0.011134 0.101847 0.471845 0.370504
79 0.645365 0.110537 0.618412 0.000419 0.000039 0.007126 8 27 0.799758 0.478864 0.249842 2.949075 0.020444 0.371463
0 0.645302 0.110427 0.619422 0.000252 0.000028 0.009674 5 10 0.760740 0.556438 0.001052 9.824890 0.028099 0.371490
89 0.645040 0.110085 0.616957 0.000109 0.000015 0.002803 8 12 0.745269 0.476133 0.066481 0.650462 0.120555 0.371639
33 0.644741 0.110354 0.617274 0.000318 0.000059 0.005417 9 30 0.657228 0.487221 0.023960 0.304705 0.056264 0.372004
20 0.644425 0.109886 0.616752 0.000090 0.000004 0.003245 7 15 0.675051 0.407253 0.002712 3.084713 0.208060 0.372167
73 0.643706 0.109858 0.615511 0.000145 0.000005 0.002097 9 20 0.860766 0.511514 0.236068 0.328810 0.316959 0.372846
12 0.642727 0.109680 0.614876 0.000148 0.000002 0.002083 8 27 0.799758 0.478864 0.986947 2.949075 0.020444 0.373729

Optuna (2)

  1. Informed seeding accelerates convergence to the optimal region.

  2. NSGA-II sampler improves multi-objective Pareto exploration.

  3. Dual stratification balances plan type and event status.

  4. Narrowed search space reduces wasted trials on poor regions.

  5. Stochastic MI maintains rigor while reducing compute cost.

  6. Aggressive pruning (mean C-index < 0.58 after three folds) filters weak trials early.

  7. 50 trials are sufficient when starting from proven configurations.

  8. Phase 2 validation still required across all imputations.

  9. Iterative refinement is best practice for hyperparameter tuning.

  10. Saved study objects enable reproducibility and audit trails.
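Point 2 hinges on what a Pareto front is: the set of trials not dominated on the pair (maximize C-index, minimize Brier score). A minimal sketch of that dominance rule, with hypothetical trial values (not taken from the study below):

```python
import numpy as np

def pareto_front(c_index, brier):
    """Return indices of non-dominated (C-index, Brier) pairs:
    higher C-index is better, lower Brier is better."""
    keep = []
    n = len(c_index)
    for i in range(n):
        dominated = any(
            c_index[j] >= c_index[i] and brier[j] <= brier[i]
            and (c_index[j] > c_index[i] or brier[j] < brier[i])
            for j in range(n)
        )
        if not dominated:
            keep.append(i)
    return keep

# Hypothetical trial results: the last trial is dominated by the first
# (lower C-index AND higher Brier), so it drops off the front.
c   = np.array([0.647, 0.645, 0.640, 0.646])
ibs = np.array([0.1106, 0.1103, 0.1099, 0.1107])
front = pareto_front(c, ibs)  # -> [0, 1, 2]
```

A trial is discarded only when some other trial is at least as good on both objectives and strictly better on one; Optuna's `study.best_trials` returns exactly this non-dominated set.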

Key assumptions

  1. Previous Phase 1 Pareto front contains near-optimal configurations.

  2. Narrowed search space still contains the global optimum.

  3. Dual stratification improves fold balance without causing sparsity.

  4. Seven seed configurations adequately cover the promising region.

  5. NSGA-II sampler converges faster than default TPE for this problem.
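Assumption 3 rests on the combined label that `get_dual_stratification_labels` (defined in the cell below) builds: the plan code occupies the tens digit and the event indicator the units digit, so a single `StratifiedKFold` pass balances both dimensions at once. A toy sketch of the encoding, on hypothetical data:

```python
import numpy as np

# One-hot plan membership for four toy subjects (hypothetical data; the real
# notebook derives these columns from the imputed SUD cohort).
plan_pg_pr  = np.array([1, 0, 0, 0])
plan_m_pr   = np.array([0, 1, 0, 0])
plan_pg_pai = np.array([0, 0, 1, 0])
plan_m_pai  = np.array([0, 0, 0, 0])
event       = np.array([1, 0, 1, 0])  # readmission event indicator

# Assign a plan code (0 = reference plan, 1-4 = the one-hot plan types)
plan = np.zeros(4, dtype=int)
plan[plan_pg_pr == 1] = 1
plan[plan_m_pr == 1] = 2
plan[plan_pg_pai == 1] = 3
plan[plan_m_pai == 1] = 4

# Combined label: plan code in the tens digit, event status in the units
# digit, e.g. 11 = plan 1 with event, 20 = plan 2 censored.
strat = plan * 10 + event  # -> [11, 20, 31, 0]
```

Because each stratum is a distinct (plan, event) cell, the encoding can create sparse strata if a plan has very few events, which is exactly the risk assumption 3 claims is avoided here.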

Code
# @title Optuna2 Phase 1 - Stochastic MI (Dual Stratified, Informed Search)
import os
import gc
import joblib
import optuna
import numpy as np
import pandas as pd
import xgboost as xgb
from datetime import datetime
from sklearn.model_selection import StratifiedKFold
from sksurv.metrics import concordance_index_ipcw, brier_score
import warnings
import time
from pathlib import Path


warnings.filterwarnings("ignore")

start_time = time.time()
# ---------------------------
# Config
# ---------------------------
if "PROJECT_ROOT" not in globals():
    raise RuntimeError("PROJECT_ROOT is not defined. Run the project-root setup cell first.")

ROOT_PROJECT = Path(PROJECT_ROOT).resolve()   # alias if you prefer this name
INPUT_DIR = ROOT_PROJECT / "_input"
OUTPUT_DIR = ROOT_PROJECT / "_out"
INPUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

EVAL_HORIZONS = [3, 6, 12, 36, 60]
N_CORES = max(1, (os.cpu_count() or 1) - 2)  # cpu_count() may return None
PHASE1_TRIALS = 50            # includes enqueued trials
EARLY_STOPPING = 30
NUM_BOOST_ROUND = 1500

nb_print(f"Optuna2 Phase 1 starting with {N_CORES} cores...")

# ---------------------------
# Data
# ---------------------------
imputations_tune = [df.copy() for df in imputations_list_mar26]

n_imputations = len(imputations_tune)

# Explicit train-split versions (70521 rows)
y_readm_struct_list = [y.copy() for y in y_surv_readm_list_corrected]
y_death_struct_list  = [y.copy() for y in y_surv_death_list]
# survival:cox label convention: positive label = event time, negative label = censoring time
y_xgb_label_list     = [np.where(y['event'], y['time'], -y['time']) for y in y_surv_readm_list_corrected]

assert all(
    len(imp) == len(yr) == len(yd) == len(yl)
    for imp, yr, yd, yl in zip(imputations_tune, y_readm_struct_list, y_death_struct_list, y_xgb_label_list)
), (
    f"Length mismatch! X={len(imputations_tune[0])}, "
    f"y_readm={len(y_readm_struct_list[0])}, "
    f"y_death={len(y_death_struct_list[0])}, "
    f"y_label={len(y_xgb_label_list[0])}"
)

if not (len(y_readm_struct_list) == len(y_death_struct_list) == n_imputations):
    raise ValueError("Mismatch lengths among imputations/readm/death lists.")

# ---------------------------
# DUAL STRATIFICATION (Plan + Event)
# ---------------------------
def get_dual_stratification_labels(df, y_struct):
    labels = np.zeros(len(df), dtype=int)
    if 'plan_type_corr_pg_pr' in df.columns: labels[df['plan_type_corr_pg_pr'] == 1] = 1
    if 'plan_type_corr_m_pr' in df.columns: labels[df['plan_type_corr_m_pr'] == 1] = 2
    if 'plan_type_corr_pg_pai' in df.columns: labels[df['plan_type_corr_pg_pai'] == 1] = 3
    if 'plan_type_corr_m_pai' in df.columns: labels[df['plan_type_corr_m_pai'] == 1] = 4
    event_status = y_struct['event'].astype(int)
    return (labels * 10) + event_status

strat_labels_dual_list = [get_dual_stratification_labels(df, y) for df, y in zip(imputations_tune, y_readm_struct_list)]

# ---------------------------
# Aalen-Johansen cumulative incidence approximation
# ---------------------------
def predict_cif_aalen_johansen_approx(y_tr_readm, y_tr_death, risk_tr, risk_va, eval_times):
    # survival:cox predictions are hazard ratios (always positive); nonpositive
    # values mean raw margins were returned, so exponentiate first.
    if np.any(risk_tr <= 0):
        risk_tr = np.exp(risk_tr)
        risk_va = np.exp(risk_va)

    time_train = y_tr_readm['time']
    event_any = y_tr_readm['event'] | y_tr_death['event']

    order = np.argsort(time_train)
    t_ord = time_train[order]
    e_any_ord = event_any[order]
    e_readm_ord = y_tr_readm['event'][order]
    risk_tr_ord = risk_tr[order]

    unique_times = np.unique(t_ord[e_any_ord])
    S_all = np.ones(len(unique_times) + 1)
    baseline_hazard_readm = np.zeros(len(unique_times))

    current_S = 1.0
    for i, t in enumerate(unique_times):
        at_risk_mask = t_ord >= t
        n_at_risk_t = np.sum(at_risk_mask)
        events_any_t = np.sum((t_ord == t) & e_any_ord)
        events_readm_t = np.sum((t_ord == t) & e_readm_ord)

        if n_at_risk_t > 0:
            S_all[i + 1] = current_S * (1.0 - events_any_t / n_at_risk_t)
            current_S = S_all[i + 1]
            baseline_hazard_readm[i] = events_readm_t / np.sum(risk_tr_ord[at_risk_mask])

    cif_va = np.zeros((len(risk_va), len(eval_times)))
    for j, eval_t in enumerate(eval_times):
        valid_idx = np.where(unique_times <= eval_t)[0]
        if len(valid_idx) > 0:
            S_all_t_minus = S_all[valid_idx]
            dH_readm = baseline_hazard_readm[valid_idx]
            base_cif_increment = S_all_t_minus * dH_readm
            cif_va[:, j] = risk_va * np.sum(base_cif_increment)

    return 1.0 - cif_va

# ---------------------------
# Informed seed configs
# ---------------------------
PARAM_KEYS = [
    "learning_rate", "max_depth", "min_child_weight", "subsample",
    "colsample_bytree", "reg_alpha", "reg_lambda", "gamma"
]
INT_KEYS = {"max_depth", "min_child_weight"}

def load_phase2_seed_configs(output_dir, top_n=7, pattern="Readmission_Pareto_Phase2_*_mar26.csv"):
    files = list(output_dir.glob(pattern))
    if not files:
        return [], None
    latest = max(files, key=lambda p: p.stat().st_mtime)
    df = pd.read_csv(latest)
    if "Distance_to_Ideal" in df.columns:
        df = df.sort_values("Distance_to_Ideal", ascending=True)
    seeds, seen = [], set()
    for _, row in df.head(top_n).iterrows():
        cfg = {}
        ok = True
        for k in PARAM_KEYS:
            if k not in row or pd.isna(row[k]):
                ok = False
                break
            v = row[k]
            cfg[k] = int(round(v)) if k in INT_KEYS else float(v)
        if not ok:
            continue
        sig = tuple(cfg[k] for k in PARAM_KEYS)
        if sig not in seen:
            seen.add(sig)
            seeds.append(cfg)
    return seeds, latest

# Load seed configurations from the most recent Phase 2 results
seed_configs, phase2_file = load_phase2_seed_configs(OUTPUT_DIR, top_n=7)
if seed_configs:
    nb_print(f"Loaded {len(seed_configs)} seed configs from: {phase2_file}")
else:
    nb_print("No previous Phase 2 seeds found. Running without seeds.")

# ---------------------------
# Objective: stochastic MI (1 imputation per trial), dual strat
# ---------------------------
def objective(trial):
    params = {
        'objective': 'survival:cox',
        'eval_metric': 'cox-nloglik',
        'tree_method': 'hist',
        'nthread': 1,
        'verbosity': 0,
        'seed': 2125,

        # Informed/narrowed ranges for faster convergence
        'learning_rate': trial.suggest_float('learning_rate', 0.0025, 0.012, log=True),
        'max_depth': trial.suggest_int('max_depth', 4, 9),
        'min_child_weight': trial.suggest_int('min_child_weight', 4, 30),
        'subsample': trial.suggest_float('subsample', 0.54, 0.92),
        'colsample_bytree': trial.suggest_float('colsample_bytree', 0.45, 0.66),
        'reg_alpha': trial.suggest_float('reg_alpha', 1e-4, 2.0, log=True),
        'reg_lambda': trial.suggest_float('reg_lambda', 0.1, 3.5, log=True),
        'gamma': trial.suggest_float('gamma', 0.2, 2.0),
    }

    imp_idx = trial.number % n_imputations
    trial.set_user_attr("Sampled_Imputation", int(imp_idx + 1))

    df_tune = imputations_tune[imp_idx]
    y_readm_struct = y_readm_struct_list[imp_idx]
    y_death_struct = y_death_struct_list[imp_idx]
    strat_labels_dual = strat_labels_dual_list[imp_idx]
    y_xgb_label = y_xgb_label_list[imp_idx]

    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=2125)

    fold_c_indices, fold_ib_scores, fold_global_c_indices = [], [], []

    for fold_idx, (train_idx, val_idx) in enumerate(skf.split(df_tune, strat_labels_dual)):
        X_tr, X_va = df_tune.iloc[train_idx], df_tune.iloc[val_idx]
        y_tr_readm_xgb, y_va_readm_xgb = y_xgb_label[train_idx], y_xgb_label[val_idx]

        y_tr_readm_struct, y_va_readm_struct = y_readm_struct[train_idx], y_readm_struct[val_idx]
        y_tr_death_struct = y_death_struct[train_idx]

        dtrain = xgb.DMatrix(X_tr, label=y_tr_readm_xgb)
        dval = xgb.DMatrix(X_va, label=y_va_readm_xgb)

        model = xgb.train(
            params,
            dtrain,
            num_boost_round=NUM_BOOST_ROUND,
            evals=[(dval, 'val')],
            early_stopping_rounds=EARLY_STOPPING,
            verbose_eval=False
        )

        risk_tr = model.predict(dtrain)
        risk_va = model.predict(dval)

        # Multi-horizon C-index
        h_c_indices = []
        for tau_val in EVAL_HORIZONS:
            try:
                c_val = concordance_index_ipcw(y_tr_readm_struct, y_va_readm_struct, risk_va, tau=tau_val)[0]
                h_c_indices.append(c_val)
            except Exception:
                pass
        avg_c_index = np.mean(h_c_indices) if len(h_c_indices) > 0 else 0.5

        # Global C-index
        try:
            global_c = concordance_index_ipcw(y_tr_readm_struct, y_va_readm_struct, risk_va)[0]
        except Exception:
            global_c = 0.5

        # IBS
        try:
            surv_probs_va = predict_cif_aalen_johansen_approx(
                y_tr_readm_struct, y_tr_death_struct, risk_tr, risk_va, EVAL_HORIZONS
            )
            _, brier_scores_at_tau = brier_score(y_tr_readm_struct, y_va_readm_struct, surv_probs_va, EVAL_HORIZONS)
            avg_ibs = np.mean(brier_scores_at_tau)
        except Exception:
            avg_ibs = 0.25

        fold_c_indices.append(avg_c_index)
        fold_global_c_indices.append(global_c)
        fold_ib_scores.append(avg_ibs)

        del model, dtrain, dval, risk_tr, risk_va
        gc.collect()

        # mild prune
        if fold_idx >= 2 and np.mean(fold_c_indices) < 0.58:
            raise optuna.TrialPruned()

    trial.set_user_attr("Global_C_Index", float(np.mean(fold_global_c_indices)))
    return float(np.mean(fold_c_indices)), float(np.mean(fold_ib_scores))

sampler = optuna.samplers.NSGAIISampler(seed=2125)
study = optuna.create_study(
    directions=['maximize', 'minimize'],
    study_name="XGB_Readm_Optuna2_Study_AJ_StochMI",
    sampler=sampler
)

# enqueue seeds so search starts from proven region
for cfg in seed_configs:
    study.enqueue_trial(cfg)

nb_print(f"Running Optuna2 Phase 1 ({PHASE1_TRIALS} trials incl. seeds)...")
study.optimize(objective, n_trials=PHASE1_TRIALS, n_jobs=N_CORES, show_progress_bar=True)

nb_print("\nPhase 1 Pareto:")
for t in study.best_trials:
    nb_print(
        f"Trial {t.number} | Imp={t.user_attrs.get('Sampled_Imputation','NA')} | "
        f"C={t.values[0]:.4f} | IBS={t.values[1]:.4f}"
    )

timestamp_str = datetime.now().strftime("%Y%m%d_%H%M")
study_filename = INPUT_DIR / f"XGB_Readm_Optuna2_Study_AJ_StochMI_{timestamp_str}_mar26.pkl"
joblib.dump(study, study_filename)
nb_print(f"\nSaved study: {study_filename}")

end_time = time.time()
elapsed_seconds = end_time - start_time
nb_print(f"Time taken: {elapsed_seconds/60:.2f} minutes")
Optuna2 Phase 1 starting with 30 cores...
Loaded 7 seed configs from: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\Readmission_Pareto_Phase2_20260305_1431_mar26.csv
[I 2026-03-05 14:45:09,771] A new study created in memory with name: XGB_Readm_Optuna2_Study_AJ_StochMI
Running Optuna2 Phase 1 (50 trials incl. seeds)...
100%|██████████| 50/50 [29:43<00:00, 35.67s/it]   
Phase 1 Pareto:
Trial 0 | Imp=1 | C=0.6464 | IBS=0.1107
Trial 4 | Imp=5 | C=0.6443 | IBS=0.1101
Trial 6 | Imp=2 | C=0.6436 | IBS=0.1099
Trial 9 | Imp=5 | C=0.6403 | IBS=0.1099
Trial 10 | Imp=1 | C=0.6452 | IBS=0.1106
Trial 11 | Imp=2 | C=0.6446 | IBS=0.1105
Trial 17 | Imp=3 | C=0.6445 | IBS=0.1105
Trial 23 | Imp=4 | C=0.6447 | IBS=0.1105
Trial 25 | Imp=1 | C=0.6386 | IBS=0.1097
Trial 33 | Imp=4 | C=0.6454 | IBS=0.1106
Trial 41 | Imp=2 | C=0.6391 | IBS=0.1098
Trial 43 | Imp=4 | C=0.6449 | IBS=0.1106
Trial 47 | Imp=3 | C=0.6395 | IBS=0.1098
Saved study: G:\My Drive\Alvacast\SISTRAT 2023\cons\_input\XGB_Readm_Optuna2_Study_AJ_StochMI_20260305_1514_mar26.pkl
Time taken: 29.73 minutes
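Phase 2 below ranks these Pareto candidates by Euclidean distance to the ideal point (C-index = 1, Brier = 0). A quick sketch of that criterion, using the first three Pareto entries printed above:

```python
import numpy as np

# First three Pareto candidates from the log above: (multi-horizon C, mean Brier)
c   = np.array([0.6464, 0.6443, 0.6436])
ibs = np.array([0.1107, 0.1101, 0.1099])

# Same criterion as the Phase 2 cell: distance to the ideal point (1, 0)
dist = np.sqrt((1.0 - c) ** 2 + ibs ** 2)
winner = int(np.argmin(dist))  # -> 0 (the highest-C candidate wins)
```

Note the asymmetry: in this region of metric space `1 - C` (~0.35) is roughly three times larger than IBS (~0.11), so equal-sized changes in discrimination move the distance about three times as much as changes in calibration.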
Code
# @title Optuna2 Phase 2 - Re-score top candidates on ALL imputations (Dual Stratified)
import os
import re
import gc
import time
import joblib
import numpy as np
import pandas as pd
import xgboost as xgb
from datetime import datetime
from collections import defaultdict
from sklearn.model_selection import StratifiedKFold
from sksurv.metrics import concordance_index_ipcw, brier_score
from joblib import Parallel, delayed, parallel_backend
from pathlib import Path

start_time = time.time()
nb_print("Optuna2 Phase 2: loading study and rescoring top candidates on all imputations...")

TOP_K = 10
CANDIDATE_PARALLEL = 3  # candidates evaluated concurrently; lower this if RAM is tight
EVAL_HORIZONS = [3, 6, 12, 36, 60]
if "ROOT_PROJECT" in globals():
    ROOT_PROJECT = Path(ROOT_PROJECT).resolve()
elif "PROJECT_ROOT" in globals():
    ROOT_PROJECT = Path(PROJECT_ROOT).resolve()
else:
    raise RuntimeError("ROOT_PROJECT/PROJECT_ROOT is not defined. Run the root setup cell first.")

INPUT_DIR = ROOT_PROJECT / "_input"
OUTPUT_DIR = ROOT_PROJECT / "_out"
INPUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

nb_print(f"ROOT_PROJECT: {ROOT_PROJECT}")


# ---------------------------
# Load/reuse study
# ---------------------------
expected_study_name = "XGB_Readm_Optuna2_Study_AJ_StochMI"
if 'study' in globals() and getattr(study, 'study_name', None) == expected_study_name:
    nb_print(f"Using in-memory study '{expected_study_name}'.")
else:
    pattern_name = "XGB_Readm_Optuna2_Study_AJ_StochMI_*_mar26.pkl"
    files = list(INPUT_DIR.glob(pattern_name))
    if not files:
        raise FileNotFoundError(f"No files found for pattern: {INPUT_DIR / pattern_name}")

    timestamped = []
    for f in files:
        m = re.search(r"XGB_Readm_Optuna2_Study_AJ_StochMI_(\d{8}_\d{4})_mar26\.pkl$", f.name)
        if m:
            try:
                dt = datetime.strptime(m.group(1), "%Y%m%d_%H%M")
                timestamped.append((dt, f))
            except ValueError:
                pass

    if not timestamped:
        raise FileNotFoundError("No valid timestamped study files found.")

    latest_file = max(timestamped, key=lambda x: x[0])[1]
    study = joblib.load(latest_file)
    nb_print(f"Loaded study: {latest_file}")

if not (len(imputations_list_mar26) == len(y_surv_readm_list_corrected) == len(y_surv_death_list)):
    raise ValueError("Mismatch among imputations/readm/death list lengths.")

# ---------------------------
# Helpers
# ---------------------------
def get_dual_stratification_labels(df, y_struct):
    labels = np.zeros(len(df), dtype=int)
    if 'plan_type_corr_pg_pr' in df.columns: labels[df['plan_type_corr_pg_pr'] == 1] = 1
    if 'plan_type_corr_m_pr' in df.columns: labels[df['plan_type_corr_m_pr'] == 1] = 2
    if 'plan_type_corr_pg_pai' in df.columns: labels[df['plan_type_corr_pg_pai'] == 1] = 3
    if 'plan_type_corr_m_pai' in df.columns: labels[df['plan_type_corr_m_pai'] == 1] = 4
    event_status = y_struct['event'].astype(int)
    return (labels * 10) + event_status

def predict_cif_aalen_johansen_approx(y_tr_readm, y_tr_death, risk_tr, risk_va, eval_times):
    # survival:cox predictions are hazard ratios (always positive); nonpositive
    # values mean raw margins were returned, so exponentiate first.
    if np.any(risk_tr <= 0):
        risk_tr = np.exp(risk_tr)
        risk_va = np.exp(risk_va)

    time_train = y_tr_readm['time']
    event_any = y_tr_readm['event'] | y_tr_death['event']

    order = np.argsort(time_train)
    t_ord = time_train[order]
    e_any_ord = event_any[order]
    e_readm_ord = y_tr_readm['event'][order]
    risk_tr_ord = risk_tr[order]

    unique_times = np.unique(t_ord[e_any_ord])
    S_all = np.ones(len(unique_times) + 1)
    baseline_hazard_readm = np.zeros(len(unique_times))

    current_S = 1.0
    for i, t in enumerate(unique_times):
        at_risk_mask = t_ord >= t
        n_at_risk_t = np.sum(at_risk_mask)
        events_any_t = np.sum((t_ord == t) & e_any_ord)
        events_readm_t = np.sum((t_ord == t) & e_readm_ord)
        if n_at_risk_t > 0:
            S_all[i + 1] = current_S * (1.0 - events_any_t / n_at_risk_t)
            current_S = S_all[i + 1]
            baseline_hazard_readm[i] = events_readm_t / np.sum(risk_tr_ord[at_risk_mask])

    cif_va = np.zeros((len(risk_va), len(eval_times)))
    for j, eval_t in enumerate(eval_times):
        valid_idx = np.where(unique_times <= eval_t)[0]
        if len(valid_idx) > 0:
            S_all_t_minus = S_all[valid_idx]
            dH_readm = baseline_hazard_readm[valid_idx]
            base_cif_increment = S_all_t_minus * dH_readm
            cif_va[:, j] = risk_va * np.sum(base_cif_increment)

    return 1.0 - cif_va

# ---------------------------
# Select candidates from Phase 1 Pareto
# ---------------------------
pareto_trials = [t for t in study.best_trials if t.values is not None and len(t.values) == 2]
if len(pareto_trials) == 0:
    raise ValueError("No valid Pareto trials in study.")

phase21_rows = []
for t in pareto_trials:
    row = {
        "trial_id": t.number,
        "Phase1_C_Index": float(t.values[0]),
        "Phase1_IBS": float(t.values[1]),
        "Phase1_Global_C_Index": float(t.user_attrs.get("Global_C_Index", np.nan)),
        "Sampled_Imputation": t.user_attrs.get("Sampled_Imputation", np.nan),
    }
    row.update(t.params)
    phase21_rows.append(row)

df_phase21 = pd.DataFrame(phase21_rows)
df_phase21["Phase1_Distance_to_Ideal"] = np.sqrt(
    (1.0 - df_phase21["Phase1_C_Index"])**2 + (df_phase21["Phase1_IBS"])**2
)
df_phase21 = df_phase21.sort_values("Phase1_Distance_to_Ideal", ascending=True).reset_index(drop=True)

top_k = min(TOP_K, len(df_phase21))
candidate_ids = df_phase21.head(top_k)["trial_id"].astype(int).tolist()
trial_map = {t.number: t for t in pareto_trials}
candidate_trials = [trial_map[i] for i in candidate_ids]

# ---------------------------
# Parallel config
# ---------------------------
n_imputations = len(imputations_list_mar26)
N_CORES = max(1, (os.cpu_count() or 1) - 2)  # cpu_count() may return None
max_tasks = len(candidate_trials) * n_imputations
TOTAL_WORKERS = min(N_CORES, CANDIDATE_PARALLEL * n_imputations, max_tasks)
TOTAL_WORKERS = max(1, TOTAL_WORKERS)

XGB_THREADS = max(1, N_CORES // TOTAL_WORKERS)
XGB_THREADS = min(XGB_THREADS, 2)

nb_print(
    f"Phase 2 config -> workers={TOTAL_WORKERS}, xgb_threads/model={XGB_THREADS}, "
    f"candidate_parallel~{CANDIDATE_PARALLEL}, candidates={len(candidate_trials)}"
)

# ---------------------------
# Precompute per-imputation payload
# ---------------------------
imp_payloads = []
for imp_id, (X_imp, y_readm_imp, y_death_imp) in enumerate(
    zip(imputations_list_mar26, y_surv_readm_list_corrected, y_surv_death_list)
):
    y_label = np.where(y_readm_imp['event'], y_readm_imp['time'], -y_readm_imp['time'])
    strat_labels = get_dual_stratification_labels(X_imp, y_readm_imp)
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=2125)
    splits = list(skf.split(X_imp, strat_labels))

    imp_payloads.append({
        "imp_id": imp_id,
        "X_imp": X_imp,
        "y_label": y_label,
        "y_readm_imp": y_readm_imp,
        "y_death_imp": y_death_imp,
        "splits": splits,
    })

def evaluate_trial_imputation(trial_id, trial_params, payload):
    try:
        X_imp = payload["X_imp"]
        y_label = payload["y_label"]
        y_readm_imp = payload["y_readm_imp"]
        y_death_imp = payload["y_death_imp"]
        splits = payload["splits"]
        imp_id = payload["imp_id"]

        params = {
            'objective': 'survival:cox',
            'eval_metric': 'cox-nloglik',
            'tree_method': 'hist',
            'nthread': XGB_THREADS,
            'verbosity': 0,
            'seed': 2125,
            **trial_params
        }

        fold_c, fold_ibs, fold_global_c = [], [], []

        for tr_idx, va_idx in splits:
            X_tr, X_va = X_imp.iloc[tr_idx], X_imp.iloc[va_idx]
            y_tr, y_va = y_label[tr_idx], y_label[va_idx]
            y_tr_readm, y_va_readm = y_readm_imp[tr_idx], y_readm_imp[va_idx]
            y_tr_death = y_death_imp[tr_idx]

            dtr = xgb.DMatrix(X_tr, label=y_tr)
            dva = xgb.DMatrix(X_va, label=y_va)

            model = xgb.train(
                params, dtr,
                num_boost_round=1500,
                evals=[(dva, "val")],
                early_stopping_rounds=30,
                verbose_eval=False
            )

            risk_tr = model.predict(dtr)
            risk_va = model.predict(dva)

            h_c = []
            for tau in EVAL_HORIZONS:
                try:
                    h_c.append(concordance_index_ipcw(y_tr_readm, y_va_readm, risk_va, tau=tau)[0])
                except Exception:
                    pass
            fold_c.append(float(np.mean(h_c)) if len(h_c) > 0 else 0.5)

            try:
                fold_global_c.append(float(concordance_index_ipcw(y_tr_readm, y_va_readm, risk_va)[0]))
            except Exception:
                fold_global_c.append(0.5)

            try:
                surv_probs_va = predict_cif_aalen_johansen_approx(
                    y_tr_readm, y_tr_death, risk_tr, risk_va, EVAL_HORIZONS
                )
                _, brier_at_tau = brier_score(y_tr_readm, y_va_readm, surv_probs_va, EVAL_HORIZONS)
                fold_ibs.append(float(np.mean(brier_at_tau)))
            except Exception:
                fold_ibs.append(0.25)

            del model, dtr, dva, risk_tr, risk_va
            gc.collect()

        return (trial_id, imp_id, float(np.mean(fold_c)), float(np.mean(fold_ibs)), float(np.mean(fold_global_c)), None)
    except Exception as e:
        return (trial_id, payload["imp_id"], 0.5, 0.25, 0.5, str(e))

# ---------------------------
# Parallel execution over (candidate, imputation)
# ---------------------------
tasks = [(t.number, t.params, payload) for t in candidate_trials for payload in imp_payloads]
nb_print(f"Launching {len(tasks)} tasks ({len(candidate_trials)} candidates x {len(imp_payloads)} imputations).")

with parallel_backend("threading", n_jobs=TOTAL_WORKERS):
    results = Parallel(verbose=10)(
        delayed(evaluate_trial_imputation)(trial_id, trial_params, payload)
        for trial_id, trial_params, payload in tasks
    )

# ---------------------------
# Aggregate
# ---------------------------
trial_metrics = defaultdict(lambda: {"c": {}, "ibs": {}, "global_c": {}, "errors": []})

for trial_id, imp_id, c_val, ibs_val, g_val, err in results:
    trial_metrics[trial_id]["c"][imp_id] = c_val
    trial_metrics[trial_id]["ibs"][imp_id] = ibs_val
    trial_metrics[trial_id]["global_c"][imp_id] = g_val
    if err is not None:
        trial_metrics[trial_id]["errors"].append((imp_id, err))

phase22_rows = []
phase22_detail = {}

for t in candidate_trials:
    m = trial_metrics[t.number]

    c_per_imp = [m["c"].get(i, 0.5) for i in range(n_imputations)]
    ibs_per_imp = [m["ibs"].get(i, 0.25) for i in range(n_imputations)]
    global_c_per_imp = [m["global_c"].get(i, 0.5) for i in range(n_imputations)]

    row = {
        "trial_id": t.number,
        "Phase2_Multi_Horizon_C_Index": float(np.mean(c_per_imp)),
        "Phase2_Aalen_Johansen_Brier_Score": float(np.mean(ibs_per_imp)),
        "Phase2_Global_C_Index": float(np.mean(global_c_per_imp)),
        "C_Index_SD_across_imputations": float(np.std(c_per_imp)),
        "IBS_SD_across_imputations": float(np.std(ibs_per_imp)),
        "Phase2_Error_Count": len(m["errors"]),
    }
    row.update(t.params)
    phase22_rows.append(row)
    phase22_detail[t.number] = {"c_per_imp": c_per_imp, "ibs_per_imp": ibs_per_imp, "errors": m["errors"]}

df_phase22 = pd.DataFrame(phase22_rows)
df_phase22["Distance_to_Ideal"] = np.sqrt(
    (1.0 - df_phase22["Phase2_Multi_Horizon_C_Index"])**2
    + (df_phase22["Phase2_Aalen_Johansen_Brier_Score"])**2
)
df_phase22 = df_phase22.sort_values("Distance_to_Ideal", ascending=True).reset_index(drop=True)

winner = df_phase22.iloc[0]
winner_trial_id = int(winner["trial_id"])
param_keys = list(candidate_trials[0].params.keys())
# Note: the DataFrame round-trip turns every value into a float; cast
# max_depth and min_child_weight back to int before retraining with these.
params_winner = {k: winner[k] for k in param_keys}

nb_print("\nWinner hyperparameters:")
nb_print(params_winner)

nb_print("\nFinal winner from Optuna2 Phase 2:")
nb_print(f"  Trial ID: {winner_trial_id}")
nb_print(f"  Multi-Horizon C-Index: {winner['Phase2_Multi_Horizon_C_Index']:.4f}")
nb_print(f"  Aalen-Johansen Brier: {winner['Phase2_Aalen_Johansen_Brier_Score']:.4f}")
nb_print(f"  Global C-Index: {winner['Phase2_Global_C_Index']:.4f}")
nb_print(f"  Distance to Ideal: {winner['Distance_to_Ideal']:.4f}")
nb_print(f"  Error count: {int(winner['Phase2_Error_Count'])}")

c_indices_all_imp = phase22_detail[winner_trial_id]["c_per_imp"]
ibs_all_imp = phase22_detail[winner_trial_id]["ibs_per_imp"]
nb_print(f"C-index SD across imputations: {np.std(c_indices_all_imp):.4f}")
nb_print(f"IBS SD across imputations: {np.std(ibs_all_imp):.4f}")

timestamp_str = datetime.now().strftime("%Y%m%d_%H%M")
phase21_csv = OUTPUT_DIR / f"Readmission_Optuna2_Pareto_Phase1_{timestamp_str}_mar26.csv"
phase22_csv = OUTPUT_DIR / f"Readmission_Optuna2_Pareto_Phase2_{timestamp_str}_mar26.csv"

df_phase21.to_csv(phase21_csv, index=False)
df_phase22.to_csv(phase22_csv, index=False)

nb_print(f"\nSaved: {phase21_csv}")
nb_print(f"Saved: {phase22_csv}")

elapsed_min = (time.time() - start_time) / 60.0
nb_print(f"Time taken: {elapsed_min:.2f} minutes")
Optuna2 Phase 2: loading study and rescoring top candidates on all imputations...
ROOT_PROJECT: G:\My Drive\Alvacast\SISTRAT 2023\cons
Using in-memory study 'XGB_Readm_Optuna2_Study_AJ_StochMI'.
Phase 2 config -> workers=15, xgb_threads/model=2, candidate_parallel~3, candidates=10
Launching 50 tasks (10 candidates x 5 imputations).
[Parallel(n_jobs=15)]: Using backend ThreadingBackend with 15 concurrent workers.
[Parallel(n_jobs=15)]: Done   2 tasks      | elapsed:  8.5min
[Parallel(n_jobs=15)]: Done  11 tasks      | elapsed:  9.3min
[Parallel(n_jobs=15)]: Done  20 tasks      | elapsed: 15.4min
[Parallel(n_jobs=15)]: Done  27 out of  50 | elapsed: 16.0min remaining: 13.6min
[Parallel(n_jobs=15)]: Done  33 out of  50 | elapsed: 23.9min remaining: 12.3min
[Parallel(n_jobs=15)]: Done  39 out of  50 | elapsed: 24.6min remaining:  6.9min
[Parallel(n_jobs=15)]: Done  45 out of  50 | elapsed: 24.9min remaining:  2.8min
[Parallel(n_jobs=15)]: Done  50 out of  50 | elapsed: 28.1min finished
Winner hyperparameters:
{'learning_rate': np.float64(0.0033765926340982), 'max_depth': np.float64(9.0), 'min_child_weight': np.float64(4.0), 'subsample': np.float64(0.6102749292690339), 'colsample_bytree': np.float64(0.405146532644876), 'reg_alpha': np.float64(0.2849441972783855), 'reg_lambda': np.float64(0.6681588791085157), 'gamma': np.float64(0.0431853685583679)}
Final winner from Optuna2 Phase 2:
  Trial ID: 0
  Multi-Horizon C-Index: 0.6459
  Aalen-Johansen Brier: 0.1106
  Global C-Index: 0.6159
  Distance to Ideal: 0.3710
  Error count: 0
C-index SD across imputations: 0.0002
IBS SD across imputations: 0.0000
Saved: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\Readmission_Optuna2_Pareto_Phase1_20260305_1553_mar26.csv
Saved: G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\Readmission_Optuna2_Pareto_Phase2_20260305_1553_mar26.csv
Time taken: 28.08 minutes
Code
params_winner_final = params_winner
display(Markdown("### Optuna2 - Phase 1 (Stochastic MI Pareto)"))
show_scrollable_df(df_phase21, max_height=500, max_width=1200)

Optuna2 - Phase 1 (Stochastic MI Pareto)

trial_id Phase1_C_Index Phase1_IBS Phase1_Global_C_Index Sampled_Imputation learning_rate max_depth min_child_weight subsample colsample_bytree reg_alpha reg_lambda gamma Phase1_Distance_to_Ideal
0 0.646380 0.110654 0.616290 1 0.003377 9 4 0.610275 0.405147 0.284944 0.668159 0.043185 0.370529
33 0.645408 0.110616 0.616405 4 0.002958 9 5 0.721627 0.465382 0.005320 1.095623 1.267510 0.371445
10 0.645224 0.110613 0.617615 1 0.004562 7 13 0.895885 0.527995 0.000327 0.294984 0.704946 0.371619
43 0.644879 0.110592 0.616547 4 0.007342 7 9 0.636120 0.588509 0.009120 1.225093 1.212974 0.371943
23 0.644671 0.110534 0.616670 4 0.007266 8 15 0.715472 0.501864 0.000276 0.206993 1.854012 0.372125
11 0.644643 0.110530 0.616396 2 0.011049 6 15 0.623485 0.577611 0.017213 0.139574 0.690884 0.372150
17 0.644535 0.110500 0.616822 3 0.006960 9 18 0.661477 0.451618 0.000257 1.477238 1.194202 0.372244
4 0.644324 0.110147 0.616293 5 0.002803 8 12 0.745269 0.476133 0.066481 0.650462 0.120555 0.372341
6 0.643633 0.109931 0.616083 2 0.003245 7 15 0.675051 0.407253 0.002712 3.084713 0.208060 0.372937
9 0.640267 0.109906 0.614428 5 0.004244 4 29 0.734294 0.539887 0.124436 1.785109 1.680789 0.376148
47 0.639532 0.109785 0.613350 3 0.003415 4 29 0.698507 0.457739 0.000113 0.132438 0.686893 0.376816
41 0.639133 0.109771 0.613131 2 0.003033 4 18 0.688294 0.523706 0.069412 0.927038 1.475790 0.377193
25 0.638612 0.109742 0.612039 1 0.002630 4 18 0.741451 0.569747 0.062265 2.086441 1.966366 0.377683
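The `Phase1_Distance_to_Ideal` column appears to be the Euclidean distance from each trial's (C-index, IBS) pair to the ideal point (C = 1, IBS = 0); this is an inference from the tabulated values, not a formula shown in the notebook. A minimal sketch under that assumption:

```python
import math

def distance_to_ideal(c_index: float, ibs: float) -> float:
    # Euclidean distance to the ideal point (C-index = 1, IBS = 0);
    # smaller is better, so the Pareto winner minimizes this value.
    return math.hypot(1.0 - c_index, ibs)

# Trial 0 from the Phase 1 table above:
d = distance_to_ideal(0.646380, 0.110654)
print(f"{d:.6f}")  # ~0.370529, matching Phase1_Distance_to_Ideal
```

The same formula reproduces the Phase 2 `Distance_to_Ideal` column from the rescored C-index and Brier values.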
Code
show_scrollable_df(df_phase22, max_height=500, max_width=1200)
trial_id Phase2_Multi_Horizon_C_Index Phase2_Aalen_Johansen_Brier_Score Phase2_Global_C_Index C_Index_SD_across_imputations IBS_SD_across_imputations Phase2_Error_Count learning_rate max_depth min_child_weight subsample colsample_bytree reg_alpha reg_lambda gamma Distance_to_Ideal
0 0.645887 0.110638 0.615911 0.000172 0.000031 0 0.003377 9 4 0.610275 0.405147 0.284944 0.668159 0.043185 0.370994
33 0.645208 0.110670 0.616221 0.000231 0.000041 0 0.002958 9 5 0.721627 0.465382 0.005320 1.095623 1.267510 0.371652
23 0.644981 0.110602 0.616796 0.000185 0.000041 0 0.007266 8 15 0.715472 0.501864 0.000276 0.206993 1.854012 0.371848
10 0.644864 0.110622 0.617598 0.000200 0.000027 0 0.004562 7 13 0.895885 0.527995 0.000327 0.294984 0.704946 0.371966
43 0.644546 0.110694 0.616975 0.000055 0.000032 0 0.007342 7 9 0.636120 0.588509 0.009120 1.225093 1.212974 0.372292
17 0.644356 0.110476 0.615752 0.000234 0.000075 0 0.006960 9 18 0.661477 0.451618 0.000257 1.477238 1.194202 0.372408
4 0.644237 0.110131 0.616062 0.000141 0.000011 0 0.002803 8 12 0.745269 0.476133 0.066481 0.650462 0.120555 0.372419
6 0.643762 0.109933 0.616189 0.000146 0.000009 0 0.003245 7 15 0.675051 0.407253 0.002712 3.084713 0.208060 0.372815
11 0.643850 0.110629 0.617196 0.000134 0.000034 0 0.011049 6 15 0.623485 0.577611 0.017213 0.139574 0.690884 0.372937
9 0.640107 0.109918 0.614482 0.000075 0.000004 0 0.004244 4 29 0.734294 0.539887 0.124436 1.785109 1.680789 0.376304
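The final configuration is the Pareto candidate with the smallest `Distance_to_Ideal` after rescoring across all imputations; filtering on `Phase2_Error_Count == 0` is an assumption inferred from the table, and the notebook's exact tie-breaking is not shown. A sketch of that selection over an illustrative two-row subset of the Phase 2 values above:

```python
import pandas as pd

# Two rows copied from the Phase 2 rescoring table above (illustrative subset).
df_phase2 = pd.DataFrame({
    "trial_id": [0, 33],
    "Distance_to_Ideal": [0.370994, 0.371652],
    "Phase2_Error_Count": [0, 0],
})

# Keep only candidates that rescored cleanly on every imputation,
# then pick the one closest to the ideal point.
ok = df_phase2[df_phase2["Phase2_Error_Count"] == 0]
winner = ok.loc[ok["Distance_to_Ideal"].idxmin()]
print(int(winner["trial_id"]))  # 0, matching "Trial ID: 0" above
```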

📊 Summary of Pareto Findings


Optimism correction

Code
# @title Harrell's Bootstrap Optimism Correction (Readmission, Parallelized, CPU-2, with 95% CI)
import numpy as np
import pandas as pd
import xgboost as xgb
import os
from sklearn.utils import resample
from sklearn.model_selection import train_test_split
from sksurv.metrics import concordance_index_ipcw, concordance_index_censored
from joblib import Parallel, delayed
import warnings
from pathlib import Path

# ROOT_PROJECT / PROJECT_ROOT -> output directory
if "ROOT_PROJECT" in globals():
    ROOT_PROJECT = Path(ROOT_PROJECT).resolve()
elif "PROJECT_ROOT" in globals():
    ROOT_PROJECT = Path(PROJECT_ROOT).resolve()
else:
    raise RuntimeError("ROOT_PROJECT (or PROJECT_ROOT) is not defined. Run the root setup cell first.")

# Fallback for nb_print (must be defined before its first use below)
if "nb_print" not in globals():
    def nb_print(*args, **kwargs):
        print(*args, **kwargs)

OUT_DIR = ROOT_PROJECT / "_out"
OUT_DIR.mkdir(parents=True, exist_ok=True)
nb_print(f"ROOT_PROJECT: {ROOT_PROJECT}")

warnings.filterwarnings("ignore")

nb_print("Initializing Parallel Harrell's Bootstrap Optimism Correction for Readmission...")

# --- CPU CONFIGURATION ---
N_CORES = max(1, (os.cpu_count() or 2) - 2)
nb_print(f"Parallel Execution Configured: Using {N_CORES} CPU cores.")

# --- 1. DATA SETUP ---
try:
    df_tune = imputations_list_mar26[0].copy()
    y_tune_struct = y_surv_readm_list_corrected[0]

    assert len(df_tune) == len(y_tune_struct), (
        f"X/y mismatch: df_tune={len(df_tune)}, y_tune_struct={len(y_tune_struct)}"
    )
except Exception as e:
    raise ValueError(f"Data Error: {e}. Please ensure the mar26 data objects are loaded.") from e

y_xgb_label = np.where(y_tune_struct["event"], y_tune_struct["time"], -y_tune_struct["time"])

# --- 2. FINAL MAR26 WINNER HYPERPARAMETERS ---
# Manual fallback winner params
manual_params_winner = {
    "objective": "survival:cox",
    "eval_metric": "cox-nloglik",
    "tree_method": "hist",
    "device": "cpu",
    "verbosity": 0,
    "seed": 2125,
    "learning_rate": 0.0033765926340982,
    "max_depth": 9,
    "min_child_weight": 4,
    "subsample": 0.6102749292690339,
    "colsample_bytree": 0.405146532644876,
    "reg_alpha": 0.2849441972783855,
    "reg_lambda": 0.6681588791085157,
    "gamma": 0.0431853685583679,
}

allowed_param_keys = {
    "objective", "eval_metric", "tree_method", "device", "verbosity", "seed",
    "learning_rate", "max_depth", "min_child_weight", "subsample",
    "colsample_bytree", "reg_alpha", "reg_lambda", "gamma"
}

if "params_winner_final" in globals() and isinstance(params_winner_final, dict) and len(params_winner_final) > 0:
    mem_params = {k: v for k, v in params_winner_final.items() if k in allowed_param_keys and pd.notna(v)}
    params_winner = manual_params_winner.copy()
    params_winner.update(mem_params)
    source_params = "params_winner_final (memory)"
else:
    params_winner = manual_params_winner.copy()
    source_params = "manual fallback"

# Enforce fixed setup for this notebook cell
params_winner.update({
    "objective": "survival:cox",
    "eval_metric": "cox-nloglik",
    "tree_method": "hist",
    "device": "cpu",
    "verbosity": 0,
    "seed": 2125,
})

for k in ("max_depth", "min_child_weight"):
    if k in params_winner:
        params_winner[k] = int(round(float(params_winner[k])))

nb_print(f"Hyperparameter source: {source_params}")

B_ITERATIONS = 500
ALPHA = 0.05  # 95% CI

# --- STRATIFICATION HELPER ---
def get_dual_stratification_labels(df, y_struct):
    labels = np.zeros(len(df), dtype=int)
    if "plan_type_corr_pg_pr" in df.columns:
        labels[df["plan_type_corr_pg_pr"] == 1] = 1
    if "plan_type_corr_m_pr" in df.columns:
        labels[df["plan_type_corr_m_pr"] == 1] = 2
    if "plan_type_corr_pg_pai" in df.columns:
        labels[df["plan_type_corr_pg_pai"] == 1] = 3
    if "plan_type_corr_m_pai" in df.columns:
        labels[df["plan_type_corr_m_pai"] == 1] = 4
    event_status = y_struct["event"].astype(int)
    return (labels * 10) + event_status

strat_labels_dual = get_dual_stratification_labels(df_tune, y_tune_struct)

# --- 3. CALCULATE APPARENT PERFORMANCE ON ORIGINAL DATA ---
nb_print("Calculating apparent performance on the original full dataset...")

params_initial = params_winner.copy()
params_initial["nthread"] = N_CORES

X_train_app, X_val_app, y_train_xgb_app, y_val_xgb_app = train_test_split(
    df_tune, y_xgb_label, test_size=0.2, random_state=2125, stratify=strat_labels_dual
)

dtrain_app = xgb.DMatrix(X_train_app, label=y_train_xgb_app)
dval_app = xgb.DMatrix(X_val_app, label=y_val_xgb_app)

temp_model = xgb.train(
    params_initial,
    dtrain_app,
    num_boost_round=2000,
    evals=[(dval_app, "val")],
    early_stopping_rounds=30,
    verbose_eval=False
)

if getattr(temp_model, "best_iteration", None) is None or temp_model.best_iteration < 0:
    optimal_boost_rounds = 2000
else:
    optimal_boost_rounds = int(temp_model.best_iteration) + 1

nb_print(f"Optimal boosting rounds determined: {optimal_boost_rounds}")

# Train definitive baseline model on 100% of data
dorig = xgb.DMatrix(df_tune, label=y_xgb_label)
baseline_model = xgb.train(
    params_initial,
    dorig,
    num_boost_round=optimal_boost_rounds,
    verbose_eval=False
)

risk_orig = baseline_model.predict(dorig)
try:
    c_apparent_orig = float(concordance_index_ipcw(y_tune_struct, y_tune_struct, risk_orig)[0])
except Exception:
    c_apparent_orig = float(
        concordance_index_censored(y_tune_struct["event"], y_tune_struct["time"], risk_orig)[0]
    )

nb_print(f"Baseline Apparent Global C-index: {c_apparent_orig:.4f}")

# --- 4. PARALLEL BOOTSTRAP WORKER FUNCTION ---
def parallel_bootstrap_worker(b, df_original, y_xgb_original, y_struct_orig, params, opt_rounds):
    boot_params = params.copy()
    boot_params["nthread"] = 1  # avoid CPU thrashing

    idx = np.arange(len(df_original))
    boot_indices = resample(idx, replace=True, n_samples=len(idx), random_state=b)

    X_boot = df_original.iloc[boot_indices]
    y_xgb_boot = y_xgb_original[boot_indices]
    y_struct_boot = y_struct_orig[boot_indices]

    dboot = xgb.DMatrix(X_boot, label=y_xgb_boot)
    dorig_local = xgb.DMatrix(df_original, label=y_xgb_original)

    boot_model = xgb.train(
        boot_params,
        dboot,
        num_boost_round=opt_rounds,
        verbose_eval=False
    )

    # Apparent boot performance
    risk_boot = boot_model.predict(dboot)
    try:
        c_boot_app = float(concordance_index_ipcw(y_struct_boot, y_struct_boot, risk_boot)[0])
    except Exception:
        c_boot_app = float(
            concordance_index_censored(y_struct_boot["event"], y_struct_boot["time"], risk_boot)[0]
        )

    # Test on original data
    risk_test_orig = boot_model.predict(dorig_local)
    try:
        c_boot_test = float(concordance_index_ipcw(y_struct_boot, y_struct_orig, risk_test_orig)[0])
    except Exception:
        c_boot_test = float(
            concordance_index_censored(y_struct_orig["event"], y_struct_orig["time"], risk_test_orig)[0]
        )

    return c_boot_app - c_boot_test

# --- 5. EXECUTE PARALLEL BOOTSTRAP LOOP ---
nb_print(f"\nLaunching {B_ITERATIONS} Parallel Bootstrap Iterations...")

optimism_values = Parallel(n_jobs=N_CORES, verbose=10)(
    delayed(parallel_bootstrap_worker)(
        b, df_tune, y_xgb_label, y_tune_struct, params_winner, optimal_boost_rounds
    )
    for b in range(B_ITERATIONS)
)

# --- 6. FINAL METRICS + 95% CI ---
optimism_values = np.asarray(optimism_values, dtype=float)
optimism_values = optimism_values[np.isfinite(optimism_values)]

if optimism_values.size == 0:
    raise ValueError("No valid bootstrap optimism values were produced.")

mean_optimism = float(np.mean(optimism_values))
c_index_corrected = float(c_apparent_orig - mean_optimism)

corrected_samples = c_apparent_orig - optimism_values
opt_ci_low, opt_ci_high = np.quantile(optimism_values, [ALPHA / 2, 1 - ALPHA / 2])
corr_ci_low, corr_ci_high = np.quantile(corrected_samples, [ALPHA / 2, 1 - ALPHA / 2])

nb_print("\n--------------------------------------------------")
nb_print("FINAL OPTIMISM-CORRECTED RESULTS (READMISSION)")
nb_print("--------------------------------------------------")
nb_print(f"Apparent C-Index (Original Data)            : {c_apparent_orig:.4f}")
nb_print(f"Mean Optimism (from {optimism_values.size} boots) : {mean_optimism:.4f}")
nb_print(f"Optimism 95% CI                             : [{opt_ci_low:.4f}, {opt_ci_high:.4f}]")
nb_print(f"Optimism-Corrected C-Index                  : {c_index_corrected:.4f}")
nb_print(f"Corrected C-Index 95% CI                    : [{corr_ci_low:.4f}, {corr_ci_high:.4f}]")
nb_print("--------------------------------------------------")

# --- 7. EXPORT ---

summary_df = pd.DataFrame({
    "Metric": ["Apparent_C_Index", "Mean_Optimism", "Corrected_C_Index"],
    "Value": [c_apparent_orig, mean_optimism, c_index_corrected],
    "CI_95_Lower": [np.nan, opt_ci_low, corr_ci_low],
    "CI_95_Upper": [np.nan, opt_ci_high, corr_ci_high],
    "Bootstrap_N": [np.nan, optimism_values.size, optimism_values.size]
})
dist_df = pd.DataFrame({
    "optimism": optimism_values,
    "corrected_c_index_sample": corrected_samples
})

timestamp_str = pd.Timestamp.now().strftime("%Y%m%d_%H%M")

summary_file = OUT_DIR / f"XGB_Readm_Bootstrap_Optimism_Results_{timestamp_str}_mar26.csv"
summary_df.to_csv(summary_file, index=False)

dist_file = OUT_DIR / f"XGB_Readm_Bootstrap_Optimism_Distribution_{timestamp_str}_mar26.csv"
dist_df.to_csv(dist_file, index=False)

nb_print(f"Results saved successfully to {summary_file}")
nb_print(f"Bootstrap distribution saved to {dist_file}")
ROOT_PROJECT: G:\My Drive\Alvacast\SISTRAT 2023\cons
Initializing Parallel Harrell's Bootstrap Optimism Correction for Readmission...
Parallel Execution Configured: Using 30 CPU cores.
Hyperparameter source: params_winner_final (memory)
Calculating apparent performance on the original full dataset...
Optimal boosting rounds determined: 1686
Baseline Apparent Global C-index: 0.7489
Launching 500 Parallel Bootstrap Iterations...
[Parallel(n_jobs=30)]: Using backend LokyBackend with 30 concurrent workers.
[Parallel(n_jobs=30)]: Done   1 tasks      | elapsed:  2.3min
[Parallel(n_jobs=30)]: Done  12 tasks      | elapsed:  2.5min
[Parallel(n_jobs=30)]: Done  25 tasks      | elapsed:  3.1min
[Parallel(n_jobs=30)]: Done  38 tasks      | elapsed:  5.1min
[Parallel(n_jobs=30)]: Done  53 tasks      | elapsed:  6.5min
[Parallel(n_jobs=30)]: Done  68 tasks      | elapsed:  7.8min
[Parallel(n_jobs=30)]: Done  85 tasks      | elapsed:  9.4min
[Parallel(n_jobs=30)]: Done 102 tasks      | elapsed: 10.4min
[Parallel(n_jobs=30)]: Done 121 tasks      | elapsed: 12.7min
[Parallel(n_jobs=30)]: Done 140 tasks      | elapsed: 13.3min
[Parallel(n_jobs=30)]: Done 161 tasks      | elapsed: 15.8min
[Parallel(n_jobs=30)]: Done 182 tasks      | elapsed: 18.7min
[Parallel(n_jobs=30)]: Done 205 tasks      | elapsed: 21.4min
[Parallel(n_jobs=30)]: Done 228 tasks      | elapsed: 22.0min
[Parallel(n_jobs=30)]: Done 253 tasks      | elapsed: 24.6min
[Parallel(n_jobs=30)]: Done 278 tasks      | elapsed: 27.3min
[Parallel(n_jobs=30)]: Done 305 tasks      | elapsed: 30.2min
[Parallel(n_jobs=30)]: Done 332 tasks      | elapsed: 32.8min
[Parallel(n_jobs=30)]: Done 361 tasks      | elapsed: 35.6min
[Parallel(n_jobs=30)]: Done 390 tasks      | elapsed: 38.5min
[Parallel(n_jobs=30)]: Done 421 tasks      | elapsed: 41.4min
[Parallel(n_jobs=30)]: Done 492 out of 500 | elapsed: 47.9min remaining:   46.6s
[Parallel(n_jobs=30)]: Done 500 out of 500 | elapsed: 48.0min finished
--------------------------------------------------
FINAL OPTIMISM-CORRECTED RESULTS (READMISSION)
--------------------------------------------------
Apparent C-Index (Original Data)            : 0.7489
Mean Optimism (from 500 boots) : 0.0927
Optimism 95% CI                             : [0.0879, 0.0973]
Optimism-Corrected C-Index                  : 0.6562
Corrected C-Index 95% CI                    : [0.6516, 0.6609]
--------------------------------------------------
Results saved successfully to G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\XGB_Readm_Bootstrap_Optimism_Results_20260305_1654_mar26.csv
Bootstrap distribution saved to G:\My Drive\Alvacast\SISTRAT 2023\cons\_out\XGB_Readm_Bootstrap_Optimism_Distribution_20260305_1654_mar26.csv

~48 minutes

Readmission Model – Internal Validation Summary

  • Apparent performance overestimated discrimination by ~0.09 C-index units.
  • The bootstrap-corrected C-index stabilized at 0.656 (the optimism-corrected global Uno’s C-index over the full follow-up distribution).
  • Corrected performance aligns closely with the cross-validation results.
  • Overfitting was moderate and appropriately corrected.
  • Readmission prediction shows structurally lower stability than mortality prediction.
  • Flexible tree-based models inflate apparent performance unless corrected.
  • Internal validation supports a realistic discrimination ceiling of roughly 0.65.
  • The model demonstrates moderate, clinically plausible discrimination.
  • Optimism correction strengthens the methodological credibility of the results.
  • Results are suitable for transparent reporting under the TRIPOD guidelines.
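Harrell's correction subtracts the mean bootstrap optimism from the apparent estimate, and the percentile CI of the corrected index follows by shifting the optimism distribution. Reproducing the headline numbers from the printed output (inputs copied from above; this mirrors, not replaces, the cell's computation, and last digits can shift with input rounding):

```python
import numpy as np

c_apparent = 0.7489          # apparent global C-index on the full data
mean_optimism = 0.0927       # mean of (c_boot_apparent - c_boot_on_original)
opt_ci = np.array([0.0879, 0.0973])  # 2.5% / 97.5% quantiles of optimism

c_corrected = c_apparent - mean_optimism
# Subtracting flips the interval: the upper optimism bound yields the
# lower corrected bound, and vice versa.
corr_ci = (c_apparent - opt_ci)[::-1]

print(f"{c_corrected:.4f}")                     # 0.6562
print(f"[{corr_ci[0]:.4f}, {corr_ci[1]:.4f}]")  # close to the printed [0.6516, 0.6609]
```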
Code
from IPython.display import HTML, display
html_table = summary_df.to_html(index=True, escape=False)
scroll_box = f"""
<div style="max-height:600px; max-width:100%; overflow-y:auto; overflow-x:auto; border:1px solid #ddd; padding:6px;">
{html_table}
</div>
"""
display(HTML(scroll_box))
Metric Value CI_95_Lower CI_95_Upper Bootstrap_N
0 Apparent_C_Index 0.748892 NaN NaN NaN
1 Mean_Optimism 0.092679 0.087945 0.097318 500.0
2 Corrected_C_Index 0.656212 0.651574 0.660946 500.0
Back to top