# -*- coding: utf-8 -*-
"""
DACTRL — Auto-Calibration with Temperature Scaling
====================================================
Current ProtoNet: binary classification — a window is assigned to whichever
prototype (PGES or BASE) it is closer to.
Problem: raw distances give no calibrated probability → can't set a clinical
threshold.

This script adds TEMPERATURE SCALING per patient:
    P(PGES) = sigmoid( (d_base - d_pges) / T )
T is fitted automatically on the K support examples per patient — the same K
examples used to build the prototype also calibrate the temperature.

Also tests an ADAPTIVE THRESHOLD: instead of the midpoint between prototypes,
use the K support to learn the optimal decision boundary via precision-recall.

Outputs:
  - ECE (Expected Calibration Error) before/after temperature scaling
  - F1 comparison: standard ProtoNet vs calibrated ProtoNet
  - Reliability diagram (predicted probability vs actual frequency)
  - Optimal operating points for clinical deployment (recall-focused)

NOTE(review): this copy of the file is damaged. Text between some '<' and the
next '>' character appears to have been stripped (see the NOTE(review) markers
inside ece() and inside the reliability-diagram loop). The file will not parse
until those regions are restored from the original source.
NOTE(review): in the surviving code the recall-targeted threshold is chosen on
the test labels, not on the K support examples as described above — see the
comment next to precision_recall_curve.
"""
# NOTE(review): PYTHONIOENCODING is read at interpreter start-up, so setting it
# here presumably only affects child processes, not this process's stdout.
import os; os.environ.setdefault('PYTHONIOENCODING','utf-8')
import gc, random, warnings, math
from pathlib import Path
from datetime import datetime
import numpy as np
import pandas as pd
# Select the non-interactive Agg backend before pyplot is imported (figures are only saved to disk).
import matplotlib; matplotlib.use('Agg')
import matplotlib.pyplot as plt
# NOTE(review): several imported names (e.g. StandardScaler, brier_score_loss,
# IsotonicRegression, sigmoid, minimize_scalar, math) are not referenced in the
# surviving code; presumably they are used in the lost driver region.
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import (f1_score, precision_recall_curve, roc_auc_score, brier_score_loss)
from sklearn.isotonic import IsotonicRegression
from scipy.special import expit as sigmoid
from scipy.optimize import minimize_scalar
import torch
import torch.nn as nn
import torch.nn.functional as F

# Silences *all* warnings for the whole run.
warnings.filterwarnings('ignore')
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"[{'GPU' if torch.cuda.is_available() else 'CPU'}]", flush=True)
# Seed torch / numpy / random. cuDNN determinism is not enforced here, so GPU
# runs may still differ slightly between executions.
torch.manual_seed(42); np.random.seed(42); random.seed(42)

# Pull helper definitions from the sibling v3 script by executing its source in
# a private namespace (_v3g). The main guard is rewritten to '__never__' so v3's
# `if __name__=='__main__':` block is skipped; its other top-level code still runs.
# NOTE(review): this relies on v3 spelling its guard exactly as
# "if __name__=='__main__':". The file is opened without an explicit encoding
# and with errors='replace', so non-ASCII text in v3 may be silently altered.
_V3 = Path(__file__).parent / "dactrl_v3_episodic_protonet.py"
_v3g = {'__file__': str(_V3)}
with open(_V3,'r',errors='replace') as f:
    exec(compile(f.read().replace("if __name__=='__main__':","if __name__=='__never__':"),str(_V3),'exec'),_v3g)
# Raises KeyError if the v3 script does not define load_all_seeg.
load_all_seeg = _v3g['load_all_seeg']

# Machine-specific absolute (Windows) paths — adjust per environment.
SEEG_ROOT = Path(r"G:\PHD Datasets\Data\Thalamus\SEEG Seizure Data")
METADATA = SEEG_ROOT / "metadata_SEEG.xlsx"
OUT_ROOT = Path(r"D:\Projects\phd\PSEG\pges_toolkit\results\dactrl_calibration")
FIG_DIR = OUT_ROOT/"figures"; TAB_DIR = OUT_ROOT/"tables"
# Output directories are created at import time.
for d in [OUT_ROOT,FIG_DIR,TAB_DIR]: d.mkdir(parents=True,exist_ok=True)

