#@title ⚡ Final Comprehensive Evaluation: Pooled DeepSurv (AJ Competing Risks)
import torch
import numpy as np
import pandas as pd
import shap
import time
import gc
import warnings
import pickle
import os
from datetime import datetime
from sklearn.model_selection import StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, f1_score, precision_score, recall_score, confusion_matrix
from sksurv.metrics import concordance_index_ipcw
from pycox.models import CoxPH
import torchtuples as tt
from lifelines import AalenJohansenFitter
start_time = time.time()  # wall-clock start for the final runtime report
# Set True for a fast smoke-test run (fewer imputations / folds / horizons / epochs).
TEST_MODE = False
# --- 1. CONFIGURATION (fixed hyperparameters) ---
BEST_LR = 0.0008              # Adam learning rate
BEST_WD = 0.00025             # weight decay (L2) set on the torch param group
BEST_BATCH = 1024             # mini-batch size
BEST_DROPOUT = 0.57           # dropout rate inside the MLP
BEST_NODES = [256, 256, 128]  # hidden-layer widths of the DeepSurv net
K_FOLDS = 10
# Evaluation horizons in the units of the survival times — presumably months; TODO confirm.
EVAL_HORIZONS = [3, 6, 12, 24, 36, 48, 60, 72, 84, 96, 108]
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): imputations_list_jan26 must be defined earlier in the notebook.
N_IMPUTATIONS = len(imputations_list_jan26)
if TEST_MODE:
    N_IMPUTATIONS_TEST = 1
    K_FOLDS_TEST = 3
    EVAL_HORIZONS_TEST = [12, 24]
    MAX_EPOCHS_TEST = 30
    SHAP_FOLDS_TEST = 2
else:
    # Full run: use every imputation, fold and horizon.
    N_IMPUTATIONS_TEST = N_IMPUTATIONS
    K_FOLDS_TEST = K_FOLDS
    EVAL_HORIZONS_TEST = EVAL_HORIZONS
    MAX_EPOCHS_TEST = 100
    SHAP_FOLDS_TEST = 3
warnings.filterwarnings("ignore")
print(f"Starting pooled DeepSurv evaluation on {N_IMPUTATIONS_TEST} imputations...")
print(f"Device: {DEVICE} | Horizons: {EVAL_HORIZONS_TEST}")
# --- 2. CUSTOM AALEN-JOHANSEN CENSORING ---
class AalenJohansenCensoring:
    """Censoring-distribution estimator G(t) = P(C > t) via Aalen-Johansen.

    Censoring is recoded as the event of interest (code 1) while any
    observed outcome (death or readmission) becomes the competing event
    (code 2), so the fitted CIF is P(C <= t, censored).
    """

    def __init__(self):
        # Variance is not needed for IPCW weights; skipping it is faster.
        self.ajf = AalenJohansenFitter(calculate_variance=False)
        self.max_time = 0

    def fit(self, durations, events_composite):
        """Fit on composite labels: 0=Censored, 1=Death, 2=Readmission."""
        codes = np.zeros_like(events_composite)
        codes[events_composite == 0] = 1  # originally censored -> event of interest
        codes[events_composite > 0] = 2   # death/readmission -> competing risk
        self.max_time = durations.max()
        self.ajf.fit(durations, event_observed=codes, event_of_interest=1)

    def predict(self, times):
        """Return G(t) = 1 - CIF_censoring(t); scalar in -> scalar out."""
        cif = self.ajf.predict(times)
        if np.isscalar(times):
            return 1.0 - cif.item()
        return 1.0 - cif.values.flatten()
