# -*- coding: utf-8 -*-
"""
DACTRL Temporal Sequence Model — Causal Transformer over EEG Windows
=====================================================================

HYPOTHESIS:
  Every prior approach treats each 30s window independently.
  PGES has a clear, unique temporal signature:
    [pre-ictal baseline] → [ictal burst] → [PGES onset] → [PGES plateau] → [recovery]
  A causal transformer over consecutive window embeddings can:
    (a) Learn this transition pattern from other patients' seizure sequences
        via self-supervised next-embedding prediction (no labels needed)
    (b) At K=0: detect the transition anomaly to classify windows
    (c) At K>0: use sequence-level prototypes that capture PGES dynamics,
        not just single-window statistics
  This is the ONLY approach in the study that exploits temporal structure.

ARCHITECTURE:
  Window encoder:   16-dim features → 32-dim embedding (MLP, pretrained SupCon)
  Sequence encoder: N_CTX=8 consecutive embeddings → causal transformer (4 layers)
  Pre-training:     predict embedding[t+1] from embeddings[t-N_CTX:t]
  Loss:             cosine similarity + MSE on thalamic baseline sequences

  At test time:
    TSM_anomaly    K=0: anomaly score = prediction error spikes at ictal→PGES
    TSM_seq_k0     K=0: sequence prototype from other patients' PGES sequences
    TSM_seq_kshot  K>0: patient-specific sequence prototype from K labeled sequences

WHY CAUSAL: We want online deployment — only look at past windows, not future.
WHY BASELINE PRE-TRAIN: We have hundreds of baseline windows (no labels needed).
  The model learns normal thalamic dynamics, so PGES = anomaly.

IMPLEMENTATION NOTE: in this script the CausalTransformer is instantiated with
d_in=N_FEAT and consumes the 16-dim normalised window features directly; the
SupCon window encoder is only used for the window-only baseline.
"""
import os
os.environ.setdefault('PYTHONIOENCODING', 'utf-8')
import gc, random, warnings, copy, math
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from scipy import signal as sp_signal
from scipy.stats import entropy as sp_entropy
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score, roc_auc_score
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import mne
    mne.set_log_level('ERROR')
    MNE_OK = True
except ImportError:
    MNE_OK = False

warnings.filterwarnings('ignore')

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"[{'GPU' if torch.cuda.is_available() else 'CPU'}] "
      f"{torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'No GPU'}", flush=True)

torch.manual_seed(42); np.random.seed(42); random.seed(42)

# ── Import v3 infrastructure ──────────────────────────────────────────────────
# The v3 script is executed in a private namespace with its __main__ guard
# disabled so that only its definitions are imported.
# NOTE(review): this exec()s a sibling project file; only run with a trusted checkout.
_V3 = Path(__file__).parent / "dactrl_v3_episodic_protonet.py"
_v3g = {'__file__': str(_V3)}
# Explicit encoding: the platform default (e.g. cp1252 on Windows) would
# otherwise garble non-ASCII characters in the v3 source.
with open(_V3, 'r', encoding='utf-8', errors='replace') as _f:
    _src = _f.read().replace("if __name__ == '__main__':",
                             "if __name__ == '__v3_never__':")
exec(compile(_src, str(_V3), 'exec'), _v3g)

load_all_seeg                = _v3g['load_all_seeg']
compute_consensus_thresholds = _v3g['compute_consensus_thresholds']
reapply_consensus            = _v3g['reapply_consensus']
FullModel                    = _v3g['FullModel']
stage1_supcon                = _v3g['stage1_supcon']
protonet_classify            = _v3g['protonet_classify']
diversity_support            = _v3g['diversity_support']

SEEG_ROOT = Path(r"G:\PHD Datasets\Data\Thalamus\SEEG Seizure Data")
METADATA = SEEG_ROOT / "metadata_SEEG.xlsx"
import pandas as _pd

OUT_ROOT = Path(r"D:\Projects\phd\PSEG\pges_toolkit\results\dactrl_temporal_seq")
FIG_DIR = OUT_ROOT / "figures"
TAB_DIR = OUT_ROOT / "tables"
for d in [OUT_ROOT, FIG_DIR, TAB_DIR]:
    d.mkdir(parents=True, exist_ok=True)