# Hyper-parameters:
#   N_FEAT   — feature dimension per row of X (default d_in of CausalTransformer)
#   N_CTX    — number of consecutive rows used as context in get_seqs()/pretrain()
#   D_MODEL, N_HEADS, N_LAYERS — transformer width / heads / depth
#   SEQ_EP, SEQ_LR — epochs and Adam learning rate used by pretrain()
N_FEAT = 16; N_CTX = 8; D_MODEL = 64; N_HEADS = 4; N_LAYERS = 4
SEQ_EP = 150; SEQ_LR = 3e-4
K_CALIB = 10 # K examples used for calibration
# NOTE(review): K_CALIB, N_TRIALS and PRIMARY_EXCLUDE are only referenced in the
# lost driver region; PRIMARY_EXCLUDE presumably lists patient IDs left out of
# the primary analysis — confirm.
N_TRIALS= 10
PRIMARY_EXCLUDE = {'P13'}

def log(msg):
    """Print msg prefixed with the current wall-clock time (HH:MM:SS), flushed immediately."""
    print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}", flush=True)

class CausalTransformer(nn.Module):
    """Causal Transformer encoder over sequences of feature vectors.

    Input x is a 3-D tensor (B, T, d_in). A learned positional embedding is
    added to the input projection, and a square subsequent mask stops each
    position from attending to later positions.

    forward() returns the hidden states (B, T, d_model) when return_hidden=True,
    otherwise their projection back to the input space, (B, T, d_in).

    NOTE(review): pos_emb has only N_CTX+4 entries, so inputs with T > N_CTX+4
    raise an IndexError in the embedding lookup.
    """
    def __init__(self, d_in=N_FEAT, d_model=D_MODEL, n_heads=N_HEADS, n_layers=N_LAYERS, dropout=0.1):
        super().__init__()
        self.proj_in = nn.Linear(d_in, d_model)
        self.pos_emb = nn.Embedding(N_CTX+4, d_model)
        enc = nn.TransformerEncoderLayer(d_model=d_model,nhead=n_heads, dim_feedforward=d_model*2,dropout=dropout,batch_first=True)
        self.transformer = nn.TransformerEncoder(enc,num_layers=n_layers)
        self.proj_out = nn.Linear(d_model, d_in)

    def forward(self, x, return_hidden=False):
        """Encode x (B, T, d_in); return hidden states or their d_in projection."""
        B,T,_ = x.shape
        h = self.proj_in(x)+self.pos_emb(torch.arange(T,device=x.device).unsqueeze(0))
        # Upper-triangular -inf mask; is_causal=True tells PyTorch the mask is causal.
        mask = nn.Transformer.generate_square_subsequent_mask(T,device=x.device)
        h = self.transformer(h,mask=mask,is_causal=True)
        return h if return_hidden else self.proj_out(h)

def pretrain(seq_tf, train_ps, scaler):
    """Self-supervised next-row pre-training of seq_tf on baseline rows only.

    For every patient dict p in train_ps, the rows of p['X'] with
    p['labels']==0 are scaled with scaler.transform and cut into windows of
    N_CTX+1 consecutive rows. The model receives the first N_CTX rows, and
    its last-position output is trained toward the final row with
    loss = (1 - mean cosine similarity) + 0.5 * MSE, for SEQ_EP epochs
    (Adam, lr=SEQ_LR, batch 64).

    Returns seq_tf (updated in place). If no window can be built, it is
    returned untrained.
    """
    seq_tf.train(); opt = torch.optim.Adam(seq_tf.parameters(),lr=SEQ_LR)
    seqs = []
    for p in train_ps:
        X_n = scaler.transform(p['X'].astype(np.float32)).astype(np.float32)
        # NOTE(review): boolean selection concatenates all baseline rows, so a
        # window can straddle a gap where non-baseline rows were removed.
        base = X_n[p['labels']==0]
        # NOTE(review): range stops at len(base)-1, so the window ending at the
        # last baseline row is never used; range(N_CTX+1, len(base)+1) would include it.
        for i in range(N_CTX+1,len(base)): seqs.append(base[i-N_CTX-1:i])
    if not seqs: return seq_tf
    seqs = np.array(seqs,dtype=np.float32)
    # Inputs: first N_CTX rows of each window; targets: the final row.
    ds = torch.utils.data.TensorDataset(torch.tensor(seqs[:,:N_CTX]),torch.tensor(seqs[:,-1]))
    ld = torch.utils.data.DataLoader(ds,batch_size=64,shuffle=True)
    for ep in range(SEQ_EP):
        for xc,xt in ld:
            xc,xt = xc.to(DEVICE),xt.to(DEVICE)
            pred = seq_tf(xc)[:,-1,:]
            loss = (1.-F.cosine_similarity(pred,xt,dim=1).mean())+0.5*F.mse_loss(pred,xt)
            opt.zero_grad(); loss.backward(); opt.step()
        if (ep+1)%50==0: log(f" ep{ep+1}")
    return seq_tf