def compute_brier_competing(cif_values_at_time_horizon, censoring_dist,
                            Y_test, D_test, event_of_interest, time_horizon):
    """Brier score for competing risks using Aalen-Johansen IPCW weights.

    Residual per subject i at horizon t:
      * still at risk past t (Y_i > t):            CIF_i(t)^2 / G(t)
      * event of interest observed by t:           (1 - CIF_i(t))^2 / G(Y_i)
      * competing event observed by t:             CIF_i(t)^2 / G(Y_i)
      * censored before t:                         0 (status unknown)

    Parameters
    ----------
    cif_values_at_time_horizon : array-like, predicted CIF per subject at `time_horizon`.
    censoring_dist : object with .predict(times) returning G(t) = P(C > t)
        (scalar in -> scalar out, array in -> array out), e.g. AalenJohansenCensoring.
    Y_test, D_test : observed times and composite event codes (0=censor, 1, 2).
    event_of_interest : event code being scored.
    time_horizon : evaluation time t.

    Returns
    -------
    float : mean IPCW residual (Brier score) over all subjects.
    """
    cif = np.asarray(cif_values_at_time_horizon, dtype=float)
    Y = np.asarray(Y_test)
    D = np.asarray(D_test)
    # Guard against zero weights (division by zero), as the loop version did.
    g_horizon = censoring_dist.predict(time_horizon)
    if g_horizon == 0:
        g_horizon = 1e-9
    g_obs = np.array(censoring_dist.predict(Y), dtype=float)  # copy: do not mutate caller data
    g_obs[g_obs == 0] = 1e-9
    # Vectorized replacement of the per-subject Python loop (identical logic).
    residuals = np.zeros(len(Y))
    at_risk = Y > time_horizon
    had_event = ~at_risk & (D == event_of_interest)
    had_competing = ~at_risk & (D != event_of_interest) & (D != 0)
    residuals[at_risk] = cif[at_risk] ** 2 / g_horizon
    residuals[had_event] = (1.0 - cif[had_event]) ** 2 / g_obs[had_event]
    residuals[had_competing] = cif[had_competing] ** 2 / g_obs[had_competing]
    return residuals.mean()
# --- 3. HELPERS ---
def get_binary_target(events, times, risk_id, t_horizon):
    """Binary outcome at a horizon: 1 if event `risk_id` occurred by t_horizon.

    Subjects censored before the horizon have unknown status there, so they
    are excluded via the returned `valid_mask`.
    Returns (y_binary over valid rows, valid_mask over all rows).
    """
    censored_early = (events == 0) & (times <= t_horizon)
    valid_mask = ~censored_early
    positives = (events == risk_id) & (times <= t_horizon)
    return positives[valid_mask].astype(int), valid_mask
def find_optimal_threshold(y_true, y_prob):
    """Return the cutoff on a 0.01..0.99 grid (step 0.01) maximizing F1.

    Vectorized over all 99 thresholds at once instead of 99 sklearn calls.
    Ties keep the lowest threshold (same as the sequential strictly-greater
    scan it replaces); F1 follows sklearn's zero_division=0 convention
    (returns 0 when tp+fp+fn is 0).
    """
    y_true = np.asarray(y_true).astype(int)
    y_prob = np.asarray(y_prob, dtype=float)
    thresholds = np.linspace(0.01, 0.99, 99)
    # Prediction matrix: (n_thresholds, n_samples), same `>=` rule as before.
    preds = y_prob[None, :] >= thresholds[:, None]
    tp = (preds & (y_true == 1)).sum(axis=1)
    fp = (preds & (y_true == 0)).sum(axis=1)
    fn = ((~preds) & (y_true == 1)).sum(axis=1)
    denom = 2 * tp + fp + fn
    f1 = np.where(denom > 0, (2 * tp) / np.maximum(denom, 1), 0.0)
    # np.argmax returns the FIRST maximum -> lowest threshold among ties.
    return float(thresholds[int(np.argmax(f1))])
def calculate_binary_metrics(y_true, y_prob, fixed_threshold):
    """Classification metrics at a fixed probability cutoff.

    Computes F1, sensitivity, specificity, PPV and NPV directly from the
    2x2 confusion counts, returning 0.0 for any undefined ratio (matching
    sklearn's zero_division=0 behaviour). Unlike the previous
    confusion_matrix(...).ravel() unpack, this cannot crash when y_true or
    y_pred contains a single class (where the matrix collapses to 1x1).
    """
    y_true = np.asarray(y_true).astype(int)
    y_pred = (np.asarray(y_prob, dtype=float) >= fixed_threshold).astype(int)
    tp = int(np.sum((y_true == 1) & (y_pred == 1)))
    tn = int(np.sum((y_true == 0) & (y_pred == 0)))
    fp = int(np.sum((y_true == 0) & (y_pred == 1)))
    fn = int(np.sum((y_true == 1) & (y_pred == 0)))

    def _ratio(num, den):
        # sklearn zero_division=0 convention: undefined ratios become 0.0
        return num / den if den > 0 else 0.0

    return {
        "F1": _ratio(2 * tp, 2 * tp + fp + fn),
        "Sens": _ratio(tp, tp + fn),
        "Spec": _ratio(tn, tn + fp),
        "PPV": _ratio(tp, tp + fp),
        "NPV": _ratio(tn, tn + fn),
    }
