# -*- coding: utf-8 -*-
"""
DACTRL — Clinical Evaluation
==============================
Runs leave-one-subject-out (LOSO) evaluation and computes:

  1. False Alarm Rate for DACTRL-TSM at K=0, 2, 5, 10, 20 (false alarms per
     hour of true-baseline recording), together with F1.
  2. Detection Latency — windows until the first correct PGES detection
     inside each contiguous PGES run (at K=10). A run with no detection is
     charged its full length.
  3. Conformal-style threshold — a split-calibration quantile threshold on
     the PGES score, targeting 90% coverage of PGES windows.
     Score: s = dp / (dp + db + 1e-8)   (lower = more PGES-like)

Outputs:
  results/dactrl_clinical_eval/tables/false_alarm_rate.csv
  results/dactrl_clinical_eval/tables/detection_latency.csv
  results/dactrl_clinical_eval/tables/conformal_coverage.csv
  results/dactrl_clinical_eval/figures/clinical_eval_summary.png
  results/dactrl_clinical_eval/figures/detection_latency_nucleus.png
"""
import os; os.environ.setdefault('PYTHONIOENCODING', 'utf-8')
import random, warnings, copy
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib; matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score
import torch
import torch.nn as nn
import torch.nn.functional as F

warnings.filterwarnings('ignore')

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"[{'GPU' if torch.cuda.is_available() else 'CPU'}] "
      f"{torch.cuda.get_device_name(0) if torch.cuda.is_available() else ''}", flush=True)

torch.manual_seed(42); np.random.seed(42); random.seed(42)

# Load helpers from the sibling v3 script by executing its source with the
# script entry point disabled, so that only its definitions are created.
# NOTE(review): the file is read with the platform default encoding
# (errors='replace'); confirm that is intended on every machine this runs on.
_V3 = Path(__file__).parent / "dactrl_v3_episodic_protonet.py"
_v3g = {'__file__': str(_V3)}
with open(_V3, 'r', errors='replace') as f:
    exec(compile(
        f.read().replace("if __name__=='__main__':", "if __name__=='__never__':"),
        str(_V3), 'exec'), _v3g)
load_all_seeg = _v3g['load_all_seeg']
diversity_support = _v3g['diversity_support']

SEEG_ROOT = Path(r"G:\PHD Datasets\Data\Thalamus\SEEG Seizure Data")
METADATA = SEEG_ROOT / "metadata_SEEG.xlsx"
OUT_ROOT = Path(r"D:\Projects\phd\PSEG\pges_toolkit\results\dactrl_clinical_eval")
FIG_DIR = OUT_ROOT / "figures"
TAB_DIR = OUT_ROOT / "tables"
for d in [OUT_ROOT, FIG_DIR, TAB_DIR]:
    d.mkdir(parents=True, exist_ok=True)

N_FEAT = 17
N_CTX = 8
D_MODEL = 64
N_HEADS = 4
N_LAYERS = 4
SEQ_EP = 150
SEQ_LR = 3e-4
K_LIST = [0, 2, 5, 10, 20]
N_TRIALS = 10
EPOCH_SEC = 5            # seconds per window
CONFORMAL_ALPHA = 0.10   # 90% coverage

NUCLEUS_MAP = {
    'P1': 'CeM', 'P3': 'CeM', 'P5': 'CeM', 'P9': 'CeM',
    'P2': 'CL', 'P7': 'CL', 'P8': 'CL',
    'P4': 'MD', 'P6': 'MD',
    'P10': 'ANT', 'P11': 'ANT', 'P12': 'ANT',
    'P14': 'ANT', 'P15': 'ANT',
}


def log(msg):
    """Print *msg* prefixed with the current wall-clock time (HH:MM:SS)."""
    print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}", flush=True)


