# -*- coding: utf-8 -*-
"""
DACTRL -- Cross-Nucleus Transfer Test
======================================
Trains DACTRL-TSM entirely on patients from one thalamic nucleus,
then evaluates on patients from every other nucleus.

12 directed pairs:
  ANT->CL,  ANT->CeM, ANT->MD,
  CL->ANT,  CL->CeM,  CL->MD,
  CeM->ANT, CeM->CL,  CeM->MD,
  MD->ANT,  MD->CL,   MD->CeM

Also includes same-nucleus LOSO baseline for reference.
K evaluated: 0, 2, 5, 10

Clinical question:
  If a new patient arrives with nucleus X, but the training set
  only contains nucleus Y patients, how well does DACTRL generalise
  across anatomically distinct thalamic targets?

Outputs:
  results/dactrl_cross_nucleus/tables/cross_nucleus_per_patient.csv
  results/dactrl_cross_nucleus/tables/cross_nucleus_summary.csv
  results/dactrl_cross_nucleus/figures/cross_nucleus_heatmap.png
  results/dactrl_cross_nucleus/figures/cross_nucleus_k_curves.png
"""
import os
os.environ.setdefault('PYTHONIOENCODING', 'utf-8')
import random
import warnings
import copy
from pathlib import Path
from datetime import datetime

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # file-only backend: figures are saved, never shown
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score
import torch
import torch.nn as nn
import torch.nn.functional as F

warnings.filterwarnings('ignore')

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"[{'GPU' if torch.cuda.is_available() else 'CPU'}]", flush=True)
torch.manual_seed(42); np.random.seed(42); random.seed(42)

# Pull `load_all_seeg` and `diversity_support` out of the sibling v3 script by
# executing its source in a private namespace, with its entry-point guard
# rewritten so that its own main section does not run.
# NOTE(review): this exec()s a whole file from disk; it is only acceptable
# because the file is a trusted part of the same project.
_V3 = Path(__file__).parent / "dactrl_v3_episodic_protonet.py"
_v3g = {'__file__': str(_V3)}
with open(_V3, 'r', errors='replace') as f:
    exec(compile(
        f.read().replace("if __name__=='__main__':", "if __name__=='__never__':"),
        str(_V3), 'exec'), _v3g)
load_all_seeg = _v3g['load_all_seeg']
diversity_support = _v3g['diversity_support']

SEEG_ROOT = Path(r"G:\PHD Datasets\Data\Thalamus\SEEG Seizure Data")
METADATA = SEEG_ROOT / "metadata_SEEG.xlsx"
OUT_ROOT = Path(r"D:\Projects\phd\PSEG\pges_toolkit\results\dactrl_cross_nucleus")
FIG_DIR = OUT_ROOT / "figures"
TAB_DIR = OUT_ROOT / "tables"
for d in [OUT_ROOT, FIG_DIR, TAB_DIR]:
    d.mkdir(parents=True, exist_ok=True)

N_FEAT = 17        # input/output width of the transformer projections
N_CTX = 8          # number of past rows in every context window
D_MODEL = 64
N_HEADS = 4
N_LAYERS = 4
SEQ_EP = 150       # pretraining epochs
SEQ_LR = 3e-4      # Adam learning rate for pretraining
N_TRIALS = 5       # support/query resamplings per K
K_LIST = [0, 2, 5, 10]

NUCLEUS_MAP = {
    'P1': 'CeM', 'P3': 'CeM', 'P5': 'CeM', 'P9': 'CeM',
    'P2': 'CL', 'P7': 'CL', 'P8': 'CL',
    'P4': 'MD', 'P6': 'MD',
    'P10': 'ANT', 'P11': 'ANT', 'P12': 'ANT',
    'P14': 'ANT', 'P15': 'ANT',
}
NUC_COLORS = {'CeM': '#d73027', 'CL': '#4dac26', 'MD': '#2166ac', 'ANT': '#fc8d59'}


def log(msg):
    """Print `msg` prefixed with the current wall-clock time (HH:MM:SS)."""
    print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}", flush=True)


