# -*- coding: utf-8 -*-
"""
DACTRL C13 — Three-Source Integrated Contrastive Pre-training
N_TRIALS=10 rerun for statistical significance (Wilcoxon power improvement)
===========================================================================
Same as dactrl_three_source_contrastive.py but with N_TRIALS=10 per fold.
Goal: reduce per-patient F1 variance so Wilcoxon D-vs-A p-value crosses 0.05.

Mac paths (/Volumes/Expansion). Run on M1 Max 64GB.
"""
# The environment variable is set before the remaining imports run.
import os
os.environ.setdefault('PYTHONIOENCODING', 'utf-8')

import gc
import glob
import random
import threading
import warnings
import copy
from pathlib import Path
from datetime import datetime
from collections import Counter

import numpy as np
import pandas as pd
# The non-interactive backend must be selected before pyplot is imported.
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from scipy.signal import resample_poly, butter, filtfilt
from scipy.stats import wilcoxon
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score
import torch
import torch.nn as nn
import torch.nn.functional as F

# MNE is optional at import time; when it is missing the name `mne` stays
# undefined and the EDF loaders further down cannot read any file.
try:
    import mne
    mne.set_log_level('ERROR')
except ImportError:
    pass

warnings.filterwarnings('ignore')


def _select_device():
    """Return the torch device to use (CUDA, then MPS, then CPU) and announce it."""
    if torch.cuda.is_available():
        chosen = torch.device('cuda')
        print(f"[GPU] {torch.cuda.get_device_name(0)}")
        return chosen
    if hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        chosen = torch.device('mps')
        print("[MPS] Apple Silicon GPU")
        return chosen
    chosen = torch.device('cpu')
    print("[CPU] No GPU")
    return chosen


DEVICE = _select_device()

# Fixed seeds for the three random number generators used in this script.
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# ── Paths ─────────────────────────────────────────────────────────────────────
SEEG_ROOT = Path("/Volumes/Expansion/phd_datasets/Data/Thalamus/SEEG Seizure Data")
METADATA = SEEG_ROOT / "metadata_SEEG.xlsx"
TUH_BASE = "/Volumes/Expansion/phd_datasets/Data/Scalp/tueeg_data/tuh_eeg_seizure/v2.0.3/edf/dev"
GTC_ROOT = Path("/Volumes/Expansion/phd_datasets/Data/Thalamus/eeg_ecg_us_clinical/GTC_Focal_SEEG")
OUT_ROOT = Path("/Volumes/Expansion/phd_datasets/Code/pges_toolkit_mac/results/dactrl_c13_hightrials")
OUT_ROOT.mkdir(parents=True, exist_ok=True)

# ── Constants ──────────────────────────────────────────────────────────────────
FS_NATIVE = 2048   # NOTE(review): presumably the native SEEG rate in Hz; not referenced in the code visible here
FS_TARGET = 250    # sampling rate (Hz) signals are resampled to before windowing
WIN_SEC = 5        # analysis window length, seconds
WIN_TARGET = WIN_SEC * FS_TARGET   # analysis window length, samples at FS_TARGET
MAX_TUH = 300      # upper bound on the number of TUH recordings used
N_FEAT = 17        # length of the vector returned by compute_features
D_MODEL = 64; N_HEADS = 4; N_LAYERS = 4; N_CTX = 8   # transformer width / heads / layers / context windows
SEQ_EP_THAL = 60; SEQ_EP_PRETRAIN = 30; SEQ_LR = 3e-4  # pre-training epochs (two regimes) and Adam learning rate
SUPCON_T = 0.07    # temperature of the supervised-contrastive loss
LAM_L2 = 0.5       # weight of the L2 (TUH ↔ institutional scalp) loss term
LAM_L3 = 1.0       # weight of the L3 (paired scalp ↔ thalamic) loss term
K_VALS = [0, 2, 5, 10]   # K values evaluated in the K-shot protocol
N_TRIALS = 10  # increased from 5 → reduces per-patient variance for Wilcoxon

# Patient ID → nucleus label.
NUCLEUS_MAP = {
    'P1': 'CeM', 'P2': 'CL', 'P3': 'CeM', 'P4': 'MD', 'P5': 'CeM',
    'P6': 'MD', 'P7': 'CL', 'P8': 'CL', 'P9': 'CeM', 'P10': 'ANT',
    'P11': 'ANT', 'P12': 'ANT', 'P13': 'ANT', 'P14': 'ANT', 'P15': 'ANT',
}
THAL_PIDS = ['P1', 'P2', 'P3', 'P4', 'P5', 'P7', 'P8', 'P15']
INVERT_IDX = [10, 0, 3]   # NOTE(review): not referenced in the code visible here — confirm it is still needed
TOPO_CH = ['FZ', 'CZ', 'C3', 'F3']   # preferred scalp channels for the averaged scalp signal
ALL_SCALP = ['FP1', 'FP2', 'F7', 'F8', 'F3', 'F4', 'FZ',
             'T3', 'T4', 'C3', 'C4', 'CZ', 'T5', 'T6', 'P3', 'P4', 'PZ', 'O1', 'O2']   # fallback channel set


def log(msg):
    """Print *msg* prefixed with the current wall-clock time, flushing immediately."""
    print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}", flush=True)