class CausalTransformer(nn.Module):
    """Transformer encoder with a causal mask over a sequence of feature windows.

    The input is projected to ``d_model``, a learned positional embedding is
    added, the causally-masked encoder is applied, and the result is either
    returned as hidden states or projected back to ``d_in``.
    """

    def __init__(self, n_ctx=N_CTX, d_in=N_FEAT, d_model=D_MODEL,
                 n_heads=N_HEADS, n_layers=N_LAYERS, dropout=0.1):
        super().__init__()
        self.proj_in = nn.Linear(d_in, d_model)
        # The positional table has n_ctx + 4 entries, so sequences up to that
        # length can be embedded.
        self.pos_emb = nn.Embedding(n_ctx + 4, d_model)
        enc = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=d_model * 2, dropout=dropout,
            batch_first=True)
        self.transformer = nn.TransformerEncoder(enc, num_layers=n_layers)
        self.proj_out = nn.Linear(d_model, d_in)

    def forward(self, x, return_hidden=False):
        """Encode ``x`` of shape (batch, time, d_in).

        Returns the hidden states (batch, time, d_model) when
        ``return_hidden`` is true, otherwise the output projection
        (batch, time, d_in).
        """
        B, T, _ = x.shape
        h = self.proj_in(x) + self.pos_emb(
            torch.arange(T, device=x.device).unsqueeze(0))
        mask = nn.Transformer.generate_square_subsequent_mask(T, device=x.device)
        h = self.transformer(h, mask=mask, is_causal=True)
        return h if return_hidden else self.proj_out(h)


def pretrain(model, train_patients, scaler):
    """Train *model* to predict the next window from the preceding ``N_CTX``.

    Only windows labelled 0 in each training patient are used. The loss is
    ``(1 - cosine similarity) + 0.5 * MSE`` between the prediction at the last
    context position and the target window. Returns the same model object
    (unchanged when no training sequence could be built).
    """
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=SEQ_LR)
    seqs = []
    for p in train_patients:
        X_n = scaler.transform(p['X'].astype(np.float32))
        base = X_n[p['labels'] == 0]
        # Each slice holds N_CTX context windows followed by one target.
        # NOTE(review): the range stops at len(base) - 1, so the final
        # baseline window is never used as a target — confirm intended.
        for i in range(N_CTX + 1, len(base)):
            seqs.append(base[i - N_CTX - 1: i])
    if not seqs:
        return model
    seqs = np.array(seqs, dtype=np.float32)
    ds = torch.utils.data.TensorDataset(
        torch.tensor(seqs[:, :N_CTX]), torch.tensor(seqs[:, -1]))
    ld = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=True)
    for _ in range(SEQ_EP):
        for xc, xt in ld:
            xc, xt = xc.to(DEVICE), xt.to(DEVICE)
            pred = model(xc)[:, -1, :]
            loss = (1. - F.cosine_similarity(pred, xt, dim=1).mean()) \
                + 0.5 * F.mse_loss(pred, xt)
            opt.zero_grad(); loss.backward(); opt.step()
    return model


def encode(model, seqs):
    """Return the last-position hidden state for every sequence in *seqs*.

    Sequences are processed in batches of 32 with gradients disabled; the
    result is a NumPy array with one row per input sequence.
    """
    model.eval()
    z = []
    for i in range(0, len(seqs), 32):
        b = torch.tensor(seqs[i:i+32], dtype=torch.float32).to(DEVICE)
        with torch.no_grad():
            z.append(model(b, return_hidden=True)[:, -1, :].cpu().numpy())
    return np.vstack(z)


def make_seqs(X_n, y):
    """Build sliding context sequences and their labels.

    For every index ``i >= N_CTX`` the sequence is ``X_n[i - N_CTX:i]`` (the
    ``N_CTX`` windows *before* ``i``) and the label is ``y[i]``. Returns
    ``(None, None)`` when the input is too short to yield any sequence.
    """
    seqs, lbls = [], []
    for i in range(N_CTX, len(X_n)):
        seqs.append(X_n[i - N_CTX: i])
        lbls.append(y[i])
    if not seqs:
        return None, None
    return np.array(seqs, dtype=np.float32), np.array(lbls, dtype=np.int32)