class CausalTransformer(nn.Module):
    """Causally masked transformer encoder over windows of N_FEAT-wide rows.

    The input is projected to D_MODEL, a learned positional embedding is
    added, and the sequence passes through N_LAYERS encoder layers under a
    square subsequent (causal) mask.
    """

    def __init__(self):
        super().__init__()
        self.proj_in = nn.Linear(N_FEAT, D_MODEL)
        # The embedding table bounds the usable sequence length to N_CTX + 4.
        self.pos_emb = nn.Embedding(N_CTX + 4, D_MODEL)
        enc = nn.TransformerEncoderLayer(
            d_model=D_MODEL, nhead=N_HEADS,
            dim_feedforward=D_MODEL * 2, dropout=0.1, batch_first=True)
        self.transformer = nn.TransformerEncoder(enc, num_layers=N_LAYERS)
        self.proj_out = nn.Linear(D_MODEL, N_FEAT)

    def forward(self, x, return_hidden=False):
        """Run the encoder on `x` of shape (batch, T, N_FEAT).

        Returns the hidden states (batch, T, D_MODEL) when `return_hidden`
        is true, otherwise their projection back to (batch, T, N_FEAT).
        """
        _, T, _ = x.shape
        positions = torch.arange(T, device=x.device).unsqueeze(0)
        h = self.proj_in(x) + self.pos_emb(positions)
        mask = nn.Transformer.generate_square_subsequent_mask(T, device=x.device)
        h = self.transformer(h, mask=mask, is_causal=True)
        return h if return_hidden else self.proj_out(h)


def pretrain(model, train_patients, scaler):
    """Pretrain `model` to predict the next label-0 row from N_CTX previous ones.

    For each patient the features are scaled with `scaler`, rows whose label
    is 0 are kept, and every run of N_CTX + 1 consecutive kept rows becomes one
    sample: the first N_CTX rows are the context, the last row is the target.
    The loss is (1 - mean cosine similarity) + 0.5 * MSE between the model
    output at the last context position and the target.

    Returns `model` (trained in place); unchanged if no window could be built.
    """
    model.train()
    opt = torch.optim.Adam(model.parameters(), lr=SEQ_LR)

    seqs = []
    for p in train_patients:
        X_n = scaler.transform(p['X'].astype(np.float32))
        # NOTE(review): label-0 rows are compacted before windowing, so a
        # window may straddle a removed non-zero segment.
        base = X_n[p['labels'] == 0]
        # `end` is exclusive, so it must reach len(base) for the final row to
        # be used as a target (the previous upper bound dropped that window).
        for end in range(N_CTX + 1, len(base) + 1):
            seqs.append(base[end - N_CTX - 1: end])
    if not seqs:
        return model

    seqs = np.array(seqs, dtype=np.float32)
    ds = torch.utils.data.TensorDataset(
        torch.tensor(seqs[:, :N_CTX]), torch.tensor(seqs[:, -1]))
    ld = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=True)

    for _ in range(SEQ_EP):
        for xc, xt in ld:
            xc, xt = xc.to(DEVICE), xt.to(DEVICE)
            pred = model(xc)[:, -1, :]
            loss = (1. - F.cosine_similarity(pred, xt, dim=1).mean()) \
                + 0.5 * F.mse_loss(pred, xt)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return model


def encode(model, seqs):
    """Embed each window in `seqs` as the hidden state at its last position.

    Runs `model` in eval mode without gradients, 32 windows at a time, and
    returns the stacked embeddings as a NumPy array (one row per window).
    """
    model.eval()
    z = []
    with torch.no_grad():
        for i in range(0, len(seqs), 32):
            b = torch.tensor(seqs[i:i + 32], dtype=torch.float32).to(DEVICE)
            z.append(model(b, return_hidden=True)[:, -1, :].cpu().numpy())
    return np.vstack(z)


def _make_windows(X_n, y):
    """Build (windows, labels) where window j holds the N_CTX rows before row i.

    For every i in [N_CTX, len(X_n)) the window is X_n[i - N_CTX:i] and its
    label is y[i]. Both outputs are empty when X_n has at most N_CTX rows.
    """
    n = len(X_n)
    if n <= N_CTX:
        return (np.empty((0, N_CTX, X_n.shape[1]), dtype=np.float32),
                np.empty((0,), dtype=np.int32))
    windows = np.stack([X_n[i - N_CTX: i] for i in range(N_CTX, n)]).astype(np.float32)
    labels = np.asarray(y[N_CTX:n], dtype=np.int32)
    return windows, labels


