Deep Hit (part 1)

Hyperparameter tuning for a DeepHit competing risks model designed to jointly predict mortality and hospital readmission. It employs a rigorous two-stage grid search strategy, using Uno’s IPCW C-index as the primary optimization metric, to identify the best model configuration. The process begins with a coarse grid search over a wide range of hyperparameters (learning rate, weight decay, network architecture, dropout, batch size, and loss function parameters), followed by a targeted “zoom-in” search around the top-performing candidates to confirm optimality. Using 5-fold stratified cross-validation on the first imputed dataset to ensure robustness, the tuning converged on a final configuration—featuring a [256, 256, 128] architecture, a learning rate of 0.001, and moderate regularization—that maximized the model’s ability to discriminate patient risk over time for both competing events.

Author

ags

Published

February 11, 2026

0. Package loading and installation

Code
# Commented out IPython magic to ensure Python compatibility.
# For Jupyter/Colab notebooks
%reset -f
import gc
gc.collect()

import numpy as np
import pandas as pd
import time

#conda activate surv-deephit
#conda install ipykernel -y
#conda install -c conda-forge pytorch torchtuples pycox
#conda install pytorch torchvision torchaudio pytorch-cuda=11.8 -c pytorch -c nvidia
# conda install pycox torchtuples scikit-learn scikit-survival lifelines shap seaborn matplotlib scipy pandas -c conda-forge -y
# por si: conda install pycox torchtuples -c conda-forge -y

#Conda te avisa que va a hacer dos cambios porque estás instalando PyTorch con CUDA:
#conda-forge::cuda-cudart 12.9  →  nvidia::cuda-cudart 11.8

#Packages stored in : 
#conda env export --no-builds > "G:\My Drive\Alvacast\SISTRAT 2023\dh\environment.yml"

#Load packages in:
#conda activate base
#conda-lock install \
#  -n surv-deephit \
#  "G:\My Drive\Alvacast\SISTRAT 2023\dh\conda-lock.yml"

import sys
import subprocess

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd
import torch
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

# Check device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"✅ Compute Device: {device}")

from sksurv.metrics import (
    concordance_index_ipcw,
    brier_score,
    integrated_brier_score
)
from sksurv.util import Surv



#Glimpse function
def glimpse(df, max_width=80):
    print(f"Rows: {df.shape[0]} | Columns: {df.shape[1]}")
    for col in df.columns:
        dtype = df[col].dtype
        preview = df[col].astype(str).head(5).tolist()
        preview_str = ", ".join(preview)
        if len(preview_str) > max_width:
            preview_str = preview_str[:max_width] + "..."
        print(f"{col:<30} {str(dtype):<15} {preview_str}")
#Tabyl function
def tabyl(series):
    counts = series.value_counts(dropna=False)
    props = series.value_counts(normalize=True, dropna=False)
    return pd.DataFrame({"value": counts.index,
                         "n": counts.values,
                         "percent": props.values})
#clean_names
import re

def clean_names(df):
    """
    Mimic janitor::clean_names for pandas DataFrames.
    - Lowercase
    - Replace spaces and special chars with underscores
    - Remove non-alphanumeric/underscore
    """
    new_cols = []
    for col in df.columns:
        # lowercase
        col = col.lower()
        # replace spaces and special chars with underscore
        col = re.sub(r"[^\w]+", "_", col)
        # strip leading/trailing underscores
        col = col.strip("_")
        new_cols.append(col)
    df.columns = new_cols
    return df
✅ Compute Device: cuda
Code
packages = ["torch", "torchtuples", "pycox"]

for p in packages:
    try:
        mod = __import__(p)
        print(f"✅ {p} installed | version:", getattr(mod, "__version__", "unknown"))
    except ImportError:
        print(f"❌ {p} NOT installed")
✅ torch installed | version: 2.5.1
✅ torchtuples installed | version: 0.2.2
✅ pycox installed | version: 0.3.0

Load data

Code

from pathlib import Path

BASE_DIR = Path(
    r"G:\My Drive\Alvacast\SISTRAT 2023\data\20241015_out\pred1"
)


import pickle

with open(BASE_DIR / "imputations_list_jan26.pkl", "rb") as f:
    imputations_list_jan26 = pickle.load(f)


imputation_nodum_1 = pd.read_parquet(
    BASE_DIR / "imputation_nondum_1.parquet",
    engine="fastparquet"
)

X_reduced_imp0 = pd.read_parquet(
    BASE_DIR / "X_reduced_imp0.parquet",
    engine="fastparquet"
)

imputation_1 = pd.read_parquet(
    BASE_DIR / "imputation_1.parquet",
    engine="fastparquet"
)