N_FEAT = 16
EMB_DIM = 32            # window embedding dimension after sequence encoder projection
N_CTX = 8               # consecutive windows as context (4 minutes at 30s/window)
D_MODEL = 64            # transformer hidden dimension
N_HEADS = 4
N_LAYERS = 4
SEQ_PRETRAIN_EP = 150   # epochs for next-embedding pre-training
SEQ_LR = 3e-4
K_LIST = [0, 2, 5, 10, 20]
N_TRIALS = 10
PRIMARY_EXCLUDE = {'P13'}


def log(msg):
    """Print ``msg`` prefixed with the current HH:MM:SS time, flushing immediately."""
    print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}", flush=True)


def _emb(model, X_norm):
    """Get embeddings from FullModel (calls forward with thalamic depth).

    Puts ``model`` in eval mode, runs one no-grad forward pass on DEVICE and
    returns the output as a NumPy array.
    """
    model.eval()
    with torch.no_grad():
        x = torch.tensor(X_norm, dtype=torch.float32).to(DEVICE)
        return model(x, 'thalamic').cpu().numpy()


def _proto(Z, proto_pges, proto_base):
    """Nearest-prototype classification in embedding space.

    Returns an int array with 1 where a row of ``Z`` is strictly closer
    (Euclidean) to ``proto_pges`` than to ``proto_base``, else 0.
    """
    d_p = np.linalg.norm(Z - proto_pges, axis=1)
    d_b = np.linalg.norm(Z - proto_base, axis=1)
    return (d_p < d_b).astype(int)


# ══════════════════════════════════════════════════════════════════════════════
# Feature extraction (same as all other scripts)
# ══════════════════════════════════════════════════════════════════════════════

def _perm_entropy(sig, order=3):
    """Normalised permutation entropy of ``sig`` for ordinal patterns of length ``order``.

    Returns 0.0 when the signal is shorter than ``order`` or on any numeric failure.
    """
    if len(sig) < order:
        return 0.0
    try:
        n = len(sig) - order + 1
        counts = {}
        for i in range(n):
            pattern = tuple(np.argsort(sig[i:i + order]))
            counts[pattern] = counts.get(pattern, 0) + 1
        probs = np.array(list(counts.values())) / n
        return float(-np.sum(probs * np.log2(probs + 1e-10))
                     / math.log2(math.factorial(order)))
    except Exception:  # best-effort feature: fall back to 0.0, but let Ctrl-C through
        return 0.0


def _lzc(sig):
    """Lempel-Ziv (LZ76) complexity of the median-binarised signal.

    The signal is binarised around its median, parsed into the LZ76
    exhaustive-history phrases, and the phrase count ``c`` is scaled as
    ``c * log2(n + 1) / n``.

    The previous implementation compared a single character against the prefix,
    so once both symbols had appeared it ran off the end of the string and the
    resulting IndexError was swallowed: it returned 0.0 for essentially every
    input of length >= 2.
    NOTE(review): values therefore differ from features produced by scripts
    that still carry the old implementation.
    """
    try:
        med = np.median(sig)
        b = ''.join('1' if v > med else '0' for v in sig)
        n = len(b)
        if n == 0:
            return 0.0
        c, pos = 0, 0
        while pos < n:
            # Grow the candidate phrase until it no longer occurs in the
            # history (which may overlap the phrase itself, minus its last symbol).
            length = 1
            while pos + length <= n and b[pos:pos + length] in b[:pos + length - 1]:
                length += 1
            c += 1
            pos += length
        return float(c * np.log2(n + 1) / (n + 1e-8))
    except Exception:
        return 0.0


def _etc(sig, q=10):
    """Fraction of consecutive samples whose ``q``-level quantised symbol changes.

    Returns 0.0 on any numeric failure (e.g. empty input).
    """
    try:
        bins = np.linspace(sig.min(), sig.max() + 1e-8, q + 1)
        sym = np.digitize(sig, bins) - 1
        return float(np.sum(sym[1:] != sym[:-1]) / max(len(sig) - 1, 1))
    except Exception:
        return 0.0


def _approx_entropy(sig, m=2, r_mult=0.2):
    """Approximate-entropy style statistic ``phi(m) - phi(m+1)``.

    ``phi`` is the log of the mean fraction of embedded vectors within
    Chebyshev distance ``r = r_mult * std(sig)`` of each template.
    Returns 0.0 for signals shorter than ``m + 2`` or on numeric failure.
    """
    if len(sig) < m + 2:
        return 0.0
    r = r_mult * (np.std(sig) + 1e-8)

    def phi(m_):
        t = np.array([sig[i:i + m_] for i in range(len(sig) - m_ + 1)])
        C = np.mean([np.mean(np.max(np.abs(t - t[i]), axis=1) <= r)
                     for i in range(len(t))]) + 1e-8
        return np.log(C)

    try:
        return float(phi(m) - phi(m + 1))
    except Exception:
        return 0.0