def encode_cls(seq_tf, seqs):
    """Embed each sequence as seq_tf's hidden state at the last position.

    seqs is assumed to be an array of shape (N, T, d_in) — TODO confirm
    against callers. It is processed in batches of 32 under torch.no_grad()
    and returns an (N, d_model) numpy array. There is no dedicated CLS token:
    "cls" here means the last time step. Leaves seq_tf in eval mode.

    NOTE(review): np.vstack raises ValueError when seqs is empty.
    """
    seq_tf.eval(); z=[]
    for i in range(0,len(seqs),32):
        b=torch.tensor(seqs[i:i+32],dtype=torch.float32).to(DEVICE)
        with torch.no_grad(): z.append(seq_tf(b,return_hidden=True)[:,-1,:].cpu().numpy())
    return np.vstack(z)

def get_seqs(patient, scaler):
    """Build sliding context windows and their labels for one patient.

    For each index i >= N_CTX, the window is the N_CTX scaled rows *before* i
    (X_n[i-N_CTX:i]). Its label is patient['labels'][i], the label of the row
    immediately following the window.

    Returns (seqs, lbls): a float32 array (N, N_CTX, n_features) and an int32
    array (N,). Returns (None, None) when the patient has N_CTX rows or fewer.
    """
    X_n = scaler.transform(patient['X'].astype(np.float32)).astype(np.float32)
    y = patient['labels'].astype(np.int32)
    seqs,lbls = [],[]
    for i in range(N_CTX,len(X_n)): seqs.append(X_n[i-N_CTX:i]); lbls.append(y[i])
    if not seqs: return None,None
    return np.array(seqs,dtype=np.float32), np.array(lbls,dtype=np.int32)

def ece(probs, labels, n_bins=10):
    """Expected Calibration Error.

    Bins predicted probabilities into n_bins equal-width bins over [0, 1].
    NOTE(review): the rest of this function body was lost (see below), so its
    accumulation formula, return value and handling of probs == 1.0 cannot be
    confirmed from this file.
    """
    bins = np.linspace(0,1,n_bins+1)
    ece_val = 0.
    for lo,hi in zip(bins[:-1],bins[1:]):
        # NOTE(review): DAMAGED LINE. Everything from a '<' (presumably the
        # start of `(probs<hi)`) up to a later '>' was stripped. The lost text
        # includes the remainder of ece() and the start of the experiment
        # driver. The following names are used below but are not defined
        # anywhere in the surviving text: probs_uncal, probs_cal, lbls_te, T,
        # test_p, trial, seq_tf, all_rows, all_probs_uncal, all_probs_cal,
        # all_labels. The lost driver code presumably also covered data
        # loading, the per-patient / per-trial loops, prototype building and
        # temperature fitting. The residue `(probs0.5).astype(int),
        # zero_division=0)` is presumably the tail of
        # `f1_uncal = f1_score(lbls_te, (probs_uncal>0.5).astype(int), zero_division=0)`.
        # Restore this region from the original source before running.
        mask = (probs>=lo) & (probs0.5).astype(int), zero_division=0)

# NOTE(review): the indentation from here to `df = pd.DataFrame(...)` is a
# reconstruction; the original nesting was lost with the damaged region. It
# assumes the statements below run once per (test patient, trial), and that
# `del seq_tf` runs once per test patient — confirm against the original.
        # F1 at the default 0.5 threshold on temperature-scaled probabilities.
        f1_cal = f1_score(lbls_te, (probs_cal>0.5).astype(int), zero_division=0)
        # Recall-optimised threshold (clinical: miss PGES is dangerous)
        # NOTE(review): the PR curve and the F1 computed from its threshold both
        # use the test labels (lbls_te). The threshold is therefore tuned on the
        # evaluation data, so F1_recall90 is optimistically biased. The module
        # docstring says the threshold should come from the K support examples.
        pr, rc, thr = precision_recall_curve(lbls_te, probs_cal)
        # Find threshold giving recall >= 0.90
        recall_target = 0.90
        # rc[:-1] aligns element-wise with thr. Recall is non-increasing as the
        # threshold rises, so the last qualifying index is the highest threshold
        # that still meets the target.
        valid = np.where(rc[:-1] >= recall_target)[0]
        if len(valid):
            best_thr = thr[valid[-1]]
            # NOTE(review): precision_recall_curve counts score >= threshold as
            # positive. The strict `>` here drops samples whose score equals
            # best_thr, so the achieved recall can fall below recall_target.
            f1_recall_opt = f1_score(lbls_te,(probs_cal>best_thr).astype(int),zero_division=0)
        else:
            best_thr = 0.5; f1_recall_opt = f1_cal
        # NOTE(review): best_thr is not stored in all_rows, so the chosen
        # operating point never reaches the output tables.
        ece_uncal = ece(probs_uncal, lbls_te)
        ece_cal = ece(probs_cal, lbls_te)
        all_rows.append({
            'pid': test_p['pid'],
            'trial': trial,
            'T_fitted': T,
            'F1_uncal': f1_uncal,
            'F1_cal': f1_cal,
            'F1_recall90': f1_recall_opt,
            'ECE_uncal': ece_uncal,
            'ECE_cal': ece_cal,
            # NaN when lbls_te has no positives. NOTE(review): roc_auc_score
            # still raises ValueError if lbls_te contains only positives.
            'AUC': roc_auc_score(lbls_te, probs_cal) if lbls_te.sum()>0 else float('nan'),
        })
        # Pool probabilities/labels across all patients and trials for the
        # reliability diagram below.
        all_probs_uncal.extend(probs_uncal.tolist())
        all_probs_cal.extend(probs_cal.tolist())
        all_labels.extend(lbls_te.tolist())
    # Release the model (and cached GPU memory) before the next iteration.
    del seq_tf; gc.collect()
    if torch.cuda.is_available(): torch.cuda.empty_cache()

