"""
DACTRL — TUH Scalp Pre-training Experiment
============================================
Platform vision: leverage vast public scalp EEG (TUH corpus) to pre-train the
DACTRL CausalTransformer in feature space, then fine-tune on thalamic LOSO.

Strategy:
  1. Filter TUH seizure corpus for gnsz/tcsz (generalized/tonic-clonic) —
     these reliably produce PGES-like post-ictal suppression.
  2. Extract 17 features from average-reference scalp signal (global brain
     state).
  3. Apply inversion correction to 3 features inverted between scalp and
     thalamic (Suppression_Ratio, Spectral_Ratio, Zero_Crossings) — C2 finding.
  4. Pre-train CausalTransformer on corrected scalp feature corpus.
  5. Fine-tune ProtoNet head on thalamic LOSO — compare vs thalamic-only
     baseline.

Conditions:
  A: Thalamic-only pre-training (baseline, reproduced from main pipeline)
  B: TUH scalp pre-training (corrected features) + thalamic fine-tune
  C: TUH scalp pre-training (NO correction) + thalamic fine-tune [ablation]
  D: TUH scalp pre-training on CycleGAN-translated features + thalamic
     fine-tune
  E: Better of the B/D corpora (by K=0 score) + Day-0 temporal heuristic
     [combo]
"""

import copy
import glob
import os
import sys
import threading
import warnings
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import f1_score
from sklearn.preprocessing import StandardScaler

# The script silences every warning (numpy/sklearn/mne are chatty on this data).
warnings.filterwarnings('ignore')

# ── Paths ─────────────────────────────────────────────────────────────────────
TUH_BASE = 'G:/PHD Datasets/Data/Scalp/tueeg_data/tuh_eeg_seizure/v2.0.3/edf'
THAL_BASE = 'G:/PHD Datasets/Data'  # thalamic SEEG patients
OUT_DIR = 'D:/Projects/phd/PSEG/pges_toolkit/results/dactrl_tuh_pretrain'
os.makedirs(OUT_DIR, exist_ok=True)

# ── Hyperparameters ───────────────────────────────────────────────────────────
WIN_SEC = 5        # 5-second windows
POST_OFFSET = 30   # seconds after seizure offset to start post-ictal extraction
POST_DUR = 240     # seconds of post-ictal to use as pseudo-PGES
PRE_DUR = 120      # seconds of pre-ictal baseline
N_CTX = 8          # context length (same as thalamic model)
SEQ_EP_TUH = 30        # pre-training epochs on TUH
SEQ_EP_THAL = 100      # fine-tuning epochs on thalamic
SEQ_LR = 3e-4
D_MODEL = 64
N_HEADS = 4
N_LAYERS = 4
MAX_TUH = 300          # max TUH files to use (memory budget)
K_VALS = [0, 2, 5, 10]
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Feature indices for inversion correction (C2 finding):
# Suppression_Ratio (idx 10), Spectral_Ratio (idx 8), Zero_Crossings (idx 2)
INVERT_IDX = [2, 8, 10]


def log(msg):
    """Print *msg* to stdout, prefixed with the current time (HH:MM:SS)."""
    print(f'[{datetime.now().strftime("%H:%M:%S")}] {msg}', flush=True)


# ── Model ─────────────────────────────────────────────────────────────────────
class CausalTransformer(nn.Module):
    """Causally masked transformer encoder over feature-vector sequences.

    Input is ``(batch, T, n_feat)`` with ``T <= N_CTX``; every position may
    only attend to itself and earlier positions (upper-triangular mask).
    NOTE(review): no positional encoding is added to the projected input —
    confirm that is intended.
    """

    def __init__(self, n_feat=17):
        super().__init__()
        self.proj = nn.Linear(n_feat, D_MODEL)
        layer = nn.TransformerEncoderLayer(
            D_MODEL, N_HEADS, D_MODEL * 4, dropout=0.1, batch_first=True)
        self.enc = nn.TransformerEncoder(layer, N_LAYERS)
        self.head = nn.Linear(D_MODEL, n_feat)
        # True above the diagonal = "may not attend" (future positions).
        mask = torch.triu(torch.ones(N_CTX, N_CTX), 1).bool()
        self.register_buffer('mask', mask)

    def forward(self, x, return_hidden=False):
        """Return per-position feature predictions, or hidden states.

        Args:
            x: tensor of shape (batch, T, n_feat).
            return_hidden: when True return the encoder output
                (batch, T, D_MODEL) instead of the projected prediction.

        Raises:
            ValueError: if T exceeds the mask size the model was built with.
        """
        T = x.shape[1]
        if T > self.mask.shape[0]:
            raise ValueError(
                f'sequence length {T} exceeds context size {self.mask.shape[0]}')
        h = self.enc(self.proj(x), mask=self.mask[:T, :T])
        if return_hidden:
            return h
        return self.head(h)


# ── Feature extraction (17 features) ─────────────────────────────────────────
def _approx_entropy(u, m=2, r=None):
    """Approximate entropy of 1-D array *u* (defaults m=2, r=0.2*std)."""
    if r is None:
        r = 0.2 * np.std(u) + 1e-10
    N = len(u)

    def phi(mm):
        x = np.array([u[i:i + mm] for i in range(N - mm + 1)])
        dist = np.max(np.abs(x[:, None] - x[None, :]), axis=2)
        C = np.sum(dist <= r, axis=0) / (N - mm + 1)
        return np.sum(np.log(C + 1e-10)) / (N - mm + 1)

    return abs(phi(m) - phi(m + 1))


