# -*- coding: utf-8 -*-
"""
DACTRL-TSM Nucleus Cross-Transfer Evaluation
=============================================
Train TSM on one set of thalamic nuclei, test on a held-out nucleus.
This tests whether temporal sequence representations generalise across
anatomically distinct thalamic targets.

Evaluation strategies:
  A1 — Single nucleus hold-out (4 folds: hold out ANT / CeM / CL / MD)
  A2 — Double nucleus hold-out (6 folds: all pairs)
  Best — Clinically motivated splits (ANT is largest group; CL is best performer)

NOTE: only strategies A1 and A2 are executed by this script; the "Best"
splits are described here but have no corresponding code below.

Directly mirrors dactrl_nucleus_comprehensive_cv.py but with TSM backbone.
"""
import os

os.environ.setdefault('PYTHONIOENCODING', 'utf-8')

import gc
import random
import warnings
from datetime import datetime
from itertools import combinations
from pathlib import Path

import matplotlib

matplotlib.use('Agg')  # select the non-interactive backend before pyplot is imported

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import f1_score
from sklearn.preprocessing import StandardScaler

warnings.filterwarnings('ignore')

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"[{'GPU' if torch.cuda.is_available() else 'CPU'}] "
      f"{torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'No GPU'}",
      flush=True)

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# ── Reuse helpers from the v3 script ─────────────────────────────────────────
# The sibling script is executed into a private namespace so its helpers can be
# reused; its ``__main__`` guard is rewritten so its entry point cannot fire.
# SECURITY NOTE: this ``exec``s a file from disk. It is only acceptable because
# the file is a trusted part of this project; never point _V3 at external input.
_V3 = Path(__file__).parent / "dactrl_v3_episodic_protonet.py"
_v3g = {'__file__': str(_V3)}
# Explicit UTF-8: without it the platform locale encoding is used, which on
# Windows silently mangles non-ASCII characters in the source.
with open(_V3, 'r', encoding='utf-8', errors='replace') as f:
    _v3_source = f.read().replace(
        "if __name__=='__main__':", "if __name__=='__never__':")
exec(compile(_v3_source, str(_V3), 'exec'), _v3g)

load_all_seeg = _v3g['load_all_seeg']
diversity_support = _v3g['diversity_support']
_infer_seizure_ids = _v3g['_infer_seizure_ids']

# ── Paths ────────────────────────────────────────────────────────────────────
SEEG_ROOT = Path(r"G:\PHD Datasets\Data\Thalamus\SEEG Seizure Data")
METADATA = SEEG_ROOT / "metadata_SEEG.xlsx"
OUT_ROOT = Path(r"D:\Projects\phd\PSEG\pges_toolkit\results\dactrl_tsm_nucleus_transfer")
FIG_DIR = OUT_ROOT / "figures"
TAB_DIR = OUT_ROOT / "tables"
for d in [OUT_ROOT, FIG_DIR, TAB_DIR]:
    d.mkdir(parents=True, exist_ok=True)

# ── Cohort definition ────────────────────────────────────────────────────────
NUCLEUS_GROUPS = {
    'ANT': ['P10', 'P11', 'P12', 'P14', 'P15'],
    'CeM': ['P1', 'P3', 'P5', 'P9'],
    'CL': ['P2', 'P7', 'P8'],
    'MD': ['P4', 'P6'],
}
# Patient -> nucleus lookup, derived from NUCLEUS_GROUPS so the two tables
# cannot drift apart.
NUCLEUS_MAP = {
    pid: nucleus
    for nucleus, pids in NUCLEUS_GROUPS.items()
    for pid in pids
}
PRIMARY_EXCLUDE = {'P13'}

# ── Hyper-parameters ─────────────────────────────────────────────────────────
N_FEAT = 17      # input feature dimension per frame
N_CTX = 8        # context length (frames) fed to the transformer
D_MODEL = 64
N_HEADS = 4
N_LAYERS = 4
SEQ_EP = 150     # pre-training epochs
SEQ_LR = 3e-4    # pre-training learning rate
K_LIST = [0, 2, 5, 10, 20]   # support-set sizes evaluated (0 = zero-shot)
N_TRIALS = 10


def log(msg):
    """Print ``msg`` prefixed with the current wall-clock time (HH:MM:SS)."""
    print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}", flush=True)