def bootstrap_ci_non_normal(data, alpha=0.05):
    """Mean plus percentile interval of the pooled per-fold values.

    Returns (mean, lower, upper). NaNs for empty input; a degenerate
    interval for a single observation.
    """
    n = len(data)
    if n == 0:
        return np.nan, np.nan, np.nan
    if n == 1:
        only = data[0]
        return only, only, only
    lower, upper = np.percentile(data, [100 * (alpha / 2), 100 * (1 - alpha / 2)])
    return np.mean(data), lower, upper
def pick_first_existing(df, candidates):
    """Return the first candidate column name present in df, else None."""
    return next((name for name in candidates if name in df.columns), None)
def build_plan_idx(X_curr):
    """Collapse the one-hot plan-type dummies into one categorical code per row.

    Codes: 0=none, 1=pg-pab (reference), 2=pg-pr, 3=pg-pai, 4=m-pr, 5=m-pai.
    Column names may use '_' or '-' separators. When the reference column is
    missing, rows with no active non-reference dummy are inferred as code 1.
    """
    columns = {
        "pg_pab": pick_first_existing(X_curr, ["plan_type_corr_pg_pab", "plan_type_corr_pg-pab"]),
        "pg_pr": pick_first_existing(X_curr, ["plan_type_corr_pg_pr", "plan_type_corr_pg-pr"]),
        "pg_pai": pick_first_existing(X_curr, ["plan_type_corr_pg_pai", "plan_type_corr_pg-pai"]),
        "m_pr": pick_first_existing(X_curr, ["plan_type_corr_m_pr", "plan_type_corr_m-pr"]),
        "m_pai": pick_first_existing(X_curr, ["plan_type_corr_m_pai", "plan_type_corr_m-pai"]),
    }
    plan_idx = np.zeros(len(X_curr), dtype=int)
    # Non-reference plans first; the reference (code 1) is assigned last so it
    # overrides on (invalid) overlaps, exactly like the original assignment order.
    for key, code in (("pg_pr", 2), ("pg_pai", 3), ("m_pr", 4), ("m_pai", 5)):
        col = columns[key]
        if col is not None:
            plan_idx[X_curr[col].astype(int) == 1] = code
    ref_col = columns["pg_pab"]
    if ref_col is not None:
        plan_idx[X_curr[ref_col].astype(int) == 1] = 1
    else:
        # Reference dummy dropped during encoding: infer it as "no other plan active".
        present = [columns[k] for k in ("pg_pr", "pg_pai", "m_pr", "m_pai") if columns[k] is not None]
        if present:
            plan_idx[X_curr[present].astype(int).sum(axis=1) == 0] = 1
    return plan_idx
def risk_at_horizon(surv_df, t_horizon):
    """Per-subject risk (1 - survival) at the last grid time <= t_horizon.

    Clamps to the first grid point when the horizon precedes the grid and to
    the last when it exceeds it.
    """
    time_grid = surv_df.index.values.astype(float)
    pos = np.searchsorted(time_grid, t_horizon, side="right") - 1
    pos = min(max(int(pos), 0), len(time_grid) - 1)
    return 1.0 - surv_df.iloc[pos].values.astype(float)
def integrated_risk_score(surv_df):
    """Integrate risk (1 - S(t)) over the full time grid for each subject.

    Returns one scalar per column (subject); used as a global ranking score
    for Uno's C-index.
    """
    grid = surv_df.index.values.astype(float)
    risk_curve = 1.0 - surv_df.values
    # np.trapz was removed in NumPy 2.0 in favour of np.trapezoid; support both.
    trapezoid = getattr(np, "trapezoid", None)
    if trapezoid is None:
        trapezoid = np.trapz
    return trapezoid(risk_curve, x=grid, axis=0)