def _sample_entropy(u, m=2, r=None):
    """Sample entropy of 1-D array *u* (defaults m=2, r=0.2*std).

    Counts ordered template pairs (i != j) whose Chebyshev distance is <= r,
    using N-mm templates for both lengths, then returns -log(A/B).
    """
    if r is None:
        r = 0.2 * np.std(u) + 1e-10
    N = len(u)

    def count(mm):
        n_tpl = N - mm
        if n_tpl < 2:
            return 0
        tpl = np.array([u[j:j + mm] for j in range(n_tpl)])
        within = np.max(np.abs(tpl[:, None] - tpl[None, :]), axis=2) <= r
        np.fill_diagonal(within, False)   # self-matches are excluded
        return int(within.sum())

    A = count(m + 1)
    B = count(m)
    return float(-np.log((A + 1e-10) / (B + 1e-10)))


def compute_features(sig, fs):
    """Extract 17 features from a 1D signal segment.

    Args:
        sig: 1-D numpy array of samples (the mean is removed internally).
        fs: sampling frequency in Hz.

    Returns:
        float32 array of length 17, in this order: rms, line length, zero
        crossings, variance, relative delta/theta/alpha/beta power, spectral
        ratio, spectral Shannon entropy, suppression ratio, approximate
        entropy, sample entropy, amplitude-histogram entropy, LZC proxy,
        permutation entropy, relative gamma power.

    Raises:
        ValueError: if *sig* is empty.
    """
    from collections import Counter
    from numpy.fft import rfft, rfftfreq

    n = len(sig)
    if n == 0:
        raise ValueError('compute_features: empty signal segment')
    sig = sig - sig.mean()

    # Time domain
    rms = float(np.sqrt(np.mean(sig ** 2)))
    ll = float(np.mean(np.abs(np.diff(sig))))
    zc = float(np.sum(np.diff(np.sign(sig)) != 0))
    var = float(np.var(sig))

    # Power spectrum
    freqs = rfftfreq(n, 1 / fs)
    psd = np.abs(rfft(sig)) ** 2

    def band(lo, hi):
        idx = (freqs >= lo) & (freqs < hi)
        return float(psd[idx].sum()) if idx.any() else 0.0

    total = float(psd.sum()) + 1e-10
    delta = band(0.5, 4)
    theta = band(4, 8)
    alpha = band(8, 13)
    beta = band(13, 30)
    gamma = band(80, 150)   # 0.0 when fs is too low to contain this band
    sr = (delta + theta) / (alpha + beta + 1e-10)   # Spectral Ratio

    # Shannon entropy of normalised PSD
    p = psd / (psd.sum() + 1e-10)
    p = p[p > 0]
    shan = float(-np.sum(p * np.log(p + 1e-10)))

    # Suppression Ratio: fraction of samples below 5 % of the peak amplitude
    supp = float(np.mean(np.abs(sig) < 0.05 * np.max(np.abs(sig) + 1e-10)))

    # Entropy measures are quadratic in length: use a short segment for speed
    u = sig[:min(200, n)]
    apen = float(_approx_entropy(u))
    sampen = float(_sample_entropy(u))

    # Entropy of the amplitude histogram (ETC proxy)
    hist, _ = np.histogram(sig, bins=10)
    hist = hist / (hist.sum() + 1e-10)
    etc = float(-np.sum(hist[hist > 0] * np.log(hist[hist > 0])))

    # LZC proxy: number of distinct 4-symbol patterns in the median-binarised
    # signal. The bit string is built once instead of once per position.
    bits = ''.join(map(str, (sig > np.median(sig)).astype(int)))
    lzc = float(len({bits[i:i + 4] for i in range(len(bits) - 3)}))

    # Permutation entropy (m=3)
    m = 3
    cnt = Counter(tuple(np.argsort(sig[i:i + m])) for i in range(n - m + 1))
    tot = sum(cnt.values())
    pent = float(-sum((v / tot) * np.log(v / tot + 1e-10) for v in cnt.values()))

    return np.array([rms, ll, zc, var,
                     delta / total, theta / total, alpha / total, beta / total,
                     sr, shan, supp, apen, sampen, etc, lzc, pent,
                     gamma / total], dtype=np.float32)


def full_window_starts(start, stop, win):
    """Start indices of all non-overlapping windows of *win* samples that fit
    entirely inside ``[start, stop)``.

    The upper bound is ``stop - win + 1`` so that a window ending exactly at
    *stop* is included.
    """
    return range(start, stop - win + 1, win)


def _parse_seizure_intervals(csv_path, target_labels):
    """Read (start_sec, stop_sec) tuples for rows whose label is a target.

    Lines starting with '#' or 'channel' are skipped; columns 1, 2 and 3 are
    read as start, stop and label. Rows with non-numeric times are skipped.
    """
    intervals = []
    with open(csv_path, errors='replace') as f:
        for line in f:
            if line.startswith(('#', 'channel')):
                continue
            parts = line.strip().split(',')
            if len(parts) < 5:
                continue
            if parts[3].strip() not in target_labels:
                continue
            try:
                intervals.append((float(parts[1]), float(parts[2])))
            except ValueError:
                continue
    return intervals