def get_prototypes_k0(model, scaler, train_patients):
    """Compute zero-shot class prototypes from the training patients.

    Returns ``(pges_prototype, baseline_prototype)`` as the mean embedding of
    all label-1 and label-0 sequences respectively, or ``(None, None)`` when
    either class has no embeddings.
    """
    Z_tr_p, Z_tr_b = [], []
    for p in train_patients:
        X_n = scaler.transform(p['X'].astype(np.float32))
        s, l = make_seqs(X_n, p['labels'])
        if s is None:
            continue
        Z = encode(model, s); l = np.array(l)
        if Z[l == 1].shape[0] > 0:
            Z_tr_p.append(Z[l == 1])
        if Z[l == 0].shape[0] > 0:
            Z_tr_b.append(Z[l == 0])
    if not Z_tr_p or not Z_tr_b:
        return None, None
    return np.mean(np.vstack(Z_tr_p), axis=0), np.mean(np.vstack(Z_tr_b), axis=0)


def pges_score(Z, pp, pb):
    """Lower score = more PGES-like (distance to PGES prototype)."""
    dp = np.linalg.norm(Z - pp, axis=1)
    db = np.linalg.norm(Z - pb, axis=1)
    return dp / (dp + db + 1e-8)   # 0=PGES, 1=baseline


def prototype_preds(Z, pp, pb):
    """Return 1 where a row of *Z* is strictly closer to *pp* than to *pb*, else 0."""
    return (np.linalg.norm(Z - pp, axis=1)
            < np.linalg.norm(Z - pb, axis=1)).astype(int)


def false_alarms_per_hour(preds, labels):
    """Return false positives per hour of true-baseline (label 0) windows.

    Each window lasts ``EPOCH_SEC`` seconds. Returns 0.0 when there are no
    baseline windows.
    """
    fp = np.sum((preds == 1) & (labels == 0))
    base_dur_hr = np.sum(labels == 0) * EPOCH_SEC / 3600.0
    return fp / base_dur_hr if base_dur_hr > 0 else 0.0


def _first_hit_offset(preds, start, stop):
    """Offset of the first positive prediction in ``preds[start:stop]``.

    Returns the run length ``stop - start`` when the run was missed entirely.
    """
    detected = np.where(preds[start:stop] == 1)[0]
    return int(detected[0]) if len(detected) > 0 else int(stop - start)


def pges_run_latencies(labels, preds):
    """Return the detection latency (in windows) of every contiguous PGES run.

    A run starts at a label-1 window and ends at the next label-0 window or
    at the end of the recording. The trailing case matters: without it a run
    that lasts until the final window would never be counted.
    """
    latencies = []
    run_start = None
    for i, lbl in enumerate(labels):
        if lbl == 1 and run_start is None:
            run_start = i
        elif lbl == 0 and run_start is not None:
            latencies.append(_first_hit_offset(preds, run_start, i))
            run_start = None
    if run_start is not None:
        # Recording ended while still inside a PGES run.
        latencies.append(_first_hit_offset(preds, run_start, len(labels)))
    return latencies


def conformal_qhat(cal_scores, alpha):
    """Return the finite-sample-corrected ``1 - alpha`` quantile of *cal_scores*.

    Uses the level ``ceil((n + 1) * (1 - alpha)) / n`` clipped to 1.0, so
    small calibration sets cannot push the level outside ``[0, 1]`` (which
    makes ``np.quantile`` raise). Returns NaN when there are no scores.
    """
    n = len(cal_scores)
    if n == 0:
        return float('nan')
    level = min(1.0, float(np.ceil((n + 1) * (1 - alpha)) / n))
    return float(np.quantile(cal_scores, level))