# Quick check
glimpse(imputation_nodum_1)
glimpse(imputation_1)
glimpse(X_reduced_imp0)
Rows: 88504 | Columns: 43
readmit_time_from_adm_m        float64         84.93548387096774, 12.833333333333334, 13.733333333333333, 11.966666666666667, 1...
death_time_from_adm_m          float64         84.93548387096774, 87.16129032258064, 117.2258064516129, 98.93548387096774, 37.9...
adm_age_rec3                   float64         31.53, 20.61, 42.52, 60.61, 45.08
porc_pobr                      float64         0.175679117441177, 0.187835901975632, 0.130412444472313, 0.133759185671806, 0.08...
dit_m                          float64         15.967741935483872, 5.833333333333334, 0.4752688172043005, 6.966666666666667, 6....
sex_rec                        object          man, man, man, woman, man
tenure_status_household        object          stays temporarily with a relative, owner/transferred dwellings/pays dividends, s...
cohabitation                   object          alone, family of origin, with couple/children, with couple/children, family of o...
sub_dep_icd10_status           object          drug dependence, hazardous consumption, drug dependence, drug dependence, drug d...
any_violence                   object          0.No domestic violence/sex abuse, 0.No domestic violence/sex abuse, 0.No domesti...
prim_sub_freq_rec              object          2.2–6 days/wk, 3.Daily, 3.Daily, 3.Daily, 3.Daily
tr_outcome                     object          referral, dropout, adm discharge - rule violation/undet, dropout, completion
adm_motive                     object          sanitary sector, spontaneous consultation, sanitary sector, sanitary sector, spo...
first_sub_used                 object          alcohol, alcohol, alcohol, cocaine paste, alcohol
primary_sub_mod                object          alcohol, cocaine paste, cocaine paste, cocaine paste, cocaine paste
tipo_de_vivienda_rec2          object          other/unknown, formal housing, formal housing, formal housing, formal housing
national_foreign               int32           0, 0, 0, 0, 0
plan_type_corr                 object          pg-pab, pg-pab, pg-pr, m-pr, pg-pai
occupation_condition_corr24    object          unemployed, employed, employed, inactive, unemployed
marital_status_rec             object          single, single, single, married/cohabiting, single
urbanicity_cat                 object          3.Urban, 3.Urban, 3.Urban, 3.Urban, 3.Urban
ed_attainment_corr             object          2-Completed high school or less, 3-Completed primary school or less, 2-Completed...
evaluacindelprocesoteraputico  object          logro alto, logro minimo, logro minimo, logro minimo, logro alto
eva_consumo                    object          logro alto, logro minimo, logro minimo, logro intermedio, logro alto
eva_fam                        object          logro intermedio, logro minimo, logro minimo, logro intermedio, logro alto
eva_relinterp                  object          logro alto, logro minimo, logro minimo, logro intermedio, logro alto
eva_ocupacion                  object          logro alto, logro minimo, logro minimo, logro minimo, logro intermedio
eva_sm                         object          logro intermedio, logro minimo, logro minimo, logro intermedio, logro minimo
eva_fisica                     object          logro alto, logro minimo, logro intermedio, logro intermedio, logro alto
eva_transgnorma                object          logro alto, logro minimo, logro minimo, logro minimo, logro intermedio
ethnicity                      float64         0.0, 0.0, 0.0, 0.0, 0.0
dg_psiq_cie_10_instudy         bool            False, False, True, False, False
dg_psiq_cie_10_dg              bool            True, False, False, True, False
dx_f3_mood                     int32           0, 0, 0, 0, 0
dx_f6_personality              int32           0, 0, 0, 1, 0
dx_f_any_severe_mental         bool            True, False, False, False, False
any_phys_dx                    bool            False, False, False, True, False
polysubstance_strict           int32           0, 1, 1, 1, 1
readmit_event                  float64         0.0, 1.0, 1.0, 1.0, 1.0
death_event                    int32           0, 0, 0, 0, 0
readmit_time_from_disch_m      float64         68.96774193548387, 7.0, 13.258064516129032, 5.0, 7.354838709677419
death_time_from_disch_m        float64         68.96774193548387, 81.3225806451613, 116.74193548387096, 91.96774193548387, 31.0...
center_id                      object          330, 489, 341, 411, 176
Rows: 88504 | Columns: 78
readmit_time_from_adm_m        float64         84.93548387096774, 12.833333333333334, 13.733333333333333, 11.966666666666667, 1...
death_time_from_adm_m          float64         84.93548387096774, 87.16129032258064, 117.2258064516129, 98.93548387096774, 37.9...
adm_age_rec3                   float64         31.53, 20.61, 42.52, 60.61, 45.08
porc_pobr                      float64         0.175679117441177, 0.187835901975632, 0.130412444472313, 0.133759185671806, 0.08...
dit_m                          float64         15.967741935483872, 5.833333333333334, 0.4752688172043005, 6.966666666666667, 6....
national_foreign               int32           0, 0, 0, 0, 0
ethnicity                      float64         0.0, 0.0, 0.0, 0.0, 0.0
dg_psiq_cie_10_instudy         bool            False, False, True, False, False
dg_psiq_cie_10_dg              bool            True, False, False, True, False
dx_f3_mood                     int32           0, 0, 0, 0, 0
dx_f6_personality              int32           0, 0, 0, 1, 0
dx_f_any_severe_mental         bool            True, False, False, False, False
any_phys_dx                    bool            False, False, False, True, False
polysubstance_strict           int32           0, 1, 1, 1, 1
readmit_time_from_disch_m      float64         68.96774193548387, 7.0, 13.258064516129032, 5.0, 7.354838709677419
readmit_event                  float64         0.0, 1.0, 1.0, 1.0, 1.0
death_time_from_disch_m        float64         68.96774193548387, 81.3225806451613, 116.74193548387096, 91.96774193548387, 31.0...
death_event                    int32           0, 0, 0, 0, 0
sex_rec_woman                  float64         0.0, 0.0, 0.0, 1.0, 0.0
tenure_status_household_illegal_settlement float64         0.0, 0.0, 0.0, 0.0, 0.0
tenure_status_household_owner_transferred_dwellings_pays_dividends float64         0.0, 1.0, 0.0, 1.0, 0.0
tenure_status_household_renting float64         0.0, 0.0, 0.0, 0.0, 0.0
tenure_status_household_stays_temporarily_with_a_relative float64         1.0, 0.0, 1.0, 0.0, 1.0
cohabitation_alone             float64         1.0, 0.0, 0.0, 0.0, 0.0
cohabitation_with_couple_children float64         0.0, 0.0, 1.0, 1.0, 0.0
cohabitation_family_of_origin  float64         0.0, 1.0, 0.0, 0.0, 1.0
sub_dep_icd10_status_drug_dependence float64         1.0, 0.0, 1.0, 1.0, 1.0
any_violence_1_domestic_violence_sex_abuse float64         0.0, 0.0, 0.0, 0.0, 0.0
prim_sub_freq_rec_2_2_6_days_wk float64         1.0, 0.0, 0.0, 0.0, 0.0
prim_sub_freq_rec_3_daily      float64         0.0, 1.0, 1.0, 1.0, 1.0
tr_outcome_adm_discharge_adm_reasons float64         0.0, 0.0, 0.0, 0.0, 0.0
tr_outcome_adm_discharge_rule_violation_undet float64         0.0, 0.0, 1.0, 0.0, 0.0
tr_outcome_completion          float64         0.0, 0.0, 0.0, 0.0, 1.0
tr_outcome_dropout             float64         0.0, 1.0, 0.0, 1.0, 0.0
tr_outcome_referral            float64         1.0, 0.0, 0.0, 0.0, 0.0
adm_motive_another_sud_facility_fonodrogas_senda_previene float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_justice_sector      float64         0.0, 0.0, 0.0, 0.0, 0.0
adm_motive_sanitary_sector     float64         1.0, 0.0, 1.0, 1.0, 0.0
adm_motive_spontaneous_consultation float64         0.0, 1.0, 0.0, 0.0, 1.0
first_sub_used_alcohol         float64         1.0, 1.0, 1.0, 0.0, 1.0
first_sub_used_cocaine_paste   float64         0.0, 0.0, 0.0, 1.0, 0.0
first_sub_used_cocaine_powder  float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_marijuana       float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_opioids         float64         0.0, 0.0, 0.0, 0.0, 0.0
first_sub_used_tranquilizers_hypnotics float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_cocaine_paste  float64         0.0, 1.0, 1.0, 1.0, 1.0
primary_sub_mod_cocaine_powder float64         0.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_alcohol        float64         1.0, 0.0, 0.0, 0.0, 0.0
primary_sub_mod_marijuana      float64         0.0, 0.0, 0.0, 0.0, 0.0
tipo_de_vivienda_rec2_other_unknown float64         1.0, 0.0, 0.0, 0.0, 0.0
plan_type_corr_m_pai           float64         0.0, 0.0, 0.0, 0.0, 0.0
plan_type_corr_m_pr            float64         0.0, 0.0, 0.0, 1.0, 0.0
plan_type_corr_pg_pai          float64         0.0, 0.0, 0.0, 0.0, 1.0
plan_type_corr_pg_pr           float64         0.0, 0.0, 1.0, 0.0, 0.0
occupation_condition_corr24_inactive float64         0.0, 0.0, 0.0, 1.0, 0.0
occupation_condition_corr24_unemployed float64         1.0, 0.0, 0.0, 0.0, 1.0
marital_status_rec_separated_divorced_annulled_widowed float64         0.0, 0.0, 0.0, 0.0, 0.0
marital_status_rec_single      float64         1.0, 1.0, 1.0, 0.0, 1.0
urbanicity_cat_1_rural         float64         0.0, 0.0, 0.0, 0.0, 0.0
urbanicity_cat_2_mixed         float64         0.0, 0.0, 0.0, 0.0, 0.0
ed_attainment_corr_2_completed_high_school_or_less float64         1.0, 0.0, 1.0, 1.0, 0.0
ed_attainment_corr_3_completed_primary_school_or_less float64         0.0, 1.0, 0.0, 0.0, 1.0
evaluacindelprocesoteraputico_logro_intermedio float64         0.0, 0.0, 0.0, 0.0, 0.0
evaluacindelprocesoteraputico_logro_minimo float64         0.0, 1.0, 1.0, 1.0, 0.0
eva_consumo_logro_intermedio   float64         0.0, 0.0, 0.0, 1.0, 0.0
eva_consumo_logro_minimo       float64         0.0, 1.0, 1.0, 0.0, 0.0
eva_fam_logro_intermedio       float64         1.0, 0.0, 0.0, 1.0, 0.0
eva_fam_logro_minimo           float64         0.0, 1.0, 1.0, 0.0, 0.0
eva_relinterp_logro_intermedio float64         0.0, 0.0, 0.0, 1.0, 0.0
eva_relinterp_logro_minimo     float64         0.0, 1.0, 1.0, 0.0, 0.0
eva_ocupacion_logro_intermedio float64         0.0, 0.0, 0.0, 0.0, 1.0
eva_ocupacion_logro_minimo     float64         0.0, 1.0, 1.0, 1.0, 0.0
eva_sm_logro_intermedio        float64         1.0, 0.0, 0.0, 1.0, 0.0
eva_sm_logro_minimo            float64         0.0, 1.0, 1.0, 0.0, 1.0
eva_fisica_logro_intermedio    float64         0.0, 0.0, 1.0, 1.0, 0.0
eva_fisica_logro_minimo        float64         0.0, 1.0, 0.0, 0.0, 0.0
eva_transgnorma_logro_intermedio float64         0.0, 0.0, 0.0, 0.0, 1.0
eva_transgnorma_logro_minimo   float64         0.0, 1.0, 1.0, 1.0, 0.0
Rows: 88504 | Columns: 123
ethnicity                      float64         0.0, 0.0, 0.0, 0.0, 0.0
dg_psiq_cie_10_instudy         int64           0, 0, 1, 0, 0
dg_psiq_cie_10_dg              int64           1, 0, 0, 1, 0
f0_organic                     int64           0, 0, 0, 0, 0
f2_psychotic                   int64           0, 0, 0, 0, 0
f3_mood                        int64           0, 0, 0, 0, 0
f4_anxiety_stress_somatoform   int64           0, 0, 0, 0, 0
f5_physio_eating_sleep_sexual  int64           0, 0, 0, 0, 0
f6_personality_adult_behaviour int64           0, 0, 0, 1, 0
f7_intellectual_disability     int64           1, 0, 0, 0, 0
f8_9_neurodevelopment_child    int64           0, 0, 0, 0, 0
dx_f2_smi_psychotic            int32           0, 0, 0, 0, 0
dx_f3_mood                     int32           0, 0, 0, 0, 0
dx_f45_anx_stress_phys         int32           0, 0, 0, 0, 0
dx_f6_personality              int32           0, 0, 0, 1, 0
dx_f0789_neurocog_dev          int32           1, 0, 0, 0, 0
phys_dx_instudy                int32           1, 1, 1, 1, 1
phys_dx_other_spec_medical_cond int32           0, 0, 0, 1, 0
phys_dx_organ_system_med_dis   int32           0, 0, 0, 0, 0
phys_dx_injuries_and_sequelae  int32           0, 0, 0, 0, 0
phys_dx_infectious_diseases    int32           0, 0, 0, 0, 0
polysubstance_strict           int32           0, 1, 1, 1, 1
treat_lt_90                    int32           0, 0, 1, 0, 0
adm_age_log                    float64         3.48216274048526, 3.0731561705187946, 3.773220602547687, 4.120824195026479, 3.83...
adm_age_pow2                   float64         994.1409000000001, 424.77209999999997, 1807.9504000000002, 3673.5721, 2032.20639...
adm_age_pow3                   float64         31345.262577000005, 8754.552980999999, 76874.05100800002, 222655.204981, 91611.8...
adm_age_c                      float64         -4.23091837657055, -15.150918376570552, 6.759081623429452, 24.84908162342945, 9....
porc_pobr                      float64         0.175679117441177, 0.187835901975632, 0.130412444472313, 0.133759185671806, 0.08...
porc_pobr_log                  float64         0.1618459529248312, 0.17213308174953304, 0.12258256123798024, 0.1255388244402368...
porc_pobr_c                    float64         0.03381019166263441, 0.04596697619708939, -0.011456481306229588, -0.008109740106...
dit_m                          float64         15.967741935483872, 5.833333333333334, 0.4752688172043005, 6.966666666666667, 6....
treat_log                      float64         2.831314008252564, 1.921812597476253, 0.3888402221385285, 2.075266170269355, 2.0...
treat_days_pow2                float64         254.96878251821022, 34.027777777777786, 0.22588044860677478, 48.534444444444446,...
treat_days_pow3                float64         4071.2757208552925, 198.49537037037044, 0.10735393363891862, 338.1232962962963, ...
sex_rec_woman                  bool            False, False, False, True, False
tenure_status_household_illegal settlement bool            False, False, False, False, False
tenure_status_household_others bool            False, False, False, False, False
tenure_status_household_renting bool            False, False, False, False, False
tenure_status_household_stays temporarily with a relative bool            True, False, True, False, True
occupation_condition_corr24_inactive bool            False, False, False, True, False
occupation_condition_corr24_unemployed bool            True, False, False, False, True
marital_status_rec_married/cohabiting bool            False, False, False, True, False
marital_status_rec_separated/divorced/annulled/widowed bool            False, False, False, False, False
marital_status_rec_single      bool            True, True, True, False, True
urbanicity_cat_2.Mixed         bool            False, False, False, False, False
urbanicity_cat_3.Urban         bool            True, True, True, True, True
ed_attainment_corr_1-More than high school bool            False, False, False, False, False
ed_attainment_corr_2-Completed high school or less bool            True, False, True, True, False
ed_attainment_corr_3-Completed primary school or less bool            False, True, False, False, True
cohabitation_with couple/children bool            False, False, True, True, False
cohabitation_family of origin  bool            False, True, False, False, True
cohabitation_Others            bool            False, False, False, False, False
sub_dep_icd10_status_drug dependence bool            True, False, True, True, True
dom_violence_Domestic violence bool            False, False, True, False, False
sex_abuse_Sexual abuse         bool            False, False, False, False, False
any_violence_0.No domestic violence/sex abuse bool            True, True, False, True, True
prim_sub_freq_1. Less than 1 day a week bool            False, False, False, False, False
prim_sub_freq_2. 1 day a week  bool            False, False, False, False, False
prim_sub_freq_3. 2 to 3 days a week bool            True, False, False, False, False
prim_sub_freq_4. 4 to 6 days a week bool            False, False, False, False, False
prim_sub_freq_5. Daily         bool            False, True, True, True, True
prim_sub_freq_rec_2.2–6 days/wk bool            True, False, False, False, False
prim_sub_freq_rec_3.Daily      bool            False, True, True, True, True
tr_outcome_adm discharge - adm reasons bool            False, False, False, False, False
tr_outcome_adm discharge - rule violation/undet bool            False, False, True, False, False
tr_outcome_dropout             bool            False, True, False, True, False
tr_outcome_other               bool            False, False, False, False, False
tr_outcome_referral            bool            True, False, False, False, False
adm_motive_another SUD facility/FONODROGAS/SENDA Previene bool            False, False, False, False, False
adm_motive_justice sector      bool            False, False, False, False, False
adm_motive_other               bool            False, False, False, False, False
adm_motive_sanitary sector     bool            True, False, True, True, False
primary_sub_amphetamine-type stimulants bool            False, False, False, False, False
primary_sub_cocaine paste      bool            False, True, True, True, True
primary_sub_cocaine powder     bool            False, False, False, False, False
primary_sub_dissociatives      bool            False, False, False, False, False
primary_sub_hallucinogens      bool            False, False, False, False, False
primary_sub_inhalants          bool            False, False, False, False, False
primary_sub_marijuana          bool            False, False, False, False, False
primary_sub_opioids            bool            False, False, False, False, False
primary_sub_others             bool            False, False, False, False, False
primary_sub_tranquilizers/hypnotics bool            False, False, False, False, False
primary_sub_mod_cocaine paste  bool            False, True, True, True, True
primary_sub_mod_cocaine powder bool            False, False, False, False, False
primary_sub_mod_alcohol        bool            True, False, False, False, False
primary_sub_mod_others         bool            False, False, False, False, False
usuario_tribunal_trat_droga_no bool            True, True, True, True, True
usuario_tribunal_trat_droga_si bool            False, False, False, False, False
tipo_de_vivienda_rec_shared/secondary unit bool            False, False, False, False, False
tipo_de_vivienda_rec_homeless/unsheltered/informal/temporary housing/institutional/collective bool            False, False, False, False, False
tipo_de_vivienda_rec_other/unknown bool            True, False, False, False, False
tipo_de_vivienda_rec2_other/unknown bool            True, False, False, True, False
evaluacindelprocesoteraputico_logro alto bool            True, False, False, False, True
evaluacindelprocesoteraputico_logro intermedio bool            False, False, False, False, False
evaluacindelprocesoteraputico_logro minimo bool            False, True, True, True, False
eva_consumo_logro alto         bool            True, False, False, False, True
eva_consumo_logro intermedio   bool            False, False, False, True, False
eva_consumo_logro minimo       bool            False, True, True, False, False
eva_fam_logro alto             bool            False, False, False, False, True
eva_fam_logro intermedio       bool            True, False, False, True, False
eva_fam_logro minimo           bool            False, True, True, False, False
eva_relinterp_logro alto       bool            True, False, False, False, True
eva_relinterp_logro intermedio bool            False, False, False, True, False
eva_relinterp_logro minimo     bool            False, True, True, False, False
eva_ocupacion_logro alto       bool            True, False, False, False, False
eva_ocupacion_logro intermedio bool            False, False, False, False, True
eva_ocupacion_logro minimo     bool            False, True, True, True, False
eva_sm_logro alto              bool            False, False, False, False, False
eva_sm_logro intermedio        bool            True, False, False, True, False
eva_sm_logro minimo            bool            False, True, True, False, True
eva_fisica_logro alto          bool            True, False, False, False, True
eva_fisica_logro intermedio    bool            False, False, True, True, False
eva_fisica_logro minimo        bool            False, True, False, False, False
eva_transgnorma_logro alto     bool            True, False, False, False, False
eva_transgnorma_logro intermedio bool            False, False, False, False, True
eva_transgnorma_logro minimo   bool            False, True, True, True, False
adm_age_cat_30-44              bool            True, False, True, False, False
adm_age_cat_45-64              bool            False, False, False, True, True
nationality_chile_other        bool            False, False, False, False, False
plan_type_corr_m-pr            bool            False, False, False, True, False
plan_type_corr_pg-pab          bool            True, True, False, False, False
plan_type_corr_pg-pai          bool            False, False, False, False, True
plan_type_corr_pg-pr           bool            False, False, True, False, False