class CausalTransformer(nn.Module):
    """Causally-masked transformer encoder over short feature sequences.

    Input frames are linearly projected to ``d_model``, summed with a learned
    positional embedding, passed through a ``TransformerEncoder`` under a
    square subsequent (causal) mask, and optionally projected back to the
    input dimension.
    """

    def __init__(self, n_ctx=N_CTX, d_in=N_FEAT, d_model=D_MODEL,
                 n_heads=N_HEADS, n_layers=N_LAYERS, dropout=0.1):
        super().__init__()
        self.proj_in = nn.Linear(d_in, d_model)
        # A few spare positions beyond n_ctx are allocated in the table.
        self.pos_emb = nn.Embedding(n_ctx + 4, d_model)
        enc = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads,
            dim_feedforward=d_model * 2, dropout=dropout,
            batch_first=True)
        self.transformer = nn.TransformerEncoder(enc, num_layers=n_layers)
        self.proj_out = nn.Linear(d_model, d_in)

    def forward(self, x, return_hidden=False):
        """Run the encoder on ``x`` of shape ``(batch, time, d_in)``.

        Returns the hidden states ``(batch, time, d_model)`` when
        ``return_hidden`` is true, otherwise the per-step output projection
        ``(batch, time, d_in)``.
        """
        _, T, _ = x.shape
        positions = torch.arange(T, device=x.device).unsqueeze(0)
        h = self.proj_in(x) + self.pos_emb(positions)
        mask = nn.Transformer.generate_square_subsequent_mask(T, device=x.device)
        h = self.transformer(h, mask=mask, is_causal=True)
        return h if return_hidden else self.proj_out(h)


def pretrain(model, train_patients, scaler):
    """Pre-train ``model`` with next-frame prediction on baseline frames.

    For every training patient the frames with label 0 are standardised with
    ``scaler`` and cut into sliding windows of ``N_CTX + 1`` frames: the first
    ``N_CTX`` frames are the context and the last frame is the target. The
    loss is ``(1 - cosine similarity) + 0.5 * MSE`` between the prediction at
    the final context position and the target frame.

    Args:
        model: network to train; expected to already live on ``DEVICE``.
        train_patients: iterable of dicts with ``'X'`` (frames x features)
            and ``'labels'`` (per-frame labels) arrays.
        scaler: fitted object exposing ``transform``.

    Returns:
        The same ``model`` instance (left in training mode). It is returned
        untouched when no window could be built.
    """
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=SEQ_LR)

    win = N_CTX + 1
    windows = []
    for p in train_patients:
        X_n = scaler.transform(p['X'].astype(np.float32))
        # Baseline frames are selected by boolean mask, so a window may span
        # frames that were not adjacent in the original recording.
        base = X_n[p['labels'] == 0]
        # ``len(base) + 1`` so the window ending on the final baseline frame
        # is included (the previous bound dropped it).
        windows.extend(base[end - win:end] for end in range(win, len(base) + 1))
    if not windows:
        return model

    seqs = np.array(windows, dtype=np.float32)
    ds = torch.utils.data.TensorDataset(
        torch.tensor(seqs[:, :N_CTX]), torch.tensor(seqs[:, -1]))
    ld = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=True)

    for ep in range(SEQ_EP):
        for xc, xt in ld:
            xc, xt = xc.to(DEVICE), xt.to(DEVICE)
            pred = model(xc)[:, -1, :]
            loss = ((1. - F.cosine_similarity(pred, xt, dim=1).mean())
                    + 0.5 * F.mse_loss(pred, xt))
            opt.zero_grad()
            loss.backward()
            opt.step()
        if (ep + 1) % 50 == 0:
            log(f" ep{ep+1}/{SEQ_EP}")
    return model


def encode(model, seqs):
    """Embed each sequence as the hidden state at its last time step.

    Args:
        model: a ``CausalTransformer`` on ``DEVICE``; switched to eval mode.
        seqs: array-like of shape ``(n, time, features)``; must be non-empty
            (``np.vstack`` raises on an empty list).

    Returns:
        ``np.ndarray`` of shape ``(n, d_model)``.
    """
    model.eval()
    chunks = []
    with torch.no_grad():
        for start in range(0, len(seqs), 32):
            batch = torch.tensor(seqs[start:start + 32],
                                 dtype=torch.float32).to(DEVICE)
            hidden = model(batch, return_hidden=True)[:, -1, :]
            chunks.append(hidden.cpu().numpy())
    return np.vstack(chunks)