def eval_patient(model, scaler, test_p, k_list, n_trials):
    """Evaluate a single test patient with K-shot ProtoNet.

    For every k > 0 in `k_list`, `diversity_support` is asked `n_trials` times
    for a support/query split of the patient's own windows; class prototypes
    are the mean support embeddings and each query window is assigned to the
    nearer prototype (Euclidean). Trials whose support lacks one of the two
    classes are skipped.

    Returns a dict k -> mean F1 over valid trials (nan if none). The k == 0
    entry is always nan here; it is filled by `eval_patient_k0_crossproto`.
    All entries are nan when the patient has no windows or no positive label.
    """
    X_te_n = scaler.transform(test_p['X'].astype(np.float32))
    y_te = test_p['labels'].astype(np.int32)
    seqs, lbls = _make_windows(X_te_n, y_te)
    if len(seqs) == 0 or lbls.sum() == 0:
        return {k: float('nan') for k in k_list}

    Z = encode(model, seqs)
    results = {}
    for k in k_list:
        if k == 0:
            # K=0: no test-patient support -- return nan, handled separately
            results[0] = float('nan')
            continue
        trial_f1 = []
        for _ in range(n_trials):
            sup, qry = diversity_support(lbls, k)
            if sup is None or len(np.unique(lbls[sup])) < 2:
                continue
            pp = Z[sup[lbls[sup] == 1]].mean(axis=0)
            pb = Z[sup[lbls[sup] == 0]].mean(axis=0)
            preds = (np.linalg.norm(Z[qry] - pp, axis=1) <
                     np.linalg.norm(Z[qry] - pb, axis=1)).astype(int)
            trial_f1.append(f1_score(lbls[qry], preds, zero_division=0))
        results[k] = float(np.mean(trial_f1)) if trial_f1 else float('nan')
    return results


def _cross_patient_prototypes(model, scaler, train_patients):
    """Average per-patient class centroids into one (positive, negative) pair.

    Each training patient contributes the mean embedding of its label-1
    windows and of its label-0 windows (when present). Returns the tuple
    (pp, pb), or None if either class has no contributing patient.
    """
    all_pp, all_pb = [], []
    for p in train_patients:
        X_n = scaler.transform(p['X'].astype(np.float32))
        y = p['labels'].astype(np.int32)
        s, l = _make_windows(X_n, y)
        if len(s) == 0:
            continue
        Z_p = encode(model, s)
        if l.sum() > 0:
            all_pp.append(Z_p[l == 1].mean(0))
        if (l == 0).sum() > 0:
            all_pb.append(Z_p[l == 0].mean(0))
    if not all_pp or not all_pb:
        return None
    return np.mean(all_pp, axis=0), np.mean(all_pb, axis=0)


def eval_patient_k0_crossproto(model, scaler, train_patients, test_p,
                               prototypes=None):
    """
    K=0: prototypes from training patients (cross-patient), applied to test.
    No test-patient labels used.

    `prototypes` may carry a precomputed (pp, pb) pair for this model, scaler
    and training set so that callers scoring several test patients do not
    re-encode the training data each time; when None it is computed here.

    Returns the F1 of nearest-prototype predictions over all test windows, or
    nan if prototypes cannot be built, the test patient has no windows, or it
    has no positive label.
    """
    if prototypes is None:
        prototypes = _cross_patient_prototypes(model, scaler, train_patients)
    if prototypes is None:
        return float('nan')
    pp, pb = prototypes

    # Evaluate on test patient
    X_te_n = scaler.transform(test_p['X'].astype(np.float32))
    y_te = test_p['labels'].astype(np.int32)
    s, l = _make_windows(X_te_n, y_te)
    if len(s) == 0 or l.sum() == 0:
        return float('nan')
    Z_te = encode(model, s)
    preds = (np.linalg.norm(Z_te - pp, axis=1) <
             np.linalg.norm(Z_te - pb, axis=1)).astype(int)
    return float(f1_score(l, preds, zero_division=0))


def _fit_source_model(train_ps):
    """Fit a StandardScaler on `train_ps` and pretrain a fresh model on them."""
    X_tr = np.vstack([p['X'].astype(np.float32) for p in train_ps])
    scaler = StandardScaler().fit(X_tr)
    model = CausalTransformer().to(DEVICE)
    model = pretrain(model, train_ps, scaler)
    model.eval()
    return model, scaler


def _score_patient(model, scaler, train_ps, test_p, prototypes):
    """Return the K -> F1 dict for one test patient, including the K=0 entry."""
    res = eval_patient(model, scaler, test_p, K_LIST, N_TRIALS)
    if prototypes is None:
        res[0] = float('nan')
    else:
        res[0] = eval_patient_k0_crossproto(
            model, scaler, train_ps, test_p, prototypes=prototypes)
    return res