def extract_features(epoch, fs=256):
    """Compute the per-window feature vector from a (channels, samples) epoch.

    Sixteen features are computed per channel and concatenated; only the first
    ``N_FEAT`` entries are returned (i.e. the first channel when N_FEAT == 16).
    Non-finite values are replaced by 0.0. Returns a float32 array.
    """
    feats = []
    for ch_idx in range(epoch.shape[0]):
        sig = epoch[ch_idx].astype(float)
        iqr = np.percentile(sig, 75) - np.percentile(sig, 25) + 1e-8
        sig_n = sig / iqr
        f, pxx = sp_signal.welch(sig, fs, nperseg=min(512, len(sig)))

        def bp(lo, hi):
            return float(np.mean(pxx[(f >= lo) & (f <= hi)]) + 1e-12)

        delta, theta, alpha, beta = bp(1, 4), bp(4, 8), bp(8, 13), bp(13, 30)
        total = delta + theta + alpha + beta + 1e-12
        mad = np.median(np.abs(sig - np.median(sig))) + 1e-8
        abs_sig = np.abs(sig)
        rel_sr = float(np.mean(abs_sig < 0.15 * mad))
        feats += [
            rel_sr,
            delta / total,
            theta / total,
            alpha / total,
            float(np.mean(pxx[(f >= 1) & (f <= 4)])
                  / (np.mean(pxx[(f >= 8) & (f <= 13)]) + 1e-12)),
            _approx_entropy(sig_n[:256]),
            # Zero-crossing rate. The comparison must sit inside the sum;
            # previously `float(np.sum(...)) != 0` produced a boolean instead
            # of a count.
            float(np.sum(np.diff(np.sign(sig_n)) != 0)) / max(len(sig_n) - 1, 1),
            float(np.std(sig_n)),
            _lzc(sig_n[:512]),
            _etc(sig_n[:512]),
            _perm_entropy(sig_n[:512]),
            float(np.mean(sig_n ** 2)),
            float(np.percentile(np.abs(sig_n), 90)),
            float(beta / total),
            float(np.mean(np.abs(np.diff(sig_n)))),
            float(np.max(np.abs(sig_n))),
        ]
    arr = np.array(feats[:N_FEAT], dtype=np.float32)
    arr = np.where(np.isfinite(arr), arr, 0.0)
    return arr


# ══════════════════════════════════════════════════════════════════════════════
# Temporal Sequence Transformer
# ══════════════════════════════════════════════════════════════════════════════

class CausalTransformer(nn.Module):
    """
    Causal (autoregressive) transformer over window embedding sequences.
    Input:  (B, T, d_in) — sequence of window embeddings
    Output: (B, T, D_MODEL) — contextual representation at each timestep
            (with ``return_hidden=True``), otherwise (B, T, d_in) next-step
            predictions.
    Sequences longer than the positional table (N_CTX + 4) are rejected.
    """

    def __init__(self, d_in=N_FEAT, d_model=D_MODEL, n_heads=N_HEADS,
                 n_layers=N_LAYERS, dropout=0.1):
        super().__init__()
        self.proj_in = nn.Linear(d_in, d_model)
        self.pos_emb = nn.Embedding(N_CTX + 4, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=d_model * 2,
            dropout=dropout, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers=n_layers)
        self.proj_out = nn.Linear(d_model, d_in)  # predict next embedding

    def forward(self, x, return_hidden=False):
        # x: (B, T, d_in)
        B, T, _ = x.shape
        if T > self.pos_emb.num_embeddings:
            # Fail with a clear message instead of an opaque embedding index error.
            raise ValueError(
                f"sequence length {T} exceeds positional table size "
                f"{self.pos_emb.num_embeddings}"
            )
        pos = torch.arange(T, device=x.device).unsqueeze(0)
        h = self.proj_in(x) + self.pos_emb(pos)
        # causal mask: each position can only attend to previous positions
        mask = nn.Transformer.generate_square_subsequent_mask(T, device=x.device)
        h = self.transformer(h, mask=mask, is_causal=True)
        if return_hidden:
            return h               # (B, T, D_MODEL)
        return self.proj_out(h)    # (B, T, d_in) — next-step predictions


