# -*- coding: utf-8 -*-
"""
DACTRL — Detection Latency
===========================
Computes how many windows (and seconds) elapse after the start of each PGES
episode before DACTRL-TSM makes its first correct detection.
Uses K=10 LOSO protocol.

For every held-out patient a causal transformer is pretrained on the other
patients' baseline (label 0) windows, the held-out recording is embedded, and
a two-prototype nearest-centroid rule labels every window.  The latency of an
episode is the offset (in windows) of the first window predicted as PGES; an
episode that is never detected is assigned its full length.

Outputs:
  results/dactrl_detection_latency/tables/latency_per_episode.csv
  results/dactrl_detection_latency/tables/latency_summary.csv
  results/dactrl_detection_latency/figures/latency_boxplot.png
"""
import os
os.environ.setdefault('PYTHONIOENCODING', 'utf-8')
import random
import warnings
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score
import torch
import torch.nn as nn
import torch.nn.functional as F

warnings.filterwarnings('ignore')

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"[{'GPU' if torch.cuda.is_available() else 'CPU'}]", flush=True)

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# ---------------------------------------------------------------------------
# Re-use helpers from the sibling v3 script without running its entry point.
# NOTE(review): this exec()s project-local source; only run against a trusted
# checkout.  The file is read as UTF-8 explicitly so that the result does not
# depend on the platform's locale codec (cp1252 on Windows).
# ---------------------------------------------------------------------------
_V3 = Path(__file__).parent / "dactrl_v3_episodic_protonet.py"
_v3g = {'__file__': str(_V3)}
with open(_V3, 'r', encoding='utf-8', errors='replace') as f:
    exec(compile(
        f.read().replace("if __name__=='__main__':",
                         "if __name__=='__never__':"),
        str(_V3), 'exec'), _v3g)
load_all_seeg = _v3g['load_all_seeg']
diversity_support = _v3g['diversity_support']

# ---------------------------------------------------------------------------
# Paths
# ---------------------------------------------------------------------------
SEEG_ROOT = Path(r"G:\PHD Datasets\Data\Thalamus\SEEG Seizure Data")
METADATA = SEEG_ROOT / "metadata_SEEG.xlsx"
OUT_ROOT = Path(r"D:\Projects\phd\PSEG\pges_toolkit\results\dactrl_detection_latency")
FIG_DIR = OUT_ROOT / "figures"
TAB_DIR = OUT_ROOT / "tables"
for d in [OUT_ROOT, FIG_DIR, TAB_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# ---------------------------------------------------------------------------
# Hyper-parameters
# ---------------------------------------------------------------------------
N_FEAT = 17      # features per window (input/output width of the transformer)
N_CTX = 8        # number of context windows fed to the transformer
D_MODEL = 64     # transformer hidden width
N_HEADS = 4      # attention heads
N_LAYERS = 4     # encoder layers
SEQ_EP = 150     # pretraining epochs
SEQ_LR = 3e-4    # pretraining learning rate (Adam)
EPOCH_SEC = 5    # factor used to convert a window count into seconds
N_TRIALS = 5     # support-set draws averaged per patient
K_SHOT = 10      # second argument passed to diversity_support (the "K=10" protocol)

NUCLEUS_MAP = {
    'P1': 'CeM', 'P3': 'CeM', 'P5': 'CeM', 'P9': 'CeM',
    'P2': 'CL', 'P7': 'CL', 'P8': 'CL',
    'P4': 'MD', 'P6': 'MD',
    'P10': 'ANT', 'P11': 'ANT', 'P12': 'ANT',
    'P14': 'ANT', 'P15': 'ANT',
}

# Face colours for the per-nucleus boxes; cycled if there are more groups.
BOX_COLORS = ['#4dac26', '#d73027', '#2166ac', '#fc8d59']


def log(msg):
    """Print ``msg`` prefixed with the current wall-clock time (HH:MM:SS)."""
    print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}", flush=True)