# One row per (patient, trial).
df = pd.DataFrame(all_rows)
df.to_csv(TAB_DIR/'calibration_per_patient.csv', index=False)
# Per-patient means over trials.
summary = df.groupby('pid').agg(
    F1_uncal=('F1_uncal','mean'),
    F1_cal=('F1_cal','mean'),
    F1_recall90=('F1_recall90','mean'),
    ECE_uncal=('ECE_uncal','mean'),
    ECE_cal=('ECE_cal','mean'),
    T_mean=('T_fitted','mean'),
    AUC=('AUC','mean')
).reset_index()
summary.to_csv(TAB_DIR/'calibration_summary.csv', index=False)

# Headline numbers are means/stds over all (patient, trial) rows of df, not
# over per-patient averages. pandas skips NaN AUCs in .mean().
log("\n=== CALIBRATION RESULTS ===")
log(f" F1 uncalibrated: {df.F1_uncal.mean():.4f} ± {df.F1_uncal.std():.4f}")
log(f" F1 calibrated : {df.F1_cal.mean():.4f} ± {df.F1_cal.std():.4f}")
log(f" F1 recall≥90% : {df.F1_recall90.mean():.4f} ± {df.F1_recall90.std():.4f}")
log(f" ECE uncalibrated: {df.ECE_uncal.mean():.4f}")
log(f" ECE calibrated : {df.ECE_cal.mean():.4f}")
log(f" Temperature T : {df.T_fitted.mean():.3f} ± {df.T_fitted.std():.3f}")
log(f" AUC : {df.AUC.mean():.4f}")

# Reliability diagram
# One panel each for uncalibrated and temperature-scaled probabilities,
# pooled over every patient and trial.
fig, axes = plt.subplots(1,2,figsize=(12,5))
for ax, probs, title in zip(axes, [np.array(all_probs_uncal), np.array(all_probs_cal)], ['Uncalibrated','Temperature-Scaled']):
    lbls_arr = np.array(all_labels)
    bins = np.linspace(0,1,11)
    bin_centers, bin_acc = [],[]
    for lo,hi in zip(bins[:-1],bins[1:]):
        # NOTE(review): DAMAGED LINE. Text between '<' and '>' was stripped; it
        # presumably read
        #     mask = (probs>=lo)&(probs<hi)
        #     if mask.sum()>5:
        # i.e. bins holding 5 or fewer samples are skipped. Restore before
        # running. With `probs<hi`, probabilities exactly equal to 1.0 fall in
        # no bin.
        mask = (probs>=lo)&(probs5:
            # x-coordinate is the bin midpoint, not the mean predicted
            # probability inside the bin, despite the axis label below.
            bin_centers.append((lo+hi)/2)
            # Empirical fraction of positive labels in the bin.
            bin_acc.append(lbls_arr[mask].mean())
    ax.plot([0,1],[0,1],'k--',alpha=0.5,label='Perfect calibration')
    ax.plot(bin_centers,bin_acc,'o-',color='#e74c3c',label='Model')
    ax.set_xlabel('Mean Predicted Probability')
    ax.set_ylabel('Fraction of PGES Positives')
    ax.set_title(f'{title}\nECE={ece(probs,lbls_arr):.4f}')
    ax.legend(); ax.grid(alpha=0.3); ax.set_xlim(0,1); ax.set_ylim(0,1)
plt.suptitle('Reliability Diagram — TSM ProtoNet Calibration', y=1.02)
plt.tight_layout()
plt.savefig(FIG_DIR/'reliability_diagram.png', dpi=150, bbox_inches='tight')
log(f"Saved: {FIG_DIR/'reliability_diagram.png'}")
log("[DONE]")