def make_sequences(patient, scaler):
    """Build labelled context windows for one patient.

    Each sample is the ``N_CTX`` standardised frames immediately *preceding*
    frame ``i`` (frame ``i`` itself is excluded), paired with the label of
    frame ``i``.

    Returns:
        ``(seqs, labels)`` as float32 / int32 arrays, or ``(None, None)`` when
        the patient has no more than ``N_CTX`` frames.
    """
    X_n = scaler.transform(patient['X'].astype(np.float32))
    y = patient['labels'].astype(np.int32)
    ends = range(N_CTX, len(X_n))
    if len(ends) == 0:
        return None, None
    seqs = np.array([X_n[i - N_CTX:i] for i in ends], dtype=np.float32)
    lbls = np.array([y[i] for i in ends], dtype=np.int32)
    return seqs, lbls


def _nearest_prototype_f1(z_query, y_query, proto_pos, proto_neg):
    """F1 of labelling each query 1 iff it is strictly closer to ``proto_pos``."""
    d_pos = np.linalg.norm(z_query - proto_pos, axis=1)
    d_neg = np.linalg.norm(z_query - proto_neg, axis=1)
    return f1_score(y_query, (d_pos < d_neg).astype(int), zero_division=0)


def kshot_eval(model, scaler, train_patients, test_patient, k_list, n_trials):
    """Nearest-prototype evaluation of one test patient at several K.

    * ``K == 0``: class prototypes are the mean embeddings of all training
      patients' label-1 / label-0 windows; every test window is scored.
    * ``K > 0``: ``diversity_support`` picks support and query indices from
      the test patient; prototypes come from the support set and the query
      set is scored.

    A score of 0.5 is recorded whenever evaluation is not possible (missing
    class in the prototypes, no support set, or fewer than 5 query windows).

    Returns:
        dict mapping each ``k`` in ``k_list`` to a list of ``n_trials`` F1
        scores.
    """
    model.eval()

    # Pool training embeddings by class for the zero-shot prototypes.
    Z_tr_p, Z_tr_b = [], []
    for p in train_patients:
        s, l = make_sequences(p, scaler)
        if s is None:
            continue
        z = encode(model, s)
        if z[l == 1].shape[0] > 0:
            Z_tr_p.append(z[l == 1])
        if z[l == 0].shape[0] > 0:
            Z_tr_b.append(z[l == 0])

    s_te, l_te = make_sequences(test_patient, scaler)
    if s_te is None:
        return {k: [0.5] * n_trials for k in k_list}
    Z_te = encode(model, s_te)

    # The zero-shot score does not depend on the trial: compute it once.
    zero_shot_f1 = None
    if 0 in k_list:
        if Z_tr_p and Z_tr_b:
            zero_shot_f1 = _nearest_prototype_f1(
                Z_te, l_te,
                np.mean(np.vstack(Z_tr_p), axis=0),
                np.mean(np.vstack(Z_tr_b), axis=0))
        else:
            zero_shot_f1 = 0.5

    # Loop-invariant: depends on the test labels only.
    # presumably a pure function of the labels — confirm in the v3 script.
    sz_ids = _infer_seizure_ids(l_te)

    # NOTE(review): diversity_support receives identical arguments on every
    # trial and no RNG is passed in. Unless it draws from global random state,
    # all trials yield the same score and the reported F1_std will be 0 —
    # confirm against dactrl_v3_episodic_protonet.py.
    results = {k: [] for k in k_list}
    for _ in range(n_trials):
        for k in k_list:
            if k == 0:
                results[0].append(zero_shot_f1)
                continue
            sup_idx, qry_idx = diversity_support(l_te, k, sz_ids)
            if sup_idx is None or len(qry_idx) < 5:
                results[k].append(0.5)
                continue
            sup_lbl = l_te[sup_idx]
            if len(np.unique(sup_lbl)) < 2:
                results[k].append(0.5)
                continue
            proto_pos = Z_te[sup_idx[sup_lbl == 1]].mean(axis=0)
            proto_neg = Z_te[sup_idx[sup_lbl == 0]].mean(axis=0)
            results[k].append(_nearest_prototype_f1(
                Z_te[qry_idx], l_te[qry_idx], proto_pos, proto_neg))
    return results