def _merge_intervals(intervals):
    """Merge overlapping/touching intervals; returns sorted [start, stop] lists."""
    ordered = sorted(set(intervals))
    merged = [list(ordered[0])]
    for s, e in ordered[1:]:
        if s <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], e)
        else:
            merged.append([s, e])
    return merged


def _window_features(avg, start, stop, win_samp, fs):
    """Features of consecutive full windows of *avg* inside [start, stop)."""
    feats = []
    for i in full_window_starts(start, stop, win_samp):
        seg = avg[i:i + win_samp]
        if len(seg) == win_samp:   # guards against stop beyond the recording
            feats.append(compute_features(seg, fs))
    return feats


def extract_tuh_features(edf_path, csv_path, target_labels=('gnsz', 'tcsz')):
    """
    Load TUH EDF+CSV, extract temporally ordered feature sequences per seizure.

    Post-ictal (pseudo-PGES) and pre-ictal baseline windows are returned as
    separate ordered sequences so TSM can build valid within-session temporal
    windows. A sequence is kept only if it has at least N_CTX + 2 windows.

    Returns:
        (pges_sessions, base_sessions) — lists of ordered (N, 17) float32
        arrays, or (None, None) when the file has no target seizure, yields no
        usable sequence, or cannot be processed (the error is logged).
    """
    try:
        # Parse the (small) annotation file first so the EDF is only loaded
        # when there is something to extract.
        szs = _parse_seizure_intervals(csv_path, target_labels)
        if not szs:
            return None, None
        # Merge overlapping intervals across channels
        merged = _merge_intervals(szs)

        import mne
        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
        fs = raw.info['sfreq']

        # Average reference — clip bad channels first
        picks = mne.pick_types(raw.info, eeg=True)
        if len(picks) == 0:
            picks = list(range(len(raw.ch_names)))
        data = raw.get_data(picks=picks)
        if np.abs(np.median(data)) < 0.01:   # V → µV
            data = data * 1e6
        ch_std = data.std(axis=1)
        good = ch_std < np.percentile(ch_std, 90) * 3
        if good.sum() == 0:
            good = np.ones(len(ch_std), dtype=bool)
        avg = data[good].mean(axis=0)   # (n_samples,) ordered in time

        win_samp = int(WIN_SEC * fs)
        min_wins = N_CTX + 2
        pges_sessions, base_sessions = [], []
        for sz_start, sz_end in merged:
            # Post-ictal: ordered sequence of consecutive windows
            pi_start = int((sz_end + POST_OFFSET) * fs)
            pi_end = min(int((sz_end + POST_OFFSET + POST_DUR) * fs), len(avg))
            pi_wins = _window_features(avg, pi_start, pi_end, win_samp, fs)
            if len(pi_wins) >= min_wins:
                pges_sessions.append(np.array(pi_wins, dtype=np.float32))

            # Pre-ictal baseline: ordered sequence ending 10 s before onset
            pre_end = int((sz_start - 10) * fs)
            pre_start = max(0, int((sz_start - 10 - PRE_DUR) * fs))
            pre_wins = _window_features(avg, pre_start, pre_end, win_samp, fs)
            if len(pre_wins) >= min_wins:
                base_sessions.append(np.array(pre_wins, dtype=np.float32))

        if not pges_sessions and not base_sessions:
            return None, None
        return pges_sessions, base_sessions
    except Exception as e:
        # Best effort: one unreadable file must not stop the corpus build,
        # but the reason is reported instead of being silently dropped.
        log(f' [ERROR] {os.path.basename(str(edf_path))}: {type(e).__name__}: {e}')
        return None, None


def apply_inversion_correction(X):
    """Invert the 3 features that are directionally flipped scalp→thalamic (C2).

    Returns a copy of the (N, 17) array with the INVERT_IDX columns negated.
    Negation only mirrors a feature around its mean when X is already
    zero-centred, so apply this to standardised features.
    """
    X = X.copy()
    X[:, INVERT_IDX] = -X[:, INVERT_IDX]
    return X


# ── Feature-space CycleGAN ────────────────────────────────────────────────────
N_FEAT = 17


class _MLP(nn.Module):
    """Three-layer LeakyReLU MLP; optional tanh on the output."""

    def __init__(self, in_d, hid, out_d, act_out=None):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_d, hid), nn.LeakyReLU(0.2),
            nn.Linear(hid, hid), nn.LeakyReLU(0.2),
            nn.Linear(hid, out_d))
        self.act_out = act_out

    def forward(self, x):
        o = self.net(x)
        return torch.tanh(o) if self.act_out == 'tanh' else o


