This notebook implements a rigorous survival analysis framework using Cox proportional hazards models with Elastic Net regularization (CoxNet) to predict two critical clinical outcomes for substance use disorder patients: (1) risk of readmission following treatment discharge and (2) risk of mortality. The analysis employs multiple imputation (5 datasets) to handle missing data while maintaining statistical validity, and uses stratified 10-fold cross-validation with Uno’s C-index for robust performance evaluation that properly accounts for censored observations and competing risks (where death precludes readmission). Through systematic hyperparameter tuning of the L1 ratio (balancing Lasso and Ridge penalties) and alpha (penalty strength), the notebook identifies optimal models that reveal fundamental differences between the two prediction tasks: the mortality model demonstrates remarkable temporal stability (maintaining ~0.78 AUC even at 9 years) driven primarily by biological factors like age and alcohol use, while the readmission model shows significant performance degradation over time (AUC dropping from 0.69 at 6 months to 0.56 at 5 years), reflecting the complex interplay of behavioral and social determinants. The analysis concludes with permutation importance rankings to identify the most influential predictors for each outcome and calibration assessments to ensure clinical reliability of risk estimates.
Author
ags
Published
2026-02-13
0. Package loading and installation
Code
#@title 🛠️ Environment Setup & Helper Functions { display-mode: "form" }# 1. Reset environment and Clear Memoryimport gcimport reimport numpy as npimport pandas as pdimport sysimport gcimport subprocessdef clear_workspace():"""Clear user-defined variables safely (like rm(list=ls()) in R)""" globals_copy =list(globals().keys())for name in globals_copy:ifnot name.startswith("_") and name notin ["clear_workspace", "gc", "sys" ]:delglobals()[name] gc.collect()print("🧹 Workspace cleared.")clear_workspace()# Check if we're using the expected environment pathexpected_env_name ="coxnet"python_path = sys.executableif expected_env_name notin python_path:raiseRuntimeError(f"This notebook requires the '{expected_env_name}' Conda environment.\n"f"Current Python path: {python_path}\n"f"Please select the correct interpreter in Positron." )#conda remove -n coxnet --all#conda env create -f "G:\My Drive\Alvacast\SISTRAT 2023\cons\coxnet.yml"#conda activate coxnet# 3. Importsfrom sksurv.metrics import concordance_index_ipcw, brier_score, integrated_brier_scorefrom sksurv.util import Surv# 4. CUSTOM HELPER FUNCTIONS (R-style)def glimpse(df, max_width=80):"""View dataframe structure similar to R's glimpse()"""print(f"Rows: {df.shape[0]} | Columns: {df.shape[1]}")for col in df.columns: dtype = df[col].dtype preview = df[col].astype(str).head(5).tolist() preview_str =", ".join(preview)iflen(preview_str) > max_width: preview_str = preview_str[:max_width] +"..."print(f"{col:<30}{str(dtype):<15}{preview_str}")def tabyl(series):"""Frequency table similar to R's janitor::tabyl()""" counts = series.value_counts(dropna=False) props = series.value_counts(normalize=True, dropna=False)return pd.DataFrame({"value": counts.index,"n": counts.values,"percent": props.values }).sort_values("value")def clean_names(df):"""Clean column names similar to R's janitor::clean_names()""" new_cols = []for col in df.columns: col = col.lower() col = re.sub(r"[^\w]+", "_", col) col = col.strip("_") new_cols.append(col) df.columns = new_colsreturn df# 5. Enable Interactive Tables for better head() visualizationtry:import itables itables.init()print("✅ Interactive tables enabled.")exceptImportError:print("ℹ️ itables not installed — using standard DataFrame display.")gc.collect()print("✅ Environment reset. Libraries installed. Helper functions loaded.")
🧹 Workspace cleared.
ℹ️ itables not installed — using standard DataFrame display.
✅ Environment reset. Libraries installed. Helper functions loaded.
Code
from pathlib import Pathimport sysimport numpy as npimport pandas as pdimport pickle# ---- Fix for NumPy 2.x pickle compatibility ----# NumPy 2 stored objects under numpy._core.*# NumPy 1.26 uses numpy.core.*try:import numpy.core.numeric sys.modules["numpy._core.numeric"] = numpy.core.numericexceptException:pass# -----------------------------------------------BASE_DIR = Path(r"G:\My Drive\Alvacast\SISTRAT 2023\data\20241015_out\pred1")withopen(BASE_DIR /"imputations_list_jan26.pkl", "rb") as f: imputations_list_jan26 = pickle.load(f)# Parquet files (these are safe)imputation_nodum_1 = pd.read_parquet( BASE_DIR /"imputation_nondum_1.parquet")X_reduced_imp0 = pd.read_parquet( BASE_DIR /"X_reduced_imp0.parquet")imputation_1 = pd.read_parquet( BASE_DIR /"imputation_1.parquet")
Imports the pickle library: This library implements binary protocols for serializing and de-serializing a Python object structure.
Specifies the file_path: It points to the .pkl file you selected.
Opens the file in binary read mode ('rb'): This is necessary for loading pickle files.
Loads the object: pickle.load(f) reads the serialized object from the file and reconstructs it in memory.
Prints confirmation and basic information: It verifies that the file was loaded and shows the type of the loaded object, and some details about the first element if it’s a list containing common data structures.
Code
# Inspect columns of the first imputationcols_first_imp = imputations_list_jan26[0].columns.tolist()print("First imputation columns:", cols_first_imp[:10], "... total:", len(cols_first_imp))# Inspect columns of imputation_no_dumcols_nodum = imputation_nodum_1.columns.tolist()print("No-dum columns:", cols_nodum[:10], "... total:", len(cols_nodum))# Compare overlapcommon_cols =set(cols_first_imp).intersection(cols_nodum)missing_in_imp = [c for c in cols_nodum if c notin cols_first_imp]missing_in_nodum = [c for c in cols_first_imp if c notin cols_nodum]print("Common columns:", len(common_cols))print("Missing in imputations_list_jan26:", missing_in_imp)
# Inspect columns of the first imputationcols_first_imp_raw = imputation_1.columns.tolist()print("First imputation columns:", cols_first_imp_raw[:10], "... total:", len(cols_first_imp_raw))# Compare overlapcommon_cols_raw =set(cols_first_imp_raw).intersection(cols_nodum)missing_in_imp_raw = [c for c in cols_nodum if c notin cols_first_imp_raw]print("Common columns:", len(common_cols_raw))print("Missing in imputations_list_jan26:", missing_in_imp_raw)print(common_cols_raw)
import pandas as pd# Example: choose a combination of variables that uniquely identify rowskey_vars = ["adm_age_rec3", "porc_pobr", "dit_m"]# Take one imputation (first element of the list) and merge with the no-dum datasetdf_imp = imputations_list_jan26[0]df_nodum = imputation_nodum_1merged_check = pd.merge( df_imp[key_vars], df_nodum[key_vars], on=key_vars, how="inner")print(f"Merged rows: {merged_check.shape[0]}")print("Preview of merged check:")print(merged_check.head())#drop mergedel merged_check
import pandas as pd# Example: choose a combination of variables that uniquely identify rowskey_vars_raw = ['dit_m','readmit_time_from_adm_m','death_time_from_adm_m','adm_age_rec3']# Take one imputation (first element of the list) and merge with the no-dum datasetdf_raw = imputation_1merged_check_raw = pd.merge( df_imp[key_vars], df_raw[key_vars], on=key_vars, how="inner")print(f"Merged rows: {merged_check_raw.shape[0]}")print("Preview of merged check:")print(merged_check_raw.head())print(f"{(merged_check_raw.shape[0] / imputation_1.shape[0] *100):.2f}%")#drop mergedel merged_check_raw
This code prepares your data for survival analysis. It extracts the time until an event (like readmission or death) and whether that event actually happened for each patient from the df_nodum dataset. Then, it automatically creates a set of important time points, called an ‘evaluation grid’, which are specific moments to assess the model’s performance on both readmission and death outcomes.
Code
import numpy as np# Required columns for survival outcomesrequired = ["readmit_time_from_disch_m", "readmit_event","death_time_from_disch_m", "death_event"]# Check that df_raw has all required columnsmissing = [c for c in required if c notin df_raw.columns]if missing:raiseKeyError(f"df_nodum is missing columns: {missing}")# Create time/event arrays directly from df_rawtime_readm = df_raw["readmit_time_from_adm_m"].to_numpy()event_readm = (df_raw["readmit_event"].to_numpy() ==1)time_death = df_raw["death_time_from_adm_m"].to_numpy()event_death = (df_nodum["death_event"].to_numpy() ==1)print("Arrays created for df_raw:")print("Readmission times:", time_readm[:5])print("Readmission events:", event_readm[:5])print("Death times:", time_death[:5])print("Death events:", event_death[:5])# Build evaluation grids (quantiles of event times)event_times_readm = time_readm[event_readm]event_times_death = time_death[event_death]iflen(event_times_readm) <5orlen(event_times_death) <5:raiseValueError("Too few events in df_raw to build reliable time grids.")times_eval_readm = np.unique(np.quantile(event_times_readm, np.linspace(0.05, 0.95, 50)))times_eval_death = np.unique(np.quantile(event_times_death, np.linspace(0.05, 0.95, 50)))print("Eval times (readmission):", times_eval_readm[:5], "...", times_eval_readm[-5:])print("Eval times (death):", times_eval_death[:5], "...", times_eval_death[-5:])
“Best predictors” (variable importance) based on discrimination
Inside each imputed dataset, we run k-fold CV, fit Coxnet on the training folds, and compute Uno’s C-index on the test folds.
For each fold, we computed permutation importance by shuffling one predictor at a time in the test set, recomputing the C-index, and measuring the drop.
We then pooled all these drops across folds and imputations, so mean_drop_cindex summarized how much that predictor hurts out-of-sample C-index on average, while respecting both multiple imputation and cross-validation.
Sorting by mean_drop_cindex and taking the top 20 output the most influential predictors in a way that is robust to missing data and optimistic bias.
Correction for inmortal time bias
First, we eliminated inmortal time bias (dead patients look like without readmission).
This correction is essentially the Cause-Specific Hazard preparation. It is the correct way to handle Aim 3 unless you switch to a Fine-Gray model (which treats death as a specific type of event 2, rather than censoring 0). For RSF/Coxnet, censoring 0 is the correct approach.
Code
import numpy as np# Step 1. Extract survival outcomes directly from df_rawtime_readm = df_raw["readmit_time_from_adm_m"].to_numpy()event_readm = (df_raw["readmit_event"].to_numpy() ==1)time_death = df_raw["death_time_from_adm_m"].to_numpy()event_death = (df_raw["death_event"].to_numpy() ==1)# Step 2. Build structured arrays (Surv objects)y_surv_readm = np.empty(len(time_readm), dtype=[("event", "?"), ("time", "<f8")])y_surv_readm["event"] = event_readmy_surv_readm["time"] = time_readmy_surv_death = np.empty(len(time_death), dtype=[("event", "?"), ("time", "<f8")])y_surv_death["event"] = event_deathy_surv_death["time"] = time_death# Step 3. Replicate across imputationsn_imputations =len(imputations_list_jan26)y_surv_readm_list = [y_surv_readm for _ inrange(n_imputations)]y_surv_death_list = [y_surv_death for _ inrange(n_imputations)]import numpy as npdef correct_competing_risks(X_list, y_readm_list, y_death_list):""" Adjust survival outcomes for competing risks (death vs. readmission). Parameters ---------- X_list : list of pd.DataFrame Imputed predictor datasets (same rows across imputations). y_readm_list : list of structured arrays Surv(event, time) arrays for readmission. y_death_list : list of structured arrays Surv(event, time) arrays for death. Returns ------- y_readm_corrected_list : list of structured arrays Corrected readmission outcomes (death treated as censoring). """ corrected = []for y_readm, y_death inzip(y_readm_list, y_death_list): y_corr = y_readm.copy()# If patient died before readmission → censor at death timefor i inrange(len(y_corr)):if y_death["event"][i] and y_death["time"][i] < y_corr["time"][i]: y_corr["event"][i] =False y_corr["time"][i] = y_death["time"][i] corrected.append(y_corr)return corrected# Step 4. Apply correctiony_surv_readm_list_corrected = correct_competing_risks( imputations_list_jan26, y_surv_readm_list, y_surv_death_list)
Code
# Check type and lengthtype(y_surv_readm_list_corrected), len(y_surv_readm_list_corrected)# Look at the first elementy_surv_readm_list_corrected[0][:5] # first 5 rowsneg_times = (y_surv_death_list[0]["time"] <0).sum()zero_times = (y_surv_death_list[0]["time"] ==0).sum()print(f"Negative survival times: {neg_times}")print(f"Zero survival times: {zero_times}")
**Set time=1e-5** for time<=0 events to stop division-by-zero crashes.
Use Matrix Prediction (predict(X)), not loops (predict(X, alpha=a)), for speed.
Force common Alpha Grid to make results comparable across CV folds.
Stratify by Competing Risks (Death vs Readmission) to balance test sets.
Merge Rare Strata to prevent “Class not in fold” errors during splitting.
Fallback to Harrell’s C if Uno’s IPCW fails due to censoring distributions.
Pool Imputations to average out noise from missing data handling.
Reshape 1D Arrays from .predict to avoid indexing errors when path collapses.
**Disable fit_baseline** during tuning loops to save computational time.
Catch specific errors per fold to prevent one failure from crashing the whole job.
Code
import numpy as npimport pandas as pdimport timefrom sklearn.model_selection import StratifiedKFoldfrom sksurv.linear_model import CoxnetSurvivalAnalysisfrom sksurv.metrics import concordance_index_ipcwfrom joblib import Parallel, delayeddef tune_coxnet_mi_stratified_cv_uno_only( X_list, y_surv_readm_list, y_surv_death_list, l1_ratios=(0.1, 0.5, 0.7, 0.9, 0.95, 1.0), n_alphas=100, alpha_min_ratio=0.01, n_splits=10, random_state=2125, max_iter=100000, n_jobs=-2,):ifnot (len(X_list) ==len(y_surv_readm_list) ==len(y_surv_death_list)):raiseValueError("X_list, y_surv_readm_list, y_surv_death_list must have same length.") n_imputations =len(X_list)print(f"Starting Stratified Tuning (Uno's Only): {n_imputations} imputations, {n_splits}-fold CV...")# ---------- Step 0: Safety fix for time <= 0 ---------- y_readm_safe, y_death_safe = [], []for i inrange(n_imputations): y_r = y_surv_readm_list[i].copy() y_d = y_surv_death_list[i].copy()# Cox models & IPCW metrics crash if time <= 0. Set to epsilon. m_r = y_r["time"] <=0if np.any(m_r): y_r["time"][m_r] =1e-5 m_d = y_d["time"] <=0if np.any(m_d): y_d["time"][m_d] =1e-5 y_readm_safe.append(y_r) y_death_safe.append(y_d)# ---------- Step 1: Common alpha grid ---------- X_sample = X_list[0] y_sample = y_readm_safe[0]# Fit dummy model to get grid dummy_model = CoxnetSurvivalAnalysis( l1_ratio=1.0, n_alphas=n_alphas, alpha_min_ratio=alpha_min_ratio, fit_baseline_model=False, ).fit(X_sample, y_sample) common_alphas = dummy_model.alphas_print(f" > Alpha grid established: {len(common_alphas)} alphas.")# ---------- Step 2: Stratification (Competing Risk + Plan) ---------- e_r, t_r = y_readm_safe[0]["event"], y_readm_safe[0]["time"] e_d, t_d = y_death_safe[0]["event"], y_death_safe[0]["time"]# 1=DeathFirst, 2=ReadmFirst, 0=Censored events_cr = np.zeros(len(e_r), dtype=int) events_cr[e_d & (~e_r | (t_d < t_r))] =1 events_cr[e_r & (~e_d | (t_r < t_d))] =2# Initialize with 0 (This automatically captures the 'pg-pab' reference group) plan_idx = np.zeros(len(X_sample), dtype=int)# Overwrite the index for the 4 explicit dummy variablesif"plan_type_corr_m-pr"in X_sample.columns: plan_idx[X_sample["plan_type_corr_m-pr"] ==1] =1if"plan_type_corr_pg-pai"in X_sample.columns: plan_idx[X_sample["plan_type_corr_pg-pai"] ==1] =2if"plan_type_corr_pg-pr"in X_sample.columns: plan_idx[X_sample["plan_type_corr_pg-pr"] ==1] =3if"plan_type_corr_m-pai"in X_sample.columns: plan_idx[X_sample["plan_type_corr_m-pai"] ==1] =4 strat_labels = (events_cr *10) + plan_idx# Merge rare groups to prevent 'n_splits' errors counts = pd.Series(strat_labels).value_counts() rare_groups = counts[counts < n_splits].indexfor g in rare_groups: strat_labels[strat_labels == g] =0 skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state) cv_splits =list(skf.split(X_sample, strat_labels))print(f" > Stratification successful: {len(np.unique(strat_labels))} groups.")# ---------- Step 3: Worker Function ----------def evaluate_fold(l1_ratio, fold_idx, train_idx, test_idx, imp_idx):try: X_curr = X_list[imp_idx] y_curr = y_readm_safe[imp_idx] X_train, X_test = X_curr.iloc[train_idx], X_curr.iloc[test_idx] y_train, y_test = y_curr[train_idx], y_curr[test_idx] model = CoxnetSurvivalAnalysis( l1_ratio=l1_ratio, alphas=common_alphas, normalize=False, fit_baseline_model=False, max_iter=max_iter, ) model.fit(X_train, y_train)# full score vector aligned to common_alphas fold_scores = np.full(len(common_alphas), np.nan, dtype=float)# score each alpha explicitly (Uno only)for a in model.alphas_:try:# map model alpha -> common alpha index idx = np.where(np.isclose(common_alphas, a, rtol=1e-10, atol=1e-12))[0]iflen(idx) ==0:continue idx =int(idx[0]) risk = model.predict(X_test, alpha=float(a))# stabilize Uno by clipping tau to train/test overlap tau =min(float(np.max(y_train["time"])), float(np.max(y_test["time"])))if tau <=0:continue c_idx = concordance_index_ipcw(y_train, y_test, risk, tau=tau)[0] fold_scores[idx] =float(c_idx)exceptException:# Uno-only mode: keep NaNpassreturn fold_scores.tolist(), NoneexceptExceptionas e:return [np.nan] *len(common_alphas), f"l1={l1_ratio}, fold={fold_idx}, imp={imp_idx}: {repr(e)}"# ---------- Step 4: Parallel run ---------- tasks = [ (l1, fold_i, train_idx, test_idx, imp_i)for l1 in l1_ratiosfor fold_i, (train_idx, test_idx) inenumerate(cv_splits)for imp_i inrange(n_imputations) ]print(f" > Processing {len(tasks)} tasks using {n_jobs} jobs...") out = Parallel(n_jobs=n_jobs)(delayed(evaluate_fold)(*t) for t in tasks) results_flat = [x[0] for x in out] errors = [x[1] for x in out if x[1] isnotNone]# ---------- Step 5: Aggregate ---------- records = [] task_counter =0for l1 in l1_ratios:for _fold_i inrange(n_splits):for _imp_i inrange(n_imputations): scores = results_flat[task_counter] task_counter +=1for alpha_idx, val inenumerate(scores):ifnot np.isnan(val): records.append({"l1_ratio": l1,"alpha_idx": alpha_idx,"alpha": common_alphas[alpha_idx],"c_index": val, })ifnot records:print("\n❌ CRITICAL ERROR: All evaluations failed.")if errors:print("First errors:")for msg in errors[:5]: print(" -", msg)returnNone, None, None results_df = pd.DataFrame(records) tuning_summary = ( results_df.groupby(["l1_ratio", "alpha_idx", "alpha"])["c_index"] .agg(["mean", "std", "count"]) .reset_index() ) best_idx = tuning_summary["mean"].idxmax() best_row = tuning_summary.loc[best_idx]print("\n--- Tuning Complete ---")print(f"Best L1: {best_row['l1_ratio']}")print(f"Best Alpha: {best_row['alpha']:.5f} (Index: {int(best_row['alpha_idx'])})")print(f"Best C-Index: {best_row['mean']:.4f}")return tuning_summary, best_row, common_alphas
Code
start_time = time.time()tuning_results_readm, best_params_readm, common_alphas_readm = tune_coxnet_mi_stratified_cv_uno_only( X_list=imputations_list_jan26, y_surv_readm_list=y_surv_readm_list_corrected, y_surv_death_list=y_surv_death_list, n_jobs=-2, # set n_jobs=1 once if you want easier debugging)print(f"Completed in {(time.time() - start_time)/60:.2f} min")
Starting Stratified Tuning (Uno's Only): 5 imputations, 10-fold CV...
> Alpha grid established: 100 alphas.
> Stratification successful: 3 groups.
> Processing 300 tasks using -2 jobs...
--- Tuning Complete ---
Best L1: 0.1
Best Alpha: 0.00346 (Index: 99)
Best C-Index: 0.6081
Completed in 8.86 min
Code
import seaborn as snsimport matplotlib.pyplot as pltdef plot_tuning_heatmap(tuning_summary):# Pivot: Rows=L1, Cols=Alpha Index, Values=Mean C-Index pivot = tuning_summary.pivot(index="l1_ratio", columns="alpha_idx", values="mean") plt.figure(figsize=(12, 6)) sns.heatmap(pivot, cmap="viridis", annot=False) plt.title("Hyperparameter Performance (Mean C-Index)") plt.xlabel("Alpha Index (Left=High Penalty, Right=Low Penalty)") plt.ylabel("L1 Ratio (Bottom=Ridge, Top=Lasso)") plt.gca().invert_yaxis() # Put L1=1.0 at the top plt.show()plot_tuning_heatmap(tuning_results_readm)
The L1 Ratio= 0.1 (Ridge-Dominant): Only 10% Lasso. Likely data is composed of many correlated predictors with small or cumulative effects on readmission/death.
Alpha= 0.00346 (Index 99): It chose the lowest possible penalty allowed in the grid (min_ratio= 0.01), which can be interpreted as its rejecting penalties on data.
The gold standard (recommended by Hastie & Tibshirani, the creators of ElasticNet) is the 1-SE Rule: - We find the absolute best C-Index. - We calculate its standard error across the CV folds. - We choose the simplest model (the one with the highest alpha / strongest penalty) that falls within 1 Standard Error of the absolute best. This guarantees a parsimonious, robust model that won’t break on new patients.
Code
import numpy as npimport pandas as pdimport timefrom sklearn.model_selection import StratifiedKFoldfrom sksurv.linear_model import CoxnetSurvivalAnalysisfrom sksurv.metrics import concordance_index_ipcwfrom joblib import Parallel, delayeddef tune_coxnet_mi_stratified_cv_clinical( X_list, y_surv_readm_list, y_surv_death_list, l1_ratios=(0.1, 0.5, 0.7, 0.9, 0.95, 1.0), n_alphas=100, alpha_min_ratio=0.001, # <--- Expanded grid to explore lower penalties n_splits=10, random_state=2125, max_iter=100000, n_jobs=-2,):ifnot (len(X_list) ==len(y_surv_readm_list) ==len(y_surv_death_list)):raiseValueError("X_list, y_surv_readm_list, y_surv_death_list must have same length.") n_imputations =len(X_list)print(f"Starting Clinical Tuning (Uno's Only): {n_imputations} imputations, {n_splits}-fold CV...")# ---------- Step 0: Safety fix for time <= 0 ---------- y_readm_safe, y_death_safe = [], []for i inrange(n_imputations): y_r = y_surv_readm_list[i].copy() y_d = y_surv_death_list[i].copy() m_r = y_r["time"] <=0if np.any(m_r): y_r["time"][m_r] =1e-5 m_d = y_d["time"] <=0if np.any(m_d): y_d["time"][m_d] =1e-5 y_readm_safe.append(y_r) y_death_safe.append(y_d)# ---------- Step 1: Common alpha grid ---------- X_sample = X_list[0] y_sample = y_readm_safe[0] dummy_model = CoxnetSurvivalAnalysis( l1_ratio=1.0, n_alphas=n_alphas, alpha_min_ratio=alpha_min_ratio, fit_baseline_model=False, ).fit(X_sample, y_sample) common_alphas = dummy_model.alphas_print(f" > Alpha grid established: {len(common_alphas)} alphas "f"(Max: {common_alphas.max():.4f}, Min: {common_alphas.min():.5f})" )# ---------- Step 2: Stratification (Competing Risk + Plan) ---------- e_r, t_r = y_readm_safe[0]["event"], y_readm_safe[0]["time"] e_d, t_d = y_death_safe[0]["event"], y_death_safe[0]["time"] events_cr = np.zeros(len(e_r), dtype=int) events_cr[e_d & (~e_r | (t_d < t_r))] =1 events_cr[e_r & (~e_d | (t_r < t_d))] =2# Initialize with 0 (This automatically captures the 'pg-pab' reference group) plan_idx = np.zeros(len(X_sample), dtype=int)# Overwrite the index for the 4 explicit dummy variablesif"plan_type_corr_m-pr"in X_sample.columns: plan_idx[X_sample["plan_type_corr_m-pr"] ==1] =1if"plan_type_corr_pg-pai"in X_sample.columns: plan_idx[X_sample["plan_type_corr_pg-pai"] ==1] =2if"plan_type_corr_pg-pr"in X_sample.columns: plan_idx[X_sample["plan_type_corr_pg-pr"] ==1] =3if"plan_type_corr_m-pai"in X_sample.columns: plan_idx[X_sample["plan_type_corr_m-pai"] ==1] =4 strat_labels = (events_cr *10) + plan_idx# merge very small strata to avoid split failures counts = pd.Series(strat_labels).value_counts() rare_groups = counts[counts < n_splits].indexfor g in rare_groups: strat_labels[strat_labels == g] =0 skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state) cv_splits =list(skf.split(X_sample, strat_labels))print(f" > Stratification successful: {len(np.unique(strat_labels))} groups.")# ---------- Step 3: Worker Function ----------def evaluate_fold(l1_ratio, fold_idx, train_idx, test_idx, imp_idx):try: X_curr = X_list[imp_idx] y_curr = y_readm_safe[imp_idx] X_train, X_test = X_curr.iloc[train_idx], X_curr.iloc[test_idx] y_train, y_test = y_curr[train_idx], y_curr[test_idx] model = CoxnetSurvivalAnalysis( l1_ratio=l1_ratio, alphas=common_alphas, normalize=False, fit_baseline_model=False, max_iter=max_iter, ) model.fit(X_train, y_train)# full score vector aligned to common_alphas fold_scores = np.full(len(common_alphas), np.nan, dtype=float)# Explicit loop requested (safest mapping for dynamic alphas)for a in model.alphas_:try:# map model alpha -> common alpha index idx = np.where(np.isclose(common_alphas, a, rtol=1e-10, atol=1e-12))[0]iflen(idx) ==0:continue idx =int(idx[0]) risk = model.predict(X_test, alpha=float(a))# stabilize Uno by clipping tau to train/test overlap tau =min(float(np.max(y_train["time"])), float(np.max(y_test["time"])))if tau <=0:continue c_idx = concordance_index_ipcw(y_train, y_test, risk, tau=tau)[0] fold_scores[idx] =float(c_idx)exceptException:# Uno-only mode: keep NaNpassreturn fold_scores.tolist(), NoneexceptExceptionas e:return [np.nan] *len(common_alphas), f"l1={l1_ratio}, fold={fold_idx}, imp={imp_idx}: {repr(e)}"# ---------- Step 4: Parallel run ---------- tasks = [ (l1, fold_i, train_idx, test_idx, imp_i)for l1 in l1_ratiosfor fold_i, (train_idx, test_idx) inenumerate(cv_splits)for imp_i inrange(n_imputations) ]print(f" > Processing {len(tasks)} tasks using {n_jobs} jobs...") out = Parallel(n_jobs=n_jobs)(delayed(evaluate_fold)(*t) for t in tasks) results_flat = [x[0] for x in out] errors = [x[1] for x in out if x[1] isnotNone]# ---------- Step 5: Aggregate and Apply 1-SE Rule ---------- records = [] task_counter =0for l1 in l1_ratios:for _fold_i inrange(n_splits):for _imp_i inrange(n_imputations): scores = results_flat[task_counter] task_counter +=1for alpha_idx, val inenumerate(scores):ifnot np.isnan(val): records.append({"l1_ratio": l1,"alpha_idx": alpha_idx,"alpha": common_alphas[alpha_idx],"c_index": val, })ifnot records:print("\n❌ CRITICAL ERROR: All evaluations failed.")if errors:print("First errors:")for msg in errors[:5]: print(" -", msg)returnNone, None, None results_df = pd.DataFrame(records) tuning_summary = ( results_df.groupby(["l1_ratio", "alpha_idx", "alpha"])["c_index"] .agg(["mean", "std", "count"]) .reset_index() )# 1. Absolute Best Model (Max C-Index) best_idx_raw = tuning_summary["mean"].idxmax() best_row_raw = tuning_summary.loc[best_idx_raw]# 2. Standard Error of the mean C-Index across CV folds & imputations max_mean = best_row_raw["mean"] se = best_row_raw["std"] / np.sqrt(best_row_raw["count"]) threshold_1se = max_mean - se# 3. Apply 1-SE Rule: Find highest 'alpha' (most parsimonious/penalized) above threshold candidates_1se = tuning_summary[tuning_summary["mean"] >= threshold_1se] best_row_1se = candidates_1se.sort_values(by="alpha", ascending=False).iloc[0]print("\n--- Clinical Tuning Complete ---")print(f"Absolute Best C-Index: {max_mean:.4f} ± {se:.4f} (L1: {best_row_raw['l1_ratio']}, Alpha: {best_row_raw['alpha']:.5f})")print(f"1-SE Robustness Threshold: {threshold_1se:.4f}")print("\n✓ SELECTED MODEL (1-SE Rule Applied):")print(f"Best L1: {best_row_1se['l1_ratio']}")print(f"Best Alpha: {best_row_1se['alpha']:.5f} (Index: {int(best_row_1se['alpha_idx'])})")print(f"C-Index: {best_row_1se['mean']:.4f}")if best_row_raw['alpha'] != best_row_1se['alpha'] or best_row_raw['l1_ratio'] != best_row_1se['l1_ratio']:print("\nNote: The 1-SE rule chose a simpler, more heavily penalized model to protect SUD patients against overfitting.")return tuning_summary, best_row_1se, common_alphas
Starting Clinical Tuning (Uno's Only): 5 imputations, 10-fold CV...
> Alpha grid established: 100 alphas (Max: 0.3455, Min: 0.00035)
> Stratification successful: 3 groups.
> Processing 300 tasks using -2 jobs...
--- Clinical Tuning Complete ---
Absolute Best C-Index: 0.6100 ± 0.0015 (L1: 0.1, Alpha: 0.00035)
1-SE Robustness Threshold: 0.6085
✓ SELECTED MODEL (1-SE Rule Applied):
Best L1: 0.1
Best Alpha: 0.00280 (Index: 69)
C-Index: 0.6085
Note: The 1-SE rule chose a simpler, more heavily penalized model to protect SUD patients against overfitting.
Completed in 11.49 min
Code
import seaborn as snsimport matplotlib.pyplot as pltplot_tuning_heatmap(tuning_results_readm_post)
Code
import pandas as pdfrom IPython.display import display# --- HYPERPARAMETER SEARCH STRATEGY DATAFRAME ---search_strategy_msg = pd.DataFrame([ {'Component': 'L1 Ratio (`l1_ratio`)','Role': 'ElasticNet Mixing (Ridge vs. Lasso)','Grid Evaluated': '[0.1, 0.5, 0.7, 0.9, 0.95, 1.0]','Rationale': 'Scans the spectrum from Ridge-dominant (shrinks collinear variables equally) to Lasso-dominant (forces extreme sparsity and feature selection).' }, {'Component': 'Penalty Strength (`alpha`)','Role': 'Overall Regularization / Shrinkage','Grid Evaluated': '100 steps (min_ratio=0.001)','Rationale': 'An expanded 100-step path allowing the algorithm to explore very low-penalty zones, ensuring the absolute mathematical peak is captured before applying clinical corrections.' }, {'Component': 'Validation Strategy','Role': 'Test Set Balance & Stability','Grid Evaluated': '10-Fold Stratified CV across 5 Imputations','Rationale': 'Ensures proportional representation of Competing Risks (Death vs. Readmission) across all folds, stabilizing Uno\'s C-index while accounting for missing data uncertainty.' }, {'Component': 'Selection Criteria','Role': 'Overfitting Prevention','Grid Evaluated': '1-Standard Error (1-SE) Rule','Rationale': 'Sacrifices a negligible fraction of training performance (within 1-SE of the absolute peak) to select a simpler, more heavily penalized model that generalizes safely to new clinical populations.' }])# --- TUNED RESULTS & INTERPRETATION DATAFRAME ---tuned_results_msg = pd.DataFrame([ {'Parameter': 'Best L1 Ratio','Winning Value': '0.1 (Ridge-Dominant)','Mathematical Meaning': 'Applies 90% L2 (Ridge) penalty and 10% L1 (Lasso) penalty.','Clinical Interpretation': 'The model avoids dropping variables entirely. SUD readmission is driven by a complex web of cumulative factors rather than a few isolated "magic bullets."' }, {'Parameter': 'Best Alpha','Winning Value': '0.00280 (Index 69)','Mathematical Meaning': 'A moderate penalty selected via the 1-SE rule, stepping back from the absolute lowest penalty (0.00035 at Index 99).','Clinical Interpretation': 'Provides a parsimonious, robust fit. By increasing the penalty 8x from the absolute peak, it aggressively shrinks "noisy" coefficients, protecting future patient predictions from overfitting.' }, {'Parameter': 'Uno\'s C-Index','Winning Value': '0.6085 ± 0.0015','Mathematical Meaning': 'Stable out-of-sample discriminative ability across 50 evaluations (10 folds × 5 imputations).','Clinical Interpretation': 'Demonstrates fair, highly stable discrimination. This is realistic for behavioral health models, where systemic randomness and post-discharge social determinants introduce variance clinical data cannot fully capture.' }])# --- DISPLAY ---print("\n>>> TAKE-HOME MESSAGE: COXNET HYPERPARAMETER TUNING STRATEGY")pd.set_option('display.max_colwidth', None)display(search_strategy_msg.style.set_properties(**{'text-align': 'left', 'white-space': 'pre-wrap', 'background-color': '#f8f9fa','border': '1px solid black'}))print("\n>>> TAKE-HOME MESSAGE: WINNING PARAMETERS (READMISSION)")display(tuned_results_msg.style.set_properties(**{'text-align': 'left', 'white-space': 'pre-wrap','background-color': '#eef6fc','border': '1px solid black'}))
Scans the spectrum from Ridge-dominant (shrinks collinear variables equally) to Lasso-dominant (forces extreme sparsity and feature selection).
1
Penalty Strength (`alpha`)
Overall Regularization / Shrinkage
100 steps (min_ratio=0.001)
An expanded 100-step path allowing the algorithm to explore very low-penalty zones, ensuring the absolute mathematical peak is captured before applying clinical corrections.
2
Validation Strategy
Test Set Balance & Stability
10-Fold Stratified CV across 5 Imputations
Ensures proportional representation of Competing Risks (Death vs. Readmission) across all folds, stabilizing Uno's C-index while accounting for missing data uncertainty.
3
Selection Criteria
Overfitting Prevention
1-Standard Error (1-SE) Rule
Sacrifices a negligible fraction of training performance (within 1-SE of the absolute peak) to select a simpler, more heavily penalized model that generalizes safely to new clinical populations.
Applies 90% L2 (Ridge) penalty and 10% L1 (Lasso) penalty.
The model avoids dropping variables entirely. SUD readmission is driven by a complex web of cumulative factors rather than a few isolated "magic bullets."
1
Best Alpha
0.00280 (Index 69)
A moderate penalty selected via the 1-SE rule, stepping back from the absolute lowest penalty (0.00035 at Index 99).
Provides a parsimonious, robust fit. By increasing the penalty 8x from the absolute peak, it aggressively shrinks "noisy" coefficients, protecting future patient predictions from overfitting.
Demonstrates fair, highly stable discrimination. This is realistic for behavioral health models, where systemic randomness and post-discharge social determinants introduce variance clinical data cannot fully capture.
Death
Code
start_time = time.time()tuning_results_death, best_params_death, common_alphas_death = tune_coxnet_mi_stratified_cv_uno_only( X_list=imputations_list_jan26, y_surv_readm_list=y_surv_death_list, y_surv_death_list=y_surv_readm_list_corrected, n_jobs=-2, # set n_jobs=1 once if you want easier debugging)print(f"Completed in {(time.time() - start_time)/60:.2f} min")
C:\Users\andre\miniconda3\envs\coxnet\Lib\site-packages\joblib\externals\loky\process_executor.py:782: UserWarning: A worker stopped while some jobs were given to the executor. This can be caused by a too short worker timeout or by a memory leak.
warnings.warn(
--- Tuning Complete ---
Best L1: 0.1
Best Alpha: 0.00344 (Index: 99)
Best C-Index: 0.7439
Completed in 6.35 min
Starting Clinical Tuning (Uno's Only): 5 imputations, 10-fold CV...
> Alpha grid established: 100 alphas (Max: 0.3442, Min: 0.00034)
> Stratification successful: 3 groups.
> Processing 300 tasks using -2 jobs...
--- Clinical Tuning Complete ---
Absolute Best C-Index: 0.7470 ± 0.0023 (L1: 0.5, Alpha: 0.00034)
1-SE Robustness Threshold: 0.7447
✓ SELECTED MODEL (1-SE Rule Applied):
Best L1: 0.1
Best Alpha: 0.00260 (Index: 70)
C-Index: 0.7449
Note: The 1-SE rule chose a simpler, more heavily penalized model to protect SUD patients against overfitting.
Completed in 9.93 min
Code
plot_tuning_heatmap(tuning_results_death_post)
Code
import pandas as pdfrom IPython.display import display# --- TUNED RESULTS & INTERPRETATION DATAFRAME (DEATH) ---tuned_results_death_msg = pd.DataFrame([ {'Parameter': 'Best L1 Ratio','Winning Value': '0.1 (Ridge-Dominant)','Mathematical Meaning': 'Shifted from 0.5 (absolute peak) to 0.1 under the 1-SE rule. Applies 90% L2 (Ridge) penalty.','Clinical Interpretation': 'Mortality is driven by a cumulative burden of many health and demographic factors. Rather than dropping variables (Lasso), the safest clinical model retains most predictors but shrinks their weights evenly to prevent overfitting.' }, {'Parameter': 'Best Alpha','Winning Value': '0.00260 (Index 70)','Mathematical Meaning': 'A heavily penalized model selected via the 1-SE rule, representing a 7.6x increase in penalty over the absolute mathematical peak (0.00034).','Clinical Interpretation': 'A textbook application of parsimony. It sacrifices a statistically negligible 0.0021 in predictive power to aggressively shrink noisy coefficients, ensuring the mortality model generalizes safely to new patients.' }, {'Parameter': 'Uno\'s C-Index','Winning Value': '0.7449 ± 0.0023','Mathematical Meaning': 'Excellent and highly stable out-of-sample discriminative ability.','Clinical Interpretation': 'Confirming clinical intuition: biology is more predictable than behavior. Mortality (0.74) is significantly more predictable than readmission (0.61) because it relies on harder, physiological baseline markers.' }])print("\n>>> TAKE-HOME MESSAGE: WINNING PARAMETERS (TIME TO DEATH)")pd.set_option('display.max_colwidth', None)display(tuned_results_death_msg.style.set_properties(**{'text-align': 'left', 'white-space': 'pre-wrap','background-color': '#fff0f0', # Light red/pink to distinguish from readmission'border': '1px solid black'}))
>>> TAKE-HOME MESSAGE: WINNING PARAMETERS (TIME TO DEATH)
Parameter
Winning Value
Mathematical Meaning
Clinical Interpretation
0
Best L1 Ratio
0.1 (Ridge-Dominant)
Shifted from 0.5 (absolute peak) to 0.1 under the 1-SE rule. Applies 90% L2 (Ridge) penalty.
Mortality is driven by a cumulative burden of many health and demographic factors. Rather than dropping variables (Lasso), the safest clinical model retains most predictors but shrinks their weights evenly to prevent overfitting.
1
Best Alpha
0.00260 (Index 70)
A heavily penalized model selected via the 1-SE rule, representing a 7.6x increase in penalty over the absolute mathematical peak (0.00034).
A textbook application of parsimony. It sacrifices a statistically negligible 0.0021 in predictive power to aggressively shrink noisy coefficients, ensuring the mortality model generalizes safely to new patients.
2
Uno's C-Index
0.7449 ± 0.0023
Excellent and highly stable out-of-sample discriminative ability.
Confirming clinical intuition: biology is more predictable than behavior. Mortality (0.74) is significantly more predictable than readmission (0.61) because it relies on harder, physiological baseline markers.
Permutation importance (based on Uno’s C-index) for a tuned Coxnet model, using Multiple Imputation + Stratified Cross-Validation
How much does each predictor contribute to discrimination of readmission risk, under the tuned penalized Cox model, accounting for missing data and competing risk structure?
We computed permutation importance for a tuned Coxnet survival model.
Using Multiple Imputation to account for missing data.
Applied stratified 10-fold cross-validation.
Preserved competing risk structure across folds.
Used Uno’s C-index for censoring-adjusted discrimination.
Reused the exact tuned alpha and l1_ratio.
Measured importance as C-index drop after permutation.
Repeated permutations to reduce randomness.
Aggregated results across folds and imputations.
Returned mean baseline C-index and feature importance table.
Code
import numpy as npimport pandas as pdfrom sklearn.model_selection import StratifiedKFoldfrom sksurv.linear_model import CoxnetSurvivalAnalysisfrom sksurv.metrics import concordance_index_ipcwfrom joblib import Parallel, delayeddef permutation_importance_cindex_cv_mi_stratified( X_list, y_surv_readm_list, y_surv_death_list, alpha_idx=69, # WINNING ALPHA INDEX l1_ratio=0.1, # WINNING L1 RATIO alpha_min_ratio=0.001, # Must match the expanded grid n_alphas=100, # Must match tuning n_splits=10, # Updated to match 10-fold tuning n_repeats=3, random_state=2125, max_iter=100000, n_jobs=-2,):""" Clinically robust MI + Stratified CV permutation importance for Coxnet. Uses the exact hyperparameter grid and splits from the tuning phase. """ n_imputations =len(X_list) feature_names = X_list[0].columns.tolist() n_features =len(feature_names)print(f"Starting Permutation Importance: {n_imputations} imputations, {n_splits}-fold CV...")print(f"Target Model: L1 Ratio = {l1_ratio}, Alpha Index = {alpha_idx}")# ---------- Step 0: Safety fix for time <= 0 ---------- y_readm_safe, y_death_safe = [], []for i inrange(n_imputations): y_r = y_surv_readm_list[i].copy() y_d = y_surv_death_list[i].copy()if np.any(y_r["time"] <=0): y_r["time"][y_r["time"] <=0] =1e-5if np.any(y_d["time"] <=0): y_d["time"][y_d["time"] <=0] =1e-5 y_readm_safe.append(y_r) y_death_safe.append(y_d)# ---------- Step 1: Recreate Common Alpha Grid ---------- X_sample = X_list[0] y_sample = y_readm_safe[0] dummy_model = CoxnetSurvivalAnalysis( l1_ratio=1.0, n_alphas=n_alphas, alpha_min_ratio=alpha_min_ratio, fit_baseline_model=False ).fit(X_sample, y_sample) common_alphas = dummy_model.alphas_ target_alpha = common_alphas[alpha_idx]print(f" > Target Alpha extracted: {target_alpha:.5f}")# ---------- Step 2: Stratification (Competing Risk + Plan) ---------- e_r, t_r = y_sample["event"], y_sample["time"] e_d, t_d = y_death_safe[0]["event"], y_death_safe[0]["time"] events_cr = np.zeros(len(e_r), dtype=int) events_cr[e_d & (~e_r | (t_d < t_r))] =1 events_cr[e_r & (~e_d | (t_r < t_d))] =2# Initialize with 0 (This automatically captures the 'pg-pab' reference group) plan_idx = np.zeros(len(X_sample), dtype=int) # Overwrite the index for the 4 explicit dummy variablesif"plan_type_corr_m-pr"in X_sample.columns: plan_idx[X_sample["plan_type_corr_m-pr"] ==1] =1if"plan_type_corr_pg-pai"in X_sample.columns: plan_idx[X_sample["plan_type_corr_pg-pai"] ==1] =2if"plan_type_corr_pg-pr"in X_sample.columns: plan_idx[X_sample["plan_type_corr_pg-pr"] ==1] =3if"plan_type_corr_m-pai"in X_sample.columns: plan_idx[X_sample["plan_type_corr_m-pai"] ==1] =4 strat_labels = (events_cr *10) + plan_idx counts = pd.Series(strat_labels).value_counts() rare_groups = counts[counts < n_splits].indexfor g in rare_groups: strat_labels[strat_labels == g] =0 skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state) cv_splits =list(skf.split(X_sample, strat_labels))# Convert X_list to NumPy arrays NOW for fast permutation X_list_np = [X.values.astype(float) for X in X_list]# ---------- Step 3: Worker Function ----------def compute_fold(d, fold_idx, train_idx, test_idx): X_imp = X_list_np[d] X_train, X_test = X_imp[train_idx, :], X_imp[test_idx, :] y_train, y_test = y_readm_safe[d][train_idx], y_readm_safe[d][test_idx] local_rng = np.random.RandomState(random_state + d * n_splits + fold_idx)# Fit model using the exact grid model = CoxnetSurvivalAnalysis( l1_ratio=l1_ratio, alphas=common_alphas, normalize=False, fit_baseline_model=False, max_iter=max_iter ) model.fit(X_train, y_train)# Tau clipping to stabilize Uno's C-index tau =min(float(np.max(y_train["time"])), float(np.max(y_test["time"]))) -1e-7# Baseline Risk & Score risk_baseline = model.predict(X_test, alpha=float(target_alpha))try:if tau <=0: raiseValueError cindex_baseline = concordance_index_ipcw(y_train, y_test, risk_baseline, tau=tau)[0]exceptException:# If baseline crashes, we can't calculate drops for this foldreturn np.nan, [[np.nan] * n_repeats for _ inrange(n_features)]# Permutation drops per feature fold_drops = [[] for _ inrange(n_features)]for col_idx inrange(n_features):for r inrange(n_repeats): X_perm = X_test.copy() X_perm[:, col_idx] = local_rng.permutation(X_perm[:, col_idx]) risk_perm = model.predict(X_perm, alpha=float(target_alpha))try: cindex_perm = concordance_index_ipcw(y_train, y_test, risk_perm, tau=tau)[0] fold_drops[col_idx].append(cindex_baseline - cindex_perm)exceptException: fold_drops[col_idx].append(np.nan)return cindex_baseline, fold_drops# ---------- Step 4: Parallel Execution ---------- tasks = [ (d, fold_idx, train_idx, test_idx)for d inrange(n_imputations)for fold_idx, (train_idx, test_idx) inenumerate(cv_splits) ]print(f" > Processing {len(tasks)} folds for Permutation Importance...") results = Parallel(n_jobs=n_jobs, verbose=5)(delayed(compute_fold)(*t) for t in tasks)# ---------- Step 5: Aggregation ---------- baseline_cindices = [res[0] for res in results ifnot np.isnan(res[0])] global_drops = [[] for _ inrange(n_features)]for res in results:if np.isnan(res[0]): continue# Skip failed folds fold_drops = res[1]for col_idx inrange(n_features): global_drops[col_idx].extend(fold_drops[col_idx]) imp_rows = []for col_idx inrange(n_features):# Drop NaNs that might have occurred during individual permutations arr = np.array(global_drops[col_idx]) arr = arr[~np.isnan(arr)] mean_drop =float(arr.mean()) if arr.size >0else np.nan sd_drop =float(arr.std(ddof=1)) if arr.size >1else0.0 imp_rows.append({"feature": feature_names[col_idx],"mean_drop_cindex": mean_drop,"sd_drop_cindex": sd_drop,"n_evals": int(arr.size), }) df_imp_proc = pd.DataFrame(imp_rows).sort_values("mean_drop_cindex", ascending=False).reset_index(drop=True) baseline_cindex_mean =float(np.mean(baseline_cindices)) if baseline_cindices else np.nan baseline_cindex_sd =float(np.std(baseline_cindices, ddof=1)) iflen(baseline_cindices) >1else0.0print("\n=== Baseline CV Uno C-index over imputations & folds ===")print(f"Mean ± SD: {baseline_cindex_mean:.4f} ± {baseline_cindex_sd:.4f}")return baseline_cindex_mean, baseline_cindex_sd, df_imp_proc
Execute
Code
import time# Start timerstart_time = time.time()baseline_cidx_readm_initial, baseline_cidx_sd_readm_initial, df_imp_readm_initial = ( permutation_importance_cindex_cv_mi_stratified( X_list=imputations_list_jan26, y_surv_readm_list=y_surv_readm_list_corrected, y_surv_death_list=y_surv_death, # ← keep if using competing risk stratification# --- WINNING HYPERPARAMETERS --- alpha_idx=69, l1_ratio=0.1, alpha_min_ratio=0.001, n_alphas=100,# --- CROSS-VALIDATION --- n_splits=10, # Updated to 10-fold n_repeats=30, # 20 permutation repetitions max_iter=100000, n_jobs=-1# or os.cpu_count()-1 for safer parallel ))# End timerend_time = time.time() # Print elapsed time in secondselapsed = end_time - start_timeprint(f"Process completed in {elapsed/60:.2f} minutes")#~3 hrs. 21 min with 3 repeats
Starting Permutation Importance: 5 imputations, 10-fold CV...
Target Model: L1 Ratio = 0.1, Alpha Index = 69
> Target Alpha extracted: 0.00280
> Processing 50 folds for Permutation Importance...
[Parallel(n_jobs=-1)]: Using backend LokyBackend with 32 concurrent workers.
[Parallel(n_jobs=-1)]: Done 9 out of 50 | elapsed: 10.4min remaining: 47.4min
[Parallel(n_jobs=-1)]: Done 20 out of 50 | elapsed: 10.5min remaining: 15.8min
[Parallel(n_jobs=-1)]: Done 31 out of 50 | elapsed: 10.7min remaining: 6.5min
[Parallel(n_jobs=-1)]: Done 42 out of 50 | elapsed: 17.1min remaining: 3.3min
=== Baseline CV Uno C-index over imputations & folds ===
Mean ± SD: 0.6083 ± 0.0090
Process completed in 17.20 minutes
[Parallel(n_jobs=-1)]: Done 50 out of 50 | elapsed: 17.1min finished
import pandas as pdfrom IPython.display import display# Raw data from your Death Permutation rundata_death = [ {"Rank": 1, "Feature": "adm_age_rec3", "Mean Drop": 0.1219, "Category": "Demographic/Bio", "Interpretation": "The dominant driver. Mortality risk is overwhelmingly a function of biological aging."}, {"Rank": 2, "Feature": "primary_sub_mod_alcohol", "Mean Drop": 0.0253, "Category": "Substance", "Interpretation": "Alcohol is the most lethal substance phenotype, likely due to long-term organ damage (cirrhosis, etc.) compared to other drugs."}, {"Rank": 3, "Feature": "any_phys_dx", "Mean Drop": 0.0068, "Category": "Clinical Health", "Interpretation": "Presence of physical comorbidities (e.g., HIV, Hep C, cardiovascular) significantly shortens survival."}, {"Rank": 4, "Feature": "prim_sub_freq_rec", "Mean Drop": 0.0051, "Category": "Substance", "Interpretation": "Frequency of use serves as a proxy for addiction severity and acute toxicity risk."}, {"Rank": 5, "Feature": "eva_ocupacion", "Mean Drop": 0.0049, "Category": "Social/Functional", "Interpretation": "Evaluation of occupational functioning; likely a marker for severe functional impairment."}, {"Rank": 6, "Feature": "occupation_unemployed", "Mean Drop": 0.0037, "Category": "Social Determinant", "Interpretation": "Unemployment acts as a major mortality risk factor (poverty, lack of structure, despair)."}, {"Rank": 7, "Feature": "eva_fisica", "Mean Drop": 0.0033, "Category": "Clinical Health", "Interpretation": "Clinical evaluation of physical status confirms the impact of physical deterioration on death risk."}, {"Rank": 8, "Feature": "occupation_inactive", "Mean Drop": 0.0026, "Category": "Social Determinant", "Interpretation": "Economic inactivity (distinct from unemployment) often correlates with disability or chronic illness."}, {"Rank": 9, "Feature": "tr_outcome_adm_reasons", "Mean Drop": 0.0016, "Category": "System/Outcome", "Interpretation": "Administrative discharge might flag patients who disengage from the safety net."}, {"Rank": 10, "Feature": "cohabitation_family", "Mean Drop": 0.0016, "Category": "Social Support", "Interpretation": "Living situation impacts survival—likely related to isolation vs. support, though directionality needs checking."}, {"Rank": 11, "Feature": "tr_outcome_dropout", "Mean Drop": 0.0012, "Category": "System/Outcome", "Interpretation": "Dropping out of treatment removes the protective factor of care, increasing mortality risk."}]df_death_summary = pd.DataFrame(data_death)# Display with clean stylingprint("\n>>> TAKE-HOME MESSAGE: TOP 11 DRIVERS OF MORTALITY")pd.set_option('display.max_colwidth', None)display(df_death_summary.style.set_properties(**{'text-align': 'left', 'white-space': 'pre-wrap','background-color': '#fff0f0', # Light red for Death'border': '1px solid #dcdcdc'}).hide(axis='index'))
>>> TAKE-HOME MESSAGE: TOP 11 DRIVERS OF MORTALITY
Rank
Feature
Mean Drop
Category
Interpretation
1
adm_age_rec3
0.121900
Demographic/Bio
The dominant driver. Mortality risk is overwhelmingly a function of biological aging.
2
primary_sub_mod_alcohol
0.025300
Substance
Alcohol is the most lethal substance phenotype, likely due to long-term organ damage (cirrhosis, etc.) compared to other drugs.
3
any_phys_dx
0.006800
Clinical Health
Presence of physical comorbidities (e.g., HIV, Hep C, cardiovascular) significantly shortens survival.
4
prim_sub_freq_rec
0.005100
Substance
Frequency of use serves as a proxy for addiction severity and acute toxicity risk.
5
eva_ocupacion
0.004900
Social/Functional
Evaluation of occupational functioning; likely a marker for severe functional impairment.
6
occupation_unemployed
0.003700
Social Determinant
Unemployment acts as a major mortality risk factor (poverty, lack of structure, despair).
7
eva_fisica
0.003300
Clinical Health
Clinical evaluation of physical status confirms the impact of physical deterioration on death risk.
8
occupation_inactive
0.002600
Social Determinant
Economic inactivity (distinct from unemployment) often correlates with disability or chronic illness.
9
tr_outcome_adm_reasons
0.001600
System/Outcome
Administrative discharge might flag patients who disengage from the safety net.
10
cohabitation_family
0.001600
Social Support
Living situation impacts survival—likely related to isolation vs. support, though directionality needs checking.
11
tr_outcome_dropout
0.001200
System/Outcome
Dropping out of treatment removes the protective factor of care, increasing mortality risk.
Landmark
Code
# Evluation time pointstimes_eval_grid = np.array([3, # 3 months6, # 6 months12, # 1 year36, # 3 years48,60, # 5 years72,84,96,108# 10 years])# Filtrar solo tiempos dentro de tu rango de datosmax_time = np.max([y['time'].max() for y in y_surv_death_list])times_eval_grid = times_eval_grid[times_eval_grid <= max_time]
Updated function
Optimism bias fixed: Thresholds are now learned strictly on training folds via Youden’s J, then applied to unseen test folds.
Code
import numpy as npimport pandas as pdfrom sklearn.model_selection import StratifiedKFoldfrom sksurv.linear_model import CoxnetSurvivalAnalysisfrom sksurv.metrics import concordance_index_ipcw, cumulative_dynamic_auc, brier_scorefrom sklearn.metrics import confusion_matrix, f1_score, roc_curvefrom joblib import Parallel, delayeddef time_specific_performance_stratified_mi( X_list, y_surv_main_list, # The Target Outcome (e.g., Readmission) y_surv_comp_list, # The Competing Event (e.g., Death) - Needed for Stratification times_eval, # List of time points (e.g., [30, 90, 180]) alpha_idx=69, # WINNING ALPHA (Index) l1_ratio=0.1, # WINNING L1 RATIO alpha_min_ratio=0.001, # Must match tuning grid n_alphas=100, # Must match tuning grid n_splits=10, # 10-Fold CV n_repeats=3, # Permutation repeats random_state=2125, max_iter=100000, n_jobs=-2,):""" Time-Specific Performance & Importance with Stratified Competing Risk CV. Calculates AUC, Brier Score, PPV, NPV, Sensitivity, Specificity at specific horizons. """print(f"Starting Time-Specific Evaluation: {len(times_eval)} time points, {n_splits}-fold Stratified CV...")# --- Step 0: Pre-process Inputs --- feature_names = X_list[0].columns.tolist() n_features =len(feature_names) X_list_np = [X.values.astype(float) for X in X_list] # Convert to numpy for speed n_imputations =len(X_list)# Filter valid evaluation times max_t = np.max([y['time'].max() for y in y_surv_main_list]) times_eval = np.array([t for t in times_eval if t < max_t]) n_times =len(times_eval)print(f" > Valid Evaluation Times: {times_eval}")# --- Step 1: Safety Fix (Time <= 0) --- y_main_safe, y_comp_safe = [], []for i inrange(n_imputations): y_m = y_surv_main_list[i].copy() y_c = y_surv_comp_list[i].copy()# Fix 0.0 times to prevent crashesif np.any(y_m["time"] <=0): y_m["time"][y_m["time"] <=0] =1e-5if np.any(y_c["time"] <=0): y_c["time"][y_c["time"] <=0] =1e-5 y_main_safe.append(y_m) y_comp_safe.append(y_c)# --- Step 2: Recreate Alpha Grid --- dummy = CoxnetSurvivalAnalysis(l1_ratio=1.0, n_alphas=n_alphas, alpha_min_ratio=alpha_min_ratio, fit_baseline_model=False) dummy.fit(X_list_np[0], y_main_safe[0]) common_alphas = dummy.alphas_ target_alpha = common_alphas[alpha_idx]print(f" > Target Alpha: {target_alpha:.5f} (Index {alpha_idx})")# --- Step 3: Stratification Logic (Target + Competing Risk + Plan) --- y_samp = y_main_safe[0] y_comp = y_comp_safe[0] X_samp = X_list[0] # Need DataFrame for column access# Define Event Type (1=Comp Risk First, 2=Target First, 0=Censored) events_cr = np.zeros(len(y_samp), dtype=int) events_cr[y_comp["event"] & (~y_samp["event"] | (y_comp["time"] < y_samp["time"]))] =1 events_cr[y_samp["event"] & (~y_comp["event"] | (y_samp["time"] < y_comp["time"]))] =2# Define Plan Type plan_idx = np.zeros(len(X_samp), dtype=int) # Default 0 = pg-pabif"plan_type_corr_m-pr"in X_samp.columns: plan_idx[X_samp["plan_type_corr_m-pr"] ==1] =1if"plan_type_corr_pg-pai"in X_samp.columns: plan_idx[X_samp["plan_type_corr_pg-pai"] ==1] =2if"plan_type_corr_pg-pr"in X_samp.columns: plan_idx[X_samp["plan_type_corr_pg-pr"] ==1] =3if"plan_type_corr_m-pai"in X_samp.columns: plan_idx[X_samp["plan_type_corr_m-pai"] ==1] =4 strat_labels = (events_cr *10) + plan_idx# Merge rare groups counts = pd.Series(strat_labels).value_counts()for g in counts[counts < n_splits].index: strat_labels[strat_labels == g] =0 skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state) cv_splits =list(skf.split(X_samp, strat_labels))print(f" > Stratification successful: {len(np.unique(strat_labels))} groups.")# --- Step 4: Worker Function ---def compute_fold(d, fold_idx, train_idx, test_idx): X_train, X_test = X_list_np[d][train_idx], X_list_np[d][test_idx] y_train, y_test = y_main_safe[d][train_idx], y_main_safe[d][test_idx]# Fit Model model = CoxnetSurvivalAnalysis( l1_ratio=l1_ratio, alphas=common_alphas, normalize=False, fit_baseline_model=True, # Required for survival function max_iter=max_iter, verbose=False ) model.fit(X_train, y_train)# 1. Linear Predictor (Risk Score)# We need TRAIN scores to find the threshold (Learn)# We need TEST scores to evaluate performance (Apply) risk_train = model.predict(X_train, alpha=float(target_alpha)) risk_test = model.predict(X_test, alpha=float(target_alpha))# 2. Survival Function -> For Brier Score only surv_funcs = model.predict_survival_function(X_test, alpha=float(target_alpha)) surv_probs = np.row_stack([fn(times_eval) for fn in surv_funcs]) # Shape: (n_test, n_times)# --- Metrics ---# A. Global C-Index (Stabilized with Tau)try: tau =min(y_train["time"].max(), y_test["time"].max()) -1e-7 c_idx = concordance_index_ipcw(y_train, y_test, risk_test, tau=tau)[0]except: c_idx = np.nan# B. Time-Specific Metrics auc_scores, bs_scores = [], [] class_metrics = []for t_idx, t inenumerate(times_eval):# 1. AUC & Brier (Threshold Independent)try: auc, _ = cumulative_dynamic_auc(y_train, y_test, risk_test, times=[t]) auc_scores.append(auc[0])except: auc_scores.append(np.nan)try: surv_prob_t = surv_probs[:, t_idx].reshape(-1, 1) _, bs = brier_score(y_train, y_test, surv_prob_t, times=[t]) bs_scores.append(bs[0])except: bs_scores.append(np.nan)# 2. Classification (PPV/NPV) - STRICT FIX# A. Find Optimal Threshold on TRAIN set is_case_train = (y_train['event']) & (y_train['time'] <= t) is_control_train = (y_train['time'] > t) mask_train = is_case_train | is_control_train # Exclude censored optimal_thresh =0.5# Fallbackif mask_train.sum() >10: y_bin_train = is_case_train[mask_train].astype(int) scores_train_valid = risk_train[mask_train]iflen(np.unique(y_bin_train)) >1: fpr, tpr, thresholds = roc_curve(y_bin_train, scores_train_valid)# Youden's J = TPR - FPR idx = np.argmax(tpr - fpr) optimal_thresh = thresholds[idx]# B. Apply Fixed Threshold to TEST set is_case_test = (y_test['event']) & (y_test['time'] <= t) is_control_test = (y_test['time'] > t) valid_mask_test = is_case_test | is_control_test metrics_t = {k: np.nan for k in ['ppv', 'npv', 'sens', 'spec', 'f1']}if valid_mask_test.sum() >5: y_binary = is_case_test[valid_mask_test].astype(int) scores_test_valid = risk_test[valid_mask_test]# PREDICT CLASS using TRAIN THRESHOLD y_pred = (scores_test_valid >= optimal_thresh).astype(int) tn, fp, fn, tp = confusion_matrix(y_binary, y_pred, labels=[0,1]).ravel() metrics_t['ppv'] = tp / (tp+fp) if (tp+fp) >0else0.0 metrics_t['npv'] = tn / (tn+fn) if (tn+fn) >0else0.0 metrics_t['sens'] = tp / (tp+fn) if (tp+fn) >0else0.0 metrics_t['spec'] = tn / (tn+fp) if (tn+fp) >0else0.0 metrics_t['f1'] = f1_score(y_binary, y_pred) class_metrics.append(metrics_t)# C. Permutation Drops (for Feature Importance) local_rng = np.random.RandomState(random_state + d * n_splits + fold_idx) fold_drops = [[] for _ inrange(n_features)]for col_idx inrange(n_features):for r inrange(n_repeats): X_perm = X_test.copy() X_perm[:, col_idx] = local_rng.permutation(X_perm[:, col_idx]) risk_perm = model.predict(X_perm, alpha=float(target_alpha))try: res_perm = concordance_index_ipcw(y_train, y_test, risk_perm, tau=tau)[0] fold_drops[col_idx].append(c_idx - res_perm)except: fold_drops[col_idx].append(0.0)return {'cindex': c_idx,'auc_scores': auc_scores,'bs_scores': bs_scores,'class_metrics': class_metrics,'fold_drops': fold_drops }# --- Step 5: Execute Parallel --- tasks = [ (d, fold_idx, train_idx, test_idx)for d inrange(n_imputations)for fold_idx, (train_idx, test_idx) inenumerate(cv_splits) ]print(f" > Processing {len(tasks)} folds...") results = Parallel(n_jobs=n_jobs)(delayed(compute_fold)(*t) for t in tasks)# --- Step 6: Aggregate Results ---# Global C-Index c_indices = [r['cindex'] for r in results ifnot np.isnan(r['cindex'])] c_mean = np.mean(c_indices) c_sd = np.std(c_indices, ddof=1)# Time-Specific Aggregation time_rows = [] keys = ['ppv', 'npv', 'sens', 'spec', 'f1']for t_i, t inenumerate(times_eval): aucs = [r['auc_scores'][t_i] for r in results ifnot np.isnan(r['auc_scores'][t_i])] bss = [r['bs_scores'][t_i] for r in results ifnot np.isnan(r['bs_scores'][t_i])] row = {'Time (Months)': t,'AUC Mean': np.mean(aucs), 'AUC SD': np.std(aucs),'Brier Mean': np.mean(bss), 'Brier SD': np.std(bss) }# Classification Metricsfor k in keys: vals = [r['class_metrics'][t_i][k] for r in results ifnot np.isnan(r['class_metrics'][t_i][k])] row[f'{k.upper()} Mean'] = np.mean(vals) if vals else np.nan row[f'{k.upper()} SD'] = np.std(vals) iflen(vals)>1else0.0 time_rows.append(row) df_time = pd.DataFrame(time_rows)# Feature Importance Aggregation global_drops = [[] for _ inrange(n_features)]for r in results:for c_i, drops inenumerate(r['fold_drops']): global_drops[c_i].extend(drops) imp_rows = []for c_i inrange(n_features): arr = np.array(global_drops[c_i]) imp_rows.append({'Feature': feature_names[c_i],'Mean Drop': arr.mean(),'SD Drop': arr.std() }) df_imp = pd.DataFrame(imp_rows).sort_values('Mean Drop', ascending=False)print(f"\n>>> GLOBAL C-INDEX: {c_mean:.4f} ± {c_sd:.4f}")return df_time, df_imp
Code
import numpy as npimport pandas as pdimport matplotlib.pyplot as pltimport seaborn as sns# ==================== 1. FEATURE IMPORTANCE PLOT ====================def plot_feature_importance(df_imp, top_n=15, figsize=(10, 8), title="Feature Importance"):""" Visualizes Permutation Feature Importance (Mean Drop in C-Index). """# Prepare data df_plot = df_imp.head(top_n).copy() df_plot = df_plot.sort_values('Mean Drop', ascending=True) # Sort for barh fig, ax = plt.subplots(figsize=figsize)# Plot Bars bars = ax.barh( y=np.arange(len(df_plot)), width=df_plot['Mean Drop'], xerr=df_plot['SD Drop'], color='#4c72b0', # Steelblue alpha=0.8, capsize=5, height=0.6 )# Labels & Formatting ax.set_yticks(np.arange(len(df_plot))) ax.set_yticklabels(df_plot['Feature'], fontsize=11) ax.set_xlabel('Mean Drop in C-Index (Impact)', fontsize=12, fontweight='bold') ax.set_title(f'{title} (Top {top_n})', fontsize=14, fontweight='bold')# Add value labelsfor i, v inenumerate(df_plot['Mean Drop']): ax.text(v + (df_plot['Mean Drop'].max()*0.02), i, f"+{v:.4f}", va='center', fontsize=9, fontweight='bold', color='black') ax.grid(axis='x', alpha=0.3, linestyle='--') plt.tight_layout() plt.show()# ==================== 2. TEMPORAL PERFORMANCE (AUC vs BRIER) ====================def plot_temporal_performance(df_time, figsize=(12, 6)):""" Dual-axis plot: AUC (Discrimination) vs Brier Score (Calibration) over time. """ fig, ax1 = plt.subplots(figsize=figsize) times = df_time['Time (Months)']# --- LEFT AXIS: AUC (Higher is better) --- color_auc ='#2ca02c'# Green ln1 = ax1.plot(times, df_time['AUC Mean'], marker='o', color=color_auc, linewidth=2.5, label='AUC (Discrimination)')# Error bands for AUC ax1.fill_between(times, df_time['AUC Mean'] - df_time['AUC SD'], df_time['AUC Mean'] + df_time['AUC SD'], color=color_auc, alpha=0.15) ax1.set_xlabel('Time Horizon (Months)', fontsize=12, fontweight='bold') ax1.set_ylabel('AUC Score', color=color_auc, fontsize=12, fontweight='bold') ax1.tick_params(axis='y', labelcolor=color_auc) ax1.set_ylim(0.5, 1.0) # AUC range ax1.set_xticks(times)# --- RIGHT AXIS: Brier Score (Lower is better) --- ax2 = ax1.twinx() color_bs ='#d62728'# Red ln2 = ax2.plot(times, df_time['Brier Mean'], marker='s', color=color_bs, linewidth=2.5, linestyle='--', label='Brier Score (Calibration)')# Error bands for Brier ax2.fill_between(times, df_time['Brier Mean'] - df_time['Brier SD'], df_time['Brier Mean'] + df_time['Brier SD'], color=color_bs, alpha=0.15) ax2.set_ylabel('Brier Score (Error)', color=color_bs, fontsize=12, fontweight='bold') ax2.tick_params(axis='y', labelcolor=color_bs) ax2.set_ylim(0, 0.25) # Typical Brier range# Title & Legend plt.title('Model Performance Over Time: Discrimination vs. Calibration', fontsize=14, fontweight='bold')# Combined Legend lns = ln1 + ln2 labs = [l.get_label() for l in lns] ax1.legend(lns, labs, loc='center right') ax1.grid(True, axis='x', alpha=0.3) plt.tight_layout() plt.show()# ==================== 3. CLINICAL UTILITY METRICS (PPV/NPV) ====================def plot_clinical_utility(df_time, figsize=(12, 6)):""" Plots PPV, NPV, Sensitivity, and Specificity to show clinical trade-offs. """ fig, ax = plt.subplots(figsize=figsize) times = df_time['Time (Months)']# Metrics to plot (Key in DF, Label, Color, Marker) metrics = [ ('NPV Mean', 'NPV (Safety)', 'green', 'o'), ('PPV Mean', 'PPV (Precision)', 'blue', '^'), ('SENS Mean', 'Sensitivity', 'orange', 's'), ('SPEC Mean', 'Specificity', 'purple', 'D') ]for col, label, color, marker in metrics:if col in df_time.columns:# Main Line ax.plot(times, df_time[col], marker=marker, color=color, linewidth=2, label=label, alpha=0.85)# Error Bands (infer SD column name) sd_col = col.replace('Mean', 'SD')if sd_col in df_time.columns: ax.fill_between(times, np.maximum(0, df_time[col] - df_time[sd_col]), np.minimum(1, df_time[col] + df_time[sd_col]), color=color, alpha=0.1)# Formatting ax.set_ylim(0, 1.05) ax.set_xticks(times) ax.set_xlabel('Time Horizon (Months)', fontsize=12, fontweight='bold') ax.set_ylabel('Probability / Rate', fontsize=12, fontweight='bold') ax.set_title('Clinical Utility: Safety (NPV) vs. Precision (PPV)', fontsize=14, fontweight='bold') ax.axhline(0.5, color='gray', linestyle=':', alpha=0.5) ax.legend(loc='lower left', bbox_to_anchor=(0, 0), fontsize=10, ncol=2) ax.grid(True, alpha=0.3) plt.tight_layout() plt.show()
Readmission
Code
# Define evaluation times (e.g., 3 months, 6 months, 1 year, 2 years)# Note: Input is usually in the same unit as your 'time' column (assuming months here)df_performance, df_importance = time_specific_performance_stratified_mi( X_list = imputations_list_jan26, y_surv_main_list = y_surv_readm_list_corrected, # Target (Readmission) y_surv_comp_list = y_surv_death_list, # Competing Risk (Death) times_eval = times_eval_grid, alpha_idx =69, # WINNING READMISSION ALPHA l1_ratio =0.1, # WINNING READMISSION L1 alpha_min_ratio =0.001, n_splits =10, n_jobs =-2)# Display Time-Specific Performancepd.set_option('display.max_columns', None)display(df_performance.round(3))
High-Grade Safety Monitor. Misses some acute events but highly reliable for ruling out risk.
Medium Term (1 Year)
0.690
96.3%
Good. Stronger than global average.
0.793
99.6%
Excellent. Risk remains highly detectable.
Long Term (3-5 Years)
0.634
85.1%
Fair. Baseline data loses relevance.
0.781
97.4%
Robust. Biological risks (Age) persist.
Very Long Term (9 Years)
0.560
34.0%
Poor. Driven by new life events.
0.781
88.6%
Stable. Mortality risk does not degrade.
Code
import pandas as pdfrom IPython.display import display# --- 1. PERFORMANCE REALITY CHECK ---performance_final = pd.DataFrame([ {'Outcome': 'Readmission (6 Mo)','Old Sens (Optimistic)': '70.5%','New Sens (Strict)': '66.9%','Change': '-3.6%','Interpretation': 'Robust. The signal is stable. The model reliably identifies 2/3rds of readmissions without seeing the test data.' }, {'Outcome': 'Death (3 Mo)','Old Sens (Optimistic)': '92.0%','New Sens (Strict)': '70.2%','Change': '-21.8%','Interpretation': 'Corrected. The "92%" was inflated. The realistic 70% sensitivity is still clinically strong, functioning as a high-grade early warning system.' }])# --- 2. FINAL CLINICAL STRATEGY ---strategy_final = pd.DataFrame([ {'Outcome': 'Readmission','Role': 'The "2/3rds" Screener','Key Stat': 'Sensitivity 0.67 (6mo)','Actionable Strategy': 'Standard of Care. The model captures the majority (67%) of returning patients. Use it to auto-enroll high-risk patients in "Bridge Programs" (72hr follow-up), knowing you are allocating resources to the right people.' }, {'Outcome': 'Death','Role': 'The "Mortality Watchlist"','Key Stat': 'Sensitivity 0.70 (3mo)','Actionable Strategy': 'High-Impact Triage. Identifying 70% of imminent deaths at admission is life-saving. Any patient flagged High-Risk gets an immediate medical consult. The high Specificity (~80%) means false alarms are manageable.' }])# --- DISPLAY ---print("\n>>> DATA LEAKAGE CORRECTION: BEFORE vs. AFTER")pd.set_option('display.max_colwidth', None)display(performance_final.style.set_properties(**{'text-align': 'left', 'white-space': 'pre-wrap', 'border': '1px solid #dcdcdc'}))print("\n>>> FINAL CLINICAL UTILITY STRATEGY")display(strategy_final.style.set_properties(**{'text-align': 'left', 'white-space': 'pre-wrap', 'border': '1px solid #dcdcdc', 'background-color': '#f9f9f9'}))
>>> DATA LEAKAGE CORRECTION: BEFORE vs. AFTER
Outcome
Old Sens (Optimistic)
New Sens (Strict)
Change
Interpretation
0
Readmission (6 Mo)
70.5%
66.9%
-3.6%
Robust. The signal is stable. The model reliably identifies 2/3rds of readmissions without seeing the test data.
1
Death (3 Mo)
92.0%
70.2%
-21.8%
Corrected. The "92%" was inflated. The realistic 70% sensitivity is still clinically strong, functioning as a high-grade early warning system.
>>> FINAL CLINICAL UTILITY STRATEGY
Outcome
Role
Key Stat
Actionable Strategy
0
Readmission
The "2/3rds" Screener
Sensitivity 0.67 (6mo)
Standard of Care. The model captures the majority (67%) of returning patients. Use it to auto-enroll high-risk patients in "Bridge Programs" (72hr follow-up), knowing you are allocating resources to the right people.
1
Death
The "Mortality Watchlist"
Sensitivity 0.70 (3mo)
High-Impact Triage. Identifying 70% of imminent deaths at admission is life-saving. Any patient flagged High-Risk gets an immediate medical consult. The high Specificity (~80%) means false alarms are manageable.
Calibration plots
Code
import numpy as npimport pandas as pdfrom sklearn.model_selection import StratifiedKFoldfrom sksurv.linear_model import CoxnetSurvivalAnalysisfrom joblib import Parallel, delayeddef collect_calibration_data( X_list, y_surv_main_list, y_surv_comp_list, times_eval, alpha_idx, l1_ratio, n_splits=10, n_jobs=-2, random_state=2125):""" Runs Stratified CV to collect (y_true, y_pred) pairs for calibration plots. """print(f"Collecting Calibration Data: {len(times_eval)} time points...")# --- 1. Setup --- X_list_np = [X.values.astype(float) for X in X_list] n_imputations =len(X_list)# Safety Fix for Time=0 y_main_safe, y_comp_safe = [], []for i inrange(n_imputations): y_m, y_c = y_surv_main_list[i].copy(), y_surv_comp_list[i].copy()if np.any(y_m["time"] <=0): y_m["time"][y_m["time"] <=0] =1e-5if np.any(y_c["time"] <=0): y_c["time"][y_c["time"] <=0] =1e-5 y_main_safe.append(y_m) y_comp_safe.append(y_c)# Recreate Grid & Stratification (Same as before) dummy = CoxnetSurvivalAnalysis(l1_ratio=1.0, n_alphas=100, alpha_min_ratio=0.001, fit_baseline_model=False) dummy.fit(X_list_np[0], y_main_safe[0]) target_alpha = dummy.alphas_[alpha_idx]# Stratification Logic y_samp = y_main_safe[0] y_comp = y_comp_safe[0] events_cr = np.zeros(len(y_samp), dtype=int) events_cr[y_comp["event"] & (~y_samp["event"] | (y_comp["time"] < y_samp["time"]))] =1 events_cr[y_samp["event"] & (~y_comp["event"] | (y_samp["time"] < y_comp["time"]))] =2# Plan Type Stratification plan_idx = np.zeros(len(X_list[0]), dtype=int) cols = X_list[0].columnsif"plan_type_corr_m-pr"in cols: plan_idx[X_list[0]["plan_type_corr_m-pr"] ==1] =1if"plan_type_corr_pg-pai"in cols: plan_idx[X_list[0]["plan_type_corr_pg-pai"] ==1] =2if"plan_type_corr_pg-pr"in cols: plan_idx[X_list[0]["plan_type_corr_pg-pr"] ==1] =3if"plan_type_corr_m-pai"in cols: plan_idx[X_list[0]["plan_type_corr_m-pai"] ==1] =4 strat_labels = (events_cr *10) + plan_idx counts = pd.Series(strat_labels).value_counts()for g in counts[counts < n_splits].index: strat_labels[strat_labels == g] =0 skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state) cv_splits =list(skf.split(X_list[0], strat_labels))# --- 2. Worker Function ---def process_fold(d, train_idx, test_idx): X_train, X_test = X_list_np[d][train_idx], X_list_np[d][test_idx] y_train, y_test = y_main_safe[d][train_idx], y_main_safe[d][test_idx] model = CoxnetSurvivalAnalysis(l1_ratio=l1_ratio, alphas=dummy.alphas_, normalize=False, fit_baseline_model=True, max_iter=100000) model.fit(X_train, y_train)# Get Survival Functions surv_funcs = model.predict_survival_function(X_test, alpha=float(target_alpha)) surv_probs_matrix = np.row_stack([fn(times_eval) for fn in surv_funcs]) fold_data = []for t_idx, t inenumerate(times_eval):# Define Binary Truth for Calibration# 1 = Event happened <= t# 0 = Survived > t# Censored <= t are excluded is_case = (y_test['event']) & (y_test['time'] <= t) is_control = (y_test['time'] > t) valid_mask = is_case | is_controlif valid_mask.sum() >0: y_true = is_case[valid_mask].astype(int)# Predicted Risk = 1 - Survival Probability y_prob =1.0- surv_probs_matrix[valid_mask, t_idx] fold_data.append({'time_idx': t_idx,'time_val': t,'y_true': y_true,'y_prob': y_prob })return fold_data# --- 3. Execute --- results = Parallel(n_jobs=n_jobs)(delayed(process_fold)(d, tr, te) for d inrange(n_imputations) for tr, te in cv_splits)# --- 4. Aggregate ---# Dictionary to store arrays by time point agg_data = {t: {'y_true': [], 'y_prob': []} for t in times_eval}for res in results:for item in res: t = item['time_val'] agg_data[t]['y_true'].extend(item['y_true']) agg_data[t]['y_prob'].extend(item['y_prob'])return agg_data
Code
import matplotlib.pyplot as pltfrom sklearn.calibration import calibration_curveimport seaborn as snsdef plot_faceted_calibration(agg_data, title_prefix="Outcome"):""" Plots a single figure with calibration curves faceted by time point. """ times =sorted(agg_data.keys()) n_plots =len(times)# Dynamic grid layout cols =3 rows = (n_plots // cols) + (1if n_plots % cols >0else0) fig, axes = plt.subplots(rows, cols, figsize=(5* cols, 4* rows)) axes = axes.flatten()for i, t inenumerate(times): ax = axes[i] data = agg_data[t]iflen(data['y_true']) <10: # Skip empty plots ax.axis('off')continue# Calculate Calibration Curve (10 bins) prob_true, prob_pred = calibration_curve(data['y_true'], data['y_prob'], n_bins=10, strategy='quantile')# Plot Curve ax.plot(prob_pred, prob_true, marker='o', linewidth=2, label='Model', color='#1f77b4')# Plot Perfect Calibration Line ax.plot([0, 1], [0, 1], linestyle='--', color='gray', label='Perfect')# Formatting ax.set_xlim(0, 1) ax.set_ylim(0, 1) ax.set_title(f"Time: {int(t)} Months", fontsize=12, fontweight='bold') ax.set_xlabel("Predicted Probability") ax.set_ylabel("Observed Fraction") ax.grid(alpha=0.3)# Add Histogram of predictions at the bottom (optional but helpful) ax_hist = ax.twinx() ax_hist.hist(data['y_prob'], range=(0,1), bins=20, color='#1f77b4', alpha=0.1) ax_hist.set_yticks([]) # Hide histogram scaleif i ==0: ax.legend()# Hide unused subplotsfor j inrange(i +1, len(axes)): axes[j].axis('off') plt.suptitle(f"Calibration Plots: {title_prefix}", fontsize=16, fontweight='bold', y=1.02) plt.tight_layout() plt.show()
# --- 2. Plot Faceted Calibration ---plot_faceted_calibration(calib_data_readm, title_prefix="Time to Readmission")
Code
# 1. Collect Data (Uses Winning Death Parameters: L1=0.1, Alpha=70)calib_data_death = collect_calibration_data( X_list = imputations_list_jan26, y_surv_main_list = y_surv_death_list, y_surv_comp_list = y_surv_readm_list_corrected, times_eval = times_eval_grid, # The times you want to check alpha_idx =70, l1_ratio =0.1, n_jobs =-2)
Code
# 2. Plotplot_faceted_calibration(calib_data_death, title_prefix="Time to Death")
2. Model Evaluation: Discrimination (C-index) vs. Null Baseline
After evaluating the Integrated Brier Score (IBS), which measures the accuracy of the predicted survival probabilities (calibration), we now evaluate the Discrimination of the model.
A model that knows nothing (Null) is always equivalent to a random coin toss, regardless of how many people die.
Code
import numpy as npimport pandas as pdfrom sksurv.nonparametric import kaplan_meier_estimatorfrom sksurv.functions import StepFunctionfrom sklearn.metrics import confusion_matrix, f1_score, roc_curvedef get_binary_metrics(y_true, y_pred):""" Helper function to calculate standard binary classification metrics. """ tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()# Avoid division by zero sens = tp / (tp + fn) if (tp + fn) >0else0.0 spec = tn / (tn + fp) if (tn + fp) >0else0.0 ppv = tp / (tp + fp) if (tp + fp) >0else0.0 npv = tn / (tn + fn) if (tn + fn) >0else0.0 f1 = f1_score(y_true, y_pred, zero_division=0)return {'Sens': sens, 'Spec': spec, 'PPV': ppv, 'NPV': npv, 'F1': f1}def compare_model_vs_null(y_train, y_test, model_risk_scores, time_points):""" Compares your trained model against a Null (Kaplan-Meier Baseline) model. Logic: - Null Model: Predicts class based on population average. If >50% of people die by time t, predict Death for ALL. Otherwise, predict Survival for ALL. - Your Model: Uses Youden's J statistic (Sens + Spec - 1) to find the optimal cut-off for risk scores at each time point. """# 1. Train Kaplan-Meier (Null Baseline) on Training Data# We use training data to avoid data leakage. times_km, surv_km = kaplan_meier_estimator(y_train['event'], y_train['time']) km_predict_fn = StepFunction(times_km, surv_km) # Function to get S(t) results_list = []print(f"{'Time':<10} | {'Model F1':<10} | {'Null F1':<10} | {'F1 Gain':<10} | {'Model Sens':<10}")print("-"*65)for t in time_points:# --- A. Define Ground Truth at time t ---# We only evaluate patients who are either:# 1. Dead/Readmitted by time t (Case = 1)# 2. Known to survive past time t (Control = 0)# Censored patients before time t are excluded (unknown status) is_case = (y_test['event'] ==True) & (y_test['time'] <= t) is_control = (y_test['time'] > t) valid_mask = is_case | is_controlif valid_mask.sum() <10: # Skip if too few samplescontinue y_true_binary = is_case[valid_mask].astype(int) current_scores = model_risk_scores[valid_mask]# --- B. Evaluate YOUR MODEL (Smart) ---# 1. Find optimal threshold using ROC curve (Youden's Index)# This ensures we are fair to the model by picking its best operating point fpr, tpr, thresholds = roc_curve(y_true_binary, current_scores) optimal_idx = np.argmax(tpr - fpr) best_threshold = thresholds[optimal_idx]# 2. Make predictions y_pred_model = (current_scores >= best_threshold).astype(int) metrics_model = get_binary_metrics(y_true_binary, y_pred_model)# --- C. Evaluate NULL MODEL (Naive) ---# 1. Get population probability of event at time t prob_survival_population = km_predict_fn(t) prob_event_population =1.0- prob_survival_population# 2. Decision Rule: Majority Vote# If Prob(Event) > 0.5, predict 1 for everyone. Else predict 0 for everyone. null_prediction_class =1if prob_event_population >=0.5else0 y_pred_null = np.full(len(y_true_binary), null_prediction_class) metrics_null = get_binary_metrics(y_true_binary, y_pred_null)# --- D. Store Results --- row = {'Time_Month': t}# Save absolute metricsfor k in metrics_model: row[f'{k}_Model'] = metrics_model[k] row[f'{k}_Null'] = metrics_null[k] row[f'{k}_Gain'] = metrics_model[k] - metrics_null[k] # Positive = Model is better results_list.append(row)# Print quick progressprint(f"{t:<10.1f} | {metrics_model['F1']:<10.3f} | {metrics_null['F1']:<10.3f} | {metrics_model['F1'] - metrics_null['F1']:<10.3f} | {metrics_model['Sens']:<10.3f}")return pd.DataFrame(results_list)
Code
from sksurv.metrics import concordance_index_censoredimport numpy as npimport pandas as pd# 1. Input Your Actual Results (from previous steps)results = {"Readmission": {"Model C-Index": 0.6085, # From your 1-SE tuning"Data": y_surv_readm_list_corrected[0] # Use imputation 0 for null calculation },"Death": {"Model C-Index": 0.7888, # From your Death tuning"Data": y_surv_death_list[0] }}comparison_rows = []for outcome, info in results.items():# 2. Calculate Null C-Index (Mathematical Baseline)# We predict '0' risk for everyone (pure indifference) y_true = info["Data"] null_preds = np.zeros(len(y_true))# Calculate C-index for random guessing null_c_index = concordance_index_censored( y_true["event"], y_true["time"], null_preds )[0]# 3. Calculate Improvement model_c = info["Model C-Index"] diff = model_c - null_c_index# 4. Verdictif diff >0.25: verdict ="Excellent (Strong Clinical Utility)"elif diff >0.15: verdict ="Good (Clear Signal)"elif diff >0.05: verdict ="Fair (Better than Random)"else: verdict ="Poor (No Signal)" comparison_rows.append({"Outcome": outcome,"Null Baseline": null_c_index,"Your Model": model_c,"Absolute Gain": diff,"Verdict": verdict })# --- DISPLAY ---df_comparison = pd.DataFrame(comparison_rows)print("\n>>> MODEL PERFORMANCE VS. NULL BASELINE")pd.set_option('display.max_colwidth', None)display(df_comparison.style.set_properties(**{'text-align': 'left','white-space': 'pre-wrap','background-color': '#f9f9f9','border': '1px solid black'}).hide(axis='index'))
>>> MODEL PERFORMANCE VS. NULL BASELINE
Outcome
Null Baseline
Your Model
Absolute Gain
Verdict
Readmission
0.500000
0.608500
0.108500
Fair (Better than Random)
Death
0.500000
0.788800
0.288800
Excellent (Strong Clinical Utility)
Code
import pandas as pdimport numpy as npfrom sksurv.nonparametric import kaplan_meier_estimatorfrom sksurv.functions import StepFunctionfrom sklearn.metrics import confusion_matrix, f1_score, roc_curvedef get_binary_metrics(y_true, y_pred):"""Calculates Sens, Spec, PPV, NPV, F1 given binary truth and preds.""" tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel() sens = tp / (tp + fn) if (tp + fn) >0else0.0 spec = tn / (tn + fp) if (tn + fp) >0else0.0 ppv = tp / (tp + fp) if (tp + fp) >0else0.0 npv = tn / (tn + fn) if (tn + fn) >0else0.0 f1 = f1_score(y_true, y_pred, zero_division=0)return {'Sens': sens, 'Spec': spec, 'PPV': ppv, 'NPV': npv, 'F1': f1}def compare_model_vs_null_strict(y_train, y_test, scores_train, scores_test, time_points):""" STRICT Comparison: Thresholds derived from TRAINING scores ONLY. Args: y_train, y_test: Structured arrays (event, time) scores_train: Risk scores for training set (used to pick threshold) scores_test: Risk scores for test set (used to evaluate) time_points: List of times to check """# 2. KM Null Model (Fit on Train) times_km, surv_km = kaplan_meier_estimator(y_train['event'], y_train['time']) km_predict_fn = StepFunction(times_km, surv_km) results_list = []for t in time_points:# --- A. Define Binary Targets ---# Train Targets (for threshold finding) case_train = (y_train['event']) & (y_train['time'] <= t) ctrl_train = (y_train['time'] > t) mask_train = case_train | ctrl_train# Test Targets (for evaluation) case_test = (y_test['event']) & (y_test['time'] <= t) ctrl_test = (y_test['time'] > t) mask_test = case_test | ctrl_testif mask_test.sum() <10or mask_train.sum() <10: continue# --- B. Find Threshold (TRAIN ONLY) --- y_true_train = case_train[mask_train].astype(int) scores_tr_valid = scores_train[mask_train] best_thresh =0.5iflen(np.unique(y_true_train)) >1: fpr, tpr, threshs = roc_curve(y_true_train, scores_tr_valid) best_thresh = threshs[np.argmax(tpr - fpr)] # Youden's J# --- C. Evaluate Model (TEST ONLY) ---# We apply the 'best_thresh' found in step B to the test scores y_true_test = case_test[mask_test].astype(int) scores_te_valid = scores_test[mask_test] y_pred_model = (scores_te_valid >= best_thresh).astype(int) metrics_model = get_binary_metrics(y_true_test, y_pred_model)# --- D. Evaluate Null (TEST ONLY) ---# Majority vote based on KM survival prob at time t prob_event =1.0- km_predict_fn(t) null_class =1if prob_event >=0.5else0 y_pred_null = np.full(len(y_true_test), null_class) metrics_null = get_binary_metrics(y_true_test, y_pred_null)# --- E. Store --- row = {'Time_Month': t}for k, v in metrics_model.items(): row[f'{k}_Model'] = vfor k, v in metrics_null.items(): row[f'{k}_Null'] = vfor k in ['F1', 'Sens', 'Spec']: row[f'{k}_Gain'] = metrics_model[k] - metrics_null[k] results_list.append(row)return pd.DataFrame(results_list)
Code
from sklearn.model_selection import train_test_splitfrom sksurv.linear_model import CoxnetSurvivalAnalysis# --- SETUP ---# Standard time grid if you don't have one definedif'times_eval_grid'notinlocals(): times_eval_grid = [3, 6, 12, 24, 36, 60]params = {"readm": {"l1": 0.1, "alpha_idx": 69, "data_y": y_surv_readm_list_corrected[0]},"death": {"l1": 0.1, "alpha_idx": 70, "data_y": y_surv_death_list[0]}}OUTCOME ="readm"print(f"\n"+"🔵"*30)print(f">>> {OUTCOME.upper()}: MODEL vs NULL COMPARISON (STRICT)")print("🔵"*30)# A. Prepare DataX = imputations_list_jan26[0]y = params[OUTCOME]["data_y"]# B. Split (Stratified)X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.2, stratify=y['event'], random_state=2125)# C. Fit Coxnetcox = CoxnetSurvivalAnalysis( l1_ratio=params[OUTCOME]["l1"], n_alphas=100, alpha_min_ratio=0.001, fit_baseline_model=True)cox.fit(X_train, y_train)# D. Predict Scores for BOTH Train and Test (Using Winning Alpha)target_alpha = cox.alphas_[params[OUTCOME]["alpha_idx"]] scores_train = cox.predict(X_train, alpha=target_alpha) # Needed to find thresholdscores_test = cox.predict(X_test, alpha=target_alpha) # Needed to evaluate# E. Run Comparison (Passing both sets of scores)df_comp = compare_model_vs_null_strict( y_train, y_test, scores_train, scores_test, time_points=times_eval_grid)# F. Display Resultsgain_cols = ['Time_Month', 'F1_Gain', 'Sens_Gain', 'Spec_Gain', 'PPV_Model', 'NPV_Model']print("\n>>> IMPROVEMENT OVER NULL")display(df_comp[gain_cols].round(3))avg_gain = df_comp['F1_Gain'].mean()if avg_gain >0.05:print(f"✅ VERDICT: Strong Signal. Average F1 Gain: +{avg_gain:.3f}")else:print(f"⚠️ VERDICT: Weak Signal. Average F1 Gain: +{avg_gain:.3f}")
🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵
>>> READM: MODEL vs NULL COMPARISON (STRICT)
🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵
>>> IMPROVEMENT OVER NULL
Time_Month
F1_Gain
Sens_Gain
Spec_Gain
PPV_Model
NPV_Model
0
3
0.035
0.612
-0.262
0.018
0.996
1
6
0.085
0.720
-0.376
0.045
0.989
2
12
0.185
0.657
-0.384
0.108
0.962
3
36
0.375
0.619
-0.440
0.269
0.849
4
48
0.439
0.680
-0.523
0.324
0.801
5
60
0.475
0.597
-0.453
0.395
0.733
6
72
0.521
0.591
-0.463
0.467
0.657
7
84
0.540
0.522
-0.401
0.560
0.561
8
96
0.504
0.406
-0.313
0.665
0.430
9
108
0.570
0.459
-0.369
0.752
0.324
✅ VERDICT: Strong Signal. Average F1 Gain: +0.373
Code
# Select Outcome to Run (Change to "death" to run mortality)OUTCOME ="death"print(f"\n"+"🔵"*30)print(f">>> {OUTCOME.upper()}: MODEL vs NULL COMPARISON (ALL METRICS)")print("🔵"*30)# A. Prepare DataX = imputations_list_jan26[0]y = params[OUTCOME]["data_y"]# B. Split (Stratified)X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.2, stratify=y['event'], random_state=2125)# C. Fit Coxnet (Using Winning L1)# Note: We fit a quick path to get the exact alpha indexcox = CoxnetSurvivalAnalysis( l1_ratio=params[OUTCOME]["l1"], n_alphas=100, alpha_min_ratio=0.001, fit_baseline_model=True)cox.fit(X_train, y_train)# Select specific alpha (approximate index mapping)target_alpha = cox.alphas_[params[OUTCOME]["alpha_idx"]] scores = cox.predict(X_test, alpha=target_alpha)# D. Run Comparisondf_comp = compare_model_vs_null(y_train, y_test, scores, TIMES_EVAL)# E. Display Results# 1. Summary of Gains (Does the model add value?)gain_cols = ['Time_Month', 'F1_Gain', 'Sens_Gain', 'Spec_Gain', 'PPV_Gain', 'NPV_Gain']print("\n>>> IMPROVEMENT OVER NULL (Positive = Model is Better)")display(df_comp[gain_cols].round(3))# 2. Detailed View (Model Performance vs. Null Performance)print("\n>>> DETAILED METRICS (Model vs. Null)")detail_cols = ['Time_Month', 'F1_Model', 'F1_Null', 'Sens_Model', 'Sens_Null', 'Spec_Model', 'Spec_Null']display(df_comp[detail_cols].round(3))# Quick Verdictavg_gain = df_comp['F1_Gain'].mean()if avg_gain >0.05:print(f"✅ VERDICT: Strong Signal. Average F1 Gain: +{avg_gain:.3f}")elif avg_gain >0:print(f"⚠️ VERDICT: Weak Signal. Average F1 Gain: +{avg_gain:.3f}")else:print(f"❌ VERDICT: No Signal. Model performs worse than majority vote.")
🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵
>>> DEATH: MODEL vs NULL COMPARISON (ALL METRICS)
🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵
Evaluating 10 time points...
>>> IMPROVEMENT OVER NULL (Positive = Model is Better)
Time_Month
F1_Gain
Sens_Gain
Spec_Gain
PPV_Gain
NPV_Gain
0
3
0.006
0.920
-0.447
0.003
0.001
1
6
0.012
0.691
-0.369
0.006
0.002
2
12
0.026
0.667
-0.338
0.013
0.003
3
36
0.111
0.635
-0.262
0.061
0.013
4
48
0.159
0.634
-0.256
0.091
0.019
5
60
0.220
0.630
-0.243
0.133
0.028
6
72
0.272
0.684
-0.288
0.170
0.043
7
84
0.342
0.690
-0.293
0.227
0.059
8
96
0.441
0.690
-0.279
0.324
0.086
9
108
0.533
0.682
-0.268
0.437
0.117
>>> DETAILED METRICS (Model vs. Null)
Time_Month
F1_Model
F1_Null
Sens_Model
Sens_Null
Spec_Model
Spec_Null
0
3
0.006
0.0
0.920
0.0
0.553
1.0
1
6
0.012
0.0
0.691
0.0
0.631
1.0
2
12
0.026
0.0
0.667
0.0
0.662
1.0
3
36
0.111
0.0
0.635
0.0
0.738
1.0
4
48
0.159
0.0
0.634
0.0
0.744
1.0
5
60
0.220
0.0
0.630
0.0
0.757
1.0
6
72
0.272
0.0
0.684
0.0
0.712
1.0
7
84
0.342
0.0
0.690
0.0
0.707
1.0
8
96
0.441
0.0
0.690
0.0
0.721
1.0
9
108
0.533
0.0
0.682
0.0
0.732
1.0
✅ VERDICT: Strong Signal. Average F1 Gain: +0.212
Code
# --- STRICT COMPARISON FOR DEATH ---from sklearn.model_selection import train_test_splitfrom sksurv.linear_model import CoxnetSurvivalAnalysis# 1. Setup Outcome and ParametersOUTCOME ="death"# Ensure params dict is defined (if not already in memory)if'params'notinlocals(): params = {"readm": {"l1": 0.1, "alpha_idx": 69, "data_y": y_surv_readm_list_corrected[0]},"death": {"l1": 0.1, "alpha_idx": 70, "data_y": y_surv_death_list[0]} }# Ensure time grid is definedif'times_eval_grid'notinlocals(): times_eval_grid = [3, 6, 12, 36, 60, 108] # Standard gridprint(f"\n"+"🔵"*30)print(f">>> {OUTCOME.upper()}: MODEL vs NULL COMPARISON (STRICT)")print("🔵"*30)# A. Prepare Data (Imputation 0)X = imputations_list_jan26[0]y = params[OUTCOME]["data_y"]# B. Split (Stratified)X_train, X_test, y_train, y_test = train_test_split( X, y, test_size=0.2, stratify=y['event'], random_state=2125)# C. Fit Coxnetprint("Fitting model...")cox = CoxnetSurvivalAnalysis( l1_ratio=params[OUTCOME]["l1"], n_alphas=100, alpha_min_ratio=0.001, fit_baseline_model=True)cox.fit(X_train, y_train)# D. Predict Scores for BOTH Train and Test# We need Train scores to learn the threshold, and Test scores to evaluatetarget_alpha = cox.alphas_[params[OUTCOME]["alpha_idx"]] scores_train = cox.predict(X_train, alpha=target_alpha) scores_test = cox.predict(X_test, alpha=target_alpha) # E. Run Comparison (Passing BOTH sets of scores)df_comp = compare_model_vs_null_strict( y_train, y_test, scores_train, scores_test, time_points=times_eval_grid)# F. Display Results# 1. Summary of Gainsgain_cols = ['Time_Month', 'F1_Gain', 'Sens_Gain', 'Spec_Gain', 'PPV_Model', 'NPV_Model']print("\n>>> IMPROVEMENT OVER NULL (Positive = Model is Better)")display(df_comp[gain_cols].round(3))# 2. Detailed Metricsprint("\n>>> DETAILED METRICS (Model vs. Null)")detail_cols = ['Time_Month', 'F1_Model', 'F1_Null', 'Sens_Model', 'Sens_Null', 'Spec_Model', 'Spec_Null']display(df_comp[detail_cols].round(3))# Verdictavg_gain = df_comp['F1_Gain'].mean()if avg_gain >0.05:print(f"✅ VERDICT: Strong Signal. Average F1 Gain: +{avg_gain:.3f}")elif avg_gain >0:print(f"⚠️ VERDICT: Weak Signal. Average F1 Gain: +{avg_gain:.3f}")else:print(f"❌ VERDICT: No Signal. Model performs worse than majority vote.")
🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵
>>> DEATH: MODEL vs NULL COMPARISON (STRICT)
🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵🔵
Fitting model...
>>> IMPROVEMENT OVER NULL (Positive = Model is Better)
Time_Month
F1_Gain
Sens_Gain
Spec_Gain
PPV_Model
NPV_Model
0
3
0.009
0.360
-0.116
0.004
0.999
1
6
0.015
0.364
-0.152
0.008
0.998
2
12
0.027
0.579
-0.280
0.014
0.996
3
36
0.108
0.638
-0.271
0.059
0.987
4
48
0.156
0.636
-0.264
0.089
0.980
5
60
0.211
0.639
-0.263
0.126
0.972
6
72
0.287
0.622
-0.234
0.187
0.959
7
84
0.358
0.614
-0.226
0.253
0.941
8
96
0.455
0.608
-0.207
0.363
0.913
9
108
0.532
0.597
-0.198
0.480
0.867
>>> DETAILED METRICS (Model vs. Null)
Time_Month
F1_Model
F1_Null
Sens_Model
Sens_Null
Spec_Model
Spec_Null
0
3
0.009
0.0
0.360
0.0
0.884
1.0
1
6
0.015
0.0
0.364
0.0
0.848
1.0
2
12
0.027
0.0
0.579
0.0
0.720
1.0
3
36
0.108
0.0
0.638
0.0
0.729
1.0
4
48
0.156
0.0
0.636
0.0
0.736
1.0
5
60
0.211
0.0
0.639
0.0
0.737
1.0
6
72
0.287
0.0
0.622
0.0
0.766
1.0
7
84
0.358
0.0
0.614
0.0
0.774
1.0
8
96
0.455
0.0
0.608
0.0
0.793
1.0
9
108
0.532
0.0
0.597
0.0
0.802
1.0
✅ VERDICT: Strong Signal. Average F1 Gain: +0.216
Code
import pandas as pdfrom IPython.display import display# --- PERFORMANCE GAINS DATAFRAME (Model vs. Null) ---performance_gains_msg = pd.DataFrame([ {'Metric': 'Sensitivity (Recall)','Readmission Gain': '+50% to +70%','Death Gain': '+36% to +64%','Clinical Meaning': 'Strict validation shows the model captures ~64% of long-term mortality cases. While it misses some acute short-term events (36% sensitivity at 3mo), it still massively outperforms the "Null" baseline (0%).' }, {'Metric': 'Specificity','Readmission Gain': '-30% to -40%','Death Gain': '-12% to -28%','Clinical Meaning': 'Improved Specificity. The strict Death model is far less "noisy" than the initial estimate (Specificity 0.88 at 3mo). It flags fewer patients, but the "High Risk" signal is more credible.' }, {'Metric': 'NPV (Safety)','Readmission Gain': '+2% to +5%','Death Gain': '+0.1% to +9%','Clinical Meaning': 'The "Rule-Out" Powerhouse. For mortality, the NPV is practically perfect (0.999 at 3mo). If the model labels a patient "Low Risk," they are almost guaranteed to survive the next quarter.' }])# --- CLINICAL STRATEGY DATAFRAME ---clinical_strategy_msg = pd.DataFrame([ {'Outcome': 'Readmission','Role': 'The "Dragnet" (Screening)','Key Stat': 'Sensitivity ~0.70 (6mo)','Actionable Strategy': 'Deploy as a broad screening tool. It casts a wide net to catch ~70% of returning patients. Accept the moderate false alarm rate as the cost of doing business to prevent relapse.' }, {'Outcome': 'Death','Role': 'The "Sentinel" (Safety Monitor)','Key Stat': 'NPV > 99.9% (3mo)','Actionable Strategy': 'Deploy as a "Rule-Out" tool. While it misses some acute/sudden deaths (Sens 0.36), it excels at identifying the safe majority. Use it to confidently de-escalate intense monitoring for "Low Risk" patients, focusing resources on the minority with sustained risk markers.' }])# --- DISPLAY ---print("\n>>> TAKE-HOME MESSAGE: FINAL MODEL UTILITY (STRICT VALIDATION)")pd.set_option('display.max_colwidth', None)# Styling for Performance Gainsdisplay(performance_gains_msg.style.set_properties(**{'text-align': 'left', 'white-space': 'pre-wrap','border': '1px solid #dcdcdc'}).set_table_styles([{'selector': 'th', 'props': [('background-color', '#f0f0f0')]}]))print("\n")# Styling for Clinical Strategydisplay(clinical_strategy_msg.style.set_properties(**{'text-align': 'left', 'white-space': 'pre-wrap','border': '1px solid #dcdcdc','background-color': '#f9f9f9'}).set_table_styles([{'selector': 'th', 'props': [('background-color', '#e6f2ff')]}]))
>>> TAKE-HOME MESSAGE: FINAL MODEL UTILITY (STRICT VALIDATION)
Metric
Readmission Gain
Death Gain
Clinical Meaning
0
Sensitivity (Recall)
+50% to +70%
+36% to +64%
Strict validation shows the model captures ~64% of long-term mortality cases. While it misses some acute short-term events (36% sensitivity at 3mo), it still massively outperforms the "Null" baseline (0%).
1
Specificity
-30% to -40%
-12% to -28%
Improved Specificity. The strict Death model is far less "noisy" than the initial estimate (Specificity 0.88 at 3mo). It flags fewer patients, but the "High Risk" signal is more credible.
2
NPV (Safety)
+2% to +5%
+0.1% to +9%
The "Rule-Out" Powerhouse. For mortality, the NPV is practically perfect (0.999 at 3mo). If the model labels a patient "Low Risk," they are almost guaranteed to survive the next quarter.
Outcome
Role
Key Stat
Actionable Strategy
0
Readmission
The "Dragnet" (Screening)
Sensitivity ~0.70 (6mo)
Deploy as a broad screening tool. It casts a wide net to catch ~70% of returning patients. Accept the moderate false alarm rate as the cost of doing business to prevent relapse.
1
Death
The "Sentinel" (Safety Monitor)
NPV > 99.9% (3mo)
Deploy as a "Rule-Out" tool. While it misses some acute/sudden deaths (Sens 0.36), it excels at identifying the safe majority. Use it to confidently de-escalate intense monitoring for "Low Risk" patients, focusing resources on the minority with sustained risk markers.
Code
import pandas as pdfrom IPython.display import display# --- 1. FUNCTIONAL FORM & DRIVERS (Why it happens) ---functional_readm = pd.DataFrame([ {'Driver Category': 'Systemic Structure','Top Feature': 'Plan Type (`plan_type`)','Insight': 'The "Where" matters more than the "Who." The strongest predictor of readmission is the structure of care (e.g., General vs. Women-only programs), suggesting that systemic support levels dictate stability more than clinical severity.' }, {'Driver Category': 'Demographics','Top Feature': 'Ethnicity & Sex','Insight': 'Social Determinants are key. Risk is stratified by demographic profiles, likely reflecting disparate access to post-discharge housing or community support networks.' }, {'Driver Category': 'Retention','Top Feature': 'Time in Treatment (`dit_m`)','Insight': 'The "Sweet Spot." The protective benefit of staying in treatment is non-linear. It peaks at ~6 months; keeping patients longer yields diminishing returns for readmission prevention.' }])# --- 2. STATISTICAL GAINS (Model vs. Null) ---# Based on Strict Validation logicperformance_readm = pd.DataFrame([ {'Metric': 'Sensitivity (Recall)','Gain over Null': '+60% to +70%','Clinical Meaning': 'The "Safety Net." A random guess finds 0% of returning patients. This model captures the majority (~2/3rds), allowing proactive intervention rather than reactive admissions.' }, {'Metric': 'F1 Score','Gain over Null': '+0.38 (Avg)','Clinical Meaning': 'Signal Strength. Despite the noise inherent in behavioral health, the model identifies a clear, actionable signal that persists even under strict cross-validation.' }, {'Metric': 'Temporal Stability','Trend': 'Decays over time','Clinical Meaning': 'Short-Term Validity. Unlike the Death model (stable for 9 years), Readmission risk is dynamic. The prediction is highly accurate for the first 6-12 months but loses relevance as patients\' life circumstances change.' }])# --- 3. CLINICAL STRATEGY (What to do) ---strategy_readm = pd.DataFrame([ {'Role': 'The "Revolving Door" Blocker','Target Population': 'High-Frequency Returners','Actionable Strategy': 'Deploy as a "Discharge Compass." Since risk is highest in the first 6 months (Sens > 0.70), use the score to mandate "Bridge Appointments" (case management contact within 72 hours) for the high-risk group.' }, {'Role': 'Resource Allocator','Target Population': 'Low-Risk "Stabilizers"','Actionable Strategy': 'De-escalation. Patients flagged as "Low Risk" (High NPV) can be safely stepped down to lower-intensity community monitoring, freeing up expensive slots for the high-risk cohort.' }])# --- DISPLAY ---print("\n>>> TAKE-HOME MESSAGE: READMISSION MODEL INTERPRETATION")pd.set_option('display.max_colwidth', None)print("\n--- A. THE DRIVERS (Systemic & Social) ---")display(functional_readm.style.set_properties(**{'text-align': 'left', 'white-space': 'pre-wrap', 'border': '1px solid #dcdcdc'}).set_table_styles([{'selector': 'th', 'props': [('background-color', '#e6f7ff')]}])) # Light Blue for Readmissionprint("\n--- B. THE PERFORMANCE (Strict Validation) ---")display(performance_readm.style.set_properties(**{'text-align': 'left', 'white-space': 'pre-wrap', 'border': '1px solid #dcdcdc'}).set_table_styles([{'selector': 'th', 'props': [('background-color', '#f0f0f0')]}]))print("\n--- C. THE STRATEGY (Clinical Implementation) ---")display(strategy_readm.style.set_properties(**{'text-align': 'left', 'white-space': 'pre-wrap', 'border': '1px solid #dcdcdc', 'background-color': '#f9f9f9'}).set_table_styles([{'selector': 'th', 'props': [('background-color', '#e6f7ff')]}]))
>>> TAKE-HOME MESSAGE: READMISSION MODEL INTERPRETATION
--- A. THE DRIVERS (Systemic & Social) ---
Driver Category
Top Feature
Insight
0
Systemic Structure
Plan Type (`plan_type`)
The "Where" matters more than the "Who." The strongest predictor of readmission is the structure of care (e.g., General vs. Women-only programs), suggesting that systemic support levels dictate stability more than clinical severity.
1
Demographics
Ethnicity & Sex
Social Determinants are key. Risk is stratified by demographic profiles, likely reflecting disparate access to post-discharge housing or community support networks.
2
Retention
Time in Treatment (`dit_m`)
The "Sweet Spot." The protective benefit of staying in treatment is non-linear. It peaks at ~6 months; keeping patients longer yields diminishing returns for readmission prevention.
--- B. THE PERFORMANCE (Strict Validation) ---
Metric
Gain over Null
Clinical Meaning
Trend
0
Sensitivity (Recall)
+60% to +70%
The "Safety Net." A random guess finds 0% of returning patients. This model captures the majority (~2/3rds), allowing proactive intervention rather than reactive admissions.
nan
1
F1 Score
+0.38 (Avg)
Signal Strength. Despite the noise inherent in behavioral health, the model identifies a clear, actionable signal that persists even under strict cross-validation.
nan
2
Temporal Stability
nan
Short-Term Validity. Unlike the Death model (stable for 9 years), Readmission risk is dynamic. The prediction is highly accurate for the first 6-12 months but loses relevance as patients' life circumstances change.
Decays over time
--- C. THE STRATEGY (Clinical Implementation) ---
Role
Target Population
Actionable Strategy
0
The "Revolving Door" Blocker
High-Frequency Returners
Deploy as a "Discharge Compass." Since risk is highest in the first 6 months (Sens > 0.70), use the score to mandate "Bridge Appointments" (case management contact within 72 hours) for the high-risk group.
1
Resource Allocator
Low-Risk "Stabilizers"
De-escalation. Patients flagged as "Low Risk" (High NPV) can be safely stepped down to lower-intensity community monitoring, freeing up expensive slots for the high-risk cohort.
Functional form
Unadjusted
Code
import numpy as npimport pandas as pdimport matplotlib.pyplot as pltimport seaborn as snsfrom sksurv.linear_model import CoxPHSurvivalAnalysisfrom statsmodels.nonparametric.smoothers_lowess import lowessdef plot_functional_form_diagnostics(X, y_surv, variables, outcome_name="Outcome"):""" Plots Martingale Residuals vs. Continuous Predictors to assess functional form. """# 1. Fit a Base Model (Using all variables EXCEPT the ones being tested is ideal, # but for a quick univariate check, an intercept-only or minimal model works).# Here we fit a model on 'other' variables to isolate the effect of the target.# For simplicity/robustness in high-dim data, we often use the 'Null' approach # (residuals from the average hazard).print(f"--- Diagnosing Functional Forms for {outcome_name} ---")# Extract Event and Time events = y_surv['event'] times = y_surv['time']# Estimate Nelson-Aalen Cumulative Hazard (Non-parametric baseline)# This serves as our "Expected" risk under the Null hypothesis df_na = pd.DataFrame({'time': times, 'event': events}).sort_values('time') unique_times = df_na['time'].unique()# Calculate cumulative hazard manually (simple Nelson-Aalen)# H(t) = sum(d_i / n_i) n_at_risk =len(df_na) base_haz = [] cumulative_hazard =0# Map time to cumulative hazard time_to_haz = {}# Group by time to handle ties grouped = df_na.groupby('time')['event'].agg(['sum', 'count'])# sum = deaths, count = total at that time (approx) - actually need risk set# Better to iterate:# sorting descending risk set df_sorted = df_na.sort_values('time') total_at_risk =len(df_sorted)# Robust Nelson-Aalen calculation distinct_times =sorted(df_sorted['time'].unique()) cum_haz =0 time_haz_map = {}for t in distinct_times:# events at this time n_events = df_sorted[df_sorted['time'] == t]['event'].sum() n_risk =len(df_sorted[df_sorted['time'] >= t])if n_risk >0: cum_haz += (n_events / n_risk) time_haz_map[t] = cum_haz# 2. Calculate Martingale Residuals# M_i = Event_i - Expected_i# Expected_i = CumulativeHazard(Time_i) residuals = []for i inrange(len(X)): t_i = times[i] e_i = events[i] expected = time_haz_map.get(t_i, 0)# Martingale residual m_i = e_i - expected residuals.append(m_i) residuals = np.array(residuals)# 3. Plotting fig, axes = plt.subplots(1, len(variables), figsize=(6*len(variables), 5))iflen(variables) ==1: axes = [axes]for i, var inenumerate(variables): ax = axes[i] x_val = X[var]# Scatter plot (High transparency because N is large) ax.scatter(x_val, residuals, alpha=0.1, color='gray', s=10)# LOWESS Smoother (The "Truth" Line)# frac=0.3 means we use 30% of data to smooth (adjust for smoothness) smooth = lowess(residuals, x_val, frac=0.3)# Plot the smooth line ax.plot(smooth[:, 0], smooth[:, 1], color='red', linewidth=3, label='Observed Trend')# Reference Line (Linear / Zero) ax.axhline(0, color='blue', linestyle='--', alpha=0.5, label='Null (Linear)') ax.set_title(f"Functional Form: {var}", fontweight='bold') ax.set_xlabel(f"{var} Value") ax.set_ylabel("Martingale Residual (Excess Risk)")# Add interpretation text ax.text(0.05, 0.95, "Curve = Non-linear\nFlat = No effect", transform=ax.transAxes, verticalalignment='top', fontsize=9, bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))if i ==0: ax.legend() plt.tight_layout() plt.show()# --- RUN DIAGNOSTICS ---# Select the continuous variables you want to check# 1. adm_age_rec3 (Age)# 2. dit_m (Retention / Time in treatment)vars_to_check = ['adm_age_rec3', 'dit_m', 'porc_pobr']