def fit_deepsurv_model(X_train_s, t_train, e_train_bin, X_val_s, t_val, e_val_bin):
    """Train one cause-specific DeepSurv (pycox CoxPH) model with early stopping.

    Uses the module-level BEST_* hyperparameters; returns the fitted model
    with baseline hazards computed on the training fold so that absolute
    survival curves can be predicted afterwards.
    """
    # pycox expects (durations, events) targets with these exact dtypes.
    target_train = (t_train.astype("float32"), e_train_bin.astype("int64"))
    target_val = (t_val.astype("float32"), e_val_bin.astype("int64"))
    mlp = tt.practical.MLPVanilla(
        in_features=X_train_s.shape[1],
        num_nodes=BEST_NODES,
        out_features=1,
        batch_norm=True,
        dropout=BEST_DROPOUT,
        output_bias=False,
    )
    cox = CoxPH(mlp, tt.optim.Adam)
    cox.set_device(DEVICE)
    cox.optimizer.set_lr(BEST_LR)
    # torchtuples' Adam wrapper exposes weight decay via the torch param group.
    cox.optimizer.param_groups[0]["weight_decay"] = BEST_WD
    cox.fit(
        X_train_s,
        target_train,
        batch_size=BEST_BATCH,
        epochs=MAX_EPOCHS_TEST,
        callbacks=[tt.callbacks.EarlyStopping(patience=15)],
        verbose=False,
        val_data=(X_val_s, target_val),
    )
    # Required before predict_surv_df: Breslow baseline hazards on training data.
    cox.compute_baseline_hazards(X_train_s, target_train)
    return cox