def run_split(split_name, train_pids, test_pids, all_patients):
    """Run one nucleus split: train on train_pids, test on each patient in test_pids.

    A fresh scaler and ``CausalTransformer`` are fitted on the training
    patients, then every test patient is evaluated with ``kshot_eval``.

    Returns:
        A list of row dicts (one per test patient and K) with keys ``split``,
        ``train_nuclei``, ``test_nucleus``, ``pid``, ``K``, ``F1_mean`` and
        ``F1_std``; empty when either side of the split has no loaded patient.
    """
    train_set, test_set = set(train_pids), set(test_pids)
    train_ps = [p for p in all_patients if p['pid'] in train_set]
    test_ps = [p for p in all_patients if p['pid'] in test_set]
    if not train_ps or not test_ps:
        return []

    train_nucs = sorted(set(NUCLEUS_MAP.get(p['pid'], '?') for p in train_ps))
    test_nucs = sorted(set(NUCLEUS_MAP.get(p['pid'], '?') for p in test_ps))
    log(f"\n {split_name}: train={train_nucs} ({len(train_ps)}p) → test={test_nucs} ({len(test_ps)}p)")

    # Standardisation statistics come from the training patients only.
    X_tr = np.vstack([p['X'].astype(np.float32) for p in train_ps])
    scaler = StandardScaler().fit(X_tr)

    model = CausalTransformer(
        n_ctx=N_CTX, d_in=N_FEAT, d_model=D_MODEL,
        n_heads=N_HEADS, n_layers=N_LAYERS
    ).to(DEVICE)
    model = pretrain(model, train_ps, scaler)

    rows = []
    for test_p in test_ps:
        pid = test_p['pid']
        nucleus = NUCLEUS_MAP.get(pid, '?')
        res = kshot_eval(model, scaler, train_ps, test_p, K_LIST, N_TRIALS)
        means = {k: np.mean(res[k]) for k in K_LIST}
        for k in K_LIST:
            rows.append({
                'split': split_name,
                'train_nuclei': '+'.join(train_nucs),
                'test_nucleus': nucleus,
                'pid': pid,
                'K': k,
                'F1_mean': means[k],
                'F1_std': np.std(res[k]),
            })
        # Report K=2 and K=10 when they were evaluated, instead of indexing
        # them unconditionally (which breaks if K_LIST is changed).
        shown = " ".join(f"K={k}={means[k]:.4f}" for k in (2, 10) if k in means)
        log(f" {pid} ({nucleus}): {shown}")

    # Release the model before the next split allocates a new one.
    del model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return rows