class SequenceClassifier(nn.Module):
    """
    Full pipeline: window features → SupCon encoder → Causal transformer
    → CLS representation → distance-based classification

    NOTE(review): ``seq_transformer`` must be built with ``d_in`` equal to the
    window encoder's output width; this is not checked here.
    """

    def __init__(self, window_encoder, seq_transformer):
        super().__init__()
        self.win_enc = window_encoder  # FullModel (frozen during seq pre-training)
        self.seq_tf = seq_transformer

    def encode_sequence(self, x_seq, depth='thalamic'):
        """Encode (B, T, N_FEAT) feature sequences to one vector per sequence.

        The window encoder is applied without gradients and called with the
        same depth argument that ``_emb`` uses for FullModel; the previous
        one-argument call was inconsistent with that usage.
        """
        # x_seq: (B, T, N_FEAT)
        n_batch, n_steps, n_feat = x_seq.shape  # avoid shadowing torch.nn.functional as F
        x_flat = x_seq.reshape(n_batch * n_steps, n_feat)
        with torch.no_grad():
            z_flat = self.win_enc(x_flat, depth)  # (B*T, emb_dim)
        z_seq = z_flat.view(n_batch, n_steps, -1)
        h = self.seq_tf(z_seq, return_hidden=True)  # (B, T, D_MODEL)
        return h[:, -1, :]  # CLS = last position (causal — summarises full context)


def build_sequences(patient, scaler, seq_len=N_CTX, stride=1):
    """
    Extract overlapping sequences of length seq_len from a patient's windows.

    Each sequence ends ON the window whose label it carries, so the label is
    that of the final window in the context. (Previously the context stopped
    one window before the labelled one, so the labelled window was never seen.)

    Returns: X_seqs (N_seq, seq_len, N_FEAT), y_seq (N_seq,), positions
             or (None, None, None) when the patient has fewer than seq_len windows.
    """
    X_raw = patient['X'].astype(np.float32)
    y_all = patient['labels'].astype(np.int32)
    X_norm = scaler.transform(X_raw).astype(np.float32)
    seqs, labels, idxs = [], [], []
    n = len(X_norm)
    for end in range(seq_len - 1, n, stride):
        seqs.append(X_norm[end - seq_len + 1:end + 1])
        labels.append(y_all[end])  # label of the final window in context
        idxs.append(end)
    if not seqs:
        return None, None, None
    return (np.array(seqs, dtype=np.float32),
            np.array(labels, dtype=np.int32),
            np.array(idxs))


def pretrain_seq_transformer(seq_tf, train_patients, scaler,
                             n_epochs=SEQ_PRETRAIN_EP, lr=SEQ_LR):
    """
    Self-supervised pre-training: predict next 16-dim feature vector from context.
    Trained on BASELINE windows only — learns normal thalamic dynamics.
    No window encoder needed: CausalTransformer operates directly on 16-dim features.

    Returns ``seq_tf`` (left in train mode); returned untrained when no
    baseline sequence of length N_CTX + 1 exists.
    """
    seq_tf.train()
    opt = torch.optim.Adam(seq_tf.parameters(), lr=lr)

    all_seqs = []
    for p in train_patients:
        X_norm = scaler.transform(p['X'].astype(np.float32)).astype(np.float32)
        base_idx = np.where(p['labels'] == 0)[0]
        # NOTE(review): baseline windows are concatenated, so a sequence can
        # span a gap where non-baseline windows were removed — confirm this is intended.
        X_base = X_norm[base_idx]
        # `+ 1` so the sequence ending on the last baseline window is included
        # (the previous range dropped it).
        for i in range(N_CTX + 1, len(X_base) + 1):
            all_seqs.append(X_base[i - N_CTX - 1:i])  # length N_CTX+1
    if not all_seqs:
        return seq_tf

    all_seqs = np.array(all_seqs, dtype=np.float32)
    X_ctx = all_seqs[:, :N_CTX, :]   # (N, N_CTX, N_FEAT)
    X_target = all_seqs[:, -1, :]    # (N, N_FEAT)

    dataset = torch.utils.data.TensorDataset(
        torch.tensor(X_ctx), torch.tensor(X_target)
    )
    loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)

    for ep in range(n_epochs):
        total_loss = 0.0
        for x_ctx_b, x_tgt_b in loader:
            x_ctx_b, x_tgt_b = x_ctx_b.to(DEVICE), x_tgt_b.to(DEVICE)
            pred = seq_tf(x_ctx_b)[:, -1, :]  # (B, N_FEAT)
            cos_loss = 1.0 - F.cosine_similarity(pred, x_tgt_b, dim=1).mean()
            mse_loss = F.mse_loss(pred, x_tgt_b)
            loss = cos_loss + 0.5 * mse_loss
            opt.zero_grad()
            loss.backward()
            opt.step()
            total_loss += loss.item()
        if (ep + 1) % 50 == 0:
            log(f" Seq pretrain ep {ep+1}/{n_epochs}: loss={total_loss/len(loader):.4f}")
    return seq_tf