def train_cyclegan(scalp_wins, thal_wins, epochs=60, lr=1e-3):
    """
    Feature-space CycleGAN: learns scalp(17-d) ↔ thalamic(17-d) mapping.

    scalp_wins, thal_wins: (N, 17) numpy arrays (already normalised).
    Returns G_S2T (scalp→thalamic generator) as eval-mode module.

    NOTE(review): both generators end in tanh, so their outputs lie in
    [-1, 1] while z-scored targets are unbounded — confirm this is intended.
    """
    G_S2T = _MLP(N_FEAT, 64, N_FEAT, 'tanh').to(DEVICE)   # scalp → thalamic
    G_T2S = _MLP(N_FEAT, 64, N_FEAT, 'tanh').to(DEVICE)   # thalamic → scalp
    D_S = _MLP(N_FEAT, 32, 1).to(DEVICE)                  # discriminator scalp
    D_T = _MLP(N_FEAT, 32, 1).to(DEVICE)                  # discriminator thalamic

    opt_G = torch.optim.Adam(
        list(G_S2T.parameters()) + list(G_T2S.parameters()),
        lr=lr, betas=(0.5, 0.999))
    opt_D = torch.optim.Adam(
        list(D_S.parameters()) + list(D_T.parameters()),
        lr=lr, betas=(0.5, 0.999))
    bce = nn.BCEWithLogitsLoss()
    l1 = nn.L1Loss()

    S = torch.tensor(scalp_wins, dtype=torch.float32)
    T = torch.tensor(thal_wins, dtype=torch.float32)
    ds_s = torch.utils.data.DataLoader(
        torch.utils.data.TensorDataset(S), batch_size=256, shuffle=True)
    n_t = len(T)

    for _ in range(epochs):
        for (s_b,) in ds_s:
            # random thalamic batch of same size
            idx = torch.randint(0, n_t, (len(s_b),))
            t_b = T[idx].to(DEVICE)
            s_b = s_b.to(DEVICE)

            # ── Generators ──
            fake_t = G_S2T(s_b)
            rec_s = G_T2S(fake_t)
            fake_s = G_T2S(t_b)
            rec_t = G_S2T(fake_s)
            idt_s = G_T2S(s_b)
            idt_t = G_S2T(t_b)
            # Each discriminator is evaluated once and reused for the target.
            d_fake_t = D_T(fake_t)
            d_fake_s = D_S(fake_s)
            loss_G = (bce(d_fake_t, torch.ones_like(d_fake_t))      # adv S→T
                      + bce(d_fake_s, torch.ones_like(d_fake_s))    # adv T→S
                      + 10. * l1(rec_s, s_b)                        # cycle S
                      + 10. * l1(rec_t, t_b)                        # cycle T
                      + 5. * l1(idt_s, s_b)                         # identity S
                      + 5. * l1(idt_t, t_b))                        # identity T
            opt_G.zero_grad()
            loss_G.backward()
            opt_G.step()

            # ── Discriminators ── (fakes regenerated with updated generators)
            fake_t2 = G_S2T(s_b).detach()
            fake_s2 = G_T2S(t_b).detach()
            d_real_t = D_T(t_b)
            d_gen_t = D_T(fake_t2)
            d_real_s = D_S(s_b)
            d_gen_s = D_S(fake_s2)
            loss_D = (bce(d_real_t, torch.ones_like(d_real_t))
                      + bce(d_gen_t, torch.zeros_like(d_gen_t))
                      + bce(d_real_s, torch.ones_like(d_real_s))
                      + bce(d_gen_s, torch.zeros_like(d_gen_s))) * 0.5
            # zero_grad also clears the gradients the generator step left
            # on the discriminator parameters.
            opt_D.zero_grad()
            loss_D.backward()
            opt_D.step()

    G_S2T.eval()
    return G_S2T


def translate_sessions(G_S2T, sessions):
    """Translate a list of (N_i, 17) scalp-feature sessions to thalamic domain."""
    translated = []
    with torch.no_grad():
        for sess in sessions:
            t_in = torch.tensor(sess, dtype=torch.float32).to(DEVICE)
            translated.append(G_S2T(t_in).cpu().numpy())
    return translated


# ── Pre-training ───────────────────────────────────────────────────────────────
def build_tsm_sequences(sessions, n_ctx=None):
    """Build (context + target) sliding windows within each session.

    Args:
        sessions: list of (N_i, n_feat) arrays, each temporally ordered.
        n_ctx: context length; defaults to the module-level N_CTX.

    Returns:
        float32 array of shape (total, n_ctx + 1, n_feat). Every window of
        each session from index n_ctx onward — including the last one — is
        used as a target exactly once. Sessions shorter than n_ctx + 1
        contribute nothing; with no sequences the first dimension is 0.
    """
    if n_ctx is None:
        n_ctx = N_CTX
    seq_len = n_ctx + 1
    seqs = [sess[end - seq_len:end]
            for sess in sessions
            for end in range(seq_len, len(sess) + 1)]
    if not seqs:
        return np.empty((0, seq_len, N_FEAT), dtype=np.float32)
    return np.asarray(seqs, dtype=np.float32)


def pretrain_on_sessions(model, sessions, epochs, lr=SEQ_LR):
    """
    Self-supervised next-window TSM prediction.

    sessions: list of (N_i, 17) arrays — each array is one temporally ordered
    session. Sliding windows are built WITHIN each session only — no
    cross-session contamination.

    The model is trained in place with Adam on
    ``(1 - cosine_similarity) + 0.5 * MSE`` between the prediction at the last
    context position and the next window, and returned in eval mode. With
    fewer than 10 sequences the model is returned untouched.
    """
    seqs = build_tsm_sequences(sessions)   # (total, N_CTX+1, 17)
    if len(seqs) < 10:
        return model

    ds = torch.utils.data.TensorDataset(
        torch.tensor(seqs[:, :N_CTX]),   # context
        torch.tensor(seqs[:, -1]))       # target (next window)
    ld = torch.utils.data.DataLoader(ds, batch_size=128, shuffle=True)
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    model.train()
    for _ in range(epochs):
        for xc, xt in ld:
            xc, xt = xc.to(DEVICE), xt.to(DEVICE)
            pred = model(xc)[:, -1, :]
            loss = ((1. - F.cosine_similarity(pred, xt, dim=1).mean())
                    + 0.5 * F.mse_loss(pred, xt))
            opt.zero_grad()
            loss.backward()
            opt.step()
    model.eval()
    return model