def _result_rows(train_nuc, test_nuc, transfer, pid, res):
    """Expand one patient's K -> F1 dict into per-K result records."""
    return [{'train_nucleus': train_nuc, 'test_nucleus': test_nuc,
             'transfer': transfer, 'test_pid': pid,
             'K': k, 'F1': res[k]} for k in K_LIST]


def _run_same_nucleus(by_nuc, nuclei):
    """Leave-one-patient-out within each nucleus; returns the result records."""
    rows = []
    log("\n=== Same-Nucleus LOSO (reference) ===")
    for nuc in nuclei:
        nuc_patients = by_nuc[nuc]
        if len(nuc_patients) < 2:
            log(f" {nuc}: only {len(nuc_patients)} patient(s), skip LOSO")
            continue
        for test_p in nuc_patients:
            train_ps = [p for p in nuc_patients if p['pid'] != test_p['pid']]
            model, scaler = _fit_source_model(train_ps)
            protos = _cross_patient_prototypes(model, scaler, train_ps)
            res = _score_patient(model, scaler, train_ps, test_p, protos)
            log(f" SAME {nuc} -> {test_p['pid']}: " +
                " ".join(f"K={k}:{res[k]:.4f}" for k in K_LIST))
            rows.extend(_result_rows(nuc, nuc, 'same', test_p['pid'], res))
            del model
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    return rows


def _run_cross_nucleus(by_nuc, nuclei):
    """Train on all patients of one nucleus, test on every other nucleus."""
    rows = []
    log("\n=== Cross-Nucleus Transfer ===")
    for src_nuc in nuclei:
        train_ps = by_nuc[src_nuc]
        # Train one model on ALL patients in the source nucleus
        model, scaler = _fit_source_model(train_ps)
        # The K=0 prototypes depend only on the source model and patients,
        # so they are built once and reused for every test patient.
        protos = _cross_patient_prototypes(model, scaler, train_ps)
        log(f"\n Source: {src_nuc} ({[p['pid'] for p in train_ps]}) "
            f"-> testing on all other nuclei")

        for tgt_nuc in nuclei:
            if tgt_nuc == src_nuc:
                continue
            for test_p in by_nuc[tgt_nuc]:
                res = _score_patient(model, scaler, train_ps, test_p, protos)
                log(f" {src_nuc}->{tgt_nuc} ({test_p['pid']}): " +
                    " ".join(f"K={k}:{res[k]:.4f}" for k in K_LIST))
                rows.extend(_result_rows(src_nuc, tgt_nuc, 'cross',
                                         test_p['pid'], res))
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return rows


def _plot_heatmaps(summary):
    """Save one train-nucleus x test-nucleus mean-F1 heatmap panel per K."""
    # squeeze=False keeps `axes` two-dimensional even for a single K.
    fig, axes = plt.subplots(1, len(K_LIST), figsize=(6 * len(K_LIST), 5),
                             squeeze=False)
    for ax, K in zip(axes[0], K_LIST):
        pivot = summary[summary.K == K].pivot(
            index='train_nucleus', columns='test_nucleus', values='F1_mean')
        im = ax.imshow(pivot.values, vmin=0.4, vmax=1.0,
                       cmap='RdYlGn', aspect='auto')
        ax.set_xticks(range(len(pivot.columns)))
        ax.set_yticks(range(len(pivot.index)))
        ax.set_xticklabels(pivot.columns, rotation=45, fontsize=10)
        ax.set_yticklabels(pivot.index, fontsize=10)
        for i in range(len(pivot.index)):
            for j in range(len(pivot.columns)):
                val = pivot.values[i, j]
                if not np.isnan(val):
                    ax.text(j, i, f'{val:.3f}', ha='center', va='center',
                            fontsize=9, fontweight='bold',
                            color='white' if val < 0.6 else 'black')
        ax.set_title(f'K={K}', fontsize=12)
        ax.set_xlabel('Test Nucleus', fontsize=10)
        ax.set_ylabel('Train Nucleus', fontsize=10)
        plt.colorbar(im, ax=ax, fraction=0.046)

    plt.suptitle('DACTRL-TSM Cross-Nucleus Transfer F1\n(diagonal=same-nucleus LOSO, off-diagonal=cross-nucleus)',
                 fontsize=13)
    plt.tight_layout()
    plt.savefig(FIG_DIR / "cross_nucleus_heatmap.png", dpi=150, bbox_inches='tight')
    plt.close(fig)


