# -*- coding: utf-8 -*-
"""
DACTRL — Probability Calibration (17 features)
================================================
Evaluates how well the ProtoNet distance score is calibrated as a
probability estimate. Reports:
  - Expected Calibration Error (ECE) at K=10
  - Maximum Calibration Error (MCE)
  - Reliability diagram (confidence vs accuracy)
  - Brier score
  - Temperature scaling: finds optimal T to minimise NLL on held-out windows
    (the query windows, i.e. those not used to build the prototypes)
  - Pre- vs post-calibration ECE comparison

Score -> probability mapping:
  raw:  p = 1 - score   where score = dp/(dp+db+1e-8)
  temp: p = sigmoid(logit(raw) / T)

LOSO protocol — same as clean eval.

Outputs:
  results/dactrl_calibration_17feat/tables/calibration_metrics.csv
  results/dactrl_calibration_17feat/figures/reliability_diagram.png
  results/dactrl_calibration_17feat/figures/ece_by_nucleus.png
"""
import os
os.environ.setdefault('PYTHONIOENCODING', 'utf-8')
import random
import warnings
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import brier_score_loss
from sklearn.calibration import calibration_curve
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.special import expit as sigmoid
from scipy.optimize import minimize_scalar

warnings.filterwarnings('ignore')

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"[{'GPU' if torch.cuda.is_available() else 'CPU'}] "
      f"{torch.cuda.get_device_name(0) if torch.cuda.is_available() else ''}",
      flush=True)
torch.manual_seed(42); np.random.seed(42); random.seed(42)

# Load helpers from the sibling v3 script by executing its source in a private
# namespace. The script's entry-point guard is rewritten so that its own main
# section does not run as a side effect of loading it here.
# NOTE(review): exec() of a file is only acceptable because the file is a
# trusted, project-local script sitting next to this one.
_V3 = Path(__file__).parent / "dactrl_v3_episodic_protonet.py"
_v3g = {'__file__': str(_V3)}
# Explicit UTF-8: this file declares UTF-8, and relying on the platform default
# encoding would mis-decode non-ASCII characters in the sibling script.
with open(_V3, 'r', encoding='utf-8', errors='replace') as f:
    exec(compile(
        f.read().replace("if __name__=='__main__':", "if __name__=='__never__':"),
        str(_V3), 'exec'), _v3g)
load_all_seeg = _v3g['load_all_seeg']
diversity_support = _v3g['diversity_support']

SEEG_ROOT = Path(r"G:\PHD Datasets\Data\Thalamus\SEEG Seizure Data")
METADATA = SEEG_ROOT / "metadata_SEEG.xlsx"
OUT_ROOT = Path(r"D:\Projects\phd\PSEG\pges_toolkit\results\dactrl_calibration_17feat")
FIG_DIR = OUT_ROOT / "figures"
TAB_DIR = OUT_ROOT / "tables"
for d in [OUT_ROOT, FIG_DIR, TAB_DIR]:
    d.mkdir(parents=True, exist_ok=True)

N_FEAT = 17
N_CTX = 8
D_MODEL = 64
N_HEADS = 4
N_LAYERS = 4
SEQ_EP = 150
SEQ_LR = 3e-4
N_BINS = 10
N_TRIALS = 5
K_SHOT = 10  # value passed to diversity_support for the support/query split

NUCLEUS_MAP = {
    'P1': 'CeM', 'P3': 'CeM', 'P5': 'CeM', 'P9': 'CeM',
    'P2': 'CL', 'P7': 'CL', 'P8': 'CL',
    'P4': 'MD', 'P6': 'MD',
    'P10': 'ANT', 'P11': 'ANT', 'P12': 'ANT',
    'P14': 'ANT', 'P15': 'ANT',
}


def log(msg):
    """Print *msg* prefixed with the current wall-clock time (HH:MM:SS)."""
    print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}", flush=True)