def finetune_on_thalamic(model, train_patients, scaler, epochs=SEQ_EP_THAL):
    """Continue pre-training on thalamic baseline sequences (per-patient sessions).

    Each patient's label-0 windows, scaled with *scaler*, form one session;
    training uses 0.3 × SEQ_LR. Returns the model unchanged when no patient
    has at least N_CTX + 2 baseline windows.
    NOTE(review): dropping label-1 windows can join windows that were not
    adjacent in time — confirm this is acceptable for the TSM objective.
    """
    sessions = []
    for p in train_patients:
        X_n = scaler.transform(p['X'].astype(np.float32))
        sess = X_n[p['labels'] == 0]   # ordered baseline windows per patient
        if len(sess) >= N_CTX + 2:
            sessions.append(sess)
    if not sessions:
        return model
    return pretrain_on_sessions(model, sessions, epochs, lr=SEQ_LR * 0.3)


# ── ProtoNet eval ──────────────────────────────────────────────────────────────
def encode(model, seqs):
    """Embed sequences (n, T, 17) as the hidden state at the last position.

    Runs in batches of 64 without gradients; returns an (n, D_MODEL) array.
    """
    model.eval()
    z = []
    with torch.no_grad():
        for i in range(0, len(seqs), 64):
            b = torch.tensor(seqs[i:i + 64], dtype=torch.float32).to(DEVICE)
            z.append(model(b, return_hidden=True)[:, -1, :].cpu().numpy())
    return np.vstack(z)


def build_seqs(patient, scaler):
    """Build evaluation sequences for one patient.

    Sequence i holds the N_CTX scaled windows *preceding* window i and takes
    the label of window i. Returns (seqs, labels) or (None, None) when the
    patient has no more than N_CTX windows.
    """
    X_n = scaler.transform(patient['X'].astype(np.float32))
    y = patient['labels'].astype(np.int32)
    seqs, lbls = [], []
    for i in range(N_CTX, len(X_n)):
        seqs.append(X_n[i - N_CTX:i])
        lbls.append(y[i])
    if not seqs:
        return None, None
    return np.array(seqs, dtype=np.float32), np.array(lbls, dtype=np.int32)


def diversity_support(lbls, K):
    """Draw K positive and K negative support indices uniformly at random.

    Returns (None, None) when either class has fewer than K examples.
    """
    pos = np.where(lbls == 1)[0]
    neg = np.where(lbls == 0)[0]
    if len(pos) < K or len(neg) < K:
        return None, None
    sup_p = np.random.choice(pos, K, replace=False)
    sup_n = np.random.choice(neg, K, replace=False)
    return sup_p, sup_n


def _nearest_proto_preds(Z, proto_pos, proto_neg):
    """1 where an embedding is strictly closer to the positive prototype."""
    return (np.linalg.norm(Z - proto_pos, axis=1)
            < np.linalg.norm(Z - proto_neg, axis=1)).astype(int)


def kshot_eval(model, seqs, lbls, K, n_trials=5):
    """Prototype-classifier F1 for one patient with K labelled shots per class.

    K == 0 builds both prototypes from ALL windows of the evaluated patient;
    K > 0 averages F1 over *n_trials* random supports, scoring only the
    non-support windows. Returns NaN when a class is missing or no trial
    could be run.
    NOTE(review): the K == 0 branch uses the evaluated patient's own labels
    for its prototypes, so it is not a label-free score — confirm intended.
    """
    if lbls.sum() == 0:
        return float('nan')
    Z = encode(model, seqs)
    if K == 0:
        if (lbls == 0).sum() == 0:   # no baseline windows → no prototype
            return float('nan')
        pp = Z[lbls == 1].mean(0)
        pb = Z[lbls == 0].mean(0)
        preds = _nearest_proto_preds(Z, pp, pb)
        return float(f1_score(lbls, preds, zero_division=0))

    scores = []
    for _ in range(n_trials):
        sup_p, sup_n = diversity_support(lbls, K)
        if sup_p is None:
            continue
        pp = Z[sup_p].mean(0)
        pb = Z[sup_n].mean(0)
        # Query set = everything that is not a support window.
        is_query = np.ones(len(lbls), dtype=bool)
        is_query[sup_p] = False
        is_query[sup_n] = False
        qry = np.flatnonzero(is_query)
        if len(qry) == 0:
            continue
        preds = _nearest_proto_preds(Z[qry], pp, pb)
        scores.append(float(f1_score(lbls[qry], preds, zero_division=0)))
    return float(np.mean(scores)) if scores else float('nan')