Load in python

Code
if isinstance(imputations_list_jan26, list) and len(imputations_list_jan26) > 0:
    print("First element type:", type(imputations_list_jan26[0]))
    if isinstance(imputations_list_jan26[0], dict):
        print("First element keys:", imputations_list_jan26[0].keys())
    elif isinstance(imputations_list_jan26[0], (pd.DataFrame, np.ndarray)):
        print("First element shape:", imputations_list_jan26[0].shape)
First element type: <class 'pandas.core.frame.DataFrame'>
First element shape: (88504, 56)

This code block:

  1. Imports the pickle library: This library implements binary protocols for serializing and de-serializing a Python object structure.
  2. Specifies the file_path: It points to the .pkl file you selected.
  3. Opens the file in binary read mode ('rb'): This is necessary for loading pickle files.
  4. Loads the object: pickle.load(f) reads the serialized object from the file and reconstructs it in memory.
  5. Prints confirmation and basic information: It verifies that the file was loaded and shows the type of the loaded object, and some details about the first element if it’s a list containing common data structures.

Compare databases (transformed and original)

Inspect and compare the column names of two datasets: the first imputation from imputations_list_jan26 (which likely contains dummy variables) and imputation_nodum_1 (which, as its name suggests, probably doesn’t have dummy variables).

Code
# Inspect columns of the first imputation
cols_first_imp = imputations_list_jan26[0].columns.tolist()
print("First imputation columns:", cols_first_imp[:10], "... total:", len(cols_first_imp))

# Inspect columns of imputation_no_dum
cols_nodum = imputation_nodum_1.columns.tolist()
print("No-dum columns:", cols_nodum[:10], "... total:", len(cols_nodum))

# Compare overlap
common_cols = set(cols_first_imp).intersection(cols_nodum)
missing_in_imp = [c for c in cols_nodum if c not in cols_first_imp]
missing_in_nodum = [c for c in cols_first_imp if c not in cols_nodum]