# ══════════════════════════════════════════════════════════════════════════════
# Feature extraction (17-feature pipeline — identical to all other scripts)
# ══════════════════════════════════════════════════════════════════════════════
def compute_features(sig, fs):
    """Compute the 17-element feature vector of one single-channel window.

    The window is mean-centred first. Feature order (must stay fixed, the
    scalers and the model depend on it):

        0 rms, 1 line length, 2 zero crossings, 3 variance,
        4-7 relative delta/theta/alpha/beta power,
        8 slow ratio (delta+theta)/(alpha+beta), 9 spectral Shannon entropy,
        10 fraction of samples below 5 % of the peak amplitude,
        11 approximate entropy, 12 sample entropy (both on the first 200 samples),
        13 10-bin amplitude-histogram entropy,
        14 number of distinct 4-symbol patterns of the median-binarised signal,
        15 permutation entropy (order 3), 16 relative 80-150 Hz power.

    Args:
        sig: 1-D numpy array of samples.
        fs: sampling rate of ``sig`` in Hz.

    Returns:
        ``np.float32`` array of length 17, or ``None`` when the window is
        shorter than half a second.
    """
    sig = sig - sig.mean()
    n = len(sig)
    if n < int(fs * 0.5):
        return None

    # ── Time-domain statistics ────────────────────────────────────────────────
    rms = float(np.sqrt(np.mean(sig**2)))
    ll = float(np.mean(np.abs(np.diff(sig))))
    zc = float(np.sum(np.diff(np.sign(sig)) != 0))
    var = float(np.var(sig))

    # ── Spectral features (unnormalised periodogram) ──────────────────────────
    freqs = np.fft.rfftfreq(n, 1/fs)
    psd = np.abs(np.fft.rfft(sig))**2

    def band(lo, hi):
        """Summed power in the half-open band [lo, hi) Hz."""
        idx = (freqs >= lo) & (freqs < hi)
        return float(psd[idx].sum()) if idx.any() else 0.0

    total = float(psd.sum()) + 1e-10
    delta = band(0.5, 4); theta = band(4, 8); alpha = band(8, 13)
    beta = band(13, 30); gamma = band(80, 150)
    sr = (delta+theta)/(alpha+beta+1e-10)
    p = psd/(psd.sum()+1e-10); p = p[p > 0]
    shan = float(-np.sum(p*np.log(p+1e-10)))
    supp = float(np.mean(np.abs(sig) < 0.05*np.max(np.abs(sig)+1e-10)))

    # ── Entropy measures on a 200-sample prefix (both are O(N²)) ──────────────
    u = sig[:min(200, n)]

    def _apen(u, m=2, r=None):
        """Approximate entropy with tolerance r (default 0.2·std)."""
        if r is None: r = 0.2*np.std(u)+1e-10
        N = len(u)

        def phi(mm):
            x = np.array([u[i:i+mm] for i in range(N-mm+1)])
            C = np.sum(np.max(np.abs(x[:, None]-x[None, :]), axis=2) <= r, axis=0)/(N-mm+1)
            return np.sum(np.log(C+1e-10))/(N-mm+1)
        return abs(phi(m)-phi(m+1))
    apen = float(_apen(u))

    def _sampen(u, m=2, r=None):
        """Sample entropy with tolerance r (default 0.2·std), self-matches excluded."""
        if r is None: r = 0.2*np.std(u)+1e-10
        N = len(u)

        def _count(mm):
            n_tmpl = N - mm
            if n_tmpl <= 0:
                return 0
            # Build the template matrix once instead of once per template.
            templates = np.array([u[j:j+mm] for j in range(n_tmpl)])
            cnt = 0
            for i in range(n_tmpl):
                hits = np.max(np.abs(templates - templates[i]), axis=1) <= r
                hits[i] = False   # a template is never compared with itself
                cnt += np.sum(hits)
            return cnt
        A = _count(m+1); B = _count(m)
        return float(-np.log((A+1e-10)/(B+1e-10)))
    sampen = float(_sampen(u))

    # ── Amplitude-histogram entropy ───────────────────────────────────────────
    hist, _ = np.histogram(sig, bins=10); hist = hist/(hist.sum()+1e-10)
    etc = float(-np.sum(hist[hist > 0]*np.log(hist[hist > 0])))

    # ── Pattern count of the binarised signal ─────────────────────────────────
    bsig = (sig > np.median(sig)).astype(int)
    bits = ''.join(map(str, bsig))   # built once; the original rebuilt it per position
    lzc = float(len({bits[i:i+4] for i in range(len(bsig)-3)}))

    # ── Permutation entropy, order 3 ──────────────────────────────────────────
    m3 = 3; perms = [tuple(np.argsort(sig[i:i+m3])) for i in range(n-m3+1)]
    cnt2 = Counter(perms); tot = sum(cnt2.values())
    pent = float(-sum((v/tot)*np.log(v/tot+1e-10) for v in cnt2.values()))

    return np.array([rms, ll, zc, var, delta/total, theta/total, alpha/total, beta/total,
                     sr, shan, supp, apen, sampen, etc, lzc, pent, gamma/total], dtype=np.float32)


def _find_ch(ch_names, targets):
    """Return the index of the first channel matching any target, else ``None``.

    Names are compared upper-cased with '-' and spaces removed. Targets are
    tried in order; a channel matches when it equals the target or contains
    it as a substring, so e.g. 'F3' also matches a channel named 'AF3'.
    """
    cu = [c.upper().replace('-', '').replace(' ', '') for c in ch_names]
    for t in targets:
        t2 = t.upper().replace('-', '').replace(' ', '')
        for i, c in enumerate(cu):
            if c == t2 or t2 in c:
                return i
    return None


def _scalp_avg(data, ch_names, targets):
    """Average the rows of *data* for the channels that match *targets*.

    Each target is resolved with ``_find_ch``; duplicate indices are dropped
    (first occurrence kept). Returns ``None`` when no target matches.
    """
    idxs = [i for t in targets for i in [_find_ch(ch_names, [t])] if i is not None]
    idxs = list(dict.fromkeys(idxs))
    return data[idxs].mean(axis=0) if idxs else None