def main():
    """Run the A1 and A2 hold-out strategies, then write the table and figure."""
    log("=" * 65)
    log("DACTRL-TSM Nucleus Cross-Transfer")
    log("=" * 65)

    meta_df = pd.read_excel(METADATA)
    raw = load_all_seeg(meta_df)
    all_patients = [
        {'pid': pid, 'X': raw[pid]['X'], 'labels': raw[pid]['y_temporal']}
        for pid in sorted(raw.keys())
        if pid not in PRIMARY_EXCLUDE
    ]
    log(f"Patients: {[p['pid'] for p in all_patients]}")

    nuclei = ['ANT', 'CeM', 'CL', 'MD']
    all_rows = []

    # ── Strategy A1: Single nucleus hold-out (4 folds) ───────────────────────
    log("\n=== Strategy A1: Single nucleus hold-out ===")
    for test_nuc in nuclei:
        test_pids = NUCLEUS_GROUPS[test_nuc]
        train_pids = [p['pid'] for p in all_patients if p['pid'] not in test_pids]
        rows = run_split(f"HoldOut_{test_nuc}", train_pids, test_pids, all_patients)
        all_rows.extend(rows)

    # ── Strategy A2: Double nucleus hold-out (6 folds) ───────────────────────
    log("\n=== Strategy A2: Double nucleus hold-out ===")
    for nuc_pair in combinations(nuclei, 2):
        test_pids = [p for n in nuc_pair for p in NUCLEUS_GROUPS[n]]
        train_pids = [p['pid'] for p in all_patients if p['pid'] not in test_pids]
        if len(train_pids) < 3:
            log(f" SKIP {nuc_pair}: only {len(train_pids)} train patients")
            continue
        split_name = f"HoldOut_{'_'.join(nuc_pair)}"
        rows = run_split(split_name, train_pids, test_pids, all_patients)
        all_rows.extend(rows)

    df = pd.DataFrame(all_rows)
    df.to_csv(TAB_DIR / "tsm_nucleus_transfer.csv", index=False)
    if df.empty:
        # No split produced rows: the frame has no columns to summarise.
        log("No results produced; skipping summary and figure.")
        return

    # ── Summary ──────────────────────────────────────────────────────────────
    # A1 splits are exactly "HoldOut_<nucleus>"; every other hold-out split is
    # an A2 pair.
    a1_names = [f"HoldOut_{n}" for n in nuclei]
    is_holdout = df['split'].str.startswith('HoldOut_')
    is_a1 = df['split'].isin(a1_names)

    log("\n=== A1 Summary (Single nucleus hold-out, K=10) ===")
    a1 = df[is_holdout & is_a1]
    a1_k10 = a1[a1['K'] == 10].groupby('test_nucleus')['F1_mean'].agg(['mean', 'std'])
    log(a1_k10.to_string(float_format='{:.4f}'.format))

    # Compare to v3 nucleus CV baseline
    V3_NUC = {'ANT': 0.870, 'CeM': 0.840, 'CL': 0.903, 'MD': 0.942}
    log("\n TSM vs v3 nucleus CV (K=10):")
    for nuc, row in a1_k10.iterrows():
        v3 = V3_NUC.get(nuc, float('nan'))
        log(f" {nuc}: TSM={row['mean']:.4f} v3={v3:.4f} Δ={row['mean']-v3:+.4f}")

    log("\n=== A2 Summary (Double nucleus hold-out, K=10) ===")
    a2 = df[is_holdout & ~is_a1]
    a2_k10 = a2[a2['K'] == 10].groupby('split')['F1_mean'].agg(['mean', 'std'])
    log(a2_k10.to_string(float_format='{:.4f}'.format))

    # ── Figure ───────────────────────────────────────────────────────────────
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Left: A1 K-curves per test nucleus
    ax = axes[0]
    colors = {'ANT': '#1b7837', 'CeM': '#762a83', 'CL': '#e66101', 'MD': '#2166ac'}
    for nuc in nuclei:
        sub = df[df['split'] == f'HoldOut_{nuc}']
        by_k = sub.groupby('K')['F1_mean'].mean().reset_index()
        ax.plot(by_k['K'], by_k['F1_mean'], marker='o', linewidth=2,
                label=f'Test={nuc}', color=colors[nuc])
    ax.set_xlabel('K')
    ax.set_ylabel('F1')
    ax.set_title('A1: Single Nucleus Hold-Out\n(train on 3 nuclei, test on 1)')
    ax.set_ylim(0.5, 1.05)
    ax.legend()
    ax.grid(alpha=0.3)

    # Right: A1 K=10 comparison TSM vs v3
    ax = axes[1]
    nuc_order = ['ANT', 'CeM', 'CL', 'MD']
    tsm_vals = [a1_k10.loc[n, 'mean'] if n in a1_k10.index else 0 for n in nuc_order]
    v3_vals = [V3_NUC[n] for n in nuc_order]
    x = np.arange(len(nuc_order))
    w = 0.35
    ax.bar(x - w/2, v3_vals, w, label='v3 K=10', color='#fc8d59', alpha=0.85)
    ax.bar(x + w/2, tsm_vals, w, label='TSM K=10', color='#2166ac', alpha=0.85)
    ax.set_xticks(x)
    ax.set_xticklabels(nuc_order)
    ax.set_ylabel('F1 (mean over held-out patients)')
    ax.set_title('TSM vs v3: Nucleus Hold-Out K=10')
    ax.legend()
    ax.grid(axis='y', alpha=0.3)
    ax.set_ylim(0.5, 1.05)

    plt.tight_layout()
    plt.savefig(FIG_DIR / "tsm_nucleus_transfer.png", dpi=150, bbox_inches='tight')
    plt.close(fig)  # release the figure once it is on disk
    log(f"\nSaved: {FIG_DIR / 'tsm_nucleus_transfer.png'}")
    log("Done.")


if __name__ == '__main__':
    main()