print("Common columns:", len(common_cols))
print("Missing in imputations_list_jan26:", missing_in_imp)

# Inspect columns of the first imputation
cols_first_imp_raw = imputation_1.columns.tolist()
print("First imputation columns:", cols_first_imp_raw[:10], "... total:", len(cols_first_imp_raw))

# Compare overlap
common_cols_raw = set(cols_first_imp_raw).intersection(cols_nodum)
missing_in_imp_raw = [c for c in cols_nodum if c not in cols_first_imp_raw]

print("Common columns:", len(common_cols_raw))
print("Missing in imputations_list_jan26:", missing_in_imp_raw)
print(common_cols_raw)

import pandas as pd

# Example: choose a combination of variables that uniquely identify rows
key_vars = ["adm_age_rec3", "porc_pobr", "dit_m"]

# Take one imputation (first element of the list) and merge with the no-dum dataset
df_imp = imputations_list_jan26[0]
df_nodum = imputation_nodum_1

merged_check = pd.merge(
    df_imp[key_vars],
    df_nodum[key_vars],
    on=key_vars,
    how="inner"
)

print(f"Merged rows: {merged_check.shape[0]}")
print("Preview of merged check:")
print(merged_check.head())

#drop merge
del merged_check

import pandas as pd

# Example: choose a combination of variables that uniquely identify rows
key_vars_raw = ['dit_m',
            'readmit_time_from_adm_m',
            'death_time_from_adm_m',
            'adm_age_rec3']
# Take one imputation (first element of the list) and merge with the no-dum dataset
df_raw = imputation_1

merged_check_raw = pd.merge(
    df_imp[key_vars],
    df_raw[key_vars],
    on=key_vars,
    how="inner"
)

print(f"Merged rows: {merged_check_raw.shape[0]}")
print("Preview of merged check:")
print(merged_check_raw.head())
print(f"{(merged_check_raw.shape[0] / imputation_1.shape[0] * 100):.2f}%")
#drop merge
del merged_check_raw
First imputation columns: ['adm_age_rec3', 'porc_pobr', 'dit_m', 'tenure_status_household', 'prim_sub_freq_rec', 'national_foreign', 'urbanicity_cat', 'ed_attainment_corr', 'evaluacindelprocesoteraputico', 'eva_consumo'] ... total: 56
No-dum columns: ['readmit_time_from_adm_m', 'death_time_from_adm_m', 'adm_age_rec3', 'porc_pobr', 'dit_m', 'sex_rec', 'tenure_status_household', 'cohabitation', 'sub_dep_icd10_status', 'any_violence'] ... total: 43
Common columns: 24
Missing in imputations_list_jan26: ['readmit_time_from_adm_m', 'death_time_from_adm_m', 'sex_rec', 'cohabitation', 'sub_dep_icd10_status', 'any_violence', 'tr_outcome', 'adm_motive', 'first_sub_used', 'primary_sub_mod', 'tipo_de_vivienda_rec2', 'plan_type_corr', 'occupation_condition_corr24', 'marital_status_rec', 'readmit_event', 'death_event', 'readmit_time_from_disch_m', 'death_time_from_disch_m', 'center_id']
First imputation columns: ['readmit_time_from_adm_m', 'death_time_from_adm_m', 'adm_age_rec3', 'porc_pobr', 'dit_m', 'national_foreign', 'ethnicity', 'dg_psiq_cie_10_instudy', 'dg_psiq_cie_10_dg', 'dx_f3_mood'] ... total: 78
Common columns: 18
Missing in imputations_list_jan26: ['sex_rec', 'tenure_status_household', 'cohabitation', 'sub_dep_icd10_status', 'any_violence', 'prim_sub_freq_rec', 'tr_outcome', 'adm_motive', 'first_sub_used', 'primary_sub_mod', 'tipo_de_vivienda_rec2', 'plan_type_corr', 'occupation_condition_corr24', 'marital_status_rec', 'urbanicity_cat', 'ed_attainment_corr', 'evaluacindelprocesoteraputico', 'eva_consumo', 'eva_fam', 'eva_relinterp', 'eva_ocupacion', 'eva_sm', 'eva_fisica', 'eva_transgnorma', 'center_id']
{'dx_f6_personality', 'dg_psiq_cie_10_instudy', 'dx_f3_mood', 'any_phys_dx', 'readmit_time_from_disch_m', 'dx_f_any_severe_mental', 'porc_pobr', 'polysubstance_strict', 'dg_psiq_cie_10_dg', 'ethnicity', 'national_foreign', 'adm_age_rec3', 'death_event', 'death_time_from_adm_m', 'dit_m', 'readmit_time_from_adm_m', 'readmit_event', 'death_time_from_disch_m'}
Merged rows: 88516
Preview of merged check:
   adm_age_rec3  porc_pobr      dit_m
0         31.53   0.175679  15.967742
1         20.61   0.187836   5.833333
2         42.52   0.130412   0.475269
3         60.61   0.133759   6.966667
4         45.08   0.083189   6.903226
Merged rows: 88516
Preview of merged check:
   adm_age_rec3  porc_pobr      dit_m
0         31.53   0.175679  15.967742
1         20.61   0.187836   5.833333
2         42.52   0.130412   0.475269
3         60.61   0.133759   6.966667
4         45.08   0.083189   6.903226
100.01%

Create bins for followup (landmarks)

This code prepares your data for survival analysis. It extracts the time until an event (like readmission or death) and whether that event actually happened for each patient from the df_nodum dataset. Then, it automatically creates a set of important time points, called an ‘evaluation grid’, which are specific moments to assess the model’s performance on both readmission and death outcomes.

Code
import numpy as np

# Required columns for survival outcomes
required = ["readmit_time_from_disch_m", "readmit_event",
            "death_time_from_disch_m", "death_event"]

# Check that df_raw has all required columns
missing = [c for c in required if c not in df_raw.columns]
if missing:
    raise KeyError(f"df_nodum is missing columns: {missing}")

# Create time/event arrays directly from df_raw
time_readm = df_raw["readmit_time_from_adm_m"].to_numpy()
event_readm = (df_raw["readmit_event"].to_numpy() == 1)

time_death = df_raw["death_time_from_adm_m"].to_numpy()
event_death = (df_nodum["death_event"].to_numpy() == 1)

print("Arrays created for df_raw:")
print("Readmission times:", time_readm[:5])
print("Readmission events:", event_readm[:5])
print("Death times:", time_death[:5])
print("Death events:", event_death[:5])

# Build evaluation grids (quantiles of event times)
event_times_readm = time_readm[event_readm]
event_times_death = time_death[event_death]

if len(event_times_readm) < 5 or len(event_times_death) < 5:
    raise ValueError("Too few events in df_raw to build reliable time grids.")

times_eval_readm = np.unique(np.quantile(event_times_readm, np.linspace(0.05, 0.95, 50)))
times_eval_death = np.unique(np.quantile(event_times_death, np.linspace(0.05, 0.95, 50)))

print("Eval times (readmission):", times_eval_readm[:5], "...", times_eval_readm[-5:])
print("Eval times (death):", times_eval_death[:5], "...", times_eval_death[-5:])
Arrays created for df_raw:
Readmission times: [84.93548387 12.83333333 13.73333333 11.96666667 14.25806452]
Readmission events: [False  True  True  True  True]
Death times: [ 84.93548387  87.16129032 117.22580645  98.93548387  37.93548387]
Death events: [False False False False False]
Eval times (readmission): [3.93548387 4.77419355 5.45058701 6.06492649 6.67741935] ... [54.44173469 58.41566162 63.23333333 68.54767171 74.68983871]
Eval times (death): [4.16290323 5.43022383 6.68564845 8.24254115 9.77961817] ... [81.92700461 85.41186103 88.78518762 93.5538183  99.21935484]

Prepare survival data

Code
import numpy as np

# Step 1. Extract survival outcomes directly from df_raw
time_readm = df_raw["readmit_time_from_adm_m"].to_numpy()
event_readm = (df_raw["readmit_event"].to_numpy() == 1)

time_death = df_raw["death_time_from_adm_m"].to_numpy()
event_death = (df_raw["death_event"].to_numpy() == 1)

# Step 2. Build structured arrays (Surv objects)
y_surv_readm = np.empty(len(time_readm), dtype=[("event", "?"), ("time", "<f8")])
y_surv_readm["event"] = event_readm
y_surv_readm["time"] = time_readm

y_surv_death = np.empty(len(time_death), dtype=[("event", "?"), ("time", "<f8")])
y_surv_death["event"] = event_death
y_surv_death["time"] = time_death

# Step 3. Replicate across imputations
n_imputations = len(imputations_list_jan26)
y_surv_readm_list = [y_surv_readm for _ in range(n_imputations)]
y_surv_death_list = [y_surv_death for _ in range(n_imputations)]

PyCox

We tuned a DeepHit competing-risks survival model for mortality and hospital readmission using five-fold stratified cross-validation and selected hyperparameters based on Uno’s inverse probability–weighted concordance index for mortality, treating readmission as a competing event.