# --- 4. MAIN POOLED LOOP ---
# Long-format metric records: one row per (imputation, fold, outcome, horizon, metric).
pooled_results = []
# Provenance of every decision threshold used for the binary metrics.
threshold_records = []
baseline_hazards_log = []  # per-fold Breslow baseline hazards for later reconstruction
for imp_idx in range(N_IMPUTATIONS_TEST):
    print(f"\nImputation {imp_idx + 1}/{N_IMPUTATIONS_TEST}")
    # NOTE(review): imputations_list_jan26, y_surv_death_list and y_surv_readm_list
    # are defined earlier in the notebook; rows are assumed aligned across them — confirm.
    X_raw = imputations_list_jan26[imp_idx].copy()
    y_d = y_surv_death_list[imp_idx]
    y_r = y_surv_readm_list[imp_idx]
    t_d = np.asarray(y_d["time"])
    e_d = np.asarray(y_d["event"]).astype(bool)
    t_r = np.asarray(y_r["time"])
    e_r = np.asarray(y_r["event"]).astype(bool)
    # Build composite competing-risks labels: 0=censored, 1=death, 2=readmission.
    # Readmission takes priority when it occurs at or before the death time.
    events = np.zeros(len(X_raw), dtype=int)
    times = t_d.copy().astype("float32")
    mask_r = e_r & (t_r <= t_d)
    events[mask_r] = 2
    times[mask_r] = t_r[mask_r]
    mask_d = e_d & (~mask_r)
    events[mask_d] = 1
    print("Event counts (0=Censor, 1=Death, 2=Readm):", np.bincount(events))
    X_curr = X_raw.copy()
    # Sanity-check the one-hot plan-type encoding (at most one active dummy per row).
    plan_cols = [c for c in X_curr.columns if c.startswith("plan_type_corr")]
    if plan_cols:
        X_curr[plan_cols] = X_curr[plan_cols].astype("float32")
        plan_sum = X_curr[plan_cols].astype(int).sum(axis=1)
        if (plan_sum > 1).any():
            raise ValueError("Invalid plan encoding: some rows have >1 plan types.")
    plan_idx = build_plan_idx(X_curr)
    # Stratify folds jointly on event type and plan type.
    strat_labels = (events * 10) + plan_idx
    skf = StratifiedKFold(n_splits=K_FOLDS_TEST, shuffle=True, random_state=2125 + imp_idx)
    for fold_idx, (train_idx, val_idx) in enumerate(skf.split(X_curr, strat_labels)):
        print(".", end="")
        X_train = X_curr.iloc[train_idx].values
        X_val = X_curr.iloc[val_idx].values
        t_train, e_train = times[train_idx], events[train_idx]
        t_val, e_val = times[val_idx], events[val_idx]
        # Scale with statistics from the training fold only (no leakage).
        scaler = StandardScaler().fit(X_train)
        X_train_s = scaler.transform(X_train).astype("float32")
        X_val_s = scaler.transform(X_val).astype("float32")
        # Cause-specific event masks (each model treats the other event as censoring).
        e_train_d = (e_train == 1).astype("int64")
        e_val_d = (e_val == 1).astype("int64")
        e_train_r = (e_train == 2).astype("int64")
        e_val_r = (e_val == 2).astype("int64")
        # Train 2 independent cause-specific DeepSurv models
        model_d = fit_deepsurv_model(X_train_s, t_train, e_train_d, X_val_s, t_val, e_val_d)
        model_r = fit_deepsurv_model(X_train_s, t_train, e_train_r, X_val_s, t_val, e_val_r)
        surv_val_d = model_d.predict_surv_df(X_val_s)
        surv_val_r = model_r.predict_surv_df(X_val_s)
        surv_train_d = model_d.predict_surv_df(X_train_s)
        surv_train_r = model_r.predict_surv_df(X_train_s)
        # ---------------------------------------------------------
        # --- ROBUST MULTI-HORIZON SHAP ---
        # Only on the first imputation and the first SHAP_FOLDS_TEST folds.
        # ---------------------------------------------------------
        SHAP_HORIZONS = [3, 6, 12, 36, 60, 72, 84, 96]
        if imp_idx == 0 and fold_idx == 0:
            shap_multi_agg = {
                'Death': {h: {'vals': [], 'data': [], 'base': []} for h in SHAP_HORIZONS},
                'Readmission': {h: {'vals': [], 'data': [], 'base': []} for h in SHAP_HORIZONS}
            }
        if imp_idx == 0 and fold_idx < SHAP_FOLDS_TEST:
            print(" [SHAP]", end="")
            try:
                # Subsample background and explained points to keep KernelSHAP tractable.
                bg_size = min(100, len(X_train_s))
                bg_idx = np.random.choice(len(X_train_s), bg_size, replace=False)
                bg_data = X_train_s[bg_idx]
                test_size = min(100, len(X_val_s))
                test_idx = np.random.choice(len(X_val_s), test_size, replace=False)
                test_data = X_val_s[test_idx]
                test_df = pd.DataFrame(test_data, columns=X_curr.columns)
                for h in SHAP_HORIZONS:
                    # NOTE: the closures capture h late, but shap_values() is evaluated
                    # within this iteration, so the current h is always the one used.
                    def pred_death_h(x):
                        return risk_at_horizon(model_d.predict_surv_df(np.asarray(x, dtype='float32')), h)
                    ex_death = shap.KernelExplainer(pred_death_h, bg_data)
                    shap_vals_d = ex_death.shap_values(test_data, nsamples=50, silent=True)
                    shap_multi_agg['Death'][h]['vals'].append(shap_vals_d)
                    shap_multi_agg['Death'][h]['data'].append(test_df)
                    # FIX: store the baseline repeated for every patient of this fold
                    shap_multi_agg['Death'][h]['base'].append(np.full(len(test_data), ex_death.expected_value))
                    def pred_readm_h(x):
                        return risk_at_horizon(model_r.predict_surv_df(np.asarray(x, dtype='float32')), h)
                    ex_readm = shap.KernelExplainer(pred_readm_h, bg_data)
                    shap_vals_r = ex_readm.shap_values(test_data, nsamples=50, silent=True)
                    shap_multi_agg['Readmission'][h]['vals'].append(shap_vals_r)
                    shap_multi_agg['Readmission'][h]['data'].append(test_df)
                    # FIX: store the baseline
                    shap_multi_agg['Readmission'][h]['base'].append(np.full(len(test_data), ex_readm.expected_value))
            except Exception as e:
                # NOTE(review): SHAP failures are swallowed silently; consider logging e.
                pass
        # ---------------------------------------------------------
        outcomes_map = {
            1: ("Death", surv_val_d, surv_train_d),
            2: ("Readmission", surv_val_r, surv_train_r),
        }
        # FIX: fit Aalen-Johansen on the training fold for competing-risks IPCW weights.
        aj_censor = AalenJohansenCensoring()
        aj_censor.fit(t_train, e_train)  # e_train contains 0, 1, 2
        for risk_id, (outcome_name, surv_val_k, surv_train_k) in outcomes_map.items():
            # Structured arrays (cause-specific event flag + time) for sksurv metrics.
            y_tr_cs = np.array([(bool(e == risk_id), t) for e, t in zip(e_train, t_train)], dtype=[("e", bool), ("t", float)])
            y_va_cs = np.array([(bool(e == risk_id), t) for e, t in zip(e_val, t_val)], dtype=[("e", bool), ("t", float)])
            # Global ranking score: integral of risk over the whole follow-up grid.
            risk_global = integrated_risk_score(surv_val_k)
            try:
                uno_g = concordance_index_ipcw(y_tr_cs, y_va_cs, risk_global)[0]
            except Exception:
                uno_g = np.nan
            pooled_results.append({
                "Imp": imp_idx, "Fold": fold_idx, "Outcome": outcome_name,
                "Time": "Global", "Metric": "Uno C-Index", "Value": uno_g,
            })
            for t in EVAL_HORIZONS_TEST:
                risk_t_val = risk_at_horizon(surv_val_k, t)
                risk_t_train = risk_at_horizon(surv_train_k, t)
                try:
                    auc_u = concordance_index_ipcw(y_tr_cs, y_va_cs, risk_t_val, tau=t)[0]
                except Exception:
                    auc_u = np.nan
                # FIX: Brier score using the Aalen-Johansen censoring weights.
                brier_cr = compute_brier_competing(
                    cif_values_at_time_horizon=risk_t_val,
                    censoring_dist=aj_censor,  # the AJ estimator fitted above
                    Y_test=t_val,
                    D_test=e_val,  # raw composite codes 0, 1, 2
                    event_of_interest=risk_id,
                    time_horizon=t,
                )
                # Binary classification view at horizon t (early-censored rows excluded).
                y_bin_train, mask_train = get_binary_target(e_train, t_train, risk_id, t)
                y_bin_val, mask_val = get_binary_target(e_val, t_val, risk_id, t)
                best_th = np.nan
                threshold_source = "Not estimated"
                metrics_pack = {"Uno C-Index": auc_u}
                # Thresholded metrics need both classes present in train (to fit the
                # cutoff) and in validation (to score it).
                if len(np.unique(y_bin_train)) > 1 and len(np.unique(y_bin_val)) > 1:
                    best_th = find_optimal_threshold(y_bin_train, risk_t_train[mask_train])
                    threshold_source = "Train max-F1"
                    bin_met = calculate_binary_metrics(y_bin_val, risk_t_val[mask_val], best_th)
                    auc_roc = roc_auc_score(y_bin_val, risk_t_val[mask_val])
                    metrics_pack.update({
                        "AUC-ROC": auc_roc, "F1": bin_met["F1"], "Sens": bin_met["Sens"],
                        "Spec": bin_met["Spec"], "PPV": bin_met["PPV"], "NPV": bin_met["NPV"],
                    })
                threshold_records.append({
                    "Imp": imp_idx, "Fold": fold_idx, "Outcome": outcome_name, "Time": t,
                    "Threshold": best_th, "Threshold_Source": threshold_source,
                    "N_train_valid": int(mask_train.sum()), "N_val_valid": int(mask_val.sum()),
                    "N_train_pos": int(y_bin_train.sum()) if len(y_bin_train) else np.nan,
                    "N_val_pos": int(y_bin_val.sum()) if len(y_bin_val) else np.nan,
                })
                pooled_results.append({
                    "Imp": imp_idx, "Fold": fold_idx, "Outcome": outcome_name,
                    "Time": t, "Metric": "Brier Score (CR)", "Value": brier_cr,
                })
                for m_name, m_val in metrics_pack.items():
                    pooled_results.append({
                        "Imp": imp_idx, "Fold": fold_idx, "Outcome": outcome_name,
                        "Time": t, "Metric": m_name, "Value": m_val,
                    })
        # --- STORE BASELINE HAZARDS ---
        # pycox exposes baseline_hazards_ after compute_baseline_hazards(); copied
        # so the arrays survive model deletion below.
        baseline_hazards_log.append({
            "Imp": imp_idx,
            "Fold": fold_idx,
            "Death_BH": model_d.baseline_hazards_.copy(),
            "Readm_BH": model_r.baseline_hazards_.copy()
        })
        # --- STORE RAW PREDICTIONS FOR CALIBRATION PLOTS ---
        if 'raw_predictions_log' not in locals():
            raw_predictions_log = []
        # Compress the large survival DataFrames into dicts: horizon -> 1D risk array.
        surv_d_compressed = {h: risk_at_horizon(surv_val_d, h) for h in EVAL_HORIZONS_TEST}
        surv_r_compressed = {h: risk_at_horizon(surv_val_r, h) for h in EVAL_HORIZONS_TEST}
        raw_predictions_log.append({
            'imp': imp_idx,
            'fold': fold_idx,
            'surv_val_d': surv_d_compressed,  # dict: horizon -> death risk at horizon
            'surv_val_r': surv_r_compressed,  # dict: horizon -> readmission risk at horizon
            'y_time_val': t_val,
            'y_event_val': e_val
        })
        # Free model memory (GPU and CPU) before the next fold.
        del model_d, model_r
        gc.collect()
        if torch.cuda.is_available(): torch.cuda.empty_cache()