def ece(probs, labels, n_bins=N_BINS):
    """Return ``(ECE, MCE)`` over equal-width probability bins on [0, 1].

    Args:
        probs: predicted probabilities of the positive class.
        labels: binary ground-truth labels aligned with ``probs``.
        n_bins: number of equal-width bins.

    Returns:
        Tuple ``(ece, mce)`` of floats: the sample-weighted mean and the
        maximum of ``|accuracy - confidence|`` over non-empty bins.

    Bins are half-open ``[lo, hi)`` except the last one, which is closed
    ``[lo, 1.0]`` so that a probability of exactly 1.0 is counted. Saturated
    sigmoid outputs hit 1.0 exactly and would otherwise be dropped from every
    bin while still contributing to the normalising sample count.
    """
    probs = np.asarray(probs)
    labels = np.asarray(labels)
    edges = np.linspace(0, 1, n_bins + 1)
    ece_val, mce_val = 0.0, 0.0
    for k, (lo, hi) in enumerate(zip(edges[:-1], edges[1:])):
        below_hi = (probs <= hi) if k == n_bins - 1 else (probs < hi)
        mask = (probs >= lo) & below_hi
        n_in_bin = mask.sum()
        if n_in_bin == 0:
            continue
        gap = abs(labels[mask].mean() - probs[mask].mean())
        ece_val += n_in_bin / len(probs) * gap
        mce_val = max(mce_val, gap)
    return float(ece_val), float(mce_val)


def temp_scale_nll(T, logits, labels):
    """Binary negative log-likelihood of ``sigmoid(logits / T)`` against *labels*.

    Probabilities are clipped to [1e-7, 1 - 1e-7] before taking logs so the
    loss stays finite. Used as the objective for the temperature search.
    """
    probs = sigmoid(logits / T)
    probs = np.clip(probs, 1e-7, 1 - 1e-7)
    return -np.mean(labels * np.log(probs) + (1 - labels) * np.log(1 - probs))


class CausalTransformer(nn.Module):
    """Transformer encoder with a causal attention mask over feature windows.

    Input of shape (batch, time, d_in) is linearly projected to ``d_model``,
    summed with a learned positional embedding, passed through the encoder
    stack under a square subsequent (causal) mask, and optionally projected
    back to ``d_in``.
    """

    def __init__(self, n_ctx=N_CTX, d_in=N_FEAT, d_model=D_MODEL,
                 n_heads=N_HEADS, n_layers=N_LAYERS, dropout=0.1):
        super().__init__()
        self.proj_in = nn.Linear(d_in, d_model)
        # Positional table has n_ctx + 4 entries, so sequences longer than
        # that cannot be embedded.
        self.pos_emb = nn.Embedding(n_ctx + 4, d_model)
        enc = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=d_model * 2, dropout=dropout,
            batch_first=True)
        self.transformer = nn.TransformerEncoder(enc, num_layers=n_layers)
        self.proj_out = nn.Linear(d_model, d_in)

    def forward(self, x, return_hidden=False):
        """Run the encoder on ``x`` of shape (B, T, d_in).

        Returns the hidden states (B, T, d_model) when ``return_hidden`` is
        true, otherwise their projection back to feature space (B, T, d_in).
        """
        B, T, _ = x.shape
        h = self.proj_in(x) + self.pos_emb(
            torch.arange(T, device=x.device).unsqueeze(0))
        mask = nn.Transformer.generate_square_subsequent_mask(T, device=x.device)
        h = self.transformer(h, mask=mask, is_causal=True)
        return h if return_hidden else self.proj_out(h)


def pretrain(model, train_patients, scaler):
    """Train *model* to predict the next window from N_CTX preceding windows.

    Only windows labelled 0 from each training patient are used. For each
    patient those windows are standardised with *scaler*, cut into runs of
    N_CTX + 1 consecutive rows, and the model is trained to predict the last
    row from the first N_CTX using a cosine-distance plus 0.5 * MSE loss.

    Args:
        model: a ``CausalTransformer`` already on ``DEVICE``.
        train_patients: list of dicts with keys ``'X'`` and ``'labels'``.
        scaler: fitted scaler exposing ``transform``.

    Returns:
        The same model (returned untrained if no sequences could be built).
    """
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=SEQ_LR)
    seqs = []
    for p in train_patients:
        X_n = scaler.transform(p['X'].astype(np.float32))
        base = X_n[p['labels'] == 0]
        for i in range(N_CTX + 1, len(base)):
            seqs.append(base[i - N_CTX - 1: i])
    if not seqs:
        return model
    seqs = np.array(seqs, dtype=np.float32)
    ds = torch.utils.data.TensorDataset(
        torch.tensor(seqs[:, :N_CTX]), torch.tensor(seqs[:, -1]))
    ld = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=True)
    for _ in range(SEQ_EP):
        for xc, xt in ld:
            xc, xt = xc.to(DEVICE), xt.to(DEVICE)
            pred = model(xc)[:, -1, :]
            loss = (1. - F.cosine_similarity(pred, xt, dim=1).mean()) \
                + 0.5 * F.mse_loss(pred, xt)
            opt.zero_grad(); loss.backward(); opt.step()
    return model