def compute_anomaly_scores(seq_tf, patient, scaler, batch_size=256):
    """
    Compute per-window anomaly scores = next-feature prediction error.
    High score at ictal→PGES transition = anomaly detector.
    CausalTransformer predicts next 16-dim feature directly (no window encoder).

    The score of window i is ``1 - cos(pred, actual)`` where ``pred`` comes
    from the N_CTX preceding windows; the first N_CTX windows keep score 0.
    Contexts are scored in batches of ``batch_size`` instead of one forward
    pass per window.

    Returns: scores (N_windows,), y_all (N_windows,)
    """
    seq_tf.eval()
    X_norm = scaler.transform(patient['X'].astype(np.float32)).astype(np.float32)
    y_all = patient['labels'].astype(np.int32)
    n = len(X_norm)
    scores = np.zeros(n, dtype=np.float32)
    if n <= N_CTX:
        return scores, y_all

    target_idx = np.arange(N_CTX, n)
    with torch.no_grad():
        for start in range(0, len(target_idx), batch_size):
            idx = target_idx[start:start + batch_size]
            ctx = np.stack([X_norm[i - N_CTX:i] for i in idx])
            pred = seq_tf(torch.tensor(ctx).to(DEVICE))[:, -1, :]  # (B, N_FEAT)
            actual = torch.tensor(X_norm[idx]).to(DEVICE)
            err = 1.0 - F.cosine_similarity(pred, actual, dim=1)
            scores[idx] = err.cpu().numpy()
    return scores, y_all


def encode_sequences_cls(seq_tf, seqs):
    """
    Encode sequences (N, T, N_FEAT) via CausalTransformer → CLS token (last position).
    Returns: Z (N, D_MODEL); an empty (0, d_model) array for empty input.
    """
    seq_tf.eval()
    if len(seqs) == 0:
        return np.empty((0, seq_tf.proj_in.out_features), dtype=np.float32)
    z_all = []
    for i in range(0, len(seqs), 32):
        batch = torch.tensor(seqs[i:i + 32], dtype=torch.float32).to(DEVICE)
        with torch.no_grad():
            h = seq_tf(batch, return_hidden=True)  # (B, T, D_MODEL)
        z_all.append(h[:, -1, :].cpu().numpy())
    return np.vstack(z_all)


def _train_baseline_threshold(seq_tf, train_patients, scaler, n_std=0.5):
    """Anomaly threshold (mean + n_std * std) from training patients' baseline scores.

    Only windows that actually received a score (index >= N_CTX) are used.
    Returns None when no such baseline window exists.
    """
    pooled = []
    for p in train_patients:
        s, y = compute_anomaly_scores(seq_tf, p, scaler)
        scored = np.arange(len(s)) >= N_CTX
        pooled.append(s[scored & (y == 0)])
    pooled = np.concatenate(pooled) if pooled else np.empty(0, dtype=np.float32)
    if pooled.size == 0:
        return None
    return float(pooled.mean() + n_std * pooled.std())


