# -*- coding: utf-8 -*-
"""
DACTRL — AUC-ROC alongside F1 for main LOSO results
=====================================================
Runs the same clean LOSO protocol as dactrl_seeg_clean_eval.py but also
computes AUC from the ProtoNet distance score:

    score = db / (dp + db + 1e-8)   (higher = more likely PGES)

Reports F1 and AUC at K=0,2,5,10,20 per patient and summarised by nucleus.
"""
import os
# NOTE(review): this only sets the variable for child processes; the running
# interpreter reads PYTHONIOENCODING at start-up, before this line executes.
os.environ.setdefault('PYTHONIOENCODING', 'utf-8')
import random
import warnings
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # headless backend; must be selected before pyplot import
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score, roc_auc_score
import torch
import torch.nn as nn
import torch.nn.functional as F

warnings.filterwarnings('ignore')

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"[{'GPU' if torch.cuda.is_available() else 'CPU'}]", flush=True)

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# ---------------------------------------------------------------------------
# Load helpers from the sibling v3 script by executing its source in a private
# namespace (it is a script, not an importable package module).
# ---------------------------------------------------------------------------
_V3 = Path(__file__).parent / "dactrl_v3_episodic_protonet.py"
# An explicit __name__ guarantees the sibling's ``if __name__ == '__main__'``
# guard cannot fire, however that guard happens to be spelled/spaced.
_v3g = {'__file__': str(_V3), '__name__': _V3.stem}
# Decode explicitly as UTF-8: without ``encoding=`` the platform default is
# used, which makes the loaded source differ between machines.
with open(_V3, 'r', encoding='utf-8', errors='replace') as f:
    _v3_src = f.read().replace("if __name__=='__main__':",
                               "if __name__=='__never__':")
exec(compile(_v3_src, str(_V3), 'exec'), _v3g)

load_all_seeg = _v3g['load_all_seeg']
diversity_support = _v3g['diversity_support']
_infer_seizure_ids = _v3g['_infer_seizure_ids']

SEEG_ROOT = Path(r"G:\PHD Datasets\Data\Thalamus\SEEG Seizure Data")
METADATA = SEEG_ROOT / "metadata_SEEG.xlsx"
OUT_ROOT = Path(r"D:\Projects\phd\PSEG\pges_toolkit\results\dactrl_auc_results")
FIG_DIR = OUT_ROOT / "figures"
TAB_DIR = OUT_ROOT / "tables"
for d in [OUT_ROOT, FIG_DIR, TAB_DIR]:
    d.mkdir(parents=True, exist_ok=True)

N_FEAT = 17      # input feature dimension of the transformer
N_CTX = 8        # context window length (rows per sequence)
D_MODEL = 64
N_HEADS = 4
N_LAYERS = 4
SEQ_EP = 150     # pre-training epochs
SEQ_LR = 3e-4    # pre-training learning rate
K_LIST = [0, 2, 5, 10, 20]
N_TRIALS = 10

NUCLEUS_MAP = {
    'P1': 'CeM', 'P3': 'CeM', 'P5': 'CeM', 'P9': 'CeM',
    'P2': 'CL', 'P7': 'CL', 'P8': 'CL',
    'P4': 'MD', 'P6': 'MD',
    'P10': 'ANT', 'P11': 'ANT', 'P12': 'ANT', 'P13': 'ANT',
    'P14': 'ANT', 'P15': 'ANT',
}

# Value reported when a metric cannot be computed for a trial/patient.
# NOTE(review): 0.5 is chance level for AUC but an arbitrary value for F1;
# kept unchanged so numbers stay comparable with earlier runs.
_FALLBACK = 0.5


def log(msg):
    """Print ``msg`` prefixed with the current wall-clock time (HH:MM:SS)."""
    print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}", flush=True)