def _downsample(sig, fs_in, fs_out):
    """Resample *sig* from fs_in to fs_out with a polyphase filter.

    Both rates are truncated to integers to derive the up/down factors.
    """
    from math import gcd
    g = gcd(int(fs_in), int(fs_out))
    return resample_poly(sig, int(fs_out)//g, int(fs_in)//g)


def _bp(sig, lo, hi, fs, order=4):
    """Zero-phase Butterworth band-pass of *sig* between lo and hi Hz.

    ``hi`` must be below the Nyquist frequency ``fs/2``; ``butter`` raises otherwise.
    """
    nyq = fs/2
    b, a = butter(order, [lo/nyq, hi/nyq], btype='band')
    return filtfilt(b, a, sig)


# ══════════════════════════════════════════════════════════════════════════════
# EDF loading helpers
# ══════════════════════════════════════════════════════════════════════════════
def _crop_load_segment(edf_path, t_start_s, t_end_s):
    """Load the [t_start_s, t_end_s] slice of an EDF file.

    The requested interval is clamped to the recording. Intervals shorter
    than one second after clamping are rejected.

    Returns:
        ``(data, fs, ch_names)`` where ``data`` is channels × samples, or
        ``(None, None, None)`` when the slice is too short or loading fails
        (the failure is logged).
    """
    raw = None
    try:
        raw = mne.io.read_raw_edf(str(edf_path), preload=False, verbose=False)
        fs = raw.info['sfreq']
        # Time of the last sample. Clamping to n_times / fs (one sample later)
        # makes Raw.crop raise, which silently discarded every segment that
        # ran up to the end of a recording.
        last_t = (raw.n_times - 1) / fs
        t0 = max(0.0, t_start_s)
        t1 = min(t_end_s, last_t)
        if t1 - t0 < 1.0:
            return None, None, None
        raw.crop(tmin=t0, tmax=t1)
        raw.load_data()
        data = raw.get_data()
        chs = raw.ch_names
        # Heuristic unit fix: a near-zero median is taken to mean volts → µV.
        if np.abs(np.median(data)) < 0.01:
            data *= 1e6
        return data, fs, chs
    except Exception as e:
        log(f" [ERR] load {edf_path}: {e}")
        return None, None, None
    finally:
        if raw is not None:
            raw.close()


# ══════════════════════════════════════════════════════════════════════════════
# Data extraction — GTC bridge (A2/A4), GTC thalamic (B2/B3), P2 paired, TUH
# ══════════════════════════════════════════════════════════════════════════════
def extract_gtc_bridge_features(edf_name):
    """Extract time-aligned scalp / thalamic feature pairs from one GTC recording.

    Two fixed intervals are read: 0-50 s (label 0) and 130-220 s (label 1).
    The scalp signal is the average of the fronto-central channels found;
    the thalamic signal is the bipolar derivation LT1 − LT2. Both are
    band-passed 0.5-100 Hz, resampled to FS_TARGET and cut into
    non-overlapping WIN_TARGET-sample windows.

    Returns:
        ``(Xs, Xt, y)`` — scalp features, thalamic features, labels — or
        ``(None, None, None)`` when the file is missing or yields no window.
    """
    edf_path = GTC_ROOT / edf_name
    if not edf_path.exists():
        return None, None, None
    GTC_TOPO = ['EEG FZ', 'EEG CZ', 'EEG C3', 'EEG F3',
                'EEG FZ-REF', 'EEG CZ-REF', 'EEG C3-REF', 'EEG F3-REF',
                'FZ', 'CZ', 'C3', 'F3']
    Xs, Xt, Y = [], [], []
    for (t0, t1, label) in [(0.0, 50.0, 0), (130.0, 220.0, 1)]:
        data, fs, chs = _crop_load_segment(str(edf_path), t0, t1)
        if data is None:
            continue
        s_raw = _scalp_avg(data, chs, GTC_TOPO)
        i1 = _find_ch(chs, ['EEG LT1', 'LT1'])
        i2 = _find_ch(chs, ['EEG LT2', 'LT2'])
        if s_raw is None or i1 is None or i2 is None:
            del data
            continue
        s_ds = _downsample(_bp(s_raw, 0.5, 100, fs), fs, FS_TARGET)
        t_ds = _downsample(_bp(data[i1] - data[i2], 0.5, 100, fs), fs, FS_TARGET)
        W = WIN_TARGET
        # NOTE(review): the `- W` bound skips a final window that fits exactly;
        # kept unchanged so window counts match earlier runs.
        for i in range(0, min(len(s_ds), len(t_ds)) - W, W):
            fs_f = compute_features(s_ds[i:i+W], FS_TARGET)
            ft_f = compute_features(t_ds[i:i+W], FS_TARGET)
            if fs_f is not None and ft_f is not None:
                Xs.append(fs_f); Xt.append(ft_f); Y.append(label)
        del data
        gc.collect()
    if not Xs:
        return None, None, None
    log(f' GTC {edf_name}: {len(Y)} bridge windows (PGES={sum(Y)}, base={sum(1 for y in Y if y==0)})')
    return (np.array(Xs, dtype=np.float32),
            np.array(Xt, dtype=np.float32),
            np.array(Y, dtype=np.int32))


def extract_gtc_thalamic_features(edf_name):
    """Extract thalamic-only window features from one GTC recording.

    Same two intervals and preprocessing as ``extract_gtc_bridge_features``,
    using the derivation LTP1 − LTP2 (LTP1 alone when LTP2 is absent).

    Returns:
        ``(X, y)`` or ``(None, None)`` when the file is missing or empty.
    """
    edf_path = GTC_ROOT / edf_name
    if not edf_path.exists():
        return None, None
    X, Y = [], []
    for (t0, t1, label) in [(0.0, 50.0, 0), (130.0, 220.0, 1)]:
        data, fs, chs = _crop_load_segment(str(edf_path), t0, t1)
        if data is None:
            continue
        i1 = _find_ch(chs, ['EEG LTP1', 'LTP1'])
        i2 = _find_ch(chs, ['EEG LTP2', 'LTP2'])
        if i1 is None:
            del data
            continue
        ref = data[i2] if i2 is not None else np.zeros_like(data[i1])
        t_ds = _downsample(_bp(data[i1] - ref, 0.5, 100, fs), fs, FS_TARGET)
        W = WIN_TARGET
        for i in range(0, len(t_ds) - W, W):
            f = compute_features(t_ds[i:i+W], FS_TARGET)
            if f is not None:
                X.append(f); Y.append(label)
        del data
        gc.collect()
    if not X:
        return None, None
    return np.array(X, dtype=np.float32), np.array(Y, dtype=np.int32)


def extract_p2_paired_features(meta_df):
    """Extract paired scalp / thalamic (LT1 − LT2) window features for patient P2.

    For every P2 row of *meta_df*, the post-ictal interval
    [offset+5 s, offset+185 s] is labelled 1 and the pre-ictal interval
    [onset−130 s, onset−10 s] is labelled 0. The scalp signal prefers
    TOPO_CH and falls back to ALL_SCALP.

    Args:
        meta_df: DataFrame with columns 'Patient ID', 'Seizure_Filename',
            'Seizure_Onset_Sec' and 'Seizure_Offset_Sec'.

    Returns:
        ``(Xs, Xt, y)`` or ``(None, None, None)`` when nothing was extracted.
    """
    p2_rows = meta_df[meta_df['Patient ID'] == 'P2']
    pdir = SEEG_ROOT / 'P2_SEEG'
    Xs, Xt, Y = [], [], []
    for _, row in p2_rows.iterrows():
        sz_file = str(row['Seizure_Filename'])
        sz_start = float(row['Seizure_Onset_Sec'])
        sz_end = float(row['Seizure_Offset_Sec'])
        edf_path = pdir / sz_file
        if not edf_path.exists():
            continue
        try:
            for (t0, t1, label) in [(sz_end+5, sz_end+185, 1),
                                    (sz_start-130, sz_start-10, 0)]:
                data, fs, chs = _crop_load_segment(edf_path, t0, t1)
                if data is None:
                    continue
                s_raw = _scalp_avg(data, chs, TOPO_CH)
                if s_raw is None:
                    s_raw = _scalp_avg(data, chs, ALL_SCALP)
                i1 = _find_ch(chs, ['LT1']); i2 = _find_ch(chs, ['LT2'])
                if s_raw is not None and i1 is not None and i2 is not None:
                    s_ds = _downsample(_bp(s_raw, 0.5, 100, fs), fs, FS_TARGET)
                    t_ds = _downsample(_bp(data[i1]-data[i2], 0.5, 100, fs), fs, FS_TARGET)
                    W = WIN_TARGET
                    for i in range(0, min(len(s_ds), len(t_ds)) - W, W):
                        fs_f = compute_features(s_ds[i:i+W], FS_TARGET)
                        ft_f = compute_features(t_ds[i:i+W], FS_TARGET)
                        if fs_f is not None and ft_f is not None:
                            Xs.append(fs_f); Xt.append(ft_f); Y.append(label)
                del data
                gc.collect()
        except Exception as e:
            log(f" [ERR] P2 {sz_file}: {e}")
    log(f" P2 paired: {len(Y)} windows (PGES={sum(Y)}, base={sum(1 for y in Y if y==0)})")
    if not Xs:
        return None, None, None
    return (np.array(Xs, dtype=np.float32),
            np.array(Xt, dtype=np.float32),
            np.array(Y, dtype=np.int32))


def extract_scalp_pges_windows(pid, meta_df):
    """Extract scalp-only window features for one institutional patient.

    Uses the same post-ictal (label 1) and pre-ictal (label 0) intervals as
    ``extract_p2_paired_features`` and reads files from
    ``SEEG_ROOT / f'{pid}_SEEG'``.

    Returns:
        ``(Xs, y)`` or ``(None, None)`` when nothing was extracted.
    """
    rows = meta_df[meta_df['Patient ID'] == pid]
    pdir = SEEG_ROOT / f'{pid}_SEEG'
    Xs, Y = [], []
    for _, row in rows.iterrows():
        sz_file = str(row['Seizure_Filename'])
        sz_start = float(row['Seizure_Onset_Sec'])
        sz_end = float(row['Seizure_Offset_Sec'])
        edf_path = pdir / sz_file
        if not edf_path.exists():
            continue
        try:
            W = WIN_TARGET
            for (t0, t1, label) in [(sz_end+5, sz_end+185, 1),
                                    (sz_start-130, sz_start-10, 0)]:
                data, fs, chs = _crop_load_segment(edf_path, t0, t1)
                if data is None:
                    continue
                s_raw = _scalp_avg(data, chs, TOPO_CH)
                if s_raw is None:
                    s_raw = _scalp_avg(data, chs, ALL_SCALP)
                if s_raw is not None:
                    s_ds = _downsample(_bp(s_raw, 0.5, 100, fs), fs, FS_TARGET)
                    for i in range(0, len(s_ds) - W, W):
                        f = compute_features(s_ds[i:i+W], FS_TARGET)
                        if f is not None:
                            Xs.append(f); Y.append(label)
                del data
                gc.collect()
        except Exception as e:
            log(f" [ERR] {pid} {sz_file}: {e}")
    log(f" {pid} scalp: {len(Y)} windows ({sum(Y)} PGES, {sum(1 for y in Y if y==0)} base)")
    if not Xs:
        return None, None
    return np.array(Xs, dtype=np.float32), np.array(Y, dtype=np.int32)


def extract_tuh_topo_features(edf_path, csv_path, target_labels=('gnsz', 'tcsz')):
    """Extract post-ictal and pre-ictal scalp feature sessions from one TUH file.

    The annotation CSV is read first: rows whose fourth column is one of
    *target_labels* contribute (start, end) from columns two and three.
    Overlapping events are merged. For each merged seizure the 180 s
    starting 5 s after offset form a post-ictal session and the 120 s ending
    10 s before onset form a baseline session. A session is kept only when
    it has at least ``N_CTX + 2`` windows.

    Returns:
        ``(pges_sessions, base_sessions)``; each is a list of
        (n_windows × N_FEAT) float32 arrays, or ``None`` when empty. Any
        failure is logged and yields ``(None, None)``.
    """
    try:
        # Parse the (small) annotation file before touching the EDF so that
        # recordings without a target seizure are never loaded.
        szs = []
        with open(csv_path) as f:
            for line in f:
                if line.startswith(('#', 'channel')):
                    continue
                parts = line.strip().split(',')
                if len(parts) < 5:
                    continue
                if parts[3].strip() in target_labels:
                    try:
                        szs.append((float(parts[1]), float(parts[2])))
                    except ValueError:
                        continue
        if not szs:
            return None, None
        szs = sorted(set(szs))
        merged = [list(szs[0])]
        for s, e in szs[1:]:
            if s <= merged[-1][1]:
                merged[-1][1] = max(merged[-1][1], e)
            else:
                merged.append([s, e])

        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
        fs = raw.info['sfreq']; chs = raw.ch_names
        data = raw.get_data()
        raw.close(); del raw
        if np.abs(np.median(data)) < 0.01:
            data *= 1e6

        sig_raw = _scalp_avg(data, chs, TOPO_CH)
        if sig_raw is None:
            sig_raw = _scalp_avg(data, chs, ALL_SCALP)
        if sig_raw is None:
            return None, None
        # NOTE(review): rates within 5 Hz of FS_TARGET (e.g. 256 Hz) are not
        # resampled but are indexed and analysed as FS_TARGET below — confirm
        # this tolerance is intended.
        if abs(fs - FS_TARGET) > 5:
            sig = _downsample(_bp(sig_raw, 0.5, 100, fs), fs, FS_TARGET)
        else:
            sig = _bp(sig_raw, 0.5, 100, fs)

        W = WIN_TARGET
        pges_wins, base_wins = [], []
        for sz_start, sz_end in merged:
            # Post-ictal session.
            pi_start = int((sz_end + 5) * FS_TARGET)
            pi_end = min(int((sz_end + 5 + 180) * FS_TARGET), len(sig))
            pi_w = []
            for i in range(pi_start, pi_end - W, W):
                f = compute_features(sig[i:i+W], FS_TARGET)
                if f is not None:
                    pi_w.append(f)
            if len(pi_w) >= N_CTX + 2:
                pges_wins.append(np.array(pi_w, dtype=np.float32))
            # Pre-ictal baseline session.
            pre_end = int((sz_start - 10) * FS_TARGET)
            pre_start = max(0, int((sz_start - 10 - 120) * FS_TARGET))
            pr_w = []
            for i in range(pre_start, pre_end - W, W):
                f = compute_features(sig[i:i+W], FS_TARGET)
                if f is not None:
                    pr_w.append(f)
            if len(pr_w) >= N_CTX + 2:
                base_wins.append(np.array(pr_w, dtype=np.float32))
        return (pges_wins or None), (base_wins or None)
    except Exception as e:
        # `Exception`, not a bare except: Ctrl-C and SystemExit must propagate.
        log(f" [ERR] TUH {edf_path}: {e}")
        return None, None


# ══════════════════════════════════════════════════════════════════════════════
# Model and losses
# ══════════════════════════════════════════════════════════════════════════════
class CausalTransformer(nn.Module):
    """Transformer encoder over sequences of N_FEAT-dimensional window features.

    ``forward`` applies an upper-triangular (causal) attention mask and
    either returns the hidden states or projects them back to N_FEAT.
    ``embed_windows`` encodes each window on its own as a length-1 sequence.
    """

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(N_FEAT, D_MODEL)
        enc = nn.TransformerEncoderLayer(D_MODEL, N_HEADS, D_MODEL*4,
                                         dropout=0.1, batch_first=True)
        self.enc = nn.TransformerEncoder(enc, N_LAYERS)
        self.head = nn.Linear(D_MODEL, N_FEAT)
        # True above the diagonal = position may not attend to later positions.
        mask = torch.triu(torch.ones(N_CTX, N_CTX), 1).bool()
        self.register_buffer('mask', mask)

    def forward(self, x, return_hidden=False):
        """Encode ``x`` (batch × seq × N_FEAT, seq ≤ N_CTX) causally.

        Returns hidden states (batch × seq × D_MODEL) when *return_hidden*
        is true, otherwise per-position feature predictions
        (batch × seq × N_FEAT).
        """
        h = self.enc(self.proj(x), mask=self.mask[:x.shape[1], :x.shape[1]])
        return h if return_hidden else self.head(h)

    def embed_windows(self, X_2d):
        """Embed independent windows: (n × N_FEAT) → (n × D_MODEL)."""
        x = X_2d.unsqueeze(1)
        h = self.enc(self.proj(x))
        return h[:, 0, :]


def _supcon_loss(z1, z2, y1, y2, temp=SUPCON_T):
    """Cross-set supervised contrastive loss between two embedding batches.

    Each row of ``z1`` is an anchor; rows of ``z2`` with the same label are
    its positives and every row of ``z2`` enters the softmax denominator.
    Anchors without any positive contribute (approximately) zero.

    Args:
        z1, z2: embeddings, (n1 × d) and (n2 × d); L2-normalised internally.
        y1, y2: integer labels of length n1 and n2.
        temp: softmax temperature.

    Returns:
        Scalar tensor, the mean loss over anchors.
    """
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    sim = torch.mm(z1, z2.t()) / temp
    mask = (y1.unsqueeze(1) == y2.unsqueeze(0)).float()
    # Subtract the row maximum for numerical stability (no gradient through it).
    sim_max, _ = sim.max(dim=1, keepdim=True)
    sim = sim - sim_max.detach()
    exp_sim = torch.exp(sim)
    # The original weighted exp_sim by (mask + neg_mask), which is all ones.
    log_prob = sim - torch.log(exp_sim.sum(1, keepdim=True) + 1e-8)
    loss = -(mask * log_prob).sum(1) / (mask.sum(1) + 1e-8)
    return loss.mean()


def pretrain_three_source(model, thal_sessions, tuh_pges_wins, tuh_base_wins,
                          scalp_pool_Xs, scalp_pool_y,
                          bridge_Xs, bridge_Xt, bridge_y,
                          thal_scaler, epochs=60, conditions='ABCD'):
    """Pre-train *model* with up to three loss terms, one optimiser step per epoch.

    * L1 — next-window prediction on thalamic sessions: from N_CTX context
      windows predict the following one (1 − cosine similarity + 0.5·MSE).
      Active when at least 10 sequences exist.
    * L2 — ``_supcon_loss`` between TUH window embeddings and the scalp pool,
      weighted by LAM_L2. Active when TUH sessions are given, 'B' is in
      *conditions* and a scalp pool is given. TUH windows are randomly
      subsampled to 2000.
    * L3 — ``_supcon_loss`` between paired scalp and thalamic bridge
      embeddings, weighted by LAM_L3. Active when bridge data is given,
      'C' is in *conditions* and at least 8 pairs exist.

    Args:
        model: a ``CausalTransformer`` already on DEVICE.
        thal_sessions: list of (n_windows × N_FEAT) normalised arrays.
        tuh_pges_wins, tuh_base_wins: lists of TUH session arrays (label 1 / 0) or None.
        scalp_pool_Xs, scalp_pool_y: scalp features and labels for L2, or None.
        bridge_Xs, bridge_Xt, bridge_y: paired features and labels for L3, or None.
        thal_scaler: unused; kept so existing call sites keep working.
        epochs: number of optimiser steps.
        conditions: string of enabled condition letters.

    Returns:
        The same model, switched to eval mode.
    """
    opt = torch.optim.Adam(model.parameters(), lr=SEQ_LR)
    model.train()

    # L1 data: sliding sequences of N_CTX context rows plus one target row.
    tsm_seqs = []
    for sess in thal_sessions:
        if len(sess) < N_CTX + 2:
            continue
        for i in range(N_CTX + 1, len(sess)):
            tsm_seqs.append(sess[i - N_CTX - 1: i])
    tsm_seqs = np.array(tsm_seqs, dtype=np.float32) if tsm_seqs else None

    # L2 data: flattened TUH windows with their labels.
    if tuh_pges_wins is not None and len(tuh_pges_wins) > 0 and 'B' in conditions:
        tuh_flat = np.vstack(tuh_pges_wins + (tuh_base_wins or []))
        tuh_flat_y = np.concatenate([
            np.ones(sum(len(s) for s in tuh_pges_wins), dtype=np.int64),
            np.zeros(sum(len(s) for s in (tuh_base_wins or [])), dtype=np.int64)])
        if len(tuh_flat) > 2000:
            idx = np.random.choice(len(tuh_flat), 2000, replace=False)
            tuh_flat, tuh_flat_y = tuh_flat[idx], tuh_flat_y[idx]
    else:
        tuh_flat = tuh_flat_y = None

    p2s_flat = scalp_pool_Xs
    p2s_flat_y = scalp_pool_y.astype(np.int64) if scalp_pool_y is not None else None

    # L3 data: time-aligned scalp / thalamic pairs.
    if bridge_Xs is not None and bridge_Xt is not None and 'C' in conditions:
        p2_pair_Xs = bridge_Xs
        p2_pair_Xt = bridge_Xt
        p2_pair_y = bridge_y.astype(np.int64)
    else:
        p2_pair_Xs = p2_pair_Xt = p2_pair_y = None

    for ep in range(epochs):
        total_loss = 0.0; n_batches = 0

        if tsm_seqs is not None and len(tsm_seqs) >= 10:
            idx = np.random.choice(len(tsm_seqs), min(128, len(tsm_seqs)), replace=False)
            xc = torch.tensor(tsm_seqs[idx, :N_CTX], dtype=torch.float32).to(DEVICE)
            xt = torch.tensor(tsm_seqs[idx, N_CTX], dtype=torch.float32).to(DEVICE)
            pred = model(xc)[:, -1, :]
            L1 = (1. - F.cosine_similarity(pred, xt, dim=1).mean()) \
                + 0.5 * F.mse_loss(pred, xt)
            total_loss += L1; n_batches += 1

        if tuh_flat is not None and p2s_flat is not None:
            t_idx = np.random.choice(len(tuh_flat), min(64, len(tuh_flat)), replace=False)
            s_idx = np.random.choice(len(p2s_flat), min(64, len(p2s_flat)), replace=False)
            z_tuh = model.embed_windows(
                torch.tensor(tuh_flat[t_idx], dtype=torch.float32).to(DEVICE))
            z_p2s = model.embed_windows(
                torch.tensor(p2s_flat[s_idx], dtype=torch.float32).to(DEVICE))
            y_tuh = torch.tensor(tuh_flat_y[t_idx], dtype=torch.long).to(DEVICE)
            y_p2s = torch.tensor(p2s_flat_y[s_idx], dtype=torch.long).to(DEVICE)
            L2 = _supcon_loss(z_tuh, z_p2s, y_tuh, y_p2s)
            total_loss += LAM_L2 * L2; n_batches += 1

        if p2_pair_Xs is not None and len(p2_pair_Xs) >= 8:
            idx = np.random.choice(len(p2_pair_Xs), min(64, len(p2_pair_Xs)), replace=False)
            z_s = model.embed_windows(
                torch.tensor(p2_pair_Xs[idx], dtype=torch.float32).to(DEVICE))
            z_t = model.embed_windows(
                torch.tensor(p2_pair_Xt[idx], dtype=torch.float32).to(DEVICE))
            y_b = torch.tensor(p2_pair_y[idx], dtype=torch.long).to(DEVICE)
            L3 = _supcon_loss(z_s, z_t, y_b, y_b)
            total_loss += LAM_L3 * L3; n_batches += 1

        if n_batches == 0:
            # No term is active and the inputs do not change between epochs.
            # The original fell through to the progress log and crashed on
            # `total_loss.item()` (a plain float) and a division by zero.
            log(' [WARN] pre-training skipped: no active loss term')
            break

        opt.zero_grad(); total_loss.backward(); opt.step()
        if (ep+1) % 15 == 0:
            log(f" ep {ep+1}/{epochs} loss={total_loss.item()/n_batches:.4f}")

    model.eval()
    return model


def build_seqs(patient, scaler):
    """Turn one patient's windows into (context, label) evaluation pairs.

    Features are normalised with *scaler*; sample ``i`` consists of the
    N_CTX windows preceding window ``i`` and carries the label of window ``i``.

    Args:
        patient: dict with 'X' (n × N_FEAT features) and 'labels' (n,).
        scaler: fitted object exposing ``transform``.

    Returns:
        ``(seqs, labels)`` as float32 (m × N_CTX × N_FEAT) and int32 (m,)
        arrays, or ``(None, None)`` when the patient has N_CTX windows or fewer.
    """
    X_n = scaler.transform(patient['X'].astype(np.float32))
    y = patient['labels'].astype(np.int32)
    seqs, lbls = [], []
    for i in range(N_CTX, len(X_n)):
        seqs.append(X_n[i-N_CTX:i]); lbls.append(y[i])
    if not seqs:
        return None, None
    return np.array(seqs, dtype=np.float32), np.array(lbls, dtype=np.int32)


def encode(model, seqs):
    """Return the last-position hidden state of *model* for every sequence.

    Runs in eval mode without gradients, 64 sequences at a time.
    ``seqs`` must be non-empty.

    Returns:
        numpy array of shape (len(seqs) × D_MODEL).
    """
    model.eval(); z = []
    for i in range(0, len(seqs), 64):
        b = torch.tensor(seqs[i:i+64], dtype=torch.float32).to(DEVICE)
        with torch.no_grad():
            z.append(model(b, return_hidden=True)[:, -1, :].cpu().numpy())
    return np.vstack(z)
torch.no_grad(): z.append(model(b, return_hidden=True)[:,-1,:].cpu().numpy()) return np.vstack(z) def kshot_eval(model, seqs, lbls, K, n_trials=N_TRIALS): Z = encode(model, seqs) if lbls.sum() == 0: return float('nan') if K == 0: pp = Z[lbls==1].mean(0); pb = Z[lbls==0].mean(0) return float(f1_score(lbls, (np.linalg.norm(Z-pp,axis=1) < np.linalg.norm(Z-pb,axis=1)).astype(int), zero_division=0)) scores = [] for _ in range(n_trials): pos = np.where(lbls==1)[0]; neg = np.where(lbls==0)[0] if len(pos) 0] log(f' {len(patients)} institutional thalamic patients.') log(' Adding GTC B2/B3...') for gtc_b in ['B2.edf', 'B3.edf']: Xb, yb = extract_gtc_thalamic_features(gtc_b) if Xb is not None: patients.append({'pid': gtc_b.replace('.edf',''), 'nucleus': 'ANT_GTC', 'X': Xb, 'labels': yb}) log(f' {gtc_b}: {len(Xb)} windows') log(f' Total thalamic pool: {len(patients)} sources') log('\nStep 2: Bridge features (P2 + GTC A2/A4)...') p2_Xs, p2_Xt, p2_y = extract_p2_paired_features(meta_df) HAS_P2 = p2_Xs is not None and len(p2_Xs) >= 10 log(' GTC A2/A4 bridge...') all_bridge_Xs, all_bridge_Xt, all_bridge_y = [], [], [] if HAS_P2: all_bridge_Xs.append(p2_Xs); all_bridge_Xt.append(p2_Xt); all_bridge_y.append(p2_y) for gtc_a in ['A2.edf', 'A4.edf']: aXs, aXt, ay = extract_gtc_bridge_features(gtc_a) if aXs is not None: all_bridge_Xs.append(aXs); all_bridge_Xt.append(aXt); all_bridge_y.append(ay) if all_bridge_Xs: bridge_Xs = np.vstack(all_bridge_Xs) bridge_Xt = np.vstack(all_bridge_Xt) bridge_y = np.concatenate(all_bridge_y) HAS_BRIDGE = True log(f' Combined bridge: {len(bridge_y)} windows from P2+A2+A4') else: bridge_Xs = bridge_Xt = bridge_y = None HAS_BRIDGE = False inst_scalp_Xs, inst_scalp_y = [], [] if HAS_P2: inst_scalp_Xs.append(p2_Xs); inst_scalp_y.append(p2_y) for pid_s in ['P10', 'P12']: Xs_s, y_s = extract_scalp_pges_windows(pid_s, meta_df) if Xs_s is not None: inst_scalp_Xs.append(Xs_s); inst_scalp_y.append(y_s) if HAS_BRIDGE: inst_scalp_Xs.append(bridge_Xs); 
inst_scalp_y.append(bridge_y) if inst_scalp_Xs: inst_scalp_Xs = np.vstack(inst_scalp_Xs) inst_scalp_y = np.concatenate(inst_scalp_y) HAS_INST_SCALP = True log(f' Scalp pool (L2): {len(inst_scalp_Xs)} windows') else: inst_scalp_Xs = inst_scalp_y = None HAS_INST_SCALP = False log('\nStep 3: TUH corpus...') csvs = [f for f in glob.glob(os.path.join(TUH_BASE, '**', '*.csv'), recursive=True) if 'worksheet' not in f.lower()] def _has_target(f): try: return any(t in open(f, errors='ignore').read() for t in ['tcsz','gnsz']) except: return False tgt_csvs = [f for f in csvs if _has_target(f)] tgt_pairs = [(f, f.replace('.csv','.edf')) for f in tgt_csvs if os.path.exists(f.replace('.csv','.edf'))] np.random.seed(42) if len(tgt_pairs) > MAX_TUH: idxs = np.random.choice(len(tgt_pairs), MAX_TUH, replace=False) tgt_pairs = [tgt_pairs[i] for i in idxs] log(f' Using {len(tgt_pairs)} TUH files.') def _with_timeout(fn, *args, timeout=240): result = [None, None] def _r(): try: result[0], result[1] = fn(*args) except: pass t = threading.Thread(target=_r, daemon=True) t.start(); t.join(timeout=timeout) return result[0], result[1] tuh_pges, tuh_base = [], [] for k, (csv_p, edf_p) in enumerate(tgt_pairs): p, b = _with_timeout(extract_tuh_topo_features, edf_p, csv_p) if p: tuh_pges.extend(p) if b: tuh_base.extend(b) if (k+1) % 20 == 0: log(f' TUH {k+1}/{len(tgt_pairs)} | PGES sessions={len(tuh_pges)}') log(f' TUH: {len(tuh_pges)} PGES | {len(tuh_base)} base') HAS_TUH = len(tuh_pges) > 0 if HAS_TUH: tuh_scaler = StandardScaler().fit( np.vstack([w for sess in tuh_pges+tuh_base for w in [sess]])) tuh_pges_n = [tuh_scaler.transform(s) for s in tuh_pges] tuh_base_n = [tuh_scaler.transform(s) for s in tuh_base] else: tuh_pges_n = tuh_base_n = [] log(f'\nStep 4: LOSO (N_TRIALS={N_TRIALS} per fold)...') results = {c: {k: [] for k in K_VALS} for c in ['A','B','C','D','E']} for fold_i, test_p in enumerate(patients): pid = test_p['pid'] train_ps = [p for p in patients if p['pid'] != pid] X_tr = 
np.vstack([p['X'].astype(np.float32) for p in train_ps]) scaler = StandardScaler().fit(X_tr) seqs, lbls = build_seqs(test_p, scaler) if seqs is None or lbls.sum() == 0: log(f' [{fold_i+1:02d}] {pid}: skip'); continue log(f'\n [{fold_i+1:02d}/{len(patients)}] Test={pid} PGES={int(lbls.sum())} Base={int((lbls==0).sum())}') thal_sess = [] for p in train_ps: X_n = scaler.transform(p['X'].astype(np.float32)) sess = X_n[p['labels'] == 0] if len(sess) >= N_CTX + 2: thal_sess.append(sess) fold_bridge_Xs_parts, fold_bridge_Xt_parts, fold_bridge_y_parts = [], [], [] if HAS_P2 and pid != 'P2': fold_bridge_Xs_parts.append(p2_Xs) fold_bridge_Xt_parts.append(p2_Xt) fold_bridge_y_parts.append(p2_y) for gtc_a in ['A2.edf', 'A4.edf']: aXs, aXt, ay = extract_gtc_bridge_features(gtc_a) if aXs is not None: fold_bridge_Xs_parts.append(aXs) fold_bridge_Xt_parts.append(aXt) fold_bridge_y_parts.append(ay) if fold_bridge_Xs_parts: fold_bridge_Xs = np.vstack(fold_bridge_Xs_parts) fold_bridge_Xt = np.vstack(fold_bridge_Xt_parts) fold_bridge_y = np.concatenate(fold_bridge_y_parts) fold_bridge_Xs_n = tuh_scaler.transform(fold_bridge_Xs) if HAS_TUH else fold_bridge_Xs fold_bridge_Xt_n = scaler.transform(fold_bridge_Xt) HAS_FOLD_BRIDGE = True else: fold_bridge_Xs_n = fold_bridge_Xt_n = fold_bridge_y = None HAS_FOLD_BRIDGE = False if HAS_INST_SCALP: inst_fold_Xs, inst_fold_y = [], [] for pid_s in ['P2', 'P10', 'P12']: if pid_s == pid: continue if pid_s == 'P2' and HAS_P2: inst_fold_Xs.append(p2_Xs); inst_fold_y.append(p2_y) else: Xs_tmp, y_tmp = extract_scalp_pges_windows(pid_s, meta_df) if Xs_tmp is not None: inst_fold_Xs.append(Xs_tmp); inst_fold_y.append(y_tmp) if HAS_FOLD_BRIDGE: inst_fold_Xs.append(fold_bridge_Xs); inst_fold_y.append(fold_bridge_y) if inst_fold_Xs: inst_fold_Xs = np.vstack(inst_fold_Xs) inst_fold_y = np.concatenate(inst_fold_y) inst_fold_Xs_n = tuh_scaler.transform(inst_fold_Xs) if HAS_TUH else inst_fold_Xs else: inst_fold_Xs_n = inst_fold_y = None else: inst_fold_Xs_n = 
inst_fold_y = None fold_tuh_pges = tuh_pges_n if HAS_TUH else [] fold_tuh_base = tuh_base_n if HAS_TUH else [] # Condition A m_A = CausalTransformer().to(DEVICE) m_A = pretrain_three_source(m_A, thal_sess, None, None, None, None, None, None, None, scaler, epochs=SEQ_EP_THAL, conditions='A') res_A = {k: kshot_eval(m_A, seqs, lbls, k) for k in K_VALS} for k in K_VALS: results['A'][k].append(res_A[k]) log(f' A: K=0={res_A[0]:.4f} K=10={res_A[10]:.4f}') del m_A # Condition B m_B = CausalTransformer().to(DEVICE) m_B = pretrain_three_source(m_B, thal_sess, fold_tuh_pges, fold_tuh_base, inst_fold_Xs_n, inst_fold_y, None, None, None, scaler, epochs=SEQ_EP_PRETRAIN, conditions='AB') res_B = {k: kshot_eval(m_B, seqs, lbls, k) for k in K_VALS} for k in K_VALS: results['B'][k].append(res_B[k]) log(f' B: K=0={res_B[0]:.4f} K=10={res_B[10]:.4f}') del m_B # Condition C if HAS_FOLD_BRIDGE: m_C = CausalTransformer().to(DEVICE) m_C = pretrain_three_source(m_C, thal_sess, None, None, None, None, fold_bridge_Xs_n, fold_bridge_Xt_n, fold_bridge_y, scaler, epochs=SEQ_EP_PRETRAIN, conditions='AC') res_C = {k: kshot_eval(m_C, seqs, lbls, k) for k in K_VALS} del m_C else: res_C = res_A for k in K_VALS: results['C'][k].append(res_C[k]) log(f' C: K=0={res_C[0]:.4f} K=10={res_C[10]:.4f}') # Condition D (MAIN) m_D = CausalTransformer().to(DEVICE) m_D = pretrain_three_source(m_D, thal_sess, fold_tuh_pges, fold_tuh_base, inst_fold_Xs_n, inst_fold_y, fold_bridge_Xs_n if HAS_FOLD_BRIDGE else None, fold_bridge_Xt_n if HAS_FOLD_BRIDGE else None, fold_bridge_y if HAS_FOLD_BRIDGE else None, scaler, epochs=SEQ_EP_PRETRAIN, conditions='ABCD' if HAS_FOLD_BRIDGE else 'AB') res_D = {k: kshot_eval(m_D, seqs, lbls, k) for k in K_VALS} for k in K_VALS: results['D'][k].append(res_D[k]) log(f' D: K=0={res_D[0]:.4f} K=10={res_D[10]:.4f} [MAIN]') # Condition E pgs = np.where(np.diff(np.concatenate([[0], lbls])) == 1)[0] if len(pgs) > 0: onset = pgs[0]; pend = onset while pend < len(lbls) and lbls[pend] == 1: pend 
+= 1 auto_p = np.arange(onset, min(onset+10, pend)) base_idx = np.where(lbls==0)[0] pre_b = base_idx[base_idx < onset] if len(pre_b) < 10: pre_b = base_idx[:10] auto_b = pre_b[-10:] Z = encode(m_D, seqs) pp = Z[auto_p].mean(0); pb = Z[auto_b].mean(0) preds = (np.linalg.norm(Z-pp,axis=1) < np.linalg.norm(Z-pb,axis=1)).astype(int) results['E'][0].append(float(f1_score(lbls, preds, zero_division=0))) for k in [2,5,10]: results['E'][k].append(res_D[k]) else: for k in K_VALS: results['E'][k].append(res_D[k]) log(f' E: K=0={results["E"][0][-1]:.4f}') del m_D; gc.collect() if DEVICE.type == 'cuda': torch.cuda.empty_cache() # ── Results ─────────────────────────────────────────────────────────────── log('\n' + '='*60) log(f'=== C13 High-Trials (N_TRIALS={N_TRIALS}) Results ===') log('='*60) labels_m = { 'A': 'A: L1 only (baseline)', 'B': 'B: L1+L2 (TUH scalp align)', 'C': 'C: L1+L3 (bridge P2+A2+A4)', 'D': 'D: L1+L2+L3 MAIN', 'E': 'E: D + Day-0', } cond_means = {c: {k: np.nanmean(results[c][k]) for k in K_VALS} for c in ['A','B','C','D','E']} cond_stds = {c: {k: np.nanstd(results[c][k]) for k in K_VALS} for c in ['A','B','C','D','E']} log(f"{'Condition':<35} {'K=0':>8} {'K=2':>8} {'K=5':>8} {'K=10':>8}") log('-'*70) for cond in ['A','B','C','D','E']: m = cond_means[cond] s = cond_stds[cond] log(f"{labels_m[cond]:<35} " f"{m[0]:>6.4f}±{s[0]:.3f} {m[2]:>6.4f}±{s[2]:.3f} " f"{m[5]:>6.4f}±{s[5]:.3f} {m[10]:>6.4f}±{s[10]:.3f}") log('\nGain D over A:') for k in K_VALS: log(f" K={k}: {cond_means['D'][k]-cond_means['A'][k]:+.4f}") # Wilcoxon at every K log('\nWilcoxon D vs A:') for k in K_VALS: d_vals = [v for v in results['D'][k] if not np.isnan(v)] a_vals = [v for v in results['A'][k] if not np.isnan(v)] n = min(len(d_vals), len(a_vals)) if n >= 5: try: stat, pv = wilcoxon(d_vals[:n], a_vals[:n]) sig = '**' if pv<0.01 else ('*' if pv<0.05 else ('~' if pv<0.10 else 'ns')) log(f" K={k}: stat={stat:.3f} p={pv:.4f} {sig}") except Exception as e: log(f" K={k}: {e}") # Bootstrap 95% 
CI for D at each K log('\nBootstrap 95% CI for D:') for k in K_VALS: v = np.array([x for x in results['D'][k] if not np.isnan(x)]) if len(v) >= 3: bs = [np.mean(np.random.choice(v, len(v))) for _ in range(2000)] log(f" K={k}: mean={np.mean(v):.4f} CI=[{np.percentile(bs,2.5):.4f}, {np.percentile(bs,97.5):.4f}]") # Save np.save(str(OUT_ROOT/'results_raw.npy'), results) rows = [] for c in ['A','B','C','D','E']: for k in K_VALS: v = [x for x in results[c][k] if not np.isnan(x)] rows.append({'cond':c,'K':k,'mean':np.nanmean(v) if v else np.nan, 'std':np.nanstd(v) if v else np.nan,'n':len(v)}) pd.DataFrame(rows).to_csv(str(OUT_ROOT/'results_summary.csv'), index=False) # Per-patient per_pat = [] for fold_i, test_p in enumerate(patients): pid = test_p['pid'] for c in ['A','B','C','D','E']: for k in K_VALS: v = results[c][k][fold_i] if fold_i < len(results[c][k]) else np.nan per_pat.append({'pid':pid,'cond':c,'K':k,'F1':v}) pd.DataFrame(per_pat).to_csv(str(OUT_ROOT/'results_per_patient.csv'), index=False) # Figure fig, axes = plt.subplots(1, 2, figsize=(14, 5)) colors = {'A':'#7f8c8d','B':'#e74c3c','C':'#f39c12','D':'#27ae60','E':'#2980b9'} for ax, conds, title in [ (axes[0], ['A','D'], 'A vs D (MAIN comparison)'), (axes[1], ['A','B','C','D','E'], 'All conditions'), ]: for cond in conds: vals = [cond_means[cond][k] for k in K_VALS] stds = [cond_stds[cond][k] for k in K_VALS] ax.plot(K_VALS, vals, 'o-', color=colors[cond], label=labels_m[cond], linewidth=2, markersize=8) ax.fill_between(K_VALS, [v-s for v,s in zip(vals,stds)], [v+s for v,s in zip(vals,stds)], color=colors[cond], alpha=0.12) ax.set_xlabel('K shots'); ax.set_ylabel('F1') ax.set_title(title); ax.legend(fontsize=8); ax.grid(alpha=0.3) ax.set_ylim(0.5, 1.05) plt.suptitle(f'C13 Three-Source Contrastive (N_TRIALS={N_TRIALS})', fontsize=12) plt.tight_layout() fig.savefig(str(OUT_ROOT/'c13_hightrials.png'), dpi=150) plt.close() log(f'Figure → {OUT_ROOT}/c13_hightrials.png') log('COMPLETE')