class CausalTransformer(nn.Module):
    """Causally-masked transformer encoder over sequences of feature windows.

    Input is ``(batch, T, N_FEAT)``.  The learned positional table has
    ``N_CTX + 4`` rows, so ``T`` must not exceed that.
    """

    def __init__(self):
        super().__init__()
        self.proj_in = nn.Linear(N_FEAT, D_MODEL)
        self.pos_emb = nn.Embedding(N_CTX + 4, D_MODEL)
        enc = nn.TransformerEncoderLayer(
            d_model=D_MODEL, nhead=N_HEADS, dim_feedforward=D_MODEL * 2,
            dropout=0.1, batch_first=True)
        self.transformer = nn.TransformerEncoder(enc, num_layers=N_LAYERS)
        self.proj_out = nn.Linear(D_MODEL, N_FEAT)

    def forward(self, x, return_hidden=False):
        """Run the encoder.

        Returns the hidden states ``(batch, T, D_MODEL)`` when
        ``return_hidden`` is true, otherwise their projection back to
        ``(batch, T, N_FEAT)``.
        """
        _, T, _ = x.shape
        positions = torch.arange(T, device=x.device).unsqueeze(0)
        h = self.proj_in(x) + self.pos_emb(positions)
        mask = nn.Transformer.generate_square_subsequent_mask(T, device=x.device)
        h = self.transformer(h, mask=mask, is_causal=True)
        return h if return_hidden else self.proj_out(h)


def pretrain(model, train_patients, scaler):
    """Pretrain ``model`` to predict the next baseline window.

    For every training patient the label-0 rows are normalised with
    ``scaler`` and cut into sliding windows of ``N_CTX + 1`` rows: the first
    ``N_CTX`` rows are the context and the last row is the target.  The loss
    is ``(1 - cosine similarity) + 0.5 * MSE`` between the model output at
    the last context position and the target.

    Args:
        model: a ``CausalTransformer`` already placed on ``DEVICE``.
        train_patients: iterable of dicts with keys ``'X'`` and ``'labels'``.
        scaler: fitted object exposing ``transform``.

    Returns:
        The same ``model`` (returned untouched if no training sequence could
        be built).
    """
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=SEQ_LR)

    win = N_CTX + 1
    seqs = []
    for p in train_patients:
        X_n = scaler.transform(p['X'].astype(np.float32))
        # NOTE(review): label-0 rows are concatenated, so a window may span a
        # gap where PGES rows were removed — confirm this is intended.
        base = X_n[p['labels'] == 0]
        # The upper bound is len(base) + 1 so the final baseline row is also
        # used as a target (the previous bound dropped the last window and
        # produced nothing when len(base) == N_CTX + 1).
        seqs.extend(base[end - win:end] for end in range(win, len(base) + 1))
    if not seqs:
        return model

    seqs = np.array(seqs, dtype=np.float32)
    ds = torch.utils.data.TensorDataset(
        torch.tensor(seqs[:, :N_CTX]), torch.tensor(seqs[:, -1]))
    ld = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=True)

    for _ in range(SEQ_EP):
        for xc, xt in ld:
            xc, xt = xc.to(DEVICE), xt.to(DEVICE)
            pred = model(xc)[:, -1, :]
            loss = (1. - F.cosine_similarity(pred, xt, dim=1).mean()) \
                + 0.5 * F.mse_loss(pred, xt)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return model


def encode(model, seqs):
    """Embed each sequence as the hidden state at its last position.

    Args:
        model: a ``CausalTransformer``; it is switched to eval mode.
        seqs: array-like of shape ``(n, T, N_FEAT)``.

    Returns:
        ``np.ndarray`` of shape ``(n, D_MODEL)``; an empty ``(0, D_MODEL)``
        array when ``seqs`` is empty.
    """
    model.eval()
    if len(seqs) == 0:
        return np.empty((0, D_MODEL), dtype=np.float32)
    z = []
    with torch.no_grad():
        for i in range(0, len(seqs), 32):  # mini-batches of 32 sequences
            b = torch.tensor(seqs[i:i + 32], dtype=torch.float32).to(DEVICE)
            z.append(model(b, return_hidden=True)[:, -1, :].cpu().numpy())
    return np.vstack(z)


def get_pges_runs(lbls):
    """Return list of (start, end) index pairs for contiguous PGES runs.

    A run opens at a label equal to 1 and closes at the next label equal to
    0; ``end`` is exclusive.  A run still open at the end of ``lbls`` is
    closed at ``len(lbls)``.
    """
    runs = []
    start = None
    for i, l in enumerate(lbls):
        if l == 1 and start is None:
            start = i
        elif l == 0 and start is not None:
            runs.append((start, i))
            start = None
    if start is not None:
        runs.append((start, len(lbls)))
    return runs