def seq_kshot_eval(seq_tf, scaler, train_patients, test_patient, k_list, n_trials):
    """
    K-shot classification using CausalTransformer CLS sequence embeddings.
    Also includes anomaly-based K=0 classification.
    CausalTransformer operates directly on 16-dim normalized features.

    Returns a dict mapping each k in ``k_list`` to a list of F1 scores (one
    per trial, or a single 0.5 when the test patient has too few windows) plus
    ``'anomaly_k0'`` with one F1 score.
    """
    seq_tf.eval()
    results = {k: [] for k in k_list}
    results['anomaly_k0'] = []

    # ── K=0: Anomaly detection ─────────────────────────────────────────
    # The threshold comes from the TRAINING patients' baseline scores so the
    # test patient's labels are not used (it was previously derived from the
    # test patient's own labelled baseline windows).
    scores, y_all = compute_anomaly_scores(seq_tf, test_patient, scaler)
    thresh = _train_baseline_threshold(seq_tf, train_patients, scaler)
    if thresh is None:
        thresh = np.median(scores)
    preds_anomaly = (scores > thresh).astype(int)
    results['anomaly_k0'].append(f1_score(y_all, preds_anomaly, zero_division=0))

    # ── Build sequence embeddings (CLS tokens) ─────────────────────────
    Z_train_pges, Z_train_base = [], []
    for p in train_patients:
        seqs_all, lbls_all, _ = build_sequences(p, scaler, seq_len=N_CTX)
        if seqs_all is None:
            continue
        z_seqs = encode_sequences_cls(seq_tf, seqs_all)
        if z_seqs[lbls_all == 1].shape[0] > 0:
            Z_train_pges.append(z_seqs[lbls_all == 1])
        if z_seqs[lbls_all == 0].shape[0] > 0:
            Z_train_base.append(z_seqs[lbls_all == 0])

    seqs_test, lbls_test, _ = build_sequences(test_patient, scaler, N_CTX)
    if seqs_test is None:
        for k in k_list:
            results[k].append(0.5)
        return results

    Z_test = encode_sequences_cls(seq_tf, seqs_test)
    pges_test_idx = np.where(lbls_test == 1)[0]
    base_test_idx = np.where(lbls_test == 0)[0]

    # K=0 prototypes depend only on the training pool: compute them once.
    if Z_train_pges:
        pool_proto_p = np.vstack(Z_train_pges).mean(0)
        pool_proto_b = (np.vstack(Z_train_base).mean(0)
                        if Z_train_base else -pool_proto_p)
    else:
        pool_proto_p = pool_proto_b = None

    for trial in range(n_trials):
        rng = np.random.RandomState(trial * 11 + 7)
        for k in k_list:
            if k == 0:
                if pool_proto_p is None:
                    results[k].append(0.5)
                    continue
                proto_p, proto_b = pool_proto_p, pool_proto_b
            else:
                # An empty class would give a NaN prototype; use the chance fallback.
                if len(pges_test_idx) == 0 or len(base_test_idx) == 0:
                    results[k].append(0.5)
                    continue
                k_use = min(k, len(pges_test_idx))
                k_base = min(k, len(base_test_idx))
                sup_p = rng.choice(pges_test_idx, k_use, replace=False)
                sup_b = rng.choice(base_test_idx, k_base, replace=False)
                proto_p = Z_test[sup_p].mean(0)
                proto_b = Z_test[sup_b].mean(0)

            # classify via nearest prototype
            # NOTE(review): the K support sequences are still part of the scored
            # set — confirm this matches the protocol used by the other scripts.
            d_p = np.linalg.norm(Z_test - proto_p, axis=1)
            d_b = np.linalg.norm(Z_test - proto_b, axis=1)
            preds = (d_p < d_b).astype(int)
            results[k].append(f1_score(lbls_test, preds, zero_division=0))

    return results


# ══════════════════════════════════════════════════════════════════════════════
# Main
# ══════════════════════════════════════════════════════════════════════════════