def day0_auto_support(lbls, k_auto):
    """Pick support windows for the Day-0 temporal heuristic.

    Finds the first 0→1 transition in *lbls* (a leading 1 counts as one),
    takes up to *k_auto* windows from the start of that positive run as
    positive support and the last *k_auto* label-0 windows before it as
    baseline support (falling back to the first *k_auto* label-0 windows when
    fewer precede the onset).

    Returns (auto_p, auto_b) index arrays, or None when there is no onset.
    NOTE(review): the onset is located from the ground-truth labels —
    presumably standing in for a known seizure offset time; confirm.
    """
    starts = np.where(np.diff(np.concatenate([[0], lbls])) == 1)[0]
    if len(starts) == 0:
        return None
    onset = starts[0]
    pges_end = onset
    while pges_end < len(lbls) and lbls[pges_end] == 1:
        pges_end += 1
    auto_p = np.arange(onset, min(onset + k_auto, pges_end))
    base_idx = np.where(lbls == 0)[0]
    pre_base = base_idx[base_idx < onset]
    if len(pre_base) < k_auto:
        pre_base = base_idx[:k_auto]
    return auto_p, pre_base[-k_auto:]


def paired_mean_gain(baseline, other):
    """Mean of (other - baseline) over folds where BOTH scores are not NaN.

    The two lists must be aligned fold-by-fold. Returns NaN when no fold has
    both scores.
    """
    diffs = [b - a for a, b in zip(baseline, other)
             if not (np.isnan(a) or np.isnan(b))]
    return float(np.mean(diffs)) if diffs else float('nan')