def _build_query_windows(X_norm, y):
    """Pair every window ``i >= N_CTX`` with its ``N_CTX`` preceding rows.

    Returns ``(seqs, lbls)`` where ``seqs[k]`` is ``X_norm[i - N_CTX:i]`` and
    ``lbls[k]`` is ``y[i]`` for ``i = N_CTX + k``.  Both are empty when the
    recording has at most ``N_CTX`` rows.
    """
    n = len(X_norm)
    if n <= N_CTX:
        return (np.empty((0, N_CTX, N_FEAT), dtype=np.float32),
                np.empty(0, dtype=np.int32))
    seqs = np.array([X_norm[i - N_CTX:i] for i in range(N_CTX, n)],
                    dtype=np.float32)
    lbls = np.array(y[N_CTX:n], dtype=np.int32)
    return seqs, lbls


def _collect_trial_latencies(Z, lbls, runs):
    """Measure first-detection latency of every run over ``N_TRIALS`` draws.

    Each trial draws a support set with ``diversity_support``, builds one
    prototype per class as the mean support embedding, and predicts PGES for
    every window that is closer (Euclidean) to the PGES prototype.  Trials
    whose support set is missing or contains a single class are skipped.

    Returns:
        A list with one list per run holding that run's latency (in windows)
        for every accepted trial; undetected runs contribute their length.
    """
    per_run = [[] for _ in runs]
    for _ in range(N_TRIALS):
        sup_idx, _ = diversity_support(lbls, K_SHOT)
        if sup_idx is None:
            continue
        sup_idx = np.asarray(sup_idx)
        sup_lbls = lbls[sup_idx]
        if len(np.unique(sup_lbls)) < 2:
            continue
        proto_pges = Z[sup_idx[sup_lbls == 1]].mean(axis=0)
        proto_base = Z[sup_idx[sup_lbls == 0]].mean(axis=0)
        preds = (np.linalg.norm(Z - proto_pges, axis=1)
                 < np.linalg.norm(Z - proto_base, axis=1)).astype(int)
        for lats, (s, e) in zip(per_run, runs):
            hits = np.flatnonzero(preds[s:e] == 1)
            # Never detected -> latency is capped at the episode length.
            lats.append(int(hits[0]) if hits.size else int(e - s))
    return per_run


def _episode_rows(pid, nucleus, runs, per_run_latencies):
    """Turn per-trial latencies into one output row per evaluated episode."""
    rows = []
    for ep_i, ((s, e), lats) in enumerate(zip(runs, per_run_latencies)):
        if not lats:
            continue
        length = int(e - s)
        mean_lat_w = float(np.mean(lats))
        rows.append({
            'pid': pid,
            'nucleus': nucleus,
            'episode': ep_i + 1,
            'episode_len_windows': length,
            'episode_len_sec': length * EPOCH_SEC,
            'mean_latency_windows': mean_lat_w,
            'mean_latency_sec': mean_lat_w * EPOCH_SEC,
            # Share of trials in which the episode was detected at all.
            'detection_rate': float(np.mean([lat < length for lat in lats])),
        })
    return rows


def _save_tables_and_figure(df):
    """Write both CSV tables, log the summary and save the two-panel figure."""
    df.to_csv(TAB_DIR / "latency_per_episode.csv", index=False)

    summary = df.groupby('nucleus').agg(
        mean_latency_sec=('mean_latency_sec', 'mean'),
        median_latency_sec=('mean_latency_sec', 'median'),
        std_latency_sec=('mean_latency_sec', 'std'),
        detection_rate=('detection_rate', 'mean'),
        n_episodes=('episode', 'count'),
    ).reset_index()
    summary.to_csv(TAB_DIR / "latency_summary.csv", index=False)

    overall_mean = df['mean_latency_sec'].mean()
    overall_median = df['mean_latency_sec'].median()
    log("\n=== Detection Latency Summary (K=10) ===")
    log(f" Overall mean: {overall_mean:.1f}s")
    log(f" Overall median: {overall_median:.1f}s")
    log(f" Total episodes: {len(df)}")
    log(summary.to_string(index=False, float_format='{:.2f}'.format))

    fig, axes = plt.subplots(1, 2, figsize=(13, 5))

    # Left panel: boxplot by nucleus
    ax = axes[0]
    nuclei = sorted(df['nucleus'].unique())
    data_by_nuc = [df[df['nucleus'] == n]['mean_latency_sec'].values
                   for n in nuclei]
    bp = ax.boxplot(data_by_nuc, patch_artist=True,
                    medianprops=dict(color='black', linewidth=2))
    # Tick labels are set on the axis rather than through boxplot(labels=...),
    # which recent Matplotlib releases deprecate in favour of tick_labels.
    ax.set_xticklabels(nuclei)
    for i, patch in enumerate(bp['boxes']):
        patch.set_facecolor(BOX_COLORS[i % len(BOX_COLORS)])
        patch.set_alpha(0.7)
    ax.axhline(overall_mean, linestyle='--', color='black', alpha=0.5,
               label=f'Overall mean={overall_mean:.1f}s')
    ax.set_ylabel('Detection Latency (seconds)', fontsize=11)
    ax.set_title('PGES Detection Latency by Nucleus\n(K=10, LOSO)', fontsize=11)
    ax.legend(fontsize=9)
    ax.grid(True, axis='y', alpha=0.3)

    # Right panel: histogram over all episodes
    ax = axes[1]
    ax.hist(df['mean_latency_sec'], bins=15, color='#2166ac', alpha=0.8,
            edgecolor='white')
    ax.axvline(overall_mean, linestyle='--', color='#d73027', linewidth=2,
               label=f'Mean={overall_mean:.1f}s')
    ax.axvline(overall_median, linestyle=':', color='#4dac26', linewidth=2,
               label=f'Median={overall_median:.1f}s')
    ax.set_xlabel('Detection Latency (seconds)', fontsize=11)
    ax.set_ylabel('Episode count', fontsize=11)
    ax.set_title('Detection Latency Distribution\n(all episodes, all patients)',
                 fontsize=11)
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)

    plt.suptitle('DACTRL-TSM: PGES Detection Latency (17 features)', fontsize=13)
    plt.tight_layout()
    plt.savefig(FIG_DIR / "latency_boxplot.png", dpi=150, bbox_inches='tight')
    plt.close(fig)
    log(f"\nSaved -> {FIG_DIR}")