def main():
    """Run the LOSO comparison (window-only ProtoNet vs. TSM) and save tables and figure."""
    log("=" * 70)
    log("DACTRL Temporal Sequence Model — Causal Transformer over EEG Windows")
    log("=" * 70)

    # ── Load data ──────────────────────────────────────────────────────────────
    log("[1] Loading thalamic patients...")
    meta_df = _pd.read_excel(METADATA)
    raw = load_all_seeg(meta_df)
    thresh = compute_consensus_thresholds(raw)
    reapply_consensus(raw, thresh)

    patients = [{'pid': pid, 'X': raw[pid]['X'], 'labels': raw[pid]['y_temporal']}
                for pid in sorted(raw.keys()) if pid not in PRIMARY_EXCLUDE]
    log(f" {len(patients)} patients loaded.")

    # ── Build scaler + train base window encoder ───────────────────────────────
    log("[2] Training base window encoder (thal-only SupCon)...")
    X_all = np.vstack([p['X'] for p in patients]).astype(np.float32)
    # NOTE(review): the scaler is fitted on ALL patients, including each
    # held-out test patient (unlabelled statistics only) — confirm acceptable.
    thal_scaler = StandardScaler().fit(X_all)
    log(" Scaler fitted.")

    # ── LOSO evaluation ───────────────────────────────────────────────────────
    log("[3] LOSO evaluation with Causal Transformer...")
    all_res = {
        'win_only': {k: [] for k in K_LIST},   # FullModel SupCon (no sequence)
        'tsm_anomaly': [],
        'tsm_seq': {k: [] for k in K_LIST},    # CausalTransformer CLS prototypes
    }
    per_patient = []

    for i, test_p in enumerate(patients):
        pid = test_p['pid']
        train_ps = [p for p in patients if p['pid'] != pid]
        log(f" [{i+1}/{len(patients)}] Patient {pid}...")

        # ── Window-only baseline: FullModel SupCon ─────────────────────────
        X_tr = thal_scaler.transform(
            np.vstack([p['X'] for p in train_ps]).astype(np.float32)
        ).astype(np.float32)
        y_tr = np.concatenate([p['labels'] for p in train_ps]).astype(np.int32)
        enc_loso = stage1_supcon(X_tr, y_tr)
        enc_loso.eval()

        X_te_n = thal_scaler.transform(test_p['X'].astype(np.float32)).astype(np.float32)
        y_te = test_p['labels'].astype(np.int32)
        Z_te = _emb(enc_loso, X_te_n)

        Z_pool_pges, Z_pool_base = [], []
        for p in train_ps:
            X_p_n = thal_scaler.transform(p['X'].astype(np.float32)).astype(np.float32)
            Z_p = _emb(enc_loso, X_p_n)
            yp = p['labels'].astype(np.int32)
            if Z_p[yp == 1].shape[0] > 0:
                Z_pool_pges.append(Z_p[yp == 1])
            if Z_p[yp == 0].shape[0] > 0:
                Z_pool_base.append(Z_p[yp == 0])
        Z_pp = np.vstack(Z_pool_pges).mean(0) if Z_pool_pges else np.zeros(Z_te.shape[1])
        Z_pb = np.vstack(Z_pool_base).mean(0) if Z_pool_base else np.zeros(Z_te.shape[1])

        pges_idx = np.where(y_te == 1)[0]
        base_idx = np.where(y_te == 0)[0]
        # Scores of THIS patient only; the per-patient table and log line
        # previously reported the running mean over all patients so far.
        win_res = {k: [] for k in K_LIST}
        for trial in range(N_TRIALS):
            rng = np.random.RandomState(trial * 13 + 5)
            for k in K_LIST:
                if k == 0:
                    pp, pb = Z_pp, Z_pb
                else:
                    # An empty class would give a NaN prototype; use the chance fallback.
                    if len(pges_idx) == 0 or len(base_idx) == 0:
                        win_res[k].append(0.5)
                        continue
                    ku = min(k, len(pges_idx))
                    kb = min(k, len(base_idx))
                    pp = Z_te[rng.choice(pges_idx, ku, replace=False)].mean(0)
                    pb = Z_te[rng.choice(base_idx, kb, replace=False)].mean(0)
                preds = _proto(Z_te, pp, pb)
                win_res[k].append(f1_score(y_te, preds, zero_division=0))
        for k in K_LIST:
            all_res['win_only'][k].extend(win_res[k])

        # ── Causal Transformer ─────────────────────────────────────────────
        log(f" Pre-training causal transformer on baseline sequences...")
        seq_tf = CausalTransformer(d_in=N_FEAT, d_model=D_MODEL,
                                   n_heads=N_HEADS, n_layers=N_LAYERS).to(DEVICE)
        seq_tf = pretrain_seq_transformer(seq_tf, train_ps, thal_scaler,
                                          n_epochs=SEQ_PRETRAIN_EP)

        tsm_res = seq_kshot_eval(seq_tf, thal_scaler, train_ps, test_p, K_LIST, N_TRIALS)
        all_res['tsm_anomaly'].extend(tsm_res['anomaly_k0'])
        for k in K_LIST:
            all_res['tsm_seq'][k].extend(tsm_res[k])

        log(f" Win-only K=0: {np.mean(win_res[0]):.3f} "
            f"K=10: {np.mean(win_res[10]):.3f}")
        log(f" TSM anomaly K=0: {np.mean(tsm_res['anomaly_k0']):.3f} "
            f"TSM seq K=0: {np.mean(tsm_res[0]):.3f} K=10: {np.mean(tsm_res[10]):.3f}")

        per_patient.append({
            'pid': pid,
            **{f'win_k{k}': np.mean(win_res[k]) for k in K_LIST},
            **{f'tsm_seq_k{k}': np.mean(tsm_res[k]) for k in K_LIST},
            'tsm_anomaly': np.mean(tsm_res['anomaly_k0']),
        })

    # ── Summary ───────────────────────────────────────────────────────────────
    log("")
    log("=" * 70)
    log("FINAL SUMMARY")
    log("=" * 70)
    log("")
    log("Window-only ProtoNet (thal SupCon, no sequence):")
    for k in K_LIST:
        v = all_res['win_only'][k]
        log(f" K={k:2d}: F1={np.mean(v):.3f} +/- {np.std(v):.3f}")
    log("")
    log("TSM Anomaly Detection (K=0, no labels):")
    v = all_res['tsm_anomaly']
    log(f" F1={np.mean(v):.3f} +/- {np.std(v):.3f}")
    log("")
    log("TSM Sequence ProtoNet:")
    for k in K_LIST:
        v = all_res['tsm_seq'][k]
        log(f" K={k:2d}: F1={np.mean(v):.3f} +/- {np.std(v):.3f}")
    log("")
    log("Delta TSM_seq - Win_only:")
    for k in K_LIST:
        d = np.mean(all_res['tsm_seq'][k]) - np.mean(all_res['win_only'][k])
        log(f" K={k:2d}: {d:+.3f}")

    # ── Save ──────────────────────────────────────────────────────────────────
    rows = []
    for k in K_LIST:
        rows.append({'scenario': 'win_only', 'K': k,
                     'F1': np.mean(all_res['win_only'][k]),
                     'std': np.std(all_res['win_only'][k])})
        rows.append({'scenario': 'tsm_seq', 'K': k,
                     'F1': np.mean(all_res['tsm_seq'][k]),
                     'std': np.std(all_res['tsm_seq'][k])})
    rows.append({'scenario': 'tsm_anomaly', 'K': 0,
                 'F1': np.mean(all_res['tsm_anomaly']),
                 'std': np.std(all_res['tsm_anomaly'])})
    pd.DataFrame(rows).to_csv(TAB_DIR / "tsm_summary.csv", index=False)
    pd.DataFrame(per_patient).to_csv(TAB_DIR / "tsm_per_patient.csv", index=False)

    # ── Figure ────────────────────────────────────────────────────────────────
    fig, axes = plt.subplots(1, 2, figsize=(13, 5))

    # Left: K-curve comparison
    ax = axes[0]
    ax.plot(K_LIST, [np.mean(all_res['win_only'][k]) for k in K_LIST], 'o-',
            color='steelblue', lw=2, label='Window-only ProtoNet')
    ax.plot(K_LIST, [np.mean(all_res['tsm_seq'][k]) for k in K_LIST], 's-',
            color='darkorange', lw=2, label='TSM Sequence ProtoNet')
    ax.axhline(np.mean(all_res['tsm_anomaly']), color='green', ls='--', lw=1.5,
               label=f"TSM Anomaly K=0 ({np.mean(all_res['tsm_anomaly']):.3f})")
    ax.set_xlabel('K'); ax.set_ylabel('F1')
    ax.set_title('Temporal Sequence Model vs Window-only')
    ax.legend(); ax.grid(alpha=0.3); ax.set_ylim(0.4, 1.0)

    # Right: anomaly score example for first test patient with PGES
    # (illustrative only: this model is trained on all patients, including the one shown)
    sample_p = next((p for p in patients if any(lbl == 1 for lbl in p['labels'])), None)
    if sample_p:
        seq_tf_vis = CausalTransformer(d_in=N_FEAT, d_model=D_MODEL,
                                       n_heads=N_HEADS, n_layers=N_LAYERS).to(DEVICE)
        seq_tf_vis = pretrain_seq_transformer(seq_tf_vis, patients, thal_scaler, n_epochs=50)
        scores_vis, y_vis = compute_anomaly_scores(seq_tf_vis, sample_p, thal_scaler)
        ax2 = axes[1]
        t_ax = np.arange(len(scores_vis))
        ax2.plot(t_ax, scores_vis, color='gray', lw=1, alpha=0.7, label='Anomaly score')
        pges_t = t_ax[y_vis == 1]
        ax2.axvspan(pges_t[0] if len(pges_t) > 0 else 0,
                    pges_t[-1] if len(pges_t) > 0 else 1,
                    alpha=0.2, color='red', label='PGES')
        ax2.set_xlabel('Window index'); ax2.set_ylabel('Prediction error (anomaly)')
        ax2.set_title(f'Anomaly scores — {sample_p["pid"]}')
        ax2.legend(); ax2.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(FIG_DIR / "tsm_results.png", dpi=150, bbox_inches='tight')
    plt.close()

    log(f"[DONE] Results: {TAB_DIR}")
    log(f" Figures: {FIG_DIR}")


if __name__ == '__main__':
    main()