class CausalTransformer(nn.Module):
    """Transformer encoder with a causal mask over a short feature sequence.

    Input features are linearly projected to ``d_model``, summed with a
    learned positional embedding, passed through ``n_layers`` encoder layers
    under a square-subsequent (causal) mask, and projected back to ``d_in``.
    The positional table has ``n_ctx + 4`` entries, so sequences longer than
    that cannot be embedded.
    """

    def __init__(self, n_ctx=N_CTX, d_in=N_FEAT, d_model=D_MODEL,
                 n_heads=N_HEADS, n_layers=N_LAYERS, dropout=0.1):
        super().__init__()
        self.proj_in = nn.Linear(d_in, d_model)
        self.pos_emb = nn.Embedding(n_ctx + 4, d_model)
        enc = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=d_model * 2, dropout=dropout,
            batch_first=True)
        self.transformer = nn.TransformerEncoder(enc, num_layers=n_layers)
        self.proj_out = nn.Linear(d_model, d_in)

    def forward(self, x, return_hidden=False):
        """Run the encoder on ``x`` of shape (batch, time, d_in).

        Returns the hidden states (batch, time, d_model) when
        ``return_hidden`` is true, otherwise their projection back to
        (batch, time, d_in).
        """
        _, T, _ = x.shape
        positions = torch.arange(T, device=x.device).unsqueeze(0)
        h = self.proj_in(x) + self.pos_emb(positions)
        mask = nn.Transformer.generate_square_subsequent_mask(
            T, device=x.device)
        h = self.transformer(h, mask=mask, is_causal=True)
        return h if return_hidden else self.proj_out(h)


def pretrain(model, train_patients, scaler, n_epochs=SEQ_EP, lr=SEQ_LR):
    """Pre-train ``model`` to predict the next row from ``N_CTX`` rows.

    Only rows whose label is 0 are used. They are selected with a boolean
    mask *before* windowing, so a window can span rows that were not adjacent
    in the original recording. Each window has ``N_CTX + 1`` rows: the first
    ``N_CTX`` are the context and the last is the target. The loss is
    ``(1 - mean cosine similarity) + 0.5 * MSE`` between the model output at
    the last context position and the target row.

    Returns the same ``model`` object (unchanged if no window could be built).
    """
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    seqs = []
    for p in train_patients:
        X_n = scaler.transform(p['X'].astype(np.float32))
        base = X_n[p['labels'] == 0]
        # NOTE(review): the range stops at len(base) - 1, so the final
        # label-0 row is never used as a target; kept as-is to stay on the
        # same protocol as earlier runs.
        seqs.extend(base[i - N_CTX - 1: i]
                    for i in range(N_CTX + 1, len(base)))
    if not seqs:
        return model

    seqs = np.array(seqs, dtype=np.float32)
    ds = torch.utils.data.TensorDataset(
        torch.tensor(seqs[:, :N_CTX]), torch.tensor(seqs[:, -1]))
    ld = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=True)

    for _ in range(n_epochs):
        for xc, xt in ld:
            xc, xt = xc.to(DEVICE), xt.to(DEVICE)
            pred = model(xc)[:, -1, :]
            loss = (1. - F.cosine_similarity(pred, xt, dim=1).mean()) \
                + 0.5 * F.mse_loss(pred, xt)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return model


def encode(model, seqs):
    """Embed each sequence as the hidden state at its last time step.

    ``seqs`` is processed in batches of 32 with gradients disabled; the
    result is a NumPy array with one row per sequence. ``seqs`` must be
    non-empty (``np.vstack`` rejects an empty list).
    """
    model.eval()
    z = []
    with torch.no_grad():
        for i in range(0, len(seqs), 32):
            b = torch.tensor(seqs[i:i + 32], dtype=torch.float32).to(DEVICE)
            z.append(model(b, return_hidden=True)[:, -1, :].cpu().numpy())
    return np.vstack(z)


def _make_windows(X, y):
    """Build context windows: window i holds the N_CTX rows before row i.

    Returns ``(windows, labels)`` where ``labels[j]`` is the label of the row
    that immediately follows ``windows[j]``, or ``None`` when ``X`` has too
    few rows to form a single window.
    """
    if len(X) <= N_CTX:
        return None
    windows = np.array([X[i - N_CTX: i] for i in range(N_CTX, len(X))],
                       dtype=np.float32)
    return windows, np.asarray(y)[N_CTX:]


def _training_prototypes(model, scaler, train_patients):
    """Mean embedding of label-1 and label-0 windows over all training patients.

    Returns ``(proto_pos, proto_neg)`` or ``None`` if either class has no
    window at all.
    """
    Z_pos, Z_neg = [], []
    for p in train_patients:
        X_n = scaler.transform(p['X'].astype(np.float32))
        windows = _make_windows(X_n, p['labels'])
        if windows is None:
            continue
        tr_seqs, tr_lbls = windows
        Z_tr = encode(model, tr_seqs)
        if (tr_lbls == 1).any():
            Z_pos.append(Z_tr[tr_lbls == 1])
        if (tr_lbls == 0).any():
            Z_neg.append(Z_tr[tr_lbls == 0])
    if not Z_pos or not Z_neg:
        return None
    return (np.mean(np.vstack(Z_pos), axis=0),
            np.mean(np.vstack(Z_neg), axis=0))