def main():
    """Run the leave-one-subject-out latency experiment end to end."""
    log("=" * 60)
    log("DACTRL — Detection Latency (K=10, LOSO)")
    log("=" * 60)

    meta_df = pd.read_excel(METADATA)
    raw = load_all_seeg(meta_df)
    # NOTE(review): P13 is excluded and has no NUCLEUS_MAP entry — the reason
    # is not visible in this file.
    all_pids = sorted(p for p in raw if p != 'P13')
    patients = [{'pid': p, 'X': raw[p]['X'], 'labels': raw[p]['y_temporal']}
                for p in all_pids]
    log(f"Patients: {all_pids} (N={len(patients)})")

    rows = []
    for fold_i, test_p in enumerate(patients):
        pid = test_p['pid']
        nucleus = NUCLEUS_MAP.get(pid, '?')
        tag = f" [{fold_i+1:02d}] {pid} ({nucleus})"
        train_ps = [p for p in patients if p['pid'] != pid]

        # Normalisation statistics come from the training patients only.
        X_tr = np.vstack([p['X'].astype(np.float32) for p in train_ps])
        scaler = StandardScaler().fit(X_tr)

        # Training happens before the skip checks below, as in the original
        # flow, so the global RNG stream seen by later folds does not depend
        # on which patients get skipped.
        model = CausalTransformer().to(DEVICE)
        model = pretrain(model, train_ps, scaler)
        model.eval()

        X_te_n = scaler.transform(test_p['X'].astype(np.float32))
        y_te = test_p['labels'].astype(np.int32)
        seqs, lbls = _build_query_windows(X_te_n, y_te)

        fold_rows = None
        if len(seqs):
            pges_runs = get_pges_runs(lbls)
            if not pges_runs or lbls.sum() == 0:
                log(f"{tag}: no PGES runs, skip")
            else:
                Z_te = encode(model, seqs)
                # Average over N_TRIALS support samples
                per_run = _collect_trial_latencies(Z_te, lbls, pges_runs)
                fold_rows = _episode_rows(pid, nucleus, pges_runs, per_run)
                if fold_rows:
                    mean_lat = np.mean([r['mean_latency_sec'] for r in fold_rows])
                    log(f"{tag}: {len(pges_runs)} episodes "
                        f"mean_latency={mean_lat:.1f}s")
                else:
                    # Every support draw was rejected; avoid a NaN mean.
                    log(f"{tag}: no valid support set in {N_TRIALS} trials, skip")
        if fold_rows:
            rows.extend(fold_rows)

        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if not rows:
        # An empty frame has no 'nucleus' column and would make groupby fail.
        log("No PGES episodes were evaluated; nothing to save.")
        log("Done.")
        return

    _save_tables_and_figure(pd.DataFrame(rows))
    log("Done.")


if __name__ == '__main__':
    main()