def encode(model, seqs):
    """Return the last-position hidden state for every sequence in *seqs*.

    Sequences are processed in batches of 32 without gradient tracking; the
    result is a NumPy array with one row per sequence.
    """
    model.eval()
    z = []
    for i in range(0, len(seqs), 32):
        b = torch.tensor(seqs[i:i+32], dtype=torch.float32).to(DEVICE)
        with torch.no_grad():
            z.append(model(b, return_hidden=True)[:, -1, :].cpu().numpy())
    return np.vstack(z)


def _plot_reliability(all_labels, all_probs_raw, all_probs_cal):
    """Save the pooled before/after reliability diagram to FIG_DIR."""
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, probs, title, color in zip(
            axes,
            [all_probs_raw, all_probs_cal],
            ['Before Temperature Scaling', 'After Temperature Scaling'],
            ['#d6604d', '#2166ac']):
        frac_pos, mean_pred = calibration_curve(all_labels, probs, n_bins=N_BINS)
        ax.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Perfect calibration')
        ax.plot(mean_pred, frac_pos, 'o-', color=color, linewidth=2,
                markersize=8, label='DACTRL-TSM')
        ax.fill_between(mean_pred, frac_pos, mean_pred, alpha=0.15, color=color)
        ece_v, _ = ece(probs, all_labels)
        ax.set_title(f'Reliability Diagram\n{title} (ECE={ece_v:.4f})', fontsize=11)
        ax.set_xlabel('Mean Predicted Probability', fontsize=11)
        ax.set_ylabel('Fraction of Positives', fontsize=11)
        ax.legend(fontsize=9); ax.grid(True, alpha=0.3)
        ax.set_xlim(0, 1); ax.set_ylim(0, 1)
    plt.suptitle('DACTRL-TSM Calibration (17 features, LOSO K=10)', fontsize=13)
    plt.tight_layout()
    plt.savefig(FIG_DIR / "reliability_diagram.png", dpi=150, bbox_inches='tight')
    plt.close(fig)  # release the figure once it is on disk


def _plot_ece_by_nucleus(df):
    """Save a grouped bar chart of mean raw vs. T-scaled ECE per nucleus."""
    fig2, ax2 = plt.subplots(figsize=(8, 5))
    nuc = df.groupby('nucleus')[['ECE_raw', 'ECE_cal']].mean()
    x = np.arange(len(nuc))
    ax2.bar(x - 0.2, nuc['ECE_raw'], 0.35, label='Raw', color='#d6604d', alpha=0.8)
    ax2.bar(x + 0.2, nuc['ECE_cal'], 0.35, label='T-scaled', color='#2166ac', alpha=0.8)
    ax2.set_xticks(x); ax2.set_xticklabels(nuc.index)
    ax2.set_ylabel('ECE', fontsize=11)
    ax2.set_title('Expected Calibration Error by Nucleus', fontsize=12)
    ax2.legend(fontsize=10); ax2.grid(True, axis='y', alpha=0.3)
    plt.tight_layout()
    plt.savefig(FIG_DIR / "ece_by_nucleus.png", dpi=150, bbox_inches='tight')
    plt.close(fig2)