We implemented a stratified hyperparameter tuning procedure for a DeepHit competing-risks survival model, jointly modeling mortality and hospital readmission. Model training utilized 5-fold stratified cross-validation over the first imputed dataset. To ensure robust generalization across rare events and heterogeneous treatment modalities, stratified sampling was performed using composite labels representing the intersection of event type (mortality, readmission, or censoring) and care plan. Continuous covariates were standardized within training folds to prevent data leakage, and event times were discretized into 100 intervals using a data-driven cut-point transformation.

Rather than manual tuning, a systematic grid search was employed to identify the optimal configuration of Hyperparameters - Learning rate (LR): how fast the model updates its weights - Weight decay: (reg) L2 regularization to reduce overfitting - Batch size: Samples per training step (large batches help with rare events) - Dropout: Fraction of neurons randomly dropped during training - Nodes: Network depth and width - Alpha: Strength of the ranking loss (time ordering) - Sigma: Smoothing for the ranking loss - num_durations: Fixed on 100 equidistant intervals (1 month per bin approx.)

A shared neural network architecture with cause-specific output heads was optimized over a grid of learning rates, weight decay regularization, network depth, and ranking-loss parameters (α,σ). Model discrimination was assessed using Uno’s Inverse Probability of Censoring Weighted (IPCW) C-index. To avoid overfitting to a single time point or outcome, the final hyperparameter configuration was selected based on the maximum composite C-index, averaged across both competing risks (mortality and readmission) and five annual evaluation horizons (1–5 years). All configurations and metrics were logged to ensure reproducibility.

Code
#@title ⚡ Step 1: DeepHit Tuning (5-Fold, Thermal Safe, Multi-Horizon)
import itertools
import gc
import time
import warnings
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchtuples as tt
import random
import os
from datetime import datetime
from pycox.models import DeepHit
from pycox.preprocessing.label_transforms import LabTransDiscreteTime
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sksurv.metrics import concordance_index_ipcw

# --- CONFIGURATION ---
NUM_RISKS = 3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Evaluation Horizons: We average performance over years 1-5 
# This is more robust than a single month, but cleaner than averaging all 108 months.
EVAL_HORIZONS = [12, 24, 36, 48, 60] 

# Suppress harmless PyTorch warnings
warnings.filterwarnings("ignore", message=".*weights_only=False.*")