def _plot_k_curves(summary):
    """Save F1-vs-K curves comparing same-nucleus and cross-nucleus training."""
    cross_pairs = [('CeM', 'ANT'), ('CeM', 'CL'), ('CL', 'ANT'), ('ANT', 'CeM')]
    fig2, axes2 = plt.subplots(1, len(cross_pairs),
                               figsize=(5 * len(cross_pairs), 5), squeeze=False)
    for ax, (src, tgt) in zip(axes2[0], cross_pairs):
        same_row = summary[(summary.train_nucleus == tgt) &
                           (summary.test_nucleus == tgt)]
        cross_row = summary[(summary.train_nucleus == src) &
                            (summary.test_nucleus == tgt)]
        ax.plot(same_row.K, same_row.F1_mean, 'o-',
                color=NUC_COLORS[tgt], linewidth=2,
                label=f'{tgt}->{tgt} (same)', markersize=7)
        ax.fill_between(same_row.K,
                        same_row.F1_mean - same_row.F1_std,
                        same_row.F1_mean + same_row.F1_std,
                        alpha=0.15, color=NUC_COLORS[tgt])
        ax.plot(cross_row.K, cross_row.F1_mean, 's--',
                color=NUC_COLORS[src], linewidth=2,
                label=f'{src}->{tgt} (cross)', markersize=7)
        ax.fill_between(cross_row.K,
                        cross_row.F1_mean - cross_row.F1_std,
                        cross_row.F1_mean + cross_row.F1_std,
                        alpha=0.15, color=NUC_COLORS[src])
        ax.set_xlabel('K (support shots)', fontsize=10)
        ax.set_ylabel('F1 Score', fontsize=10)
        ax.set_title(f'Test nucleus: {tgt}', fontsize=11)
        ax.legend(fontsize=8)
        ax.grid(True, alpha=0.3)
        ax.set_ylim(0.3, 1.05)

    plt.suptitle('DACTRL-TSM: Same-Nucleus vs Cross-Nucleus Transfer', fontsize=13)
    plt.tight_layout()
    plt.savefig(FIG_DIR / "cross_nucleus_k_curves.png", dpi=150, bbox_inches='tight')
    plt.close(fig2)


def main():
    """Run both experiments, then write the CSV tables and the two figures."""
    log("=" * 60)
    log("DACTRL -- Cross-Nucleus Transfer Test")
    log("=" * 60)

    meta_df = pd.read_excel(METADATA)
    raw = load_all_seeg(meta_df)
    all_pids = sorted(p for p in raw.keys() if p != 'P13')
    patients = [{'pid': p,
                 'nucleus': NUCLEUS_MAP[p],
                 'X': raw[p]['X'],
                 'labels': raw[p]['y_temporal']} for p in all_pids]

    # Group by nucleus
    by_nuc = {}
    for p in patients:
        by_nuc.setdefault(p['nucleus'], []).append(p)
    nuclei = sorted(by_nuc.keys())
    pid_groups = {n: [p['pid'] for p in ps] for n, ps in by_nuc.items()}
    log(f"Nuclei: {pid_groups}")

    rows = _run_same_nucleus(by_nuc, nuclei)
    rows.extend(_run_cross_nucleus(by_nuc, nuclei))
    if not rows:
        # An empty frame has no columns, so the groupby below would fail.
        log("No results were produced; nothing to save.")
        return

    # Save per-patient results
    df = pd.DataFrame(rows)
    df.to_csv(TAB_DIR / "cross_nucleus_per_patient.csv", index=False)

    # Summary: mean F1 per (train_nuc, test_nuc, K)
    summary = df.groupby(['train_nucleus', 'test_nucleus', 'K']).F1.agg(
        ['mean', 'std', 'count']).reset_index()
    summary.columns = ['train_nucleus', 'test_nucleus', 'K', 'F1_mean', 'F1_std', 'n']
    summary.to_csv(TAB_DIR / "cross_nucleus_summary.csv", index=False)

    log("\n=== Cross-Nucleus Transfer Summary (K=10) ===")
    k10 = summary[summary.K == 10].pivot(index='train_nucleus',
                                         columns='test_nucleus',
                                         values='F1_mean')
    log("\nF1 heatmap (rows=train nucleus, cols=test nucleus):")
    log(k10.to_string(float_format='{:.3f}'.format))

    _plot_heatmaps(summary)
    _plot_k_curves(summary)

    log(f"\nSaved -> {FIG_DIR}")
    log("Done.")


if __name__ == '__main__':
    main()