# ══════════════════════════════════════════════════════════════════════════════
# MAIN
# ══════════════════════════════════════════════════════════════════════════════
if __name__ == '__main__':
    log('=' * 60)
    log('DACTRL — TUH Scalp Pre-training (Feature-Space + Inversion Correction)')
    log('=' * 60)

    # ── Step 1: Load thalamic patients ──────────────────────────────────────
    log('Loading thalamic patients...')
    # The sibling script is executed with its __main__ guard disabled so only
    # its definitions are loaded. It is a local project file, not external
    # input.
    _V3 = Path(__file__).parent / "dactrl_v3_episodic_protonet.py"
    _v3g = {'__file__': str(_V3)}
    with open(_V3, 'r', errors='replace') as _f:
        exec(compile(
            _f.read().replace("if __name__=='__main__':",
                              "if __name__=='__never__':"),
            str(_V3), 'exec'), _v3g)
    load_all_seeg = _v3g['load_all_seeg']
    log('Imported v1 data loaders.')

    SEEG_ROOT = Path(r"G:\PHD Datasets\Data\Thalamus\SEEG Seizure Data")
    METADATA = SEEG_ROOT / "metadata_SEEG.xlsx"
    meta_df = pd.read_excel(METADATA)
    raw_data = load_all_seeg(meta_df)

    NUCLEUS_MAP = {
        'P1': 'CeM', 'P2': 'CL', 'P3': 'CeM', 'P4': 'MD', 'P5': 'CeM',
        'P6': 'MD', 'P7': 'CL', 'P8': 'CL', 'P9': 'CeM', 'P10': 'ANT',
        'P11': 'ANT', 'P12': 'ANT', 'P13': 'ANT', 'P14': 'ANT', 'P15': 'ANT',
    }
    # Keep patients that have at least one positive window; P13 is excluded.
    patients = [{'pid': p, 'nucleus': NUCLEUS_MAP.get(p, 'UNK'),
                 'X': d['X'], 'labels': d['y_temporal']}
                for p, d in raw_data.items()
                if d['y_temporal'].sum() > 0 and p != 'P13']
    log(f'Loaded {len(patients)} thalamic patients.')

    # ── Step 2: Build TUH feature corpus ────────────────────────────────────
    log('')
    log('Step 2: Extracting features from TUH corpus (gnsz/tcsz only)...')
    csvs = [f for f in glob.glob(os.path.join(TUH_BASE, '**', '*.csv'),
                                 recursive=True)
            if 'worksheet' not in f.lower()]

    def _has_target(f):
        """True when the annotation file mentions a target seizure label."""
        try:
            with open(f, errors='ignore') as fh:
                text = fh.read()
        except Exception:
            return False
        return any(t in text for t in ['tcsz', 'gnsz'])

    tgt_csvs = [f for f in csvs if _has_target(f)]
    tgt_pairs = [(f, f.replace('.csv', '.edf')) for f in tgt_csvs
                 if os.path.exists(f.replace('.csv', '.edf'))]
    log(f'Found {len(tgt_pairs)} valid EDF+CSV pairs with gnsz/tcsz.')

    # Subsample to MAX_TUH for memory budget
    np.random.seed(42)
    if len(tgt_pairs) > MAX_TUH:
        idxs = np.random.choice(len(tgt_pairs), MAX_TUH, replace=False)
        tgt_pairs = [tgt_pairs[i] for i in idxs]
    log(f'Using {len(tgt_pairs)} files (MAX_TUH={MAX_TUH}).')

    def _extract_with_timeout(edf_path, csv_path, timeout=180):
        """Run extract_tuh_features in a daemon thread, giving up after
        *timeout* seconds. On timeout the worker keeps running in the
        background; its late result is ignored."""
        result = [None, None]

        def _run():
            try:
                result[0], result[1] = extract_tuh_features(edf_path, csv_path)
            except Exception as e:
                log(f' [ERROR] {os.path.basename(edf_path)}: {e}')

        t = threading.Thread(target=_run, daemon=True)
        t.start()
        t.join(timeout=timeout)
        if t.is_alive():
            log(f' [TIMEOUT] {os.path.basename(edf_path)} — skipping')
            return None, None
        return result[0], result[1]

    # sessions = lists of (N_i, 17) ordered arrays — one per TUH seizure event
    tuh_sessions_raw = []
    skipped = 0
    total_pges_wins = 0
    for k, (csv_p, edf_p) in enumerate(tgt_pairs):
        pges_sess, base_sess = _extract_with_timeout(edf_p, csv_p, timeout=180)
        if pges_sess is None:
            skipped += 1
            continue
        tuh_sessions_raw.extend(pges_sess)
        tuh_sessions_raw.extend(base_sess or [])
        total_pges_wins += sum(len(s) for s in pges_sess)
        if (k + 1) % 10 == 0:
            log(f' Processed {k+1}/{len(tgt_pairs)} | sessions={len(tuh_sessions_raw)} '
                f'| PGES wins={total_pges_wins} | skipped={skipped}')

    total_wins = sum(len(s) for s in tuh_sessions_raw)
    log(f'TUH corpus: {total_wins} windows in {len(tuh_sessions_raw)} sessions '
        f'| PGES wins={total_pges_wins} | skipped={skipped}')
    if not tuh_sessions_raw:
        raise SystemExit('No TUH sessions could be extracted — nothing to pre-train on.')

    # Normalize using global stats across all windows
    tuh_scaler = StandardScaler().fit(np.vstack(tuh_sessions_raw))
    tuh_sessions_raw_n = [tuh_scaler.transform(s) for s in tuh_sessions_raw]
    # The inversion is applied AFTER standardisation: negating raw values and
    # then scaling with statistics fitted on un-negated data would give
    # -(x + mean)/std, i.e. a feature shifted by -2*mean/std rather than
    # mirrored around zero.
    tuh_sessions_cor_n = [apply_inversion_correction(s) for s in tuh_sessions_raw_n]

    # ── Step 2b: Train CycleGAN (feature-space, scalp ↔ thalamic) ───────────
    log('')
    log('Step 2b: Training feature-space CycleGAN (TUH scalp ↔ thalamic)...')
    # Build flat arrays for CycleGAN training (does not need session ordering)
    all_scalp_wins = np.vstack(tuh_sessions_raw_n)   # normalised raw scalp
    # All thalamic windows pooled with a single scaler.
    # NOTE(review): this pool includes every patient, so the generator used
    # in condition D has seen each LOSO test patient's windows — confirm
    # whether it should be trained per fold instead.
    X_all_thal = np.vstack([p['X'].astype(np.float32) for p in patients])
    scaler_thal = StandardScaler().fit(X_all_thal)
    all_thal_wins = scaler_thal.transform(X_all_thal)
    log(f' Scalp windows (TUH): {len(all_scalp_wins)} | '
        f'Thalamic windows: {len(all_thal_wins)}')

    G_S2T = train_cyclegan(all_scalp_wins, all_thal_wins, epochs=60)
    log(' CycleGAN training complete.')

    # Translate TUH sessions to thalamic domain
    tuh_sessions_gan = translate_sessions(G_S2T, tuh_sessions_raw_n)
    log(f' Translated {len(tuh_sessions_gan)} TUH sessions to thalamic domain.')

    # ── Step 3: LOSO experiment ──────────────────────────────────────────────
    log('')
    log('Step 3: LOSO evaluation (5 conditions)...')
    log(' A = Thalamic-only TSM (baseline)')
    log(' B = TUH TSM + inversion correction + thalamic fine-tune')
    log(' C = TUH TSM + NO correction (ablation)')
    log(' D = TUH CycleGAN translation + TSM fine-tune')
    log(' E = Best TUH backbone + Day-0 temporal heuristic [COMBO]')

    results = {c: {k: [] for k in K_VALS} for c in ['A', 'B', 'C', 'D', 'E']}

    def _tuh_then_thalamic(tuh_sessions, train_ps, scaler):
        """Fresh model: TUH pre-training followed by thalamic fine-tuning."""
        model = CausalTransformer().to(DEVICE)
        model = pretrain_on_sessions(model, tuh_sessions, epochs=SEQ_EP_TUH)
        return finetune_on_thalamic(model, train_ps, scaler)

    for fold_i, test_p in enumerate(patients):
        pid = test_p['pid']
        train_ps = [p for p in patients if p['pid'] != pid]
        X_tr = np.vstack([p['X'].astype(np.float32) for p in train_ps])
        scaler = StandardScaler().fit(X_tr)   # fitted on training patients only

        seqs, lbls = build_seqs(test_p, scaler)
        if seqs is None or lbls.sum() == 0:
            log(f' {pid}: skip')
            continue

        # ── Condition A: thalamic-only TSM (baseline) ──
        thal_sessions = []
        for p in train_ps:
            X_n = scaler.transform(p['X'].astype(np.float32))
            sess = X_n[p['labels'] == 0]
            if len(sess) >= N_CTX + 2:
                thal_sessions.append(sess)
        model_A = CausalTransformer().to(DEVICE)
        model_A = pretrain_on_sessions(model_A, thal_sessions, epochs=SEQ_EP_THAL)
        res_A = {k: kshot_eval(model_A, seqs, lbls, k) for k in K_VALS}
        del model_A

        # ── Condition B: TUH TSM (corrected) → thalamic fine-tune ──
        model_B = _tuh_then_thalamic(tuh_sessions_cor_n, train_ps, scaler)
        res_B = {k: kshot_eval(model_B, seqs, lbls, k) for k in K_VALS}
        del model_B

        # ── Condition C: TUH TSM (no correction ablation) → thalamic fine-tune ──
        model_C = _tuh_then_thalamic(tuh_sessions_raw_n, train_ps, scaler)
        res_C = {k: kshot_eval(model_C, seqs, lbls, k) for k in K_VALS}
        del model_C

        # ── Condition D: CycleGAN translated sessions → TSM fine-tune ──
        model_D = _tuh_then_thalamic(tuh_sessions_gan, train_ps, scaler)
        res_D = {k: kshot_eval(model_D, seqs, lbls, k) for k in K_VALS}
        del model_D

        # ── Condition E: Best TUH (B or D, pick higher K=0) + Day-0 heuristic ──
        # Use the better TUH corpus (B vs D by K=0), then apply temporal
        # auto-label for zero-label adaptation — the full platform vision combo.
        # NOTE(review): the B-vs-D choice is made on the TEST patient's K=0
        # score, and the model is retrained from scratch rather than reused.
        best_k0_B = res_B[0]
        best_k0_D = res_D[0]
        use_B = best_k0_B >= best_k0_D or np.isnan(best_k0_D)
        best_sessions = tuh_sessions_cor_n if use_B else tuh_sessions_gan
        model_E = _tuh_then_thalamic(best_sessions, train_ps, scaler)

        # Day-0 temporal heuristic: auto-label first K_AUTO windows after
        # seizure offset
        K_AUTO = 10
        support = day0_auto_support(lbls, K_AUTO)
        if support is None:
            res_E = {k: float('nan') for k in K_VALS}
        else:
            auto_p, auto_b = support
            res_E = None
            if len(auto_p) > 0 and len(auto_b) > 0:
                Z_E = encode(model_E, seqs)
                pp_e = Z_E[auto_p].mean(0)
                pb_e = Z_E[auto_b].mean(0)
                is_query = np.ones(len(lbls), dtype=bool)
                is_query[auto_p] = False
                is_query[auto_b] = False
                qry = np.flatnonzero(is_query)
                if len(qry) > 0:
                    preds_e = _nearest_proto_preds(Z_E[qry], pp_e, pb_e)
                    # Only K=0 uses the heuristic; K>0 uses labelled shots.
                    res_E = {k: (float(f1_score(lbls[qry], preds_e, zero_division=0))
                                 if k == 0 else kshot_eval(model_E, seqs, lbls, k))
                             for k in K_VALS}
            if res_E is None:
                res_E = {k: kshot_eval(model_E, seqs, lbls, k) for k in K_VALS}
        del model_E

        # All five conditions are appended together, so the per-condition
        # lists stay aligned fold-by-fold.
        for k in K_VALS:
            results['A'][k].append(res_A[k])
            results['B'][k].append(res_B[k])
            results['C'][k].append(res_C[k])
            results['D'][k].append(res_D[k])
            results['E'][k].append(res_E[k])

        log(f' [{fold_i+1:02d}] {pid}: '
            f'A={res_A[10]:.4f} B_TSM={res_B[10]:.4f} '
            f'D_GAN={res_D[10]:.4f} E_Combo={res_E[10]:.4f}')

    # ── Step 4: Summary ──────────────────────────────────────────────────────
    log('')
    log('=' * 60)
    log('=== TUH Pre-training Summary (K=10 F1) ===')
    log(f'{"Condition":<35} {"K=0":>8} {"K=2":>8} {"K=5":>8} {"K=10":>8}')
    log('-' * 65)
    for cond, label in [('A', 'A: Thalamic-only TSM (baseline)'),
                        ('B', 'B: TUH TSM + Inversion Correction'),
                        ('C', 'C: TUH TSM + No Correction [ablation]'),
                        ('D', 'D: TUH CycleGAN → TSM fine-tune'),
                        ('E', 'E: Best TUH + Day-0 Heuristic [COMBO]')]:
        row = f'{label:<35}'
        for k in K_VALS:
            vals = [v for v in results[cond][k] if not np.isnan(v)]
            row += f' {np.mean(vals):.4f}' if vals else ' nan'
        log(row)

    log('')
    log('Gain over thalamic-only baseline (A):')
    for cond, label in [('B', 'TSM+Correction'), ('C', 'TSM+NoCorrection'),
                        ('D', 'CycleGAN'), ('E', 'BestTUH+Day0Combo')]:
        log(f' {label}:')
        for k in K_VALS:
            # Paired per fold: a fold counts only when both scores exist.
            delta = paired_mean_gain(results['A'][k], results[cond][k])
            log(f' K={k:>2}: {delta:+.4f}')

    # Save results (a dict is stored as a pickled object array; load it with
    # np.load(..., allow_pickle=True).item()).
    np.save(os.path.join(OUT_DIR, 'tuh_pretrain_results.npy'), results)
    log(f'Results saved -> {OUT_DIR}')
    log('COMPLETE')