def _prototype_metrics(Z, y, proto_pos, proto_neg):
    """Score embeddings ``Z`` against two prototypes; return ``(F1, AUC)``.

    A window is predicted positive when it is closer (Euclidean) to
    ``proto_pos`` than to ``proto_neg``. The AUC score is
    ``db / (dp + db + 1e-8)``; AUC is NaN when ``y`` holds a single class.
    """
    dp = np.linalg.norm(Z - proto_pos, axis=1)
    db = np.linalg.norm(Z - proto_neg, axis=1)
    score = db / (dp + db + 1e-8)
    f1 = f1_score(y, (dp < db).astype(int), zero_division=0)
    if len(np.unique(y)) == 2:
        auc = roc_auc_score(y, score)
    else:
        auc = float('nan')
    return f1, auc


def eval_patient(model, scaler, train_patients, test_patient, k_list, n_trials):
    """Evaluate one held-out patient at every K in ``k_list``.

    * ``K == 0`` (zero-shot): prototypes are the class means of the training
      patients' embeddings and *all* test windows are scored. The result is
      deterministic, so it is computed once rather than once per trial.
    * ``K > 0`` (few-shot): ``diversity_support(lbls, K)`` splits the test
      windows into support and query indices; prototypes come from the
      support embeddings and both F1 and AUC are computed on the query
      windows only. The values are NaN-averaged over ``n_trials`` trials.

    Whenever a metric cannot be computed (no test windows, missing class in
    the prototypes/support, fewer than 5 query windows) 0.5 is recorded for
    both F1 and AUC.

    Returns ``{k: {'F1': float, 'AUC': float}}``.
    """
    model.eval()
    X_te_n = scaler.transform(test_patient['X'].astype(np.float32))
    y_te = test_patient['labels'].astype(np.int32)

    windows = _make_windows(X_te_n, y_te)
    if windows is None:
        return {k: {'F1': _FALLBACK, 'AUC': _FALLBACK} for k in k_list}
    seqs, lbls = windows
    Z_te = encode(model, seqs)

    # Training embeddings are only needed for the zero-shot setting.
    train_protos = None
    if 0 in k_list:
        train_protos = _training_prototypes(model, scaler, train_patients)

    results = {}
    for k in k_list:
        if k == 0:
            if train_protos is None:
                f1, auc = _FALLBACK, _FALLBACK
            else:
                f1, auc = _prototype_metrics(Z_te, lbls, *train_protos)
            results[k] = {'F1': float(f1), 'AUC': float(auc)}
            continue

        f1s, aucs = [], []
        for _ in range(n_trials):
            # NOTE(review): the original created a per-trial RandomState that
            # was never passed anywhere. Unless diversity_support draws from
            # global random state, every trial yields the same split — confirm
            # against its definition in the v3 script.
            sup_idx, qry_idx = diversity_support(lbls, k)
            if (sup_idx is None or len(qry_idx) < 5
                    or len(np.unique(lbls[sup_idx])) < 2):
                f1s.append(_FALLBACK)
                aucs.append(_FALLBACK)
                continue
            sup_lbls = lbls[sup_idx]
            proto_pos = Z_te[sup_idx[sup_lbls == 1]].mean(axis=0)
            proto_neg = Z_te[sup_idx[sup_lbls == 0]].mean(axis=0)
            f1, auc = _prototype_metrics(
                Z_te[qry_idx], lbls[qry_idx], proto_pos, proto_neg)
            f1s.append(f1)
            aucs.append(auc)
        results[k] = {'F1': np.nanmean(f1s), 'AUC': np.nanmean(aucs)}
    return results