# --- 0. REPRODUCIBILITY SEED ---
def set_seed(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed(42)

# --- HELPER: Network Wrapper ---
class CauseSpecificNet(nn.Module):
    def __init__(self, in_f, nodes, out_f, dropout, num_risks):
        super().__init__()
        self.net = tt.practical.MLPVanilla(in_f, nodes, out_f, batch_norm=True, dropout=dropout)
        self.num_risks = num_risks
    def forward(self, x):
        return self.net(x).view(x.size(0), self.num_risks, -1)

# --- 1. DATA PREP ---
def prepare_stratified_data(df_idx=0):
    df = imputations_list_jan26[df_idx]
    y_d = y_surv_death_list[df_idx]
    y_r = y_surv_readm_list[df_idx]

    t_d = y_d['time'].values if hasattr(y_d['time'], 'values') else y_d['time']
    e_d_raw = y_d['event'].values if hasattr(y_d['event'], 'values') else y_d['event']
    e_r_raw = y_r['event'].values if hasattr(y_r['event'], 'values') else y_r['event']

    events = np.zeros(len(df), dtype=int)
    times = t_d.copy().astype('float32')
    e_d = e_d_raw.astype(bool)
    e_r = e_r_raw.astype(bool)

    # 🟢 Competing Risk Priority: Death (1) overrides Readm (2)
    events[e_r] = 2
    events[e_d] = 1 

    # Stratification Logic
    plan_cols = ['plan_type_corr_pg_pr', 'plan_type_corr_m_pr', 
                 'plan_type_corr_pg_pai', 'plan_type_corr_m_pai']
    available_plans = [c for c in plan_cols if c in df.columns]

    plan_category = np.zeros(len(df), dtype=int)
    for i, col in enumerate(available_plans, 1):
        plan_category[df[col] == 1] = i

    strat_labels = (events * 10) + plan_category
    return df, events, times, strat_labels

# --- 2. EXECUTION ---
X_all, events_all, times_all, strat_labels = prepare_stratified_data()
start_time = time.time()

# Updated Search Space
param_grid = {
    'lr': [1e-3, 1e-4], #(learning rate): Controls the step size of parameter updates during optimization; lower values promote more stable convergence, particularly important for rare-event survival outcomes.
    'weight_decay': [1e-4, 1e-3], #L2 regularization applied to network weights to reduce overfitting by penalizing large parameter values.
    'batch_size': [1024, 2048], #Number of observations processed per training step; larger batches improve gradient stability and event representation in datasets with low event incidence.
    'dropout': [0.2, 0.5], #Proportion of neurons randomly deactivated during training to prevent overfitting and improve generalization.
    'nodes': [[256, 256], [256, 256, 128]], #Defines the number and size of hidden layers in the neural network, controlling model capacity and representational complexity.
    'alpha': [0.2, 0.5], #Weight of the ranking loss component in DeepHit, regulating the emphasis on temporal risk discrimination across individuals.
    'sigma': [0.1, 0.5] #Smoothing parameter for the ranking loss that controls tolerance to small temporal ordering errors in event times.
}

keys, values = zip(*param_grid.items())
search_space = [dict(zip(keys, v)) for v in itertools.product(*values)]
tuning_results = []

print(f"⚡ Starting Defensible Tuning on {len(search_space)} combos...")

for i, params in enumerate(search_space):
    # 🟢 CHANGED: 5 Folds as requested
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    fold_scores = []

    for fold_idx, (train_idx, val_idx) in enumerate(skf.split(X_all, strat_labels)):
        torch.cuda.empty_cache()
        gc.collect()

        X_train, X_val = X_all.iloc[train_idx], X_all.iloc[val_idx]
        e_train, e_val = events_all[train_idx], events_all[val_idx]
        t_train, t_val = times_all[train_idx], times_all[val_idx]

        scaler = StandardScaler().fit(X_train)
        X_train_s = scaler.transform(X_train).astype('float32')
        X_val_s = scaler.transform(X_val).astype('float32')

        # 100 intervals balance resolution & stability
        labtrans = LabTransDiscreteTime(100)
        y_train = labtrans.fit_transform(t_train, e_train)
        y_val = labtrans.transform(t_val, e_val)
        y_train = (y_train[0].astype('int64'), y_train[1].astype('int64'))
        y_val = (y_val[0].astype('int64'), y_val[1].astype('int64'))

        in_f = X_train.shape[1]
        out_f = labtrans.out_features * NUM_RISKS

        net = CauseSpecificNet(in_f, params['nodes'], out_f, params['dropout'], NUM_RISKS)
        
        # Single Initialization (Corrected)
        model = DeepHit(net, tt.optim.Adam, 
                        alpha=params['alpha'], 
                        sigma=params['sigma'], 
                        duration_index=labtrans.cuts)        
        
        model.set_device(DEVICE)
        model.optimizer.set_lr(params['lr'])
        model.optimizer.param_groups[0]['weight_decay'] = params['weight_decay']

        try:
            model.fit(X_train_s, y_train, batch_size=params['batch_size'], epochs=50,
                      callbacks=[tt.callbacks.EarlyStopping()], verbose=False, val_data=(X_val_s, y_val))
            
            cif = model.predict_cif(X_val_s)
            
            # 🟢 UPDATED: Evaluate across 5 key years (12-60m) and average
            horizon_scores = []
            for h in EVAL_HORIZONS:
                idx_h = np.searchsorted(model.duration_index, h)
                if idx_h >= len(model.duration_index): idx_h = len(model.duration_index) - 1
                
                score_d = cif[1][idx_h, :] # Death
                score_r = cif[2][idx_h, :] # Readm

                y_tr_st = np.array([(bool(e==1), t) for e, t in zip(e_train, t_train)], dtype=[('e', bool), ('t', float)])
                y_va_st_d = np.array([(bool(e==1), t) for e, t in zip(e_val, t_val)], dtype=[('e', bool), ('t', float)])
                y_va_st_r = np.array([(bool(e==2), t) for e, t in zip(e_val, t_val)], dtype=[('e', bool), ('t', float)])

                c_d = concordance_index_ipcw(y_tr_st, y_va_st_d, score_d, tau=h)[0]
                c_r = concordance_index_ipcw(y_tr_st, y_va_st_r, score_r, tau=h)[0]
                horizon_scores.append((c_d + c_r) / 2)
            
            fold_scores.append(np.mean(horizon_scores))

        except Exception as e:
            # print(f"Fold Error: {e}") 
            fold_scores.append(np.nan)

        # Cleanup per fold
        del model; del net; gc.collect()

        # 🟢 THERMAL PAUSE
        # If the GPU temperature is high, enforce a short real cooldown period
        # This allows the cooling system to reduce core temperature and
        # helps prevent thermal throttling or unstable training behavior
        # (We run this after every fold to stay safe)
        print("❄️", end="") 
        time.sleep(30) 

    avg_s = np.nanmean(fold_scores)
    tuning_results.append({**params, 'score': avg_s})
    print(f"   [{i+1}/{len(search_space)}] Avg C-Index (1-5yr): {avg_s:.4f}")

# --- 3. RESULTS & EXPORT ---
results_df = pd.DataFrame(tuning_results).sort_values('score', ascending=False)
best_params = results_df.iloc[0].to_dict()

print("\n" + "="*60)
print(f"🏆 Best Config: {best_params}")
print("="*60)

timestamp = datetime.now().strftime("%Y%m%d_%H%M")
filename = f"DH_Tuning_5Fold_{timestamp}.csv"
results_df.to_csv(filename, index=False)
print(f"💾 Results saved to: {filename}")

elapsed_minutes = (time.time() - start_time) / 60
print(f"⏱️ Total Time: {elapsed_minutes:.2f} min")
⚡ Starting Defensible Tuning on 128 combos...
❄️❄️❄️❄️❄️   [1/128] Avg C-Index (1-5yr): 0.6977
❄️❄️❄️❄️❄️   [2/128] Avg C-Index (1-5yr): 0.6521
❄️❄️❄️❄️❄️   [3/128] Avg C-Index (1-5yr): 0.6722
❄️❄️❄️❄️❄️   [4/128] Avg C-Index (1-5yr): 0.6449
❄️❄️❄️❄️❄️   [5/128] Avg C-Index (1-5yr): 0.7061
❄️❄️❄️❄️❄️   [6/128] Avg C-Index (1-5yr): 0.6637
❄️❄️❄️❄️❄️   [7/128] Avg C-Index (1-5yr): 0.6920
❄️❄️❄️❄️❄️   [8/128] Avg C-Index (1-5yr): 0.6539
❄️❄️❄️❄️❄️   [9/128] Avg C-Index (1-5yr): 0.7183
❄️❄️❄️❄️❄️   [10/128] Avg C-Index (1-5yr): 0.7145
❄️❄️❄️❄️❄️   [11/128] Avg C-Index (1-5yr): 0.7111
❄️❄️❄️❄️❄️   [12/128] Avg C-Index (1-5yr): 0.7067
❄️❄️❄️❄️❄️   [13/128] Avg C-Index (1-5yr): 0.7281
❄️❄️❄️❄️❄️   [14/128] Avg C-Index (1-5yr): 0.7266
❄️❄️❄️❄️❄️   [15/128] Avg C-Index (1-5yr): 0.7264
❄️❄️❄️❄️❄️   [16/128] Avg C-Index (1-5yr): 0.7265
❄️❄️❄️❄️❄️   [17/128] Avg C-Index (1-5yr): 0.6943
❄️❄️❄️❄️❄️   [18/128] Avg C-Index (1-5yr): 0.6527
❄️❄️❄️❄️❄️   [19/128] Avg C-Index (1-5yr): 0.6734
❄️❄️❄️❄️❄️   [20/128] Avg C-Index (1-5yr): 0.6171
❄️❄️❄️❄️❄️   [21/128] Avg C-Index (1-5yr): 0.6977
❄️❄️❄️❄️❄️   [22/128] Avg C-Index (1-5yr): 0.6643
❄️❄️❄️❄️❄️   [23/128] Avg C-Index (1-5yr): 0.6717
❄️❄️❄️❄️❄️   [24/128] Avg C-Index (1-5yr): 0.6497
❄️❄️❄️❄️❄️   [25/128] Avg C-Index (1-5yr): 0.7218
❄️❄️❄️❄️❄️   [26/128] Avg C-Index (1-5yr): 0.7121
❄️❄️❄️❄️❄️   [27/128] Avg C-Index (1-5yr): 0.7132
❄️❄️❄️❄️❄️   [28/128] Avg C-Index (1-5yr): 0.7050
❄️❄️❄️❄️❄️   [29/128] Avg C-Index (1-5yr): 0.7276
❄️❄️❄️❄️❄️   [30/128] Avg C-Index (1-5yr): 0.7221
❄️❄️❄️❄️❄️   [31/128] Avg C-Index (1-5yr): 0.7305
❄️❄️❄️❄️❄️   [32/128] Avg C-Index (1-5yr): 0.7186
❄️❄️❄️❄️❄️   [33/128] Avg C-Index (1-5yr): 0.7212
❄️❄️❄️❄️❄️   [34/128] Avg C-Index (1-5yr): 0.7022
❄️❄️❄️❄️❄️   [35/128] Avg C-Index (1-5yr): 0.7124
❄️❄️❄️❄️❄️   [36/128] Avg C-Index (1-5yr): 0.7100
❄️❄️❄️❄️❄️   [37/128] Avg C-Index (1-5yr): 0.7243
❄️❄️❄️❄️❄️   [38/128] Avg C-Index (1-5yr): 0.7183
❄️❄️❄️❄️❄️   [39/128] Avg C-Index (1-5yr): 0.7197
❄️❄️❄️❄️❄️   [40/128] Avg C-Index (1-5yr): 0.7097
❄️❄️❄️❄️❄️   [41/128] Avg C-Index (1-5yr): 0.7164
❄️❄️❄️❄️❄️   [42/128] Avg C-Index (1-5yr): 0.7032
❄️❄️❄️❄️❄️   [43/128] Avg C-Index (1-5yr): 0.7299
❄️❄️❄️❄️❄️   [44/128] Avg C-Index (1-5yr): 0.7310
❄️❄️❄️❄️❄️   [45/128] Avg C-Index (1-5yr): 0.7218
❄️❄️❄️❄️❄️   [46/128] Avg C-Index (1-5yr): 0.7227
❄️❄️❄️❄️❄️   [47/128] Avg C-Index (1-5yr): 0.7300
❄️❄️❄️❄️❄️   [48/128] Avg C-Index (1-5yr): 0.7372
❄️❄️❄️❄️❄️   [49/128] Avg C-Index (1-5yr): 0.7149
❄️❄️❄️❄️❄️   [50/128] Avg C-Index (1-5yr): 0.7077
❄️❄️❄️❄️❄️   [51/128] Avg C-Index (1-5yr): 0.7095
❄️❄️❄️❄️❄️   [52/128] Avg C-Index (1-5yr): 0.7011
❄️❄️❄️❄️❄️   [53/128] Avg C-Index (1-5yr): 0.7105
❄️❄️❄️❄️❄️   [54/128] Avg C-Index (1-5yr): 0.7056
❄️❄️❄️❄️❄️   [55/128] Avg C-Index (1-5yr): 0.7123
❄️❄️❄️❄️❄️   [56/128] Avg C-Index (1-5yr): 0.7022
❄️❄️❄️❄️❄️   [57/128] Avg C-Index (1-5yr): 0.7225
❄️❄️❄️❄️❄️   [58/128] Avg C-Index (1-5yr): 0.7092
❄️❄️❄️❄️❄️   [59/128] Avg C-Index (1-5yr): 0.7240
❄️❄️❄️❄️❄️   [60/128] Avg C-Index (1-5yr): 0.7274
❄️❄️❄️❄️❄️   [61/128] Avg C-Index (1-5yr): 0.7223
❄️❄️❄️❄️❄️   [62/128] Avg C-Index (1-5yr): 0.7243
❄️❄️❄️❄️❄️   [63/128] Avg C-Index (1-5yr): 0.7341
❄️❄️❄️❄️❄️   [64/128] Avg C-Index (1-5yr): 0.7344
❄️❄️❄️❄️❄️   [65/128] Avg C-Index (1-5yr): 0.6774
❄️❄️❄️❄️❄️   [66/128] Avg C-Index (1-5yr): 0.6427
❄️❄️❄️❄️❄️   [67/128] Avg C-Index (1-5yr): 0.6562
❄️❄️❄️❄️❄️   [68/128] Avg C-Index (1-5yr): 0.6190
❄️❄️❄️❄️❄️   [69/128] Avg C-Index (1-5yr): 0.6784
❄️❄️❄️❄️❄️   [70/128] Avg C-Index (1-5yr): 0.6295
❄️❄️❄️❄️❄️   [71/128] Avg C-Index (1-5yr): 0.6541
❄️❄️❄️❄️❄️   [72/128] Avg C-Index (1-5yr): 0.6087
❄️❄️❄️❄️❄️   [73/128] Avg C-Index (1-5yr): 0.6827
❄️❄️❄️❄️❄️   [74/128] Avg C-Index (1-5yr): 0.6418
❄️❄️❄️❄️❄️   [75/128] Avg C-Index (1-5yr): 0.6560
❄️❄️❄️❄️❄️   [76/128] Avg C-Index (1-5yr): 0.6270
❄️❄️❄️❄️❄️   [77/128] Avg C-Index (1-5yr): 0.6517
❄️❄️❄️❄️❄️   [78/128] Avg C-Index (1-5yr): 0.6167
❄️❄️❄️❄️❄️   [79/128] Avg C-Index (1-5yr): 0.6204
❄️❄️❄️❄️❄️   [80/128] Avg C-Index (1-5yr): 0.5881
❄️❄️❄️❄️❄️   [81/128] Avg C-Index (1-5yr): 0.6415
❄️❄️❄️❄️❄️   [82/128] Avg C-Index (1-5yr): 0.5711
❄️❄️❄️❄️❄️   [83/128] Avg C-Index (1-5yr): 0.6103
❄️❄️❄️❄️❄️   [84/128] Avg C-Index (1-5yr): 0.5779
❄️❄️❄️❄️❄️   [85/128] Avg C-Index (1-5yr): 0.6494
❄️❄️❄️❄️❄️   [86/128] Avg C-Index (1-5yr): 0.5649
❄️❄️❄️❄️❄️   [87/128] Avg C-Index (1-5yr): 0.6164
❄️❄️❄️❄️❄️   [88/128] Avg C-Index (1-5yr): 0.5761
❄️❄️❄️❄️❄️   [89/128] Avg C-Index (1-5yr): 0.6635
❄️❄️❄️❄️❄️   [90/128] Avg C-Index (1-5yr): 0.5780
❄️❄️❄️❄️❄️   [91/128] Avg C-Index (1-5yr): 0.6252
❄️❄️❄️❄️❄️   [92/128] Avg C-Index (1-5yr): 0.5837
❄️❄️❄️❄️❄️   [93/128] Avg C-Index (1-5yr): 0.6267
❄️❄️❄️❄️❄️   [94/128] Avg C-Index (1-5yr): 0.5435
❄️❄️❄️❄️❄️   [95/128] Avg C-Index (1-5yr): 0.5810
❄️❄️❄️❄️❄️   [96/128] Avg C-Index (1-5yr): 0.5442
❄️❄️❄️❄️❄️   [97/128] Avg C-Index (1-5yr): 0.7006
❄️❄️❄️❄️❄️   [98/128] Avg C-Index (1-5yr): 0.6583
❄️❄️❄️❄️❄️   [99/128] Avg C-Index (1-5yr): 0.6939
❄️❄️❄️❄️❄️   [100/128] Avg C-Index (1-5yr): 0.6670
❄️❄️❄️❄️❄️   [101/128] Avg C-Index (1-5yr): 0.7036
❄️❄️❄️❄️❄️   [102/128] Avg C-Index (1-5yr): 0.6658
❄️❄️❄️❄️❄️   [103/128] Avg C-Index (1-5yr): 0.6911
❄️❄️❄️❄️❄️   [104/128] Avg C-Index (1-5yr): 0.6500
❄️❄️❄️❄️❄️   [105/128] Avg C-Index (1-5yr): 0.6989
❄️❄️❄️❄️❄️   [106/128] Avg C-Index (1-5yr): 0.6592
❄️❄️❄️❄️❄️   [107/128] Avg C-Index (1-5yr): 0.6878
❄️❄️❄️❄️❄️   [108/128] Avg C-Index (1-5yr): 0.6557
❄️❄️❄️❄️❄️   [109/128] Avg C-Index (1-5yr): 0.6707
❄️❄️❄️❄️❄️   [110/128] Avg C-Index (1-5yr): 0.6603
❄️❄️❄️❄️❄️   [111/128] Avg C-Index (1-5yr): 0.6611
❄️❄️❄️❄️❄️   [112/128] Avg C-Index (1-5yr): 0.6362
❄️❄️❄️❄️❄️   [113/128] Avg C-Index (1-5yr): 0.6858
❄️❄️❄️❄️❄️   [114/128] Avg C-Index (1-5yr): 0.6210
❄️❄️❄️❄️❄️   [115/128] Avg C-Index (1-5yr): 0.6527
❄️❄️❄️❄️❄️   [116/128] Avg C-Index (1-5yr): 0.6136
❄️❄️❄️❄️❄️   [117/128] Avg C-Index (1-5yr): 0.6901
❄️❄️❄️❄️❄️   [118/128] Avg C-Index (1-5yr): 0.6321
❄️❄️❄️❄️❄️   [119/128] Avg C-Index (1-5yr): 0.6570
❄️❄️❄️❄️❄️   [120/128] Avg C-Index (1-5yr): 0.6061
❄️❄️❄️❄️❄️   [121/128] Avg C-Index (1-5yr): 0.6759
❄️❄️❄️❄️❄️   [122/128] Avg C-Index (1-5yr): 0.6088
❄️❄️❄️❄️❄️   [123/128] Avg C-Index (1-5yr): 0.6323
❄️❄️❄️❄️❄️   [124/128] Avg C-Index (1-5yr): 0.6001
❄️❄️❄️❄️❄️   [125/128] Avg C-Index (1-5yr): 0.6583
❄️❄️❄️❄️❄️   [126/128] Avg C-Index (1-5yr): 0.6064
❄️❄️❄️❄️❄️   [127/128] Avg C-Index (1-5yr): 0.6242
❄️❄️❄️❄️❄️   [128/128] Avg C-Index (1-5yr): 0.5564

============================================================
🏆 Best Config: {'lr': 0.001, 'weight_decay': 0.001, 'batch_size': 1024, 'dropout': 0.5, 'nodes': [256, 256, 128], 'alpha': 0.5, 'sigma': 0.5, 'score': 0.7371553906507226}
============================================================
💾 Results saved to: DH_Tuning_5Fold_20260208_1324.csv
⏱️ Total Time: 823.12 min

🏆 Best Config: {‘lr’: 0.001, ‘weight_decay’: 0.001, ‘batch_size’: 1024, ‘dropout’: 0.5, ‘nodes’: [256, 256, 128], ‘alpha’: 0.5, ‘sigma’: 0.5, ‘score’: 0.7371553906507226}

⏱️ Total Time: 823.12 min Downloading “DH1_tun_par_20260203_1458.csv”:

Code
#@title 📝 Take-Home Message: Interpretation of Best DeepHit Configuration

import pandas as pd
from IPython.display import display

# --- HYPERPARAMETER INTERPRETATION DATAFRAME ---
config_interpretation = pd.DataFrame([
    {
        'Component': 'Regularization (The "Shield")',
        'Selected Value': 'Dropout: 0.5 | Weight Decay: 0.001 (High)',
        'Interpretation': 'The model required maximum "braking" power. The high dropout indicates the 4% mortality class is noisy; the model forces itself to ignore specific patient details (memorization) and learn only the most robust, universal predictors.'
    },
    {
        'Component': 'Model Capacity (Architecture)',
        'Selected Value': 'Nodes: [256, 256, 128] (Deep & Wide)',
        'Interpretation': 'Risk factors are highly non-linear. Simple "older = higher risk" logic is insufficient. The model requires deep layers to capture complex interactions, likely between Substance Type, Treatment Plan, and comorbidities.'
    },
    {
        'Component': 'Loss Function Strategy',
        'Selected Value': 'Alpha: 0.5 | Sigma: 0.5 (Balanced & Soft)',
        'Interpretation': 'The model prioritizes "Ranking" (who dies first) equally with "Timing" (when they die). The higher sigma (0.5) creates a "softer" margin of error, stabilizing the training against the high volume of censored data.'
    },
    {
        'Component': 'Optimization Mechanics',
        'Selected Value': 'Batch: 1024 | LR: 0.001 (Slow & Steady)',
        'Interpretation': 'Large batches are critical for rare events. A batch of 1024 ensures the model sees enough death cases in every single update to learn effectively, preventing "empty learning" steps common with smaller batches.'
    }
])

# --- DISPLAY ---
print("\n>>> TAKE-HOME MESSAGE: WHY THIS CONFIGURATION WON")
pd.set_option('display.max_colwidth', None)
display(config_interpretation)

>>> TAKE-HOME MESSAGE: WHY THIS CONFIGURATION WON
Component Selected Value Interpretation
0 Regularization (The "Shield") Dropout: 0.5 | Weight Decay: 0.001 (High) The model required maximum "braking" power. The high dropout indicates the 4% mortality class is noisy; the model forces itself to ignore specific patient details (memorization) and learn only the most robust, universal predictors.
1 Model Capacity (Architecture) Nodes: [256, 256, 128] (Deep & Wide) Risk factors are highly non-linear. Simple "older = higher risk" logic is insufficient. The model requires deep layers to capture complex interactions, likely between Substance Type, Treatment Plan, and comorbidities.
2 Loss Function Strategy Alpha: 0.5 | Sigma: 0.5 (Balanced & Soft) The model prioritizes "Ranking" (who dies first) equally with "Timing" (when they die). The higher sigma (0.5) creates a "softer" margin of error, stabilizing the training against the high volume of censored data.
3 Optimization Mechanics Batch: 1024 | LR: 0.001 (Slow & Steady) Large batches are critical for rare events. A batch of 1024 ensures the model sees enough death cases in every single update to learn effectively, preventing "empty learning" steps common with smaller batches.

Given that the initial hyperparameter sweep identified optimal values at the upper boundaries of the search space—specifically favoring maximum regularization (dropout 0.5, weight decay 0.001) and the deepest available network architecture—we conducted a secondary, targeted ‘zoom-in’ grid search. This follow-up analysis extended the search range to explore stronger regularization (dropout 0.6, weight decay 0.01) and increased model capacity (up to 512 nodes) to determine if the performance plateau had truly been reached. Parameters that demonstrated stability in the initial phase (batch size, alpha, and sigma) were fixed to their winning values, concentrating computational resources on fine-tuning the critical balance between model complexity and overfitting.

Code
#@title ⚡ Step 1.5: Targeted "Zoom-In" Tuning (Pushing Boundaries)
import itertools
import gc
import time
import warnings
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torchtuples as tt
import random
import os
from datetime import datetime
from pycox.models import DeepHit
from pycox.preprocessing.label_transforms import LabTransDiscreteTime
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sksurv.metrics import concordance_index_ipcw

# --- CONFIGURATION ---
NUM_RISKS = 3
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
EVAL_HORIZONS = [12, 24, 36, 48, 60] 

warnings.filterwarnings("ignore", message=".*weights_only=False.*")

# --- REPRODUCIBILITY SEED ---
def set_seed(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

set_seed(42)

# --- NETWORK WRAPPER ---
class CauseSpecificNet(nn.Module):
    def __init__(self, in_f, nodes, out_f, dropout, num_risks):
        super().__init__()
        self.net = tt.practical.MLPVanilla(in_f, nodes, out_f, batch_norm=True, dropout=dropout)
        self.num_risks = num_risks
    def forward(self, x):
        return self.net(x).view(x.size(0), self.num_risks, -1)

# --- DATA PREP ---
def prepare_stratified_data(df_idx=0):
    df = imputations_list_jan26[df_idx]
    y_d = y_surv_death_list[df_idx]
    y_r = y_surv_readm_list[df_idx]

    t_d = y_d['time'].values if hasattr(y_d['time'], 'values') else y_d['time']
    e_d_raw = y_d['event'].values if hasattr(y_d['event'], 'values') else y_d['event']
    e_r_raw = y_r['event'].values if hasattr(y_r['event'], 'values') else y_r['event']

    events = np.zeros(len(df), dtype=int)
    times = t_d.copy().astype('float32')
    e_d = e_d_raw.astype(bool)
    e_r = e_r_raw.astype(bool)

    events[e_r] = 2
    events[e_d] = 1 

    plan_cols = ['plan_type_corr_pg_pr', 'plan_type_corr_m_pr', 
                 'plan_type_corr_pg_pai', 'plan_type_corr_m_pai']
    available_plans = [c for c in plan_cols if c in df.columns]

    plan_category = np.zeros(len(df), dtype=int)
    for i, col in enumerate(available_plans, 1):
        plan_category[df[col] == 1] = i

    strat_labels = (events * 10) + plan_category
    return df, events, times, strat_labels

# --- EXECUTION ---
X_all, events_all, times_all, strat_labels = prepare_stratified_data()
start_time = time.time()

# 🚀 TARGETED SEARCH SPACE (Based on Previous Winners)
param_grid = {
    # Test slightly higher LR vs current winner (0.001)
    'lr': [0.001, 0.003], 
    
    # Test stronger regularization vs current winner (0.001)
    'weight_decay': [0.001, 0.01],
    
    # Test extreme dropout vs current winner (0.5)
    'dropout': [0.5, 0.6], 
    
    # Test wider capacity vs current winner ([256, 256, 128])
    'nodes': [[256, 256, 128], [512, 256, 128]], 
    
    # Fixed best performers from previous run to save time
    'batch_size': [1024], 
    'alpha': [0.5], 
    'sigma': [0.5] 
}

keys, values = zip(*param_grid.items())
search_space = [dict(zip(keys, v)) for v in itertools.product(*values)]
tuning_results = []

print(f"⚡ Starting Targeted 'Zoom-In' Tuning on {len(search_space)} combos...")

for i, params in enumerate(search_space):
    skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
    fold_scores = []

    for fold_idx, (train_idx, val_idx) in enumerate(skf.split(X_all, strat_labels)):
        torch.cuda.empty_cache()
        gc.collect()

        X_train, X_val = X_all.iloc[train_idx], X_all.iloc[val_idx]
        e_train, e_val = events_all[train_idx], events_all[val_idx]
        t_train, t_val = times_all[train_idx], times_all[val_idx]

        scaler = StandardScaler().fit(X_train)
        X_train_s = scaler.transform(X_train).astype('float32')
        X_val_s = scaler.transform(X_val).astype('float32')

        labtrans = LabTransDiscreteTime(100)
        y_train = labtrans.fit_transform(t_train, e_train)
        y_val = labtrans.transform(t_val, e_val)
        y_train = (y_train[0].astype('int64'), y_train[1].astype('int64'))
        y_val = (y_val[0].astype('int64'), y_val[1].astype('int64'))

        in_f = X_train.shape[1]
        out_f = labtrans.out_features * NUM_RISKS

        net = CauseSpecificNet(in_f, params['nodes'], out_f, params['dropout'], NUM_RISKS)
        model = DeepHit(net, tt.optim.Adam, alpha=params['alpha'], sigma=params['sigma'], duration_index=labtrans.cuts)        
        model.set_device(DEVICE)
        model.optimizer.set_lr(params['lr'])
        model.optimizer.param_groups[0]['weight_decay'] = params['weight_decay']

        try:
            model.fit(X_train_s, y_train, batch_size=params['batch_size'], epochs=50,
                      callbacks=[tt.callbacks.EarlyStopping()], verbose=False, val_data=(X_val_s, y_val))
            
            cif = model.predict_cif(X_val_s)
            
            # Evaluate across 5 key years (12-60m) and average
            horizon_scores = []
            for h in EVAL_HORIZONS:
                idx_h = np.searchsorted(model.duration_index, h)
                if idx_h >= len(model.duration_index): idx_h = len(model.duration_index) - 1
                
                score_d = cif[1][idx_h, :] # Death
                score_r = cif[2][idx_h, :] # Readm

                y_tr_st = np.array([(bool(e==1), t) for e, t in zip(e_train, t_train)], dtype=[('e', bool), ('t', float)])
                y_va_st_d = np.array([(bool(e==1), t) for e, t in zip(e_val, t_val)], dtype=[('e', bool), ('t', float)])
                y_va_st_r = np.array([(bool(e==2), t) for e, t in zip(e_val, t_val)], dtype=[('e', bool), ('t', float)])

                c_d = concordance_index_ipcw(y_tr_st, y_va_st_d, score_d, tau=h)[0]
                c_r = concordance_index_ipcw(y_tr_st, y_va_st_r, score_r, tau=h)[0]
                horizon_scores.append((c_d + c_r) / 2)
            
            fold_scores.append(np.mean(horizon_scores))

        except Exception as e:
            fold_scores.append(np.nan)

        del model; del net; gc.collect()
        
        # Short pause to prevent thermal throttling
        time.sleep(5) 

    avg_s = np.nanmean(fold_scores)
    tuning_results.append({**params, 'score': avg_s})
    print(f"   [{i+1}/{len(search_space)}] Avg C-Index (1-5yr): {avg_s:.4f}")

# --- RESULTS ---
results_df = pd.DataFrame(tuning_results).sort_values('score', ascending=False)
best_params = results_df.iloc[0].to_dict()

print("\n" + "="*60)
print(f"🏆 Best Zoom-In Config: {best_params}")
print("="*60)

timestamp = datetime.now().strftime("%Y%m%d_%H%M")
filename = f"DH_ZoomIn_Tuning_{timestamp}.csv"
results_df.to_csv(filename, index=False)
print(f"💾 Saved to: {filename}")
⚡ Starting Targeted 'Zoom-In' Tuning on 16 combos...
   [1/16] Avg C-Index (1-5yr): 0.7373
   [2/16] Avg C-Index (1-5yr): 0.7331
   [3/16] Avg C-Index (1-5yr): 0.7321
   [4/16] Avg C-Index (1-5yr): 0.7316
   [5/16] Avg C-Index (1-5yr): 0.7008
   [6/16] Avg C-Index (1-5yr): 0.6936
   [7/16] Avg C-Index (1-5yr): 0.6946
   [8/16] Avg C-Index (1-5yr): 0.6987
   [9/16] Avg C-Index (1-5yr): 0.7145
   [10/16] Avg C-Index (1-5yr): 0.7191
   [11/16] Avg C-Index (1-5yr): 0.7165
   [12/16] Avg C-Index (1-5yr): 0.7106
   [13/16] Avg C-Index (1-5yr): 0.6884
   [14/16] Avg C-Index (1-5yr): 0.6802
   [15/16] Avg C-Index (1-5yr): 0.6863
   [16/16] Avg C-Index (1-5yr): 0.6813

============================================================
🏆 Best Zoom-In Config: {'lr': 0.001, 'weight_decay': 0.001, 'dropout': 0.5, 'nodes': [256, 256, 128], 'batch_size': 1024, 'alpha': 0.5, 'sigma': 0.5, 'score': 0.7372626094515411}
============================================================
💾 Saved to: DH_ZoomIn_Tuning_20260208_1440.csv

To identify the optimal hyperparameter configuration, we conducted a two-stage tuning process using 5-fold stratified cross-validation on the first imputed dataset. An initial coarse grid search explored a broad range of learning rates, regularization strengths (weight decay, dropout), and network architectures. The top-performing configuration from this phase (C-index: 0.737) favored high model capacity and strong regularization, residing at the upper boundaries of the search space. Consequently, a secondary, targeted ‘zoom-in’ search was performed to explore even strictly higher regularization and deeper architectures. This confirmatory step yielded negligible performance gains (C-index: 0.7373 vs. 0.7371), indicating that the model had reached its convergence plateau. The final selected configuration (Learning Rate: 0.001, Weight Decay: 0.001, Dropout: 0.5, Nodes: [256, 256, 128]) was thus confirmed as both robust and optimal for the subsequent 10-fold evaluation.

Back to top