# -*- coding: utf-8 -*- """ DACTRL — Paired-Supervised CycleGAN + TUH Scale (C11) ======================================================= Motivation ---------- C8 showed that TUH-scale unsupervised CycleGAN provides zero benefit over thalamic-only TSM (null result at all K). The hypothesis is that the unsupervised generator — trained with no temporal correspondence between scalp and thalamic windows — learns a rough marginal distribution match but fails to capture the PGES-specific feature relationships. We have a unique asset that C8 did not exploit: three patients (P2, P10, P12) with *simultaneous* scalp (19-ch) and thalamic LFP recordings of the same seizures. For each 5-second window at time t we can extract: x_scalp(t) — 17 features from average-reference scalp x_thal(t) — 17 features from thalamic LFP This gives supervised signal: the ground-truth answer to "what does this scalp feature vector look like in thalamic feature space?" Strategy -------- Stage 1: Train unsupervised CycleGAN on TUH scalp ↔ thalamic baseline (identical to C8 — scales to thousands of scalp windows). Stage 2: Fine-tune G_S2T with a paired supervised loss on P2/P10/P12: L_paired = ||G_S2T(x_scalp(t)) - x_thal(t)||² Combined: L = L_GAN + lambda * L_paired (lambda=2.0) Stage 3: Translate TUH PGES sessions → synthetic thalamic PGES using the now-calibrated generator. Stage 4: LOSO TSM pre-training on real thalamic + synthetic thalamic PGES. LOSO leakage prevention ----------------------- When the test patient is P2, P10, or P12, that patient's paired data is excluded from Stage 2 fine-tuning. Only the remaining paired patients contribute supervised signal. Conditions ---------- A: Thalamic-only TSM (canonical baseline — reproduces C1) B: TUH CycleGAN unsupervised → TSM fine-tune (C8 null, reproduced) C: Paired-supervised G_S2T only (no TUH Stage 1; cold-start paired fine-tune) → translate TUH PGES → TSM pre-train + thalamic fine-tune D: TUH unsupervised Stage 1 → Paired-supervised Stage 2 (main result) → translate TUH PGES → TSM pre-train + thalamic fine-tune E: Condition D + Day-0 temporal heuristic (auto-label via device timestamp) """ import os; os.environ.setdefault('PYTHONIOENCODING', 'utf-8') import gc, glob, random, threading, warnings, copy from pathlib import Path from datetime import datetime import numpy as np import pandas as pd import matplotlib; matplotlib.use('Agg') import matplotlib.pyplot as plt from scipy.stats import wilcoxon from sklearn.preprocessing import StandardScaler from sklearn.metrics import f1_score import torch import torch.nn as nn import torch.nn.functional as F try: import mne; mne.set_log_level('ERROR') MNE_OK = True except ImportError: MNE_OK = False warnings.filterwarnings('ignore') DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"[{'GPU' if torch.cuda.is_available() else 'CPU'}] " f"{torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'No GPU'}") torch.manual_seed(42); np.random.seed(42); random.seed(42) # ── Paths ───────────────────────────────────────────────────────────────────── SEEG_ROOT = Path(r"G:\PHD Datasets\Data\Thalamus\SEEG Seizure Data") METADATA = SEEG_ROOT / "metadata_SEEG.xlsx" TUH_BASE = r"G:\PHD Datasets\Data\TUH\edf" OUT_ROOT = Path(r"D:\Projects\phd\PSEG\pges_toolkit\results\dactrl_paired_tuh_cyclegan") OUT_ROOT.mkdir(parents=True, exist_ok=True) # ── Hyper-parameters ────────────────────────────────────────────────────────── WIN_SEC = 5 FS_THAL = 250 # thalamic SEEG sampling rate FS_SCALP = 256 # institutional scalp EEG POST_OFFSET = 5 # seconds after seizure end before PGES windows POST_DUR = 180 # max PGES window extraction (seconds) PRE_DUR = 120 # pre-ictal baseline window extraction (seconds) MAX_TUH = 300 # TUH file budget N_FEAT = 17 D_MODEL = 64 N_HEADS = 4 N_LAYERS = 4 N_CTX = 8 SEQ_EP_THAL = 60 SEQ_EP_TUH = 30 SEQ_LR = 3e-4 K_VALS = [0, 2, 5, 10] N_TRIALS = 5 # Features inverted between scalp and thalamic (C2) # Indices: 10=Suppression_Ratio, 0=RMS, 3=Variance (same as C8) INVERT_IDX = [10, 0, 3] NUCLEUS_MAP = { 'P1':'CeM','P2':'CL','P3':'CeM','P4':'MD','P5':'CeM', 'P6':'MD','P7':'CL','P8':'CL','P9':'CeM','P10':'ANT', 'P11':'ANT','P12':'ANT','P13':'ANT','P14':'ANT','P15':'ANT', } # Patients with simultaneous scalp+thalamic (≥18 ch) PAIRED_PIDS = {'P2', 'P10', 'P12'} # Standard 10-20 scalp channel substrings (any casing) SCALP_KEYS = ['FP1','FP2','F3','F4','F7','F8','FZ','C3','C4','CZ', 'P3','P4','PZ','T3','T4','T5','T6','O1','O2', 'A1','A2','M1','M2'] def log(msg): print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}", flush=True) # ══════════════════════════════════════════════════════════════════════════════ # Feature extraction # ══════════════════════════════════════════════════════════════════════════════ def compute_features(sig, fs): from numpy.fft import rfft, rfftfreq sig = sig - sig.mean() n = len(sig) rms = float(np.sqrt(np.mean(sig**2))) ll = float(np.mean(np.abs(np.diff(sig)))) zc = float(np.sum(np.diff(np.sign(sig)) != 0)) var = float(np.var(sig)) freqs = rfftfreq(n, 1/fs) psd = np.abs(rfft(sig))**2 def band(lo, hi): idx = (freqs >= lo) & (freqs < hi) return float(psd[idx].sum()) if idx.any() else 0.0 total = float(psd.sum()) + 1e-10 delta = band(0.5, 4); theta = band(4, 8) alpha = band(8, 13); beta = band(13, 30) gamma = band(80, 150) sr = (delta + theta) / (alpha + beta + 1e-10) p = psd / (psd.sum() + 1e-10); p = p[p > 0] shan = float(-np.sum(p * np.log(p + 1e-10))) supp = float(np.mean(np.abs(sig) < 0.05 * np.max(np.abs(sig) + 1e-10))) def _apen(u, m=2, r=None): if r is None: r = 0.2 * np.std(u) + 1e-10 N = len(u) def phi(mm): x = np.array([u[i:i+mm] for i in range(N-mm+1)]) C = np.sum(np.max(np.abs(x[:,None]-x[None,:]),axis=2) <= r, axis=0) / (N-mm+1) return np.sum(np.log(C + 1e-10)) / (N-mm+1) return abs(phi(m) - phi(m+1)) u = sig[:min(200, n)] apen = float(_apen(u)) def _sampen(u, m=2, r=None): if r is None: r = 0.2 * np.std(u) + 1e-10 N = len(u) def _count(mm): cnt = 0 for i in range(N-mm): cnt += np.sum(np.max(np.abs( np.array([u[i:i+mm]]) - np.array([u[j:j+mm] for j in range(N-mm) if j != i])), axis=1) <= r) return cnt A = _count(m+1); B = _count(m) return float(-np.log((A+1e-10)/(B+1e-10))) sampen = float(_sampen(u)) hist, _ = np.histogram(sig, bins=10) hist = hist / (hist.sum() + 1e-10) etc = float(-np.sum(hist[hist>0] * np.log(hist[hist>0]))) bsig = (sig > np.median(sig)).astype(int) lzc = float(len(set(''.join(map(str,bsig))[i:i+4] for i in range(len(bsig)-3)))) m2 = 3; perms2 = [] for i in range(n - m2 + 1): perms2.append(tuple(np.argsort(sig[i:i+m2]))) from collections import Counter cnt = Counter(perms2); tot = sum(cnt.values()) pent = float(-sum((v/tot)*np.log(v/tot+1e-10) for v in cnt.values())) return np.array([rms, ll, zc, var, delta/total, theta/total, alpha/total, beta/total, sr, shan, supp, apen, sampen, etc, lzc, pent, gamma/total], dtype=np.float32) def apply_inversion_correction(X): X = X.copy() for idx in INVERT_IDX: X[:, idx] = -X[:, idx] return X # ══════════════════════════════════════════════════════════════════════════════ # TUH feature extraction # ══════════════════════════════════════════════════════════════════════════════ def extract_tuh_features(edf_path, csv_path, target_labels=('gnsz','tcsz')): try: raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False) fs = raw.info['sfreq'] szs = [] with open(csv_path) as f: for line in f: if line.startswith('#') or line.startswith('channel'): continue parts = line.strip().split(',') if len(parts) < 5: continue lbl = parts[3].strip() if lbl in target_labels: try: szs.append((float(parts[1]), float(parts[2]))) except ValueError: continue if not szs: return None, None szs = sorted(set(szs)) merged = [list(szs[0])] for s, e in szs[1:]: if s <= merged[-1][1]: merged[-1][1] = max(merged[-1][1], e) else: merged.append([s, e]) picks = mne.pick_types(raw.info, eeg=True) if len(picks) == 0: picks = list(range(len(raw.ch_names))) data = raw.get_data(picks=picks) if np.abs(np.median(data)) < 0.01: data = data * 1e6 ch_std = data.std(axis=1) good = ch_std < np.percentile(ch_std, 90) * 3 if good.sum() == 0: good = np.ones(len(ch_std), dtype=bool) avg = data[good].mean(axis=0) win_samp = int(WIN_SEC * fs) pges_sessions, base_sessions = [], [] for sz_start, sz_end in merged: pi_start = int((sz_end + POST_OFFSET) * fs) pi_end = min(int((sz_end + POST_OFFSET + POST_DUR) * fs), len(avg)) pi_wins = [] for i in range(pi_start, pi_end - win_samp, win_samp): seg = avg[i:i+win_samp] if len(seg) == win_samp: pi_wins.append(compute_features(seg, fs)) if len(pi_wins) >= N_CTX + 2: pges_sessions.append(np.array(pi_wins, dtype=np.float32)) pre_end = int((sz_start - 10) * fs) pre_start = max(0, int((sz_start - 10 - PRE_DUR) * fs)) pre_wins = [] for i in range(pre_start, pre_end - win_samp, win_samp): seg = avg[i:i+win_samp] if len(seg) == win_samp: pre_wins.append(compute_features(seg, fs)) if len(pre_wins) >= N_CTX + 2: base_sessions.append(np.array(pre_wins, dtype=np.float32)) if not pges_sessions and not base_sessions: return None, None return pges_sessions, base_sessions except Exception: return None, None # ══════════════════════════════════════════════════════════════════════════════ # Paired scalp+thalamic extraction (P2, P10, P12) # ══════════════════════════════════════════════════════════════════════════════ def _pick_scalp_channels(raw): """Return indices of scalp channels in raw object.""" found = [] for i, ch in enumerate(raw.ch_names): cu = ch.upper().replace('-','').replace(' ','') for k in SCALP_KEYS: if cu == k or cu == k+'-AVG' or cu == k+'REF': found.append(i); break if not found or i != found[-1]: # also match on contains for k in SCALP_KEYS: if k in cu and i not in found: found.append(i); break return found def _pick_thalamic_channel(raw): """Return index of best thalamic LT bipolar channel.""" for i, ch in enumerate(raw.ch_names): cu = ch.upper().replace(' ','').replace('-','') if 'LT' in cu or 'THAL' in cu or 'ANT' in cu: return i return None def _resolve_edf(pdir, sz_file): for fname in [sz_file, sz_file.replace('.edf','.EDF')]: p = pdir / fname if p.exists(): return p return None def extract_paired_windows(pid, meta_df, seeg_root): """ For a paired patient, extract simultaneous (scalp_feat, thal_feat) pairs from every PGES window and baseline window in each seizure EDF. Returns (Xs, Xt, y) — scalp features, thalamic features, label (1=PGES/0=base). """ row = meta_df[meta_df['Patient ID'] == pid] if len(row) == 0: log(f" [WARN] {pid}: not in metadata"); return None, None, None row = row.iloc[0] pdir = seeg_root / pid # Gather (sz_file, sz_start, sz_end, pges_end) for each seizure sz_info = [] for col in meta_df.columns: if 'seizure' in col.lower() and 'file' in col.lower(): idx = col.lower().replace('file','') start_col = [c for c in meta_df.columns if idx in c.lower() and 'start' in c.lower()] end_col = [c for c in meta_df.columns if idx in c.lower() and 'end' in c.lower()] if not start_col or not end_col: continue sz_file = row.get(col, None) sz_start = row.get(start_col[0], None) sz_end = row.get(end_col[0], None) if pd.isna(sz_file) or pd.isna(sz_start): continue sz_info.append((str(sz_file), float(sz_start), float(sz_end) if not pd.isna(sz_end) else float(sz_start)+60)) if not sz_info: log(f" [WARN] {pid}: no seizure info found"); return None, None, None all_Xs, all_Xt, all_y = [], [], [] for sz_file, sz_start, sz_end in sz_info: edf_path = _resolve_edf(pdir, sz_file) if edf_path is None: continue try: raw = mne.io.read_raw_edf(str(edf_path), preload=True, verbose=False) fs = raw.info['sfreq'] scalp_idx = _pick_scalp_channels(raw) thal_idx = _pick_thalamic_channel(raw) if not scalp_idx or thal_idx is None: log(f" [WARN] {pid} {sz_file}: no scalp/thal channels found") continue data = raw.get_data() # (n_ch, n_samp) if np.abs(np.median(data)) < 0.01: data = data * 1e6 scalp_data = data[scalp_idx].mean(axis=0) # average reference thal_data = data[thal_idx] # single thalamic channel win_samp = int(WIN_SEC * fs) # PGES windows pi_start = int((sz_end + POST_OFFSET) * fs) pi_end = min(int((sz_end + POST_OFFSET + POST_DUR) * fs), scalp_data.shape[0]) for i in range(pi_start, pi_end - win_samp, win_samp): xs = compute_features(scalp_data[i:i+win_samp], fs) xt = compute_features(thal_data[i:i+win_samp], fs) if xs is not None and xt is not None: all_Xs.append(xs); all_Xt.append(xt); all_y.append(1) # Baseline windows (pre-ictal) pre_end = int((sz_start - 10) * fs) pre_start = max(0, int((sz_start - 10 - 120) * fs)) for i in range(pre_start, pre_end - win_samp, win_samp): xs = compute_features(scalp_data[i:i+win_samp], fs) xt = compute_features(thal_data[i:i+win_samp], fs) if xs is not None and xt is not None: all_Xs.append(xs); all_Xt.append(xt); all_y.append(0) except Exception as e: log(f" [ERR] {pid} {sz_file}: {e}"); continue if not all_Xs: return None, None, None return (np.array(all_Xs, dtype=np.float32), np.array(all_Xt, dtype=np.float32), np.array(all_y, dtype=np.int32)) # ══════════════════════════════════════════════════════════════════════════════ # CycleGAN (feature-space) # ══════════════════════════════════════════════════════════════════════════════ class _MLP(nn.Module): def __init__(self, in_d, hid, out_d, act_out=None): super().__init__() self.net = nn.Sequential( nn.Linear(in_d, hid), nn.LeakyReLU(0.2), nn.Linear(hid, hid), nn.LeakyReLU(0.2), nn.Linear(hid, out_d)) self.act_out = act_out def forward(self, x): o = self.net(x) return torch.tanh(o) if self.act_out == 'tanh' else o def train_cyclegan_unsupervised(scalp_wins, thal_wins, epochs=60, lr=1e-3): """Stage 1: standard unsupervised CycleGAN (same as C8).""" G_S2T = _MLP(N_FEAT, 64, N_FEAT, 'tanh').to(DEVICE) G_T2S = _MLP(N_FEAT, 64, N_FEAT, 'tanh').to(DEVICE) D_S = _MLP(N_FEAT, 32, 1).to(DEVICE) D_T = _MLP(N_FEAT, 32, 1).to(DEVICE) opt_G = torch.optim.Adam(list(G_S2T.parameters()) + list(G_T2S.parameters()), lr=lr, betas=(0.5, 0.999)) opt_D = torch.optim.Adam(list(D_S.parameters()) + list(D_T.parameters()), lr=lr, betas=(0.5, 0.999)) bce = nn.BCEWithLogitsLoss(); l1 = nn.L1Loss() S = torch.tensor(scalp_wins, dtype=torch.float32) T = torch.tensor(thal_wins, dtype=torch.float32) ds_s = torch.utils.data.DataLoader( torch.utils.data.TensorDataset(S), batch_size=256, shuffle=True) n_t = len(T) for ep in range(epochs): for (s_b,) in ds_s: idx = torch.randint(0, n_t, (len(s_b),)) t_b = T[idx].to(DEVICE); s_b = s_b.to(DEVICE) # Generators t_fake = G_S2T(s_b); s_rec = G_T2S(t_fake) s_fake = G_T2S(t_b); t_rec = G_S2T(s_fake) loss_G = (bce(D_T(t_fake), torch.ones(len(s_b),1,device=DEVICE)) + bce(D_S(s_fake), torch.ones(len(s_b),1,device=DEVICE)) + 10 * l1(s_rec, s_b) + 10 * l1(t_rec, t_b)) opt_G.zero_grad(); loss_G.backward(); opt_G.step() # Discriminators t_fake2 = G_S2T(s_b).detach(); s_fake2 = G_T2S(t_b).detach() loss_D = (bce(D_T(t_b), torch.ones(len(s_b),1,device=DEVICE)) + bce(D_T(t_fake2), torch.zeros(len(s_b),1,device=DEVICE)) + bce(D_S(s_b), torch.ones(len(s_b),1,device=DEVICE)) + bce(D_S(s_fake2), torch.zeros(len(s_b),1,device=DEVICE))) opt_D.zero_grad(); loss_D.backward(); opt_D.step() G_S2T.eval() return G_S2T def fine_tune_with_pairs(G_S2T, paired_data, epochs=30, lr=5e-4, lam_paired=2.0): """ Stage 2: Supervised fine-tuning of G_S2T using simultaneous scalp+thalamic pairs. paired_data: list of (Xs, Xt, y) tuples from training paired patients. Loss = L1(G_S2T(x_scalp), x_thal) — direct supervised regression. """ if not paired_data: return G_S2T Xs_all = np.vstack([d[0] for d in paired_data]) Xt_all = np.vstack([d[1] for d in paired_data]) Xs_t = torch.tensor(Xs_all, dtype=torch.float32) Xt_t = torch.tensor(Xt_all, dtype=torch.float32) ds = torch.utils.data.DataLoader( torch.utils.data.TensorDataset(Xs_t, Xt_t), batch_size=128, shuffle=True) opt = torch.optim.Adam(G_S2T.parameters(), lr=lr) l1 = nn.L1Loss() G_S2T.train() for ep in range(epochs): ep_loss = 0.0 for xs_b, xt_b in ds: xs_b = xs_b.to(DEVICE); xt_b = xt_b.to(DEVICE) pred = G_S2T(xs_b) loss = lam_paired * l1(pred, xt_b) opt.zero_grad(); loss.backward(); opt.step() ep_loss += loss.item() if (ep + 1) % 10 == 0: log(f" [paired fine-tune] ep {ep+1}/{epochs} loss={ep_loss/len(ds):.4f}") G_S2T.eval() return G_S2T def translate_sessions(G_S2T, sessions): translated = [] for sess in sessions: t_in = torch.tensor(sess, dtype=torch.float32).to(DEVICE) with torch.no_grad(): t_out = G_S2T(t_in).cpu().numpy() translated.append(t_out) return translated # ══════════════════════════════════════════════════════════════════════════════ # TSM (CausalTransformer) + ProtoNet # ══════════════════════════════════════════════════════════════════════════════ class CausalTransformer(nn.Module): def __init__(self): super().__init__() self.proj = nn.Linear(N_FEAT, D_MODEL) enc = nn.TransformerEncoderLayer(D_MODEL, N_HEADS, D_MODEL*4, dropout=0.1, batch_first=True) self.enc = nn.TransformerEncoder(enc, N_LAYERS) self.head = nn.Linear(D_MODEL, N_FEAT) mask = torch.triu(torch.ones(N_CTX, N_CTX), 1).bool() self.register_buffer('mask', mask) def forward(self, x, return_hidden=False): h = self.enc(self.proj(x), mask=self.mask[:x.shape[1], :x.shape[1]]) return h if return_hidden else self.head(h) def pretrain_on_sessions(model, sessions, epochs, lr=SEQ_LR): seqs = [] for sess in sessions: if len(sess) < N_CTX + 2: continue for i in range(N_CTX + 1, len(sess)): seqs.append(sess[i - N_CTX - 1: i]) if len(seqs) < 10: return model seqs = np.array(seqs, dtype=np.float32) ds = torch.utils.data.DataLoader( torch.utils.data.TensorDataset( torch.tensor(seqs[:, :N_CTX]), torch.tensor(seqs[:, -1])), batch_size=128, shuffle=True) opt = torch.optim.Adam(model.parameters(), lr=lr) model.train() for _ in range(epochs): for xc, xt in ds: xc, xt = xc.to(DEVICE), xt.to(DEVICE) pred = model(xc)[:, -1, :] loss = (1. - F.cosine_similarity(pred, xt, dim=1).mean()) \ + 0.5 * F.mse_loss(pred, xt) opt.zero_grad(); loss.backward(); opt.step() model.eval() return model def finetune_on_thalamic(model, train_patients, scaler, epochs=SEQ_EP_THAL): sessions = [] for p in train_patients: X_n = scaler.transform(p['X'].astype(np.float32)) sess = X_n[p['labels'] == 0] if len(sess) >= N_CTX + 2: sessions.append(sess) return pretrain_on_sessions(model, sessions, epochs, lr=SEQ_LR * 0.3) if sessions else model def encode(model, seqs): model.eval() z = [] for i in range(0, len(seqs), 64): b = torch.tensor(seqs[i:i+64], dtype=torch.float32).to(DEVICE) with torch.no_grad(): z.append(model(b, return_hidden=True)[:, -1, :].cpu().numpy()) return np.vstack(z) def build_seqs(patient, scaler): X_n = scaler.transform(patient['X'].astype(np.float32)) y = patient['labels'].astype(np.int32) seqs, lbls = [], [] for i in range(N_CTX, len(X_n)): seqs.append(X_n[i - N_CTX: i]); lbls.append(y[i]) if not seqs: return None, None return np.array(seqs, dtype=np.float32), np.array(lbls, dtype=np.int32) def diversity_support(lbls, K): pos = np.where(lbls == 1)[0]; neg = np.where(lbls == 0)[0] if len(pos) < K or len(neg) < K: return None, None return (np.random.choice(pos, K, replace=False), np.random.choice(neg, K, replace=False)) def kshot_eval(model, seqs, lbls, K, n_trials=N_TRIALS): Z = encode(model, seqs) if lbls.sum() == 0: return float('nan') if K == 0: pp = Z[lbls==1].mean(0); pb = Z[lbls==0].mean(0) preds = (np.linalg.norm(Z-pp,axis=1) < np.linalg.norm(Z-pb,axis=1)).astype(int) return float(f1_score(lbls, preds, zero_division=0)) scores = [] for _ in range(n_trials): sup_p, sup_n = diversity_support(lbls, K) if sup_p is None: continue pp = Z[sup_p].mean(0); pb = Z[sup_n].mean(0) qry = [i for i in range(len(lbls)) if i not in set(sup_p)|set(sup_n)] if not qry: continue preds = (np.linalg.norm(Z[qry]-pp,axis=1) < np.linalg.norm(Z[qry]-pb,axis=1)).astype(int) scores.append(float(f1_score(lbls[qry], preds, zero_division=0))) return float(np.mean(scores)) if scores else float('nan') # ══════════════════════════════════════════════════════════════════════════════ # MAIN # ══════════════════════════════════════════════════════════════════════════════ if __name__ == '__main__': log('=' * 60) log('DACTRL C11 — Paired-Supervised CycleGAN + TUH Scale') log('=' * 60) # ── Step 1: Load thalamic patients ────────────────────────────────────── log('\nStep 1: Loading thalamic patients...') _V3 = Path(__file__).parent / "dactrl_v3_episodic_protonet.py" _v3g = {'__file__': str(_V3)} with open(_V3, 'r', errors='replace') as _f: exec(compile( _f.read().replace("if __name__=='__main__':", "if __name__=='__never__':"), str(_V3), 'exec'), _v3g) load_all_seeg = _v3g['load_all_seeg'] meta_df = pd.read_excel(METADATA) raw_data = load_all_seeg(meta_df) patients = [{'pid': p, 'nucleus': NUCLEUS_MAP.get(p,'UNK'), 'X': d['X'], 'labels': d['y_temporal']} for p, d in raw_data.items() if d['y_temporal'].sum() > 0 and p != 'P13'] log(f' Loaded {len(patients)} thalamic patients (P13 excluded).') # ── Step 2: Extract paired scalp+thalamic windows (P2, P10, P12) ──────── log('\nStep 2: Extracting simultaneous scalp+thalamic pairs...') paired_bank = {} # pid -> (Xs, Xt, y) for pid in sorted(PAIRED_PIDS): log(f' Extracting {pid}...') Xs, Xt, y = extract_paired_windows(pid, meta_df, SEEG_ROOT) if Xs is not None: paired_bank[pid] = (Xs, Xt, y) log(f' {pid}: {len(Xs)} paired windows ({y.sum()} PGES, {(y==0).sum()} base)') else: log(f' {pid}: extraction failed — will skip') # Normalise paired data to thalamic scaler (fit later per-fold) # We store raw (unnormalised) pairs here and normalise per fold log(f' Paired bank: {len(paired_bank)} patients with simultaneous data.') # ── Step 3: Build TUH corpus ───────────────────────────────────────────── log('\nStep 3: Building TUH feature corpus (gnsz/tcsz)...') csvs = [f for f in glob.glob(os.path.join(TUH_BASE, '**', '*.csv'), recursive=True) if 'worksheet' not in f.lower()] def _has_target(f): try: return any(t in open(f, errors='ignore').read() for t in ['tcsz','gnsz']) except: return False tgt_csvs = [f for f in csvs if _has_target(f)] tgt_pairs_tuh = [(f, f.replace('.csv','.edf')) for f in tgt_csvs if os.path.exists(f.replace('.csv','.edf'))] log(f' Found {len(tgt_pairs_tuh)} EDF+CSV pairs.') np.random.seed(42) if len(tgt_pairs_tuh) > MAX_TUH: idxs = np.random.choice(len(tgt_pairs_tuh), MAX_TUH, replace=False) tgt_pairs_tuh = [tgt_pairs_tuh[i] for i in idxs] log(f' Using {len(tgt_pairs_tuh)} files (MAX_TUH={MAX_TUH}).') def _extract_with_timeout(edf_path, csv_path, timeout=180): result = [None, None] def _run(): try: result[0], result[1] = extract_tuh_features(edf_path, csv_path) except: pass t = threading.Thread(target=_run, daemon=True) t.start(); t.join(timeout=timeout) if t.is_alive(): log(f' [TIMEOUT] {os.path.basename(edf_path)}') return result[0], result[1] tuh_pges_sessions = [] # list of (N_i, 17) PGES sessions from TUH tuh_base_sessions = [] # baseline sessions skipped = 0 for k, (csv_p, edf_p) in enumerate(tgt_pairs_tuh): pges_sess, base_sess = _extract_with_timeout(edf_p, csv_p) if pges_sess is None: skipped += 1; continue tuh_pges_sessions.extend(pges_sess) tuh_base_sessions.extend(base_sess if base_sess else []) if (k+1) % 20 == 0: log(f' TUH: {k+1}/{len(tgt_pairs_tuh)} | ' f'PGES sessions={len(tuh_pges_sessions)} | skipped={skipped}') all_tuh_wins = sum(len(s) for s in tuh_pges_sessions) + sum(len(s) for s in tuh_base_sessions) log(f' TUH: {all_tuh_wins} windows | ' f'{len(tuh_pges_sessions)} PGES sessions | ' f'{len(tuh_base_sessions)} base sessions | skipped={skipped}') # Global TUH scaler (for CycleGAN training only — not applied to thalamic eval) all_tuh_raw = np.vstack(tuh_pges_sessions + tuh_base_sessions) tuh_scaler = StandardScaler().fit(all_tuh_raw) tuh_pges_n = [tuh_scaler.transform(s) for s in tuh_pges_sessions] tuh_base_n = [tuh_scaler.transform(s) for s in tuh_base_sessions] tuh_all_n = tuh_pges_n + tuh_base_n # ── Step 4: Global CycleGAN Stage 1 (unsupervised, same as C8) ────────── log('\nStep 4: Training global CycleGAN Stage 1 (unsupervised TUH ↔ thalamic)...') X_all_thal = np.vstack([p['X'].astype(np.float32) for p in patients]) scaler_thal = StandardScaler().fit(X_all_thal) all_thal_wins_n = scaler_thal.transform(X_all_thal) all_scalp_wins = np.vstack(tuh_pges_n + tuh_base_n) G_S2T_base = train_cyclegan_unsupervised(all_scalp_wins, all_thal_wins_n, epochs=60, lr=1e-3) log(' Stage 1 CycleGAN complete.') # ── Step 5: LOSO experiment ────────────────────────────────────────────── log('\nStep 5: LOSO evaluation (5 conditions)...') log(' A = Thalamic-only TSM (baseline)') log(' B = TUH CycleGAN unsupervised → TSM (C8 null reproduced)') log(' C = Paired-supervised G_S2T cold-start → translate TUH → TSM') log(' D = TUH unsup Stage1 → Paired-sup Stage2 → translate TUH → TSM [MAIN]') log(' E = D + Day-0 temporal heuristic') results = {c: {k: [] for k in K_VALS} for c in ['A','B','C','D','E']} # Pre-translate TUH with Stage-1 G_S2T (for Condition B — per-fold not needed) log('\n Pre-translating TUH sessions with Stage-1 G_S2T (for Condition B)...') tuh_gan_B = translate_sessions(G_S2T_base, tuh_all_n) for fold_i, test_p in enumerate(patients): pid = test_p['pid'] train_ps = [p for p in patients if p['pid'] != pid] X_tr = np.vstack([p['X'].astype(np.float32) for p in train_ps]) scaler = StandardScaler().fit(X_tr) seqs, lbls = build_seqs(test_p, scaler) if seqs is None or lbls.sum() == 0: log(f' [{fold_i+1:02d}/{len(patients)}] {pid}: skip'); continue log(f'\n [{fold_i+1:02d}/{len(patients)}] Test={pid} ' f'(PGES={int(lbls.sum())}, Base={int((lbls==0).sum())})') # Thalamic baseline sessions for pre-training thal_sessions = [] for p in train_ps: X_n = scaler.transform(p['X'].astype(np.float32)) sess = X_n[p['labels'] == 0] if len(sess) >= N_CTX + 2: thal_sessions.append(sess) # Paired patients available for this fold (exclude test patient) fold_paired_pids = [pp for pp in paired_bank if pp != pid] fold_paired_data = [] for pp in fold_paired_pids: Xs_raw, Xt_raw, py = paired_bank[pp] # Normalise scalp features with TUH scaler; thalamic with fold scaler Xs_n = tuh_scaler.transform(Xs_raw) Xt_n = scaler.transform(Xt_raw) fold_paired_data.append((Xs_n, Xt_n, py)) # ── Condition A: thalamic-only TSM ──────────────────────────────── model_A = CausalTransformer().to(DEVICE) model_A = pretrain_on_sessions(model_A, thal_sessions, SEQ_EP_THAL) res_A = {k: kshot_eval(model_A, seqs, lbls, k) for k in K_VALS} for k in K_VALS: results['A'][k].append(res_A[k]) log(f' A: K=0={res_A[0]:.4f} K=10={res_A[10]:.4f}') del model_A # ── Condition B: TUH Stage-1 CycleGAN → TSM fine-tune (C8 repro) ─ model_B = CausalTransformer().to(DEVICE) model_B = pretrain_on_sessions(model_B, tuh_gan_B, SEQ_EP_TUH) model_B = finetune_on_thalamic(model_B, train_ps, scaler) res_B = {k: kshot_eval(model_B, seqs, lbls, k) for k in K_VALS} for k in K_VALS: results['B'][k].append(res_B[k]) log(f' B: K=0={res_B[0]:.4f} K=10={res_B[10]:.4f}') del model_B # ── Condition C: Paired-supervised cold-start → translate TUH ──── # Fresh G_S2T, no Stage 1 — pure paired supervision only G_C = _MLP(N_FEAT, 64, N_FEAT, 'tanh').to(DEVICE) G_C.train() G_C = fine_tune_with_pairs(G_C, fold_paired_data, epochs=40, lr=5e-4) tuh_gan_C = translate_sessions(G_C, tuh_all_n) del G_C model_C = CausalTransformer().to(DEVICE) model_C = pretrain_on_sessions(model_C, tuh_gan_C, SEQ_EP_TUH) model_C = finetune_on_thalamic(model_C, train_ps, scaler) res_C = {k: kshot_eval(model_C, seqs, lbls, k) for k in K_VALS} for k in K_VALS: results['C'][k].append(res_C[k]) log(f' C: K=0={res_C[0]:.4f} K=10={res_C[10]:.4f} ' f'(paired={len(fold_paired_pids)} patients)') del model_C, tuh_gan_C # ── Condition D: Stage1 → Paired fine-tune → translate TUH ────── G_D = copy.deepcopy(G_S2T_base) # start from Stage-1 weights G_D = fine_tune_with_pairs(G_D, fold_paired_data, epochs=30, lr=5e-4) tuh_gan_D = translate_sessions(G_D, tuh_all_n) del G_D model_D = CausalTransformer().to(DEVICE) model_D = pretrain_on_sessions(model_D, tuh_gan_D, SEQ_EP_TUH) model_D = finetune_on_thalamic(model_D, train_ps, scaler) res_D = {k: kshot_eval(model_D, seqs, lbls, k) for k in K_VALS} for k in K_VALS: results['D'][k].append(res_D[k]) log(f' D: K=0={res_D[0]:.4f} K=10={res_D[10]:.4f} [MAIN — paired+TUH]') del model_D # ── Condition E: D + Day-0 temporal heuristic ──────────────────── G_E = copy.deepcopy(G_S2T_base) G_E = fine_tune_with_pairs(G_E, fold_paired_data, epochs=30, lr=5e-4) tuh_gan_E = translate_sessions(G_E, tuh_all_n) del G_E model_E = CausalTransformer().to(DEVICE) model_E = pretrain_on_sessions(model_E, tuh_gan_E, SEQ_EP_TUH) model_E = finetune_on_thalamic(model_E, train_ps, scaler) # Auto-label first K=10 PGES windows using device timestamp K_AUTO = 10 pges_starts = np.where(np.diff(np.concatenate([[0], lbls])) == 1)[0] if len(pges_starts) > 0: onset = pges_starts[0] pges_end = onset while pges_end < len(lbls) and lbls[pges_end] == 1: pges_end += 1 auto_p = np.arange(onset, min(onset + K_AUTO, pges_end)) base_idx = np.where(lbls == 0)[0] pre_base = base_idx[base_idx < onset] if len(pre_base) < K_AUTO: pre_base = base_idx[:K_AUTO] auto_b = pre_base[-K_AUTO:] Z = encode(model_E, seqs) pp = Z[auto_p].mean(0); pb = Z[auto_b].mean(0) preds = (np.linalg.norm(Z-pp,axis=1) < np.linalg.norm(Z-pb,axis=1)).astype(int) f1_e0 = float(f1_score(lbls, preds, zero_division=0)) for k in K_VALS: if k == 0: results['E'][k].append(f1_e0) else: results['E'][k].append(res_D[k]) # same as D for K>0 else: for k in K_VALS: results['E'][k].append(res_D[k]) log(f' E: K=0={results["E"][0][-1]:.4f} K=10={results["E"][10][-1]:.4f}') del model_E, tuh_gan_E gc.collect(); torch.cuda.empty_cache() if torch.cuda.is_available() else None # ── Results ────────────────────────────────────────────────────────────── log('\n' + '=' * 60) log('=== C11: Paired-Supervised CycleGAN + TUH Results ===') log('=' * 60) cond_means = {} for cond in ['A','B','C','D','E']: means = {k: np.nanmean(results[cond][k]) for k in K_VALS} cond_means[cond] = means header = f"{'Condition':<45} {'K=0':>6} {'K=2':>6} {'K=5':>6} {'K=10':>6}" log(header) log('-' * 70) labels = { 'A': 'A: Thalamic-only TSM (baseline)', 'B': 'B: TUH CycleGAN unsup → TSM [C8 repro]', 'C': 'C: Paired-sup G_S2T cold-start → TUH → TSM', 'D': 'D: TUH unsup + Paired-sup → TUH → TSM [MAIN]', 'E': 'E: D + Day-0 temporal heuristic', } for cond in ['A','B','C','D','E']: m = cond_means[cond] log(f"{labels[cond]:<45} " f"{m[0]:>6.4f} {m[2]:>6.4f} {m[5]:>6.4f} {m[10]:>6.4f}") log('\nGain over thalamic-only baseline (A):') for cond in ['B','C','D','E']: gains = {k: cond_means[cond][k] - cond_means['A'][k] for k in K_VALS} sign = {k: '+' if gains[k] >= 0 else '' for k in K_VALS} log(f" {labels[cond][:42]}:") for k in K_VALS: log(f" K={k:>2}: {sign[k]}{gains[k]:+.4f}") # Wilcoxon: D vs A at K=10 d_vals = [v for v in results['D'][10] if not np.isnan(v)] a_vals = [v for v in results['A'][10] if not np.isnan(v)] if len(d_vals) >= 5 and len(a_vals) >= 5: try: stat, pval = wilcoxon(d_vals[:len(a_vals)], a_vals[:len(d_vals)]) log(f'\nWilcoxon D vs A (K=10): stat={stat:.3f}, p={pval:.4f}') except Exception as e: log(f'\nWilcoxon failed: {e}') # ── Save results ────────────────────────────────────────────────────────── np.save(str(OUT_ROOT / 'results_raw.npy'), results) rows = [] for cond in ['A','B','C','D','E']: for k in K_VALS: vals = [v for v in results[cond][k] if not np.isnan(v)] rows.append({'condition': cond, 'K': k, 'mean': np.mean(vals) if vals else np.nan, 'std': np.std(vals) if vals else np.nan, 'n': len(vals)}) pd.DataFrame(rows).to_csv(str(OUT_ROOT / 'results_summary.csv'), index=False) # ── Figure ──────────────────────────────────────────────────────────────── fig, axes = plt.subplots(1, 2, figsize=(14, 5)) fig.suptitle('C11: Paired-Supervised CycleGAN + TUH Scale', fontsize=12, fontweight='bold') colors = {'A':'#7f8c8d','B':'#e74c3c','C':'#f39c12','D':'#27ae60','E':'#2980b9'} ax = axes[0] for cond in ['A','B','C','D','E']: vals_k0 = [cond_means[cond][0]] vals_k10= [cond_means[cond][10]] ax.plot([0, 10], [cond_means[cond][0], cond_means[cond][10]], 'o-', color=colors[cond], label=cond, linewidth=2, markersize=8) ax.set_xlabel('K (few-shot examples)'); ax.set_ylabel('F1') ax.set_title('F1 vs K — all conditions') ax.legend(fontsize=8); ax.grid(alpha=0.3) # All K values ax2 = axes[1] x = np.arange(len(K_VALS)); w = 0.15 for i, cond in enumerate(['A','B','C','D','E']): vals = [cond_means[cond][k] for k in K_VALS] ax2.bar(x + i*w, vals, w, label=cond, color=colors[cond], alpha=0.85) ax2.set_xticks(x + 2*w); ax2.set_xticklabels([f'K={k}' for k in K_VALS]) ax2.set_ylim(0.7, 1.05); ax2.set_ylabel('Mean F1') ax2.set_title('Condition comparison by K') ax2.legend(fontsize=8); ax2.grid(alpha=0.3, axis='y') plt.tight_layout() fig.savefig(str(OUT_ROOT / 'c11_paired_tuh_cyclegan.png'), dpi=150) plt.close(fig) log(f'Figure saved -> {OUT_ROOT}/c11_paired_tuh_cyclegan.png') log('Results saved -> ' + str(OUT_ROOT)) log('COMPLETE')