def main():
    """Run the LOSO evaluation and write all tables and figures."""
    log("=" * 65)
    log("DACTRL — Clinical Evaluation (FA Rate, Latency, Conformal)")
    log("=" * 65)

    meta_df = pd.read_excel(METADATA)
    raw = load_all_seeg(meta_df)
    all_pids = sorted(p for p in raw.keys() if p != 'P13')
    patients = [{'pid': p, 'X': raw[p]['X'], 'labels': raw[p]['y_temporal']}
                for p in all_pids]
    log(f"Patients: {all_pids} (N={len(patients)})")

    fa_rows, lat_rows, conf_rows = [], [], []

    # Scores collected across folds, used for the conformal threshold at the end.
    cal_scores_pges = []   # scores of PGES windows
    cal_scores_base = []   # scores of baseline windows

    for fold_i, test_p in enumerate(patients):
        pid = test_p['pid']
        nucleus = NUCLEUS_MAP.get(pid, '?')
        train_ps = [p for p in patients if p['pid'] != pid]

        X_tr = np.vstack([p['X'].astype(np.float32) for p in train_ps])
        scaler = StandardScaler().fit(X_tr)

        model = CausalTransformer().to(DEVICE)
        model = pretrain(model, train_ps, scaler)
        model.eval()

        X_te_n = scaler.transform(test_p['X'].astype(np.float32))
        y_te = test_p['labels'].astype(np.int32)
        seqs_te, lbls = make_seqs(X_te_n, y_te)
        if seqs_te is None:
            log(f" [{fold_i+1:02d}] {pid}: skip (no sequences)")
            continue

        Z_te = encode(model, seqs_te)
        n_windows = len(lbls)
        total_dur_hr = n_windows * EPOCH_SEC / 3600.0
        log(f" [{fold_i+1:02d}/{len(patients)}] {pid} ({nucleus}) "
            f"n={n_windows} PGES={lbls.sum()} dur={total_dur_hr:.2f}hr")

        pp0, pb0 = get_prototypes_k0(model, scaler, train_ps)

        # ── False Alarm Rate at each K ─────────────────────────────
        for k in K_LIST:
            f1s, fas = [], []
            if k == 0:
                # Zero-shot evaluation is deterministic, so one pass is
                # enough (repeating it N_TRIALS times gives the same mean).
                if pp0 is not None:
                    preds = prototype_preds(Z_te, pp0, pb0)
                    f1s.append(f1_score(lbls, preds, zero_division=0))
                    fas.append(false_alarms_per_hour(preds, lbls))
            else:
                for _ in range(N_TRIALS):
                    sup_idx, qry_idx = diversity_support(lbls, k)
                    if sup_idx is None or len(qry_idx) < 5:
                        continue
                    if len(np.unique(lbls[sup_idx])) < 2:
                        continue
                    pp = Z_te[sup_idx[lbls[sup_idx] == 1]].mean(axis=0)
                    pb = Z_te[sup_idx[lbls[sup_idx] == 0]].mean(axis=0)
                    # Score only the query windows; support windows were
                    # used to build the prototypes.
                    preds = prototype_preds(Z_te, pp, pb)[qry_idx]
                    lbls_q = lbls[qry_idx]
                    f1s.append(f1_score(lbls_q, preds, zero_division=0))
                    fas.append(false_alarms_per_hour(preds, lbls_q))

            fa_rows.append({'pid': pid, 'nucleus': nucleus, 'K': k,
                            'F1': np.mean(f1s) if f1s else float('nan'),
                            'FA_per_hour': np.mean(fas) if fas else float('nan')})

        # ── Detection Latency at K=10 ─────────────────────────────
        sup_idx, qry_idx = diversity_support(lbls, 10)
        if sup_idx is not None and len(np.unique(lbls[sup_idx])) == 2:
            pp = Z_te[sup_idx[lbls[sup_idx] == 1]].mean(axis=0)
            pb = Z_te[sup_idx[lbls[sup_idx] == 0]].mean(axis=0)
            preds_all = prototype_preds(Z_te, pp, pb)

            # Windows until the first true positive in each PGES run,
            # including a run that reaches the end of the recording.
            latencies = pges_run_latencies(lbls, preds_all)
            if latencies:
                lat_rows.append({
                    'pid': pid, 'nucleus': nucleus,
                    'mean_latency_windows': np.mean(latencies),
                    'mean_latency_sec': np.mean(latencies) * EPOCH_SEC,
                    'median_latency_sec': np.median(latencies) * EPOCH_SEC,
                    'n_episodes': len(latencies),
                })

        # ── Conformal: collect scores from the K=0 prototypes ─────
        if pp0 is not None:
            scores = pges_score(Z_te, pp0, pb0)
            cal_scores_pges.extend(scores[lbls == 1].tolist())
            cal_scores_base.extend(scores[lbls == 0].tolist())

        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # ── Save FA rate ───────────────────────────────────────────────
    # Explicit columns keep the groupby below valid when no fold produced rows.
    fa_df = pd.DataFrame(
        fa_rows, columns=['pid', 'nucleus', 'K', 'F1', 'FA_per_hour'])
    fa_df.to_csv(TAB_DIR / "false_alarm_rate.csv", index=False)
    fa_summary = fa_df.groupby('K')[['F1', 'FA_per_hour']].mean().reset_index()
    log("\n=== False Alarm Rate (DACTRL-TSM, mean across patients) ===")
    log(fa_summary.to_string(index=False, float_format='{:.4f}'.format))

    # ── Save detection latency ─────────────────────────────────────
    lat_df = pd.DataFrame(
        lat_rows, columns=['pid', 'nucleus', 'mean_latency_windows',
                           'mean_latency_sec', 'median_latency_sec',
                           'n_episodes'])
    lat_df.to_csv(TAB_DIR / "detection_latency.csv", index=False)
    if not lat_df.empty:
        log(f"\n=== Detection Latency (K=10) ===")
        log(f" Mean: {lat_df['mean_latency_sec'].mean():.1f}s")
        log(f" Median: {lat_df['median_latency_sec'].median():.1f}s")
        log(f" Total PGES episodes: {lat_df['n_episodes'].sum()}")

    # ── Conformal prediction ───────────────────────────────────────
    cal_p = np.array(cal_scores_pges)
    cal_b = np.array(cal_scores_base)

    # Low score = PGES. q_hat is a quantile of the PGES-window scores, so
    # thresholding at q_hat should retain about 1 - alpha of PGES windows.
    all_scores = np.concatenate([cal_p, cal_b])
    all_labels = np.concatenate([np.ones(len(cal_p)), np.zeros(len(cal_b))])

    # Split: use 50% as calibration, measure coverage on the remainder.
    rng = np.random.RandomState(42)
    idx = rng.permutation(len(all_scores))
    n_cal = len(idx) // 2
    cal_idx, val_idx = idx[:n_cal], idx[n_cal:]

    cal_pges_scores = all_scores[cal_idx][all_labels[cal_idx] == 1]
    qhat = conformal_qhat(cal_pges_scores, CONFORMAL_ALPHA)

    val_preds = (all_scores[val_idx] <= qhat).astype(int)
    val_true = all_labels[val_idx].astype(int)

    coverage = np.mean(val_preds[val_true == 1])   # coverage of PGES windows
    fpr = np.mean(val_preds[val_true == 0])        # false positive rate

    conf_rows.append({
        'alpha': CONFORMAL_ALPHA,
        'target_coverage': 1 - CONFORMAL_ALPHA,
        'empirical_coverage': float(coverage),
        'false_positive_rate': float(fpr),
        'qhat': float(qhat),
        'n_cal': len(cal_pges_scores),
    })
    conf_df = pd.DataFrame(conf_rows)
    conf_df.to_csv(TAB_DIR / "conformal_coverage.csv", index=False)
    log(f"\n=== Conformal Prediction (alpha={CONFORMAL_ALPHA}) ===")
    log(f" Target coverage: {1-CONFORMAL_ALPHA:.0%}")
    log(f" Empirical coverage:{coverage:.3f}")
    log(f" False positive rate:{fpr:.3f}")
    log(f" q_hat: {qhat:.4f}")

    # ── Figures ───────────────────────────────────────────────────
    fig, axes = plt.subplots(1, 3, figsize=(16, 5))

    # FA rate vs K
    ax = axes[0]
    ax.plot(fa_summary['K'], fa_summary['FA_per_hour'], 'o-',
            color='#d73027', linewidth=2, markersize=8)
    ax.axhline(257, linestyle='--', color='#fc8d59', alpha=0.7,
               label='XGBoost LOSO')
    ax.axhline(720, linestyle=':', color='#999999', alpha=0.7,
               label='Threshold rule')
    ax.set_xlabel('K (support examples)', fontsize=11)
    ax.set_ylabel('False Alarms / Hour', fontsize=11)
    ax.set_title('DACTRL-TSM False Alarm Rate\nvs K (LOSO)', fontsize=11)
    ax.legend(fontsize=9); ax.grid(True, alpha=0.3)

    # Detection latency distribution
    ax = axes[1]
    if not lat_df.empty:
        ax.hist(lat_df['mean_latency_sec'], bins=10, color='#2166ac',
                alpha=0.8, edgecolor='white')
        ax.axvline(lat_df['mean_latency_sec'].mean(), linestyle='--',
                   color='#d73027',
                   label=f"Mean={lat_df['mean_latency_sec'].mean():.1f}s")
    ax.set_xlabel('Detection Latency (seconds)', fontsize=11)
    ax.set_ylabel('Count (patients)', fontsize=11)
    ax.set_title('PGES Detection Latency\n(K=10, first correct window)',
                 fontsize=11)
    ax.legend(fontsize=9); ax.grid(True, alpha=0.3)

    # Conformal prediction score distributions
    ax = axes[2]
    # Skip empty inputs: a density histogram of no data is undefined.
    if cal_p.size:
        ax.hist(cal_p, bins=40, alpha=0.6, color='#d73027',
                label='PGES windows', density=True)
    if cal_b.size:
        ax.hist(cal_b, bins=40, alpha=0.6, color='#2166ac',
                label='Baseline windows', density=True)
    if np.isfinite(qhat):
        ax.axvline(qhat, linestyle='--', color='black', linewidth=2,
                   label=f'q_hat={qhat:.3f} (90% cov.)')
    ax.set_xlabel('Nonconformity Score (lower = more PGES)', fontsize=11)
    ax.set_ylabel('Density', fontsize=11)
    ax.set_title(f'Conformal Score Distributions\nCoverage={coverage:.3f} (target=0.90)',
                 fontsize=11)
    ax.legend(fontsize=9); ax.grid(True, alpha=0.3)

    plt.suptitle('DACTRL-TSM Clinical Evaluation', fontsize=13)
    plt.tight_layout()
    plt.savefig(FIG_DIR / "clinical_eval_summary.png", dpi=150,
                bbox_inches='tight')
    plt.close(fig)

    # Detection latency per nucleus
    if not lat_df.empty:
        fig2, ax2 = plt.subplots(figsize=(8, 5))
        nuc_lat = lat_df.groupby('nucleus')['mean_latency_sec'].agg(['mean', 'std'])
        ax2.bar(nuc_lat.index, nuc_lat['mean'], yerr=nuc_lat['std'],
                color='#2166ac', alpha=0.8, capsize=5, edgecolor='white')
        ax2.axhline(lat_df['mean_latency_sec'].mean(), linestyle='--',
                    color='#d73027',
                    label=f"Overall mean={lat_df['mean_latency_sec'].mean():.1f}s")
        ax2.set_ylabel('Detection Latency (seconds)', fontsize=11)
        ax2.set_title('Detection Latency by Nucleus (K=10)', fontsize=12)
        ax2.legend(fontsize=10); ax2.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(FIG_DIR / "detection_latency_nucleus.png", dpi=150,
                    bbox_inches='tight')
        plt.close(fig2)

    log(f"\nSaved -> {FIG_DIR}")
    log("Done.")


if __name__ == '__main__':
    main()