# --- 5. AGGREGATION & EXPORT CSVs ---
df_res = pd.DataFrame(pooled_results)
df_thresholds = pd.DataFrame(threshold_records)
# Pool folds x imputations per (outcome, horizon, metric) into mean + percentile interval.
summary_stats = []
for (outcome, time_pt, metric), group in df_res.groupby(["Outcome", "Time", "Metric"]):
    vals = group["Value"].dropna().values
    mean_val, lower, upper = bootstrap_ci_non_normal(vals)
    summary_stats.append({
        "Outcome": outcome, "Time": time_pt, "Metric": metric,
        "Mean": mean_val, "CI_Lower": lower, "CI_Upper": upper,
        "Format": f"{mean_val:.3f} [{lower:.3f}-{upper:.3f}]",
    })
df_summary = pd.DataFrame(summary_stats)
timestamp = datetime.now().strftime("%Y%m%d_%H%M")
metrics_file = f"DS_AJ_Pooled_DeepSurv_{timestamp}.csv"
threshold_file = f"DS_AJ_Thresholds_DeepSurv_{timestamp}.csv"
# NOTE(review): ';' separator — presumably for locales where ',' is the decimal mark.
df_summary.to_csv(metrics_file, sep=";", index=False)
df_thresholds.to_csv(threshold_file, sep=";", index=False)
# --- 6. EXPORT SHAP PERSISTENCE (.pkl) ---
shap_file = None
# Export only if the SHAP step actually produced values (it may have failed silently).
if 'shap_multi_agg' in locals() and len(shap_multi_agg['Death'][SHAP_HORIZONS[0]]['vals']) > 0:
    print("\n💾 Consolidating Multi-Horizon SHAP data with Base Values...")
    final_shap_export = {'Death': {}, 'Readmission': {}}
    for outcome in ['Death', 'Readmission']:
        for h in SHAP_HORIZONS:
            vals_list = shap_multi_agg[outcome][h]['vals']
            data_list = shap_multi_agg[outcome][h]['data']
            base_list = shap_multi_agg[outcome][h]['base']  # per-fold expected values
            if vals_list:
                # Stack per-fold pieces into one array/frame per outcome and horizon.
                final_shap_export[outcome][h] = {
                    'shap_values': np.concatenate(vals_list, axis=0),
                    'data': pd.concat(data_list, axis=0),
                    'base_values': np.concatenate(base_list, axis=0)
                }
    shap_file = f"DS_AJ_MultiHorizon_SHAP_{timestamp}.pkl"
    with open(shap_file, "wb") as f: pickle.dump(final_shap_export, f)
# --- 7. EXPORT BASELINE HAZARDS (.pkl) ---
bh_file = f"DS_AJ_BaselineHazards_{timestamp}.pkl"
with open(bh_file, "wb") as f:
    pickle.dump(baseline_hazards_log, f)
print(f"✅ Saved Baseline Hazards pickle: '{bh_file}'")
# Raw per-fold validation risks for the calibration plots.
with open(f"DS_AJ_RawPreds_{timestamp}.pkl", "wb") as f:
    pickle.dump(raw_predictions_log, f)
print(f"\n✅ Finished! Metrics saved to: '{metrics_file}'")
total_duration_min = (time.time() - start_time) / 60
print(f"🏁 Total Execution Time: {total_duration_min:.2f} minutes")