def main():
    """Run the leave-one-subject-out calibration study and write all outputs."""
    log("=" * 60)
    log("DACTRL — Calibration (17 features, LOSO)")
    log("=" * 60)

    meta_df = pd.read_excel(METADATA)
    raw = load_all_seeg(meta_df)
    all_pids = sorted(p for p in raw.keys() if p != 'P13')
    patients = [{'pid': p, 'X': raw[p]['X'], 'labels': raw[p]['y_temporal']}
                for p in all_pids]
    log(f"Patients: {all_pids} (N={len(patients)})")

    all_probs_raw, all_probs_cal, all_labels = [], [], []
    rows = []

    for fold_i, test_p in enumerate(patients):
        pid = test_p['pid']
        nucleus = NUCLEUS_MAP.get(pid, '?')
        train_ps = [p for p in patients if p['pid'] != pid]

        # Scaler and encoder are fitted on the training patients only.
        X_tr = np.vstack([p['X'].astype(np.float32) for p in train_ps])
        scaler = StandardScaler().fit(X_tr)

        model = CausalTransformer().to(DEVICE)
        model = pretrain(model, train_ps, scaler)
        model.eval()

        X_te_n = scaler.transform(test_p['X'].astype(np.float32))
        y_te = test_p['labels'].astype(np.int32)
        # Each sample is the N_CTX windows preceding window i, labelled y_te[i].
        seqs, lbls = [], []
        for i in range(N_CTX, len(X_te_n)):
            seqs.append(X_te_n[i - N_CTX: i])
            lbls.append(y_te[i])
        if not seqs:
            log(f" [{fold_i+1}] {pid}: skip")
            continue
        seqs = np.array(seqs, dtype=np.float32)
        lbls = np.array(lbls, dtype=np.int32)
        Z_te = encode(model, seqs)

        # K=10 support/query split for calibration
        sup_idx, qry_idx = diversity_support(lbls, K_SHOT)
        if sup_idx is None:
            log(f" [{fold_i+1}] {pid}: skip (no support)")
            continue
        # Boolean-mask indexing below needs arrays, whatever sequence type
        # diversity_support returns.
        sup_idx = np.asarray(sup_idx)
        qry_idx = np.asarray(qry_idx)
        if len(np.unique(lbls[sup_idx])) < 2:
            log(f" [{fold_i+1}] {pid}: skip (no support)")
            continue

        # Class prototypes = mean support embedding per class.
        pp = Z_te[sup_idx[lbls[sup_idx] == 1]].mean(axis=0)
        pb = Z_te[sup_idx[lbls[sup_idx] == 0]].mean(axis=0)
        dp = np.linalg.norm(Z_te[qry_idx] - pp, axis=1)
        db = np.linalg.norm(Z_te[qry_idx] - pb, axis=1)
        score = dp / (dp + db + 1e-8)
        probs_raw = 1.0 - score  # prob of PGES
        y_q = lbls[qry_idx]

        # Temperature scaling on this fold's query set
        # NOTE(review): T is fitted and then evaluated on the same query
        # windows, so the post-calibration metrics are in-sample and therefore
        # optimistic. Splitting the query set into fit/eval parts would give an
        # unbiased estimate; left unchanged to keep results comparable.
        logits = np.log(np.clip(probs_raw, 1e-7, 1-1e-7) /
                        np.clip(1 - probs_raw, 1e-7, 1-1e-7))
        res = minimize_scalar(temp_scale_nll, bounds=(0.1, 10.0),
                              method='bounded', args=(logits, y_q))
        T_opt = float(res.x)
        probs_cal = sigmoid(logits / T_opt)

        ece_raw, mce_raw = ece(probs_raw, y_q)
        ece_cal, mce_cal = ece(probs_cal, y_q)
        brier_raw = brier_score_loss(y_q, probs_raw)
        brier_cal = brier_score_loss(y_q, probs_cal)

        log(f" [{fold_i+1:02d}] {pid} ({nucleus}) "
            f"ECE_raw={ece_raw:.4f} ECE_cal={ece_cal:.4f} T={T_opt:.3f}")

        rows.append({'pid': pid, 'nucleus': nucleus,
                     'ECE_raw': ece_raw, 'MCE_raw': mce_raw,
                     'ECE_cal': ece_cal, 'MCE_cal': mce_cal,
                     'Brier_raw': brier_raw, 'Brier_cal': brier_cal,
                     'T_opt': T_opt, 'n_query': len(y_q)})
        all_probs_raw.extend(probs_raw.tolist())
        all_probs_cal.extend(probs_cal.tolist())
        all_labels.extend(y_q.tolist())

        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if not rows:
        # Every fold was skipped: there are no metric columns to summarise
        # and nothing to plot.
        log("No folds produced calibration results; nothing to report.")
        return

    df = pd.DataFrame(rows)
    df.to_csv(TAB_DIR / "calibration_metrics.csv", index=False)

    log("\n=== Calibration Summary ===")
    log(f" ECE (raw): {df['ECE_raw'].mean():.4f} +/- {df['ECE_raw'].std():.4f}")
    log(f" ECE (T-scaled): {df['ECE_cal'].mean():.4f} +/- {df['ECE_cal'].std():.4f}")
    log(f" Brier (raw): {df['Brier_raw'].mean():.4f}")
    log(f" Brier (T-cal): {df['Brier_cal'].mean():.4f}")
    log(f" Mean T_opt: {df['T_opt'].mean():.3f}")

    all_probs_raw = np.array(all_probs_raw)
    all_probs_cal = np.array(all_probs_cal)
    all_labels = np.array(all_labels)

    # ── Reliability diagram ────────────────────────────────────────
    _plot_reliability(all_labels, all_probs_raw, all_probs_cal)

    # ── ECE by nucleus ─────────────────────────────────────────────
    _plot_ece_by_nucleus(df)

    log(f"\nSaved -> {FIG_DIR}")
    log("Done.")


if __name__ == '__main__':
    main()