def main():
    """Run the LOSO evaluation, write CSV tables and save both figures."""
    log("=" * 60)
    log("DACTRL — AUC-ROC + F1 (LOSO, K=0..20)")
    log("=" * 60)

    meta_df = pd.read_excel(METADATA)
    raw = load_all_seeg(meta_df)
    all_pids = sorted(p for p in raw if p != 'P13')
    patients = [{'pid': p, 'X': raw[p]['X'], 'labels': raw[p]['y_temporal']}
                for p in all_pids]
    n_patients = len(patients)
    log(f"Patients: {all_pids}")

    rows = []
    for test_p in patients:
        pid = test_p['pid']
        nucleus = NUCLEUS_MAP.get(pid, '?')
        train_ps = [p for p in patients if p['pid'] != pid]

        # The scaler is fitted on training patients only (no test leakage).
        X_tr = np.vstack([p['X'].astype(np.float32) for p in train_ps])
        scaler = StandardScaler().fit(X_tr)

        model = CausalTransformer().to(DEVICE)
        model = pretrain(model, train_ps, scaler)

        res = eval_patient(model, scaler, train_ps, test_p, K_LIST, N_TRIALS)
        for k in K_LIST:
            f1 = res[k]['F1']
            auc = res[k]['AUC']
            log(f"  {pid} ({nucleus}) K={k:2d}: F1={f1:.4f}  AUC={auc:.4f}")
            rows.append({'pid': pid, 'nucleus': nucleus, 'K': k,
                         'F1': f1, 'AUC': auc})

        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    df = pd.DataFrame(rows)
    df.to_csv(TAB_DIR / "auc_f1_per_patient.csv", index=False)

    summary = df.groupby('K')[['F1', 'AUC']].agg(['mean', 'std']).reset_index()
    summary.columns = ['K', 'F1_mean', 'F1_std', 'AUC_mean', 'AUC_std']
    summary.to_csv(TAB_DIR / "auc_f1_summary.csv", index=False)

    # Report the patient count actually evaluated instead of a fixed "14".
    log(f"\n=== Summary (N={n_patients} patients, excl. P13) ===")
    log(summary.to_string(index=False, float_format='{:.4f}'.format))

    nucleus_summary = (df.groupby(['nucleus', 'K'])[['F1', 'AUC']]
                       .mean().reset_index())
    nucleus_summary.to_csv(TAB_DIR / "auc_f1_by_nucleus.csv", index=False)

    # Plot F1 and AUC together at K=10
    df10 = df[df['K'] == 10].copy()
    fig, axes = plt.subplots(1, 2, figsize=(12, 5))
    for ax, metric, color in zip(axes, ['F1', 'AUC'], ['#2166ac', '#d6604d']):
        vals = df10.groupby('nucleus')[metric].agg(['mean', 'std'])
        overall = df10[metric].mean()
        ax.bar(vals.index, vals['mean'], yerr=vals['std'],
               color=color, alpha=0.8, capsize=5, edgecolor='white')
        ax.axhline(overall, linestyle='--', color='black',
                   alpha=0.6, label=f'Overall mean={overall:.3f}')
        ax.set_ylim(0.5, 1.05)
        ax.set_ylabel(metric, fontsize=12)
        ax.set_title(f'{metric} by Nucleus (K=10)', fontsize=12)
        ax.legend(fontsize=9)
        ax.grid(True, alpha=0.3)
    fig.suptitle('DACTRL-TSM: F1 and AUC-ROC at K=10 (LOSO)', fontsize=13)
    fig.tight_layout()
    fig.savefig(FIG_DIR / "auc_f1_nucleus.png", dpi=150, bbox_inches='tight')
    plt.close(fig)  # release the figure once it is on disk

    # K-curve
    fig2, ax2 = plt.subplots(figsize=(8, 5))
    ax2.errorbar(summary['K'], summary['F1_mean'], yerr=summary['F1_std'],
                 fmt='o-', capsize=4, color='#2166ac', label='F1', linewidth=2)
    ax2.errorbar(summary['K'], summary['AUC_mean'], yerr=summary['AUC_std'],
                 fmt='s--', capsize=4, color='#d6604d', label='AUC-ROC',
                 linewidth=2)
    ax2.set_xlabel('K (support examples per class)', fontsize=12)
    ax2.set_ylabel('Score', fontsize=12)
    ax2.set_title(f'DACTRL-TSM: F1 and AUC-ROC vs K\n(LOSO, N={n_patients})',
                  fontsize=13)
    ax2.set_ylim(0.5, 1.05)
    ax2.legend(fontsize=11)
    ax2.grid(True, alpha=0.3)
    fig2.tight_layout()
    fig2.savefig(FIG_DIR / "auc_f1_k_curve.png", dpi=150, bbox_inches='tight')
    plt.close(fig2)

    log(f"\nSaved figures to {FIG_DIR}")
    log("Done.")


if __name__ == '__main__':
    main()