# -*- coding: utf-8 -*-
"""
DACTRL C14 — Bio-Prior Prototype Initialization for Honest K=0 Evaluation
==========================================================================

Core finding that motivates this experiment
-------------------------------------------
Every prior experiment (C13, v3, TSM) reports K=0 using:

    pp = Z[test_lbls==1].mean(0)   # ← uses ALL test patient labels
    pb = Z[test_lbls==0].mean(0)   # ← uses ALL test patient labels

This is NOT a real zero-shot scenario — it is an oracle that cheats by using
the test patient's own labeled data to construct prototypes. A true Day-0
patient has NO labeled seizures yet.

C14 therefore measures three K=0 variants — the oracle plus two honest ones:

    K0_oracle : current method (uses test labels) — reported in all prior work.
                Upper bound only; not deployable.
    K0_train  : prototype = mean embedding of the training sources' labeled
                windows. This is the true deployment scenario: the training
                patients have labeled data; the new patient has none.
    K0_bio    : prototype = canonical thalamic PGES signature (mean
                standardised feature vector of the training patients'
                windows → encoded). Encodes the perspective-inversion biology
                directly as a prior. Needs no data from the new patient — not
                even test-patient-specific training data. Most deployable.

Expected findings
-----------------
    K0_oracle ≈ 0.882–0.903   (current reported numbers — upper bound)
    K0_train  ~ 0.75–0.85     (honest cross-patient zero-shot)
    K0_bio    ~ 0.75–0.85     (bio-prior; close to train if the encoder is good)

If K0_train / K0_bio > 0.75, the zero-shot claim is defensible for the thesis.
If they are < 0.65, the thesis should report "K=0 is not viable without
patient-specific labels" — which is also an honest, publishable finding.

Evaluated on two encoders:
    A : Thalamic TSM only (baseline)
    D : C13 three-source contrastive (best domain-transfer model)

Mac paths (/Volumes/Expansion). Run on M1 Max 64GB.
"""
import os
import sys

# PYTHONIOENCODING is only read at interpreter start-up, so setting it here
# affects child processes only. Also reconfigure this process's own streams
# (unless the user chose an encoding explicitly) so the '═', '←', '×', '—'
# characters in log messages cannot raise UnicodeEncodeError on a non-UTF-8
# console.
_USER_SET_IOENCODING = 'PYTHONIOENCODING' in os.environ
os.environ.setdefault('PYTHONIOENCODING', 'utf-8')
if not _USER_SET_IOENCODING:
    for _stream in (sys.stdout, sys.stderr):
        try:
            _stream.reconfigure(encoding='utf-8')
        except (AttributeError, ValueError):
            pass  # stream is not a TextIOWrapper (e.g. notebook) — leave it alone

import gc
import glob
import random
import threading
import warnings
from pathlib import Path
from datetime import datetime
from collections import Counter

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # headless backend; must be selected before pyplot is imported
import matplotlib.pyplot as plt
from scipy.signal import resample_poly, butter, filtfilt
from scipy.stats import wilcoxon
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import mne
    mne.set_log_level('ERROR')
except ImportError:
    # Without mne every EDF read fails; make that visible instead of letting
    # the extractors silently return no data.
    mne = None
    print('[WARN] mne is not installed — EDF loading will fail and no data will be extracted')

warnings.filterwarnings('ignore')

# Pick the fastest available backend: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda')
    print(f"[GPU] {torch.cuda.get_device_name(0)}")
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    DEVICE = torch.device('mps')
    print("[MPS] Apple Silicon GPU")
else:
    DEVICE = torch.device('cpu')
    print("[CPU]")

torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# ── Paths ─────────────────────────────────────────────────────────────────────
SEEG_ROOT = Path("/Volumes/Expansion/phd_datasets/Data/Thalamus/SEEG Seizure Data")
METADATA = SEEG_ROOT / "metadata_SEEG.xlsx"
TUH_BASE = "/Volumes/Expansion/phd_datasets/Data/Scalp/tueeg_data/tuh_eeg_seizure/v2.0.3/edf/dev"
GTC_ROOT = Path("/Volumes/Expansion/phd_datasets/Data/Thalamus/eeg_ecg_us_clinical/GTC_Focal_SEEG")
OUT_ROOT = Path("/Volumes/Expansion/phd_datasets/Code/pges_toolkit_mac/results/dactrl_c14_bioprior")
OUT_ROOT.mkdir(parents=True, exist_ok=True)

# ── Constants ─────────────────────────────────────────────────────────────────
FS_TARGET = 250                     # Hz — every signal is resampled to this rate
WIN_SEC = 5                         # window length in seconds
WIN_TARGET = WIN_SEC * FS_TARGET    # window length in samples at FS_TARGET
MAX_TUH = 300                       # max number of TUH files sampled
N_FEAT = 17                         # features per window (see compute_features)
D_MODEL = 64                        # transformer width
N_HEADS = 4
N_LAYERS = 4
N_CTX = 8                           # context length (windows) of the causal transformer
SEQ_EP_THAL = 60                    # epochs for condition A
SEQ_EP_PRETRAIN = 30                # epochs for condition D
SEQ_LR = 3e-4
SUPCON_T = 0.07                     # supervised-contrastive temperature
LAM_L2 = 0.5                        # weight of the TUH ↔ institutional-scalp contrastive loss
LAM_L3 = 1.0                        # weight of the scalp ↔ thalamic bridge contrastive loss
K_VALS = [2, 5, 10]                 # standard K (K=0 handled separately)
N_TRIALS = 10
THAL_PIDS = ['P1', 'P2', 'P3', 'P4', 'P5', 'P7', 'P8', 'P15']
NUCLEUS_MAP = {
    'P1': 'CeM', 'P2': 'CL', 'P3': 'CeM', 'P4': 'MD', 'P5': 'CeM',
    'P7': 'CL', 'P8': 'CL', 'P15': 'ANT',
}
TOPO_CH = ['FZ', 'CZ', 'C3', 'F3']
ALL_SCALP = ['FP1', 'FP2', 'F7', 'F8', 'F3', 'F4', 'FZ',
             'T3', 'T4', 'C3', 'C4', 'CZ', 'T5', 'T6', 'P3', 'P4', 'PZ', 'O1', 'O2']


def log(msg):
    """Print ``msg`` prefixed with the local time as ``[HH:MM:SS]``, flushing immediately."""
    print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}", flush=True)


# ══════════════════════════════════════════════════════════════════════════════
# Feature extraction — identical 17-feature pipeline
# ══════════════════════════════════════════════════════════════════════════════

def _approximate_entropy(u, m=2, r=None):
    """Approximate entropy ``|phi(m) - phi(m+1)|`` of the 1-D signal ``u``.

    ``r`` defaults to ``0.2 * std(u) + 1e-10``. Template matches use the
    Chebyshev (max-abs) distance and include self-matches.
    """
    if r is None:
        r = 0.2 * np.std(u) + 1e-10
    n_pts = len(u)

    def _phi(mm):
        tmpl = np.lib.stride_tricks.sliding_window_view(u, mm)  # (n_pts - mm + 1, mm)
        close = np.max(np.abs(tmpl[:, None] - tmpl[None, :]), axis=2) <= r
        c = np.sum(close, axis=0) / (n_pts - mm + 1)
        return np.sum(np.log(c + 1e-10)) / (n_pts - mm + 1)

    return abs(_phi(m) - _phi(m + 1))


def _sample_entropy(u, m=2, r=None):
    """Sample entropy ``-ln((A + 1e-10) / (B + 1e-10))`` of the 1-D signal ``u``.

    B (resp. A) counts ordered pairs ``(i, j)``, ``i != j``, of length-``m``
    (resp. ``m+1``) templates whose Chebyshev distance is ``<= r``. For
    template length ``mm`` the start indices run over ``range(len(u) - mm)``.
    ``r`` defaults to ``0.2 * std(u) + 1e-10``.

    This is a vectorised form of the original per-template Python loop and
    yields exactly the same integer counts.
    """
    if r is None:
        r = 0.2 * np.std(u) + 1e-10
    n_pts = len(u)

    def _count(mm):
        n_tmpl = n_pts - mm
        if n_tmpl < 2:
            return 0  # fewer than two templates → no (i, j) pair with i != j
        tmpl = np.lib.stride_tricks.sliding_window_view(u, mm)[:n_tmpl]
        close = np.max(np.abs(tmpl[:, None] - tmpl[None, :]), axis=2) <= r
        # The original loop skipped j == i, so remove the diagonal (self-matches).
        return int(close.sum()) - int(np.diagonal(close).sum())

    a_cnt = _count(m + 1)
    b_cnt = _count(m)
    return float(-np.log((a_cnt + 1e-10) / (b_cnt + 1e-10)))


def compute_features(sig, fs):
    """Return the 17-element float32 feature vector of one window.

    Returns None when the window is shorter than 0.5 s (``int(fs * 0.5)``
    samples). The signal is mean-removed first.

    Feature order (index: meaning):
        0 RMS, 1 mean absolute first difference (line length),
        2 number of sign changes, 3 variance,
        4-7 relative delta (0.5-4 Hz), theta (4-8), alpha (8-13), beta (13-30) power,
        8 slowing ratio (delta+theta) / (alpha+beta),
        9 Shannon entropy of the normalised periodogram,
        10 fraction of samples with |x| < 5 % of max|x| (suppression),
        11 approximate entropy, 12 sample entropy (both on the first 200 samples),
        13 entropy of a 10-bin amplitude histogram,
        14 number of distinct 4-symbol patterns in the median-binarised signal
           (a pattern count, not a true Lempel-Ziv complexity),
        15 permutation entropy (order 3),
        16 relative "gamma" (80-150 Hz) power.

    Relative powers are divided by the total periodogram power (+1e-10).
    NOTE(review): callers band-pass to 0.5-100 Hz and resample to 250 Hz
    (Nyquist 125 Hz), so feature 16 mostly reflects filter roll-off.
    """
    sig = np.asarray(sig)
    n = len(sig)
    if n < int(fs * 0.5):
        return None
    sig = sig - sig.mean()

    rms = float(np.sqrt(np.mean(sig ** 2)))
    ll = float(np.mean(np.abs(np.diff(sig))))
    zc = float(np.sum(np.diff(np.sign(sig)) != 0))
    var = float(np.var(sig))

    freqs = np.fft.rfftfreq(n, 1 / fs)
    psd = np.abs(np.fft.rfft(sig)) ** 2

    def band(lo, hi):
        idx = (freqs >= lo) & (freqs < hi)
        return float(psd[idx].sum()) if idx.any() else 0.0

    total = float(psd.sum()) + 1e-10
    delta = band(0.5, 4)
    theta = band(4, 8)
    alpha = band(8, 13)
    beta = band(13, 30)
    gamma = band(80, 150)
    sr = (delta + theta) / (alpha + beta + 1e-10)

    p = psd / (psd.sum() + 1e-10)
    p = p[p > 0]
    shan = float(-np.sum(p * np.log(p + 1e-10)))
    supp = float(np.mean(np.abs(sig) < 0.05 * np.max(np.abs(sig) + 1e-10)))

    u = sig[:min(200, n)]  # entropy estimators are O(len^2); cap their input length
    apen = float(_approximate_entropy(u))
    sampen = float(_sample_entropy(u))

    hist, _ = np.histogram(sig, bins=10)
    hist = hist / (hist.sum() + 1e-10)
    etc = float(-np.sum(hist[hist > 0] * np.log(hist[hist > 0])))

    bsig = (sig > np.median(sig)).astype(int)
    # Build the bit string once; the original rebuilt it for every position (O(n^2)).
    bits = ''.join(map(str, bsig))
    lzc = float(len({bits[i:i + 4] for i in range(len(bsig) - 3)}))

    m3 = 3
    perms = [tuple(np.argsort(sig[i:i + m3])) for i in range(n - m3 + 1)]
    cnt2 = Counter(perms)
    tot = sum(cnt2.values())
    pent = float(-sum((v / tot) * np.log(v / tot + 1e-10) for v in cnt2.values()))

    return np.array([rms, ll, zc, var, delta / total, theta / total, alpha / total, beta / total,
                     sr, shan, supp, apen, sampen, etc, lzc, pent, gamma / total],
                    dtype=np.float32)


def _norm_ch(name):
    """Upper-case a channel name and drop '-' and ' ' (e.g. 'EEG Fz-REF' -> 'EEGFZREF')."""
    return name.upper().replace('-', '').replace(' ', '')


def _core_ch(norm_name):
    """Strip a leading 'EEG' and one trailing 'REF'/'LE' from a normalised channel name."""
    core = norm_name[3:] if norm_name.startswith('EEG') else norm_name
    for suffix in ('REF', 'LE'):
        if core.endswith(suffix) and len(core) > len(suffix):
            return core[:-len(suffix)]
    return core


def _find_ch(ch_names, targets):
    """Return the index in ``ch_names`` of the first channel matching a target, else None.

    Targets are tried in order. Names are compared case-insensitively with '-'
    and ' ' removed. For each target three passes are made:

    1. exact match;
    2. match after stripping an 'EEG' prefix and a 'REF'/'LE' suffix from both
       names;
    3. the legacy substring match (target contained in the channel name).

    The exact passes come first so that 'T3' selects scalp 'T3' rather than a
    depth contact 'LT3', and 'LT1' selects 'EEG LT1-REF' rather than
    'EEG LT10-REF'. The original accepted the first substring hit in channel
    order.
    """
    norm = [_norm_ch(c) for c in ch_names]
    cores = [_core_ch(c) for c in norm]
    for t in targets:
        t_norm = _norm_ch(t)
        t_core = _core_ch(t_norm)
        for i, c in enumerate(norm):
            if c == t_norm:
                return i
        for i, c in enumerate(cores):
            if c == t_core:
                return i
        for i, c in enumerate(norm):
            if t_norm in c:
                return i
    return None


def _scalp_avg(data, ch_names, targets):
    """Average the rows of ``data`` whose channels match ``targets`` (deduplicated); None if none match."""
    idxs = [i for t in targets for i in [_find_ch(ch_names, [t])] if i is not None]
    idxs = list(dict.fromkeys(idxs))  # drop duplicates, keep first-seen order
    return data[idxs].mean(axis=0) if idxs else None


def _downsample(sig, fs_in, fs_out):
    """Polyphase-resample ``sig`` from ``fs_in`` to ``fs_out`` (rates truncated to int)."""
    from math import gcd
    g = gcd(int(fs_in), int(fs_out))
    return resample_poly(sig, int(fs_out) // g, int(fs_in) // g)


def _bp(sig, lo, hi, fs, order=4):
    """Zero-phase Butterworth band-pass of ``sig`` between ``lo`` and ``hi`` Hz.

    When ``hi`` is at or above Nyquist (e.g. 100 Hz at fs <= 200 Hz) it is
    lowered to 0.99 * Nyquist; ``butter`` would otherwise raise. The result is
    unchanged whenever ``hi`` is already below Nyquist.
    """
    nyq = fs / 2
    if hi >= nyq:
        hi = 0.99 * nyq
    b, a = butter(order, [lo / nyq, hi / nyq], btype='band')
    return filtfilt(b, a, sig)


def _crop_load_segment(edf_path, t_start_s, t_end_s):
    """Load ``[t_start_s, t_end_s]`` seconds of an EDF file with MNE.

    The window is clipped to the recording. Returns ``(data, fs, ch_names)``.
    ``data`` is a ``(n_channels, n_samples)`` array, multiplied by 1e6 when its
    median magnitude is below 0.01 (heuristic volts → µV).

    Returns ``(None, None, None)`` when fewer than 1 s remains after clipping,
    or when reading fails. Read failures are logged; the original swallowed
    them silently.
    """
    raw = None
    try:
        raw = mne.io.read_raw_edf(str(edf_path), preload=False, verbose=False)
        fs = raw.info['sfreq']
        # Raw.crop rejects a tmax beyond the time of the LAST sample,
        # (n_times - 1) / fs. Clamping to n_times / fs (as before) made every
        # segment that reached the end of the file fail.
        t_last = (raw.n_times - 1) / fs
        t0 = max(0.0, t_start_s)
        t1 = min(t_end_s, t_last)
        if t1 - t0 < 1.0:
            return None, None, None
        raw.crop(tmin=t0, tmax=t1)
        raw.load_data()
        data = raw.get_data()
        chs = raw.ch_names
    except Exception as exc:
        log(f'  [WARN] could not load {edf_path} [{t_start_s:.1f}, {t_end_s:.1f}] s: {exc}')
        return None, None, None
    finally:
        if raw is not None:
            raw.close()  # also released on the error path now
    if np.abs(np.median(data)) < 0.01:
        data *= 1e6
    return data, fs, chs


# ══════════════════════════════════════════════════════════════════════════════
# Bridge / TUH data loading — identical to C13
# ══════════════════════════════════════════════════════════════════════════════

def extract_gtc_bridge_features(edf_name):
    """Extract paired scalp/thalamic feature windows from one GTC recording.

    Two fixed segments of ``GTC_ROOT / edf_name`` are read: 0-50 s (label 0)
    and 130-220 s (label 1).

    - Scalp signal: average of the matching GTC fronto-central channels.
    - Thalamic signal: the LT1 - LT2 bipolar derivation.

    Both are band-passed 0.5-100 Hz, resampled to FS_TARGET and cut into
    non-overlapping WIN_TARGET-sample windows.

    Returns ``(Xs, Xt, Y)``: scalp features (n, N_FEAT) float32, thalamic
    features (n, N_FEAT) float32 and labels (n,) int32. Returns
    ``(None, None, None)`` when the file is missing or no window was produced.
    """
    edf_path = GTC_ROOT / edf_name
    if not edf_path.exists():
        return None, None, None
    gtc_topo = ['EEG FZ', 'EEG CZ', 'EEG C3', 'EEG F3',
                'EEG FZ-REF', 'EEG CZ-REF', 'EEG C3-REF', 'EEG F3-REF',
                'FZ', 'CZ', 'C3', 'F3']
    Xs, Xt, Y = [], [], []
    for (t0, t1, label) in [(0.0, 50.0, 0), (130.0, 220.0, 1)]:
        data, fs, chs = _crop_load_segment(str(edf_path), t0, t1)
        if data is None:
            continue
        s_raw = _scalp_avg(data, chs, gtc_topo)
        i1 = _find_ch(chs, ['EEG LT1', 'LT1'])
        i2 = _find_ch(chs, ['EEG LT2', 'LT2'])
        if s_raw is None or i1 is None or i2 is None:
            del data
            continue
        s_ds = _downsample(_bp(s_raw, 0.5, 100, fs), fs, FS_TARGET)
        t_ds = _downsample(_bp(data[i1] - data[i2], 0.5, 100, fs), fs, FS_TARGET)
        W = WIN_TARGET
        # NOTE(review): stopping at `len - W` skips the last full window when the
        # length is an exact multiple of W; kept unchanged for parity with C13.
        for i in range(0, min(len(s_ds), len(t_ds)) - W, W):
            fs_f = compute_features(s_ds[i:i + W], FS_TARGET)
            ft_f = compute_features(t_ds[i:i + W], FS_TARGET)
            if fs_f is not None and ft_f is not None:
                Xs.append(fs_f)
                Xt.append(ft_f)
                Y.append(label)
        del data
        gc.collect()
    if not Xs:
        return None, None, None
    log(f' GTC {edf_name}: {len(Y)} bridge windows')
    return (np.array(Xs, dtype=np.float32), np.array(Xt, dtype=np.float32),
            np.array(Y, dtype=np.int32))


def extract_gtc_thalamic_features(edf_name):
    """Extract thalamic-only feature windows from one GTC recording.

    Same segments and labels as ``extract_gtc_bridge_features`` (0-50 s → 0,
    130-220 s → 1). The signal is LTP1 - LTP2, or LTP1 alone when LTP2 is
    absent.

    Returns ``(X (n, N_FEAT) float32, Y (n,) int32)``, or ``(None, None)``.
    """
    edf_path = GTC_ROOT / edf_name
    if not edf_path.exists():
        return None, None
    X, Y = [], []
    for (t0, t1, label) in [(0.0, 50.0, 0), (130.0, 220.0, 1)]:
        data, fs, chs = _crop_load_segment(str(edf_path), t0, t1)
        if data is None:
            continue
        i1 = _find_ch(chs, ['EEG LTP1', 'LTP1'])
        i2 = _find_ch(chs, ['EEG LTP2', 'LTP2'])
        if i1 is None:
            del data
            continue
        ref = data[i2] if i2 is not None else np.zeros_like(data[i1])
        t_ds = _downsample(_bp(data[i1] - ref, 0.5, 100, fs), fs, FS_TARGET)
        for i in range(0, len(t_ds) - WIN_TARGET, WIN_TARGET):
            f = compute_features(t_ds[i:i + WIN_TARGET], FS_TARGET)
            if f is not None:
                X.append(f)
                Y.append(label)
        del data
        gc.collect()
    if not X:
        return None, None
    return np.array(X, dtype=np.float32), np.array(Y, dtype=np.int32)


def extract_p2_paired_features(meta_df):
    """Extract paired scalp/thalamic windows for patient P2 from its seizure EDFs.

    For each P2 row of ``meta_df`` two segments are read:

    - post-ictal ``[offset + 5, offset + 185]`` s, label 1;
    - pre-ictal ``[onset - 130, onset - 10]`` s, label 0.

    Scalp signal: average of TOPO_CH, falling back to ALL_SCALP. Thalamic
    signal: LT1 - LT2.

    Returns ``(Xs, Xt, Y)`` arrays, or ``(None, None, None)``.
    """
    p2_rows = meta_df[meta_df['Patient ID'] == 'P2']
    pdir = SEEG_ROOT / 'P2_SEEG'
    Xs, Xt, Y = [], [], []
    for _, row in p2_rows.iterrows():
        sz_file = str(row['Seizure_Filename'])
        sz_start = float(row['Seizure_Onset_Sec'])
        sz_end = float(row['Seizure_Offset_Sec'])
        edf_path = pdir / sz_file
        if not edf_path.exists():
            continue
        try:
            for (t0, t1, label) in [(sz_end + 5, sz_end + 185, 1),
                                    (sz_start - 130, sz_start - 10, 0)]:
                data, fs, chs = _crop_load_segment(edf_path, t0, t1)
                if data is None:
                    continue
                s_raw = _scalp_avg(data, chs, TOPO_CH)
                if s_raw is None:
                    s_raw = _scalp_avg(data, chs, ALL_SCALP)
                i1 = _find_ch(chs, ['LT1'])
                i2 = _find_ch(chs, ['LT2'])
                if s_raw is not None and i1 is not None and i2 is not None:
                    s_ds = _downsample(_bp(s_raw, 0.5, 100, fs), fs, FS_TARGET)
                    t_ds = _downsample(_bp(data[i1] - data[i2], 0.5, 100, fs), fs, FS_TARGET)
                    for i in range(0, min(len(s_ds), len(t_ds)) - WIN_TARGET, WIN_TARGET):
                        fs_f = compute_features(s_ds[i:i + WIN_TARGET], FS_TARGET)
                        ft_f = compute_features(t_ds[i:i + WIN_TARGET], FS_TARGET)
                        if fs_f is not None and ft_f is not None:
                            Xs.append(fs_f)
                            Xt.append(ft_f)
                            Y.append(label)
                del data
                gc.collect()
        except Exception as e:
            log(f" [ERR] P2 {sz_file}: {e}")
    log(f" P2 paired: {len(Y)} windows")
    if not Xs:
        return None, None, None
    return (np.array(Xs, dtype=np.float32), np.array(Xt, dtype=np.float32),
            np.array(Y, dtype=np.int32))


def extract_scalp_pges_windows(pid, meta_df):
    """Extract scalp-only windows for patient ``pid``.

    Uses the same post-ictal (label 1) and pre-ictal (label 0) segments as
    ``extract_p2_paired_features``.

    Returns ``(X (n, N_FEAT) float32, Y (n,) int32)``, or ``(None, None)``.
    """
    rows = meta_df[meta_df['Patient ID'] == pid]
    pdir = SEEG_ROOT / f'{pid}_SEEG'
    Xs, Y = [], []
    for _, row in rows.iterrows():
        sz_file = str(row['Seizure_Filename'])
        sz_start = float(row['Seizure_Onset_Sec'])
        sz_end = float(row['Seizure_Offset_Sec'])
        edf_path = pdir / sz_file
        if not edf_path.exists():
            continue
        try:
            for (t0, t1, label) in [(sz_end + 5, sz_end + 185, 1),
                                    (sz_start - 130, sz_start - 10, 0)]:
                data, fs, chs = _crop_load_segment(edf_path, t0, t1)
                if data is None:
                    continue
                s_raw = _scalp_avg(data, chs, TOPO_CH)
                if s_raw is None:
                    s_raw = _scalp_avg(data, chs, ALL_SCALP)
                if s_raw is not None:
                    s_ds = _downsample(_bp(s_raw, 0.5, 100, fs), fs, FS_TARGET)
                    for i in range(0, len(s_ds) - WIN_TARGET, WIN_TARGET):
                        f = compute_features(s_ds[i:i + WIN_TARGET], FS_TARGET)
                        if f is not None:
                            Xs.append(f)
                            Y.append(label)
                del data
                gc.collect()
        except Exception as e:
            log(f" [ERR] {pid} {sz_file}: {e}")
    if not Xs:
        return None, None
    return np.array(Xs, dtype=np.float32), np.array(Y, dtype=np.int32)


def extract_tuh_topo_features(edf_path, csv_path, target_labels=('gnsz', 'tcsz')):
    """Extract per-seizure post-ictal and baseline window stacks from one TUH recording.

    Seizure intervals whose label (4th CSV column) is in ``target_labels`` are
    read from ``csv_path`` and merged when they overlap. For each merged
    seizure:

    - post-ictal windows ``[end + 5, end + 185]`` s;
    - baseline windows ``[start - 130, start - 10]`` s.

    Each stack is kept only if it has at least N_CTX + 2 windows.

    Returns ``(pges_wins, base_wins)``: lists of (n_i, N_FEAT) float32 arrays,
    each None when empty. Returns ``(None, None)`` on any failure.
    """
    try:
        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
        fs = raw.info['sfreq']
        chs = raw.ch_names
        data = raw.get_data()
        if np.abs(np.median(data)) < 0.01:
            data *= 1e6
        szs = []
        with open(csv_path) as f:
            for line in f:
                if line.startswith('#') or line.startswith('channel'):
                    continue
                parts = line.strip().split(',')
                if len(parts) < 5:
                    continue
                if parts[3].strip() in target_labels:
                    try:
                        szs.append((float(parts[1]), float(parts[2])))
                    except ValueError:
                        continue
        if not szs:
            return None, None
        szs = sorted(set(szs))
        merged = [list(szs[0])]
        for s, e in szs[1:]:
            if s <= merged[-1][1]:
                merged[-1][1] = max(merged[-1][1], e)
            else:
                merged.append([s, e])
        sig_raw = _scalp_avg(data, chs, TOPO_CH)
        if sig_raw is None:
            sig_raw = _scalp_avg(data, chs, ALL_SCALP)
        if sig_raw is None:
            return None, None
        # Resample only when fs differs from FS_TARGET by more than 5 Hz.
        if abs(fs - FS_TARGET) > 5:
            sig = _downsample(_bp(sig_raw, 0.5, 100, fs), fs, FS_TARGET)
        else:
            sig = _bp(sig_raw, 0.5, 100, fs)
        W = WIN_TARGET
        pges_wins, base_wins = [], []
        for sz_start, sz_end in merged:
            pi_w = [compute_features(sig[i:i + W], FS_TARGET)
                    for i in range(int((sz_end + 5) * FS_TARGET),
                                   min(int((sz_end + 185) * FS_TARGET), len(sig)) - W, W)]
            pi_w = [f for f in pi_w if f is not None]
            if len(pi_w) >= N_CTX + 2:
                pges_wins.append(np.array(pi_w, dtype=np.float32))
            pr_w = [compute_features(sig[i:i + W], FS_TARGET)
                    for i in range(max(0, int((sz_start - 130) * FS_TARGET)),
                                   int((sz_start - 10) * FS_TARGET) - W, W)]
            pr_w = [f for f in pr_w if f is not None]
            if len(pr_w) >= N_CTX + 2:
                base_wins.append(np.array(pr_w, dtype=np.float32))
        return (pges_wins or None), (base_wins or None)
    except Exception as exc:
        # Best effort per file, but no longer silent.
        log(f'  [WARN] TUH {os.path.basename(str(edf_path))} skipped: {exc}')
        return None, None


# ══════════════════════════════════════════════════════════════════════════════
# Model — same CausalTransformer as all prior experiments
# ══════════════════════════════════════════════════════════════════════════════

class CausalTransformer(nn.Module):
    """Causal transformer encoder over sequences of N_FEAT-dim feature windows.

    ``forward`` maps (batch, seq <= N_CTX, N_FEAT) to next-window feature
    predictions (batch, seq, N_FEAT), or to hidden states (batch, seq, D_MODEL)
    when ``return_hidden`` is True.
    """

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(N_FEAT, D_MODEL)
        enc = nn.TransformerEncoderLayer(D_MODEL, N_HEADS, D_MODEL * 4,
                                         dropout=0.1, batch_first=True)
        self.enc = nn.TransformerEncoder(enc, N_LAYERS)
        self.head = nn.Linear(D_MODEL, N_FEAT)
        # True above the diagonal = position i may not attend to j > i.
        mask = torch.triu(torch.ones(N_CTX, N_CTX), 1).bool()
        self.register_buffer('mask', mask)

    def forward(self, x, return_hidden=False):
        seq_len = x.shape[1]
        h = self.enc(self.proj(x), mask=self.mask[:seq_len, :seq_len])
        return h if return_hidden else self.head(h)

    def embed_windows(self, X_2d):
        """Embed single windows (batch, N_FEAT) as length-1 sequences → (batch, D_MODEL)."""
        x = X_2d.unsqueeze(1)
        return self.enc(self.proj(x))[:, 0, :]


def _supcon_loss(z1, z2, y1, y2, temp=SUPCON_T):
    """Cross-view supervised contrastive loss between embeddings ``z1`` and ``z2``.

    Positives for row i of ``z1`` are the rows of ``z2`` with the same label.
    The loss is the mean, over rows, of the negative average log-softmax of
    the cosine similarities (divided by ``temp``) at the positive entries.
    """
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    sim = torch.mm(z1, z2.t()) / temp
    pos_mask = (y1.unsqueeze(1) == y2.unsqueeze(0)).float()
    sim_max, _ = sim.max(dim=1, keepdim=True)
    sim = sim - sim_max.detach()  # numerical stability; cancels in the log-softmax
    log_prob = sim - torch.log(torch.exp(sim).sum(1, keepdim=True) + 1e-8)
    loss = -(pos_mask * log_prob).sum(1) / (pos_mask.sum(1) + 1e-8)
    return loss.mean()


def pretrain_three_source(model, thal_sessions, tuh_pges_wins, tuh_base_wins,
                          scalp_pool_Xs, scalp_pool_y, bridge_Xs, bridge_Xt, bridge_y,
                          epochs=60, conditions='A'):
    """Train ``model`` with up to three losses, one sampled mini-batch each per epoch.

    - L1 (always): next-window prediction on sequences of N_CTX + 1 windows
      cut from ``thal_sessions``; loss = (1 - cosine) + 0.5 * MSE.
    - L2 (if 'B' in ``conditions``): supervised contrastive loss between TUH
      windows (label 1 = post-ictal stack, 0 = baseline) and the
      ``scalp_pool_Xs`` windows; weight LAM_L2.
    - L3 (if 'C' in ``conditions``): supervised contrastive loss between paired
      scalp ``bridge_Xs`` and thalamic ``bridge_Xt`` windows; weight LAM_L3.

    Any other letters in ``conditions`` are not read here. Returns the model
    in eval mode.
    """
    opt = torch.optim.Adam(model.parameters(), lr=SEQ_LR)
    model.train()

    # L1: TSM sequences (N_CTX context windows + 1 target window).
    tsm_seqs = []
    for sess in thal_sessions:
        if len(sess) < N_CTX + 2:
            continue
        for i in range(N_CTX + 1, len(sess)):
            tsm_seqs.append(sess[i - N_CTX - 1:i])
    tsm_seqs = np.array(tsm_seqs, dtype=np.float32) if tsm_seqs else None

    # L2: TUH flat windows (subsampled to at most 2000).
    if tuh_pges_wins and 'B' in conditions:
        tuh_flat = np.vstack(tuh_pges_wins + (tuh_base_wins or []))
        tuh_flat_y = np.concatenate([
            np.ones(sum(len(s) for s in tuh_pges_wins), dtype=np.int64),
            np.zeros(sum(len(s) for s in (tuh_base_wins or [])), dtype=np.int64)])
        if len(tuh_flat) > 2000:
            idx = np.random.choice(len(tuh_flat), 2000, replace=False)
            tuh_flat, tuh_flat_y = tuh_flat[idx], tuh_flat_y[idx]
    else:
        tuh_flat = tuh_flat_y = None
    p2s_flat = scalp_pool_Xs
    p2s_flat_y = scalp_pool_y.astype(np.int64) if scalp_pool_y is not None else None

    # L3: bridge pairs.
    if bridge_Xs is not None and 'C' in conditions:
        p2_pair_Xs = bridge_Xs
        p2_pair_Xt = bridge_Xt
        p2_pair_y = bridge_y.astype(np.int64)
    else:
        p2_pair_Xs = p2_pair_Xt = p2_pair_y = None

    for ep in range(epochs):
        total_loss = 0.0
        n_batches = 0
        if tsm_seqs is not None and len(tsm_seqs) >= 10:
            idx = np.random.choice(len(tsm_seqs), min(128, len(tsm_seqs)), replace=False)
            xc = torch.tensor(tsm_seqs[idx, :N_CTX], dtype=torch.float32).to(DEVICE)
            xt = torch.tensor(tsm_seqs[idx, N_CTX], dtype=torch.float32).to(DEVICE)
            pred = model(xc)[:, -1, :]
            loss_tsm = (1. - F.cosine_similarity(pred, xt, dim=1).mean()) + 0.5 * F.mse_loss(pred, xt)
            total_loss += loss_tsm
            n_batches += 1
        if tuh_flat is not None and p2s_flat is not None:
            t_idx = np.random.choice(len(tuh_flat), min(64, len(tuh_flat)), replace=False)
            s_idx = np.random.choice(len(p2s_flat), min(64, len(p2s_flat)), replace=False)
            z_tuh = model.embed_windows(torch.tensor(tuh_flat[t_idx], dtype=torch.float32).to(DEVICE))
            z_p2s = model.embed_windows(torch.tensor(p2s_flat[s_idx], dtype=torch.float32).to(DEVICE))
            y_tuh = torch.tensor(tuh_flat_y[t_idx], dtype=torch.long).to(DEVICE)
            y_p2s = torch.tensor(p2s_flat_y[s_idx], dtype=torch.long).to(DEVICE)
            total_loss += LAM_L2 * _supcon_loss(z_tuh, z_p2s, y_tuh, y_p2s)
            n_batches += 1
        if p2_pair_Xs is not None and len(p2_pair_Xs) >= 8:
            idx = np.random.choice(len(p2_pair_Xs), min(64, len(p2_pair_Xs)), replace=False)
            z_s = model.embed_windows(torch.tensor(p2_pair_Xs[idx], dtype=torch.float32).to(DEVICE))
            z_t = model.embed_windows(torch.tensor(p2_pair_Xt[idx], dtype=torch.float32).to(DEVICE))
            y_b = torch.tensor(p2_pair_y[idx], dtype=torch.long).to(DEVICE)
            total_loss += LAM_L3 * _supcon_loss(z_s, z_t, y_b, y_b)
            n_batches += 1
        if n_batches > 0:
            opt.zero_grad()
            total_loss.backward()
            opt.step()
            # Only log when a loss exists: with no batch, total_loss is the float
            # 0.0 (no .item()) and n_batches is 0 — the original crashed here.
            if (ep + 1) % 15 == 0:
                log(f" ep {ep + 1}/{epochs} loss={total_loss.item() / n_batches:.4f}")
    model.eval()
    return model


def build_seqs(patient, scaler):
    """Standardise ``patient['X']`` with ``scaler`` and cut sliding context sequences.

    For every i >= N_CTX the context is windows ``i-N_CTX .. i-1`` and the
    label is ``patient['labels'][i]``. Window i itself is not in its context.
    NOTE(review): confirm this one-step-ahead labelling is intended.

    Returns ``(seqs (n, N_CTX, N_FEAT) float32, lbls (n,) int32)``, or
    ``(None, None)`` when there are at most N_CTX windows.
    """
    X_n = scaler.transform(patient['X'].astype(np.float32))
    y = patient['labels'].astype(np.int32)
    seqs, lbls = [], []
    for i in range(N_CTX, len(X_n)):
        seqs.append(X_n[i - N_CTX:i])
        lbls.append(y[i])
    if not seqs:
        return None, None
    return np.array(seqs, dtype=np.float32), np.array(lbls, dtype=np.int32)


def encode(model, seqs):
    """Return the last-position hidden state of ``model`` for each sequence, shape (n, D_MODEL).

    Sequences are processed in batches of 64, in eval mode, without gradients.
    """
    model.eval()
    z = []
    for i in range(0, len(seqs), 64):
        b = torch.tensor(seqs[i:i + 64], dtype=torch.float32).to(DEVICE)
        with torch.no_grad():
            z.append(model(b, return_hidden=True)[:, -1, :].cpu().numpy())
    return np.vstack(z)


# ══════════════════════════════════════════════════════════════════════════════
# Three honest K=0 evaluation variants — the core contribution of C14
# ══════════════════════════════════════════════════════════════════════════════

def _prototype_f1(Z, pp, pb, lbls):
    """F1 of nearest-prototype labelling: a row of Z is PGES (1) if strictly closer to pp than to pb."""
    preds = (np.linalg.norm(Z - pp, axis=1) < np.linalg.norm(Z - pb, axis=1)).astype(int)
    return float(f1_score(lbls, preds, zero_division=0))


def k0_oracle(model, test_seqs, test_lbls):
    """
    ORACLE (current method in all prior experiments).
    Uses ALL test patient labels to build prototypes — NOT deployable.
    Reported here for comparison only.

    Returns NaN when the test patient lacks either class: a missing prototype
    is undefined. Previously the all-PGES case silently produced F1 = 0.
    """
    if test_lbls.sum() == 0 or (test_lbls == 0).sum() == 0:
        return float('nan')
    Z = encode(model, test_seqs)
    pp = Z[test_lbls == 1].mean(0)
    pb = Z[test_lbls == 0].mean(0)
    return _prototype_f1(Z, pp, pb, test_lbls)


def k0_train_prior(model, train_patients, scaler, test_seqs, test_lbls):
    """
    TRAINING-PRIOR K=0 — true deployment scenario.

    Prototypes = mean embeddings of the labeled PGES / baseline sequences of
    all training sources (every LOSO source except the test one). The test
    patient contributes ZERO data to prototype construction. This is what K=0
    should have measured in all prior experiments.

    Returns NaN when no training source has PGES or baseline sequences.
    """
    all_pges_z, all_base_z = [], []
    for p in train_patients:
        seqs_t, lbls_t = build_seqs(p, scaler)
        if seqs_t is None:
            continue
        Z_t = encode(model, seqs_t)
        if lbls_t.sum() > 0:
            all_pges_z.append(Z_t[lbls_t == 1])
        base_mask = lbls_t == 0
        if base_mask.sum() > 0:
            all_base_z.append(Z_t[base_mask])
    if not all_pges_z or not all_base_z:
        return float('nan')
    pp = np.vstack(all_pges_z).mean(0)
    pb = np.vstack(all_base_z).mean(0)
    Z = encode(model, test_seqs)
    return _prototype_f1(Z, pp, pb, test_lbls)


def k0_bio_prior(model, train_patients, scaler, test_seqs, test_lbls):
    """
    BIO-PRIOR K=0 — perspective-inversion biology encoded as a prior.

    Instead of embedding training windows in the learned space (which could
    overfit to idiosyncratic patient anatomy), build a CANONICAL prototype.
    It is the mean standardised feature vector (``scaler``-transformed) of the
    training patients' PGES windows; likewise for baseline windows.

    The hypothesised thalamic PGES signature (not checked by this code):
        - high RMS, high delta, high SR (the thalamus actively generates delta)
        - low suppression ratio (feature 10): thalamus is NOT silent during PGES
        - low zero crossings: slow delta has few sign changes

    Each canonical vector is repeated N_CTX times as a context sequence and
    encoded with the same CausalTransformer. The last-position hidden state is
    the bio-prior prototype. No test patient data is used at any point.

    Returns NaN when no training source has PGES or baseline windows.
    """
    all_pges_X, all_base_X = [], []
    for p in train_patients:
        X = p['X'].astype(np.float32)
        y = p['labels']
        X_n = scaler.transform(X)
        if y.sum() > 0:
            all_pges_X.append(X_n[y == 1])
        if (y == 0).sum() > 0:
            all_base_X.append(X_n[y == 0])
    if not all_pges_X or not all_base_X:
        return float('nan')

    mean_pges = np.vstack(all_pges_X).mean(0)  # (N_FEAT,)
    mean_base = np.vstack(all_base_X).mean(0)  # (N_FEAT,)

    def _encode_canonical(feat_vec):
        # "The encoder sees N_CTX windows that all look like the canonical vector."
        seq = np.tile(feat_vec[None, None, :], (1, N_CTX, 1)).astype(np.float32)
        t = torch.tensor(seq, dtype=torch.float32).to(DEVICE)
        with torch.no_grad():
            h = model(t, return_hidden=True)  # (1, N_CTX, D_MODEL)
        return h[0, -1, :].cpu().numpy()       # last context position

    pp = _encode_canonical(mean_pges)
    pb = _encode_canonical(mean_base)
    Z = encode(model, test_seqs)
    return _prototype_f1(Z, pp, pb, test_lbls)


def kshot_eval(model, seqs, lbls, K, n_trials=N_TRIALS):
    """Standard K>0 evaluation — unchanged from all prior scripts.

    NOTE(review): LOST TEXT. Only the opening of this function survived in this
    copy of the file. The rest, and the start of the main script, appear to
    have been stripped (everything from a '<' to the next '>'). Surviving
    prefix, for whoever restores it from version control:

        Z = encode(model, seqs)
        if lbls.sum() == 0: return float('nan')
        scores = []
        for _ in range(n_trials):
            pos = np.where(lbls==1)[0]; neg = np.where(lbls==0)[0]
            if len(pos) ...            # <- text lost from here

    It is deliberately NOT reconstructed: a guessed protocol would silently
    change the reported K=2/5/10 numbers.
    """
    if lbls.sum() == 0:
        return float('nan')
    raise NotImplementedError(
        'kshot_eval body is missing from this copy of the C14 script; '
        'restore it from version control.')


def _load_step1_inputs():
    """Return ``(meta_df, patients)`` for Step 1 of the pipeline.

    NOTE(review): LOST TEXT (same gap as in ``kshot_eval``). Only the tail
    ``... 0]`` of this step survived, apparently the end of a list filter.
    Judging from how ``main`` uses the results:

    - ``meta_df`` is a table with columns 'Patient ID', 'Seizure_Filename',
      'Seizure_Onset_Sec' and 'Seizure_Offset_Sec' (presumably read from
      METADATA).
    - ``patients`` is a list of dicts with keys 'pid', 'nucleus', 'X'
      (windows × N_FEAT) and 'labels' (per window), presumably built for
      THAL_PIDS.

    Restore from version control; not reconstructed here.
    """
    raise NotImplementedError(
        'Step 1 (metadata + thalamic patient loading) is missing from this copy '
        'of the C14 script; restore it from version control.')


def main():
    """Run C14: data assembly (Steps 1-3), LOSO with two encoders (Step 4), reporting."""
    meta_df, patients = _load_step1_inputs()
    log(f' {len(patients)} thalamic patients')
    for gtc_b in ['B2.edf', 'B3.edf']:
        Xb, yb = extract_gtc_thalamic_features(gtc_b)
        if Xb is not None:
            patients.append({'pid': gtc_b.replace('.edf', ''), 'nucleus': 'ANT_GTC',
                             'X': Xb, 'labels': yb})
    log(f' Total L1 pool: {len(patients)} sources')

    # Scalp window extraction is deterministic: compute each patient once and
    # reuse it in every fold (the original re-read the EDFs in every fold).
    scalp_cache = {}

    def _scalp_windows(pid_s):
        if pid_s not in scalp_cache:
            scalp_cache[pid_s] = extract_scalp_pges_windows(pid_s, meta_df)
        return scalp_cache[pid_s]

    # ── Step 2: Bridge (P2 + GTC A2/A4) ──────────────────────────────────────
    log('\nStep 2: Bridge features...')
    p2_Xs, p2_Xt, p2_y = extract_p2_paired_features(meta_df)
    HAS_P2 = p2_Xs is not None and len(p2_Xs) >= 10
    gtc_bridge = {}  # edf name -> (Xs, Xt, y); extracted once, reused per fold
    for gtc_a in ['A2.edf', 'A4.edf']:
        aXs, aXt, ay = extract_gtc_bridge_features(gtc_a)
        if aXs is not None:
            gtc_bridge[gtc_a] = (aXs, aXt, ay)
    all_bridge = ([(p2_Xs, p2_Xt, p2_y)] if HAS_P2 else []) + list(gtc_bridge.values())
    HAS_BRIDGE = bool(all_bridge)
    if HAS_BRIDGE:
        log(f' Bridge: {sum(len(b[2]) for b in all_bridge)} windows from P2+A2+A4')
    inst_scalp_found = [_scalp_windows(pid_s)[0] is not None for pid_s in ['P10', 'P12']]
    HAS_INST_SCALP = HAS_P2 or HAS_BRIDGE or any(inst_scalp_found)

    # ── Step 3: TUH ───────────────────────────────────────────────────────────
    log('\nStep 3: TUH corpus...')
    # glob order is filesystem-dependent; sort so the seeded subsample below is reproducible.
    csvs = sorted(f for f in glob.glob(os.path.join(TUH_BASE, '**', '*.csv'), recursive=True)
                  if 'worksheet' not in f.lower())

    def _has_target(f):
        try:
            with open(f, errors='ignore') as fh:
                text = fh.read()
        except OSError:
            return False
        return any(t in text for t in ['tcsz', 'gnsz'])

    tgt_pairs = [(f, f.replace('.csv', '.edf')) for f in csvs
                 if _has_target(f) and os.path.exists(f.replace('.csv', '.edf'))]
    np.random.seed(42)
    if len(tgt_pairs) > MAX_TUH:
        idxs = np.random.choice(len(tgt_pairs), MAX_TUH, replace=False)
        tgt_pairs = [tgt_pairs[i] for i in idxs]
    log(f' Using {len(tgt_pairs)} TUH files')

    def _with_timeout(fn, *args, timeout=240):
        # Runs fn in a daemon thread and gives up after `timeout` s. A thread
        # that times out cannot be killed and keeps running in the background.
        result = [None, None]

        def _r():
            try:
                result[0], result[1] = fn(*args)
            except Exception:
                pass  # fn already returns (None, None) on its own failures

        t = threading.Thread(target=_r, daemon=True)
        t.start()
        t.join(timeout=timeout)
        return result[0], result[1]

    tuh_pges, tuh_base = [], []
    for k, (csv_p, edf_p) in enumerate(tgt_pairs):
        p, b = _with_timeout(extract_tuh_topo_features, edf_p, csv_p)
        if p:
            tuh_pges.extend(p)
        if b:
            tuh_base.extend(b)
        if (k + 1) % 20 == 0:
            log(f' TUH {k + 1}/{len(tgt_pairs)} | PGES={len(tuh_pges)}')
    log(f' TUH: {len(tuh_pges)} PGES | {len(tuh_base)} base')
    HAS_TUH = len(tuh_pges) > 0
    tuh_scaler = None
    if HAS_TUH:
        tuh_scaler = StandardScaler().fit(np.vstack(tuh_pges + tuh_base))
        tuh_pges_n = [tuh_scaler.transform(s) for s in tuh_pges]
        tuh_base_n = [tuh_scaler.transform(s) for s in tuh_base]
    else:
        tuh_pges_n = tuh_base_n = []

    # ── Step 4: LOSO — two encoders × three K=0 variants ─────────────────────
    log(f'\nStep 4: LOSO — 2 encoders × 3 K=0 variants + K=2,5,10...')
    # Results per encoder (A: TSM-only, D: three-source). Lists are aligned
    # with `evaluated_pids`, NOT with `patients` (skipped folds add nothing).
    r = {enc: {'k0_oracle': [], 'k0_train': [], 'k0_bio': [], 2: [], 5: [], 10: []}
         for enc in ['A', 'D']}
    evaluated_pids = []
    for fold_i, test_p in enumerate(patients):
        pid = test_p['pid']
        train_ps = [p for p in patients if p['pid'] != pid]
        X_tr = np.vstack([p['X'].astype(np.float32) for p in train_ps])
        scaler = StandardScaler().fit(X_tr)
        seqs, lbls = build_seqs(test_p, scaler)
        if seqs is None or lbls.sum() == 0:
            log(f' [{fold_i + 1:02d}] {pid}: skip')
            continue
        evaluated_pids.append(pid)
        log(f'\n [{fold_i + 1:02d}/{len(patients)}] Test={pid} '
            f'PGES={int(lbls.sum())} Base={int((lbls == 0).sum())}')

        # L1 sessions: baseline-only (label 0) windows of each training source.
        thal_sess = []
        for p in train_ps:
            X_n = scaler.transform(p['X'].astype(np.float32))
            sess = X_n[p['labels'] == 0]
            if len(sess) >= N_CTX + 2:
                thal_sess.append(sess)

        # Fold bridge: P2 pairs (unless P2 is the test patient) + cached GTC A2/A4.
        fold_b = ([(p2_Xs, p2_Xt, p2_y)] if HAS_P2 and pid != 'P2' else []) \
            + list(gtc_bridge.values())
        if fold_b:
            fold_bXs = np.vstack([b[0] for b in fold_b])
            fold_bXt = np.vstack([b[1] for b in fold_b])
            fold_by = np.concatenate([b[2] for b in fold_b])
            fold_bXs_n = tuh_scaler.transform(fold_bXs) if HAS_TUH else fold_bXs
            fold_bXt_n = scaler.transform(fold_bXt)
            HAS_FOLD_BRIDGE = True
        else:
            fold_bXs = fold_bXs_n = fold_bXt_n = fold_by = None
            HAS_FOLD_BRIDGE = False

        # Institutional scalp pool for L2 (test patient excluded).
        iXs_n = iy_arr = None
        if HAS_INST_SCALP:
            iXs, iy = [], []
            for pid_s in ['P2', 'P10', 'P12']:
                if pid_s == pid:
                    continue
                if pid_s == 'P2' and HAS_P2:
                    iXs.append(p2_Xs)
                    iy.append(p2_y)
                else:
                    Xs_t, y_t = _scalp_windows(pid_s)
                    if Xs_t is not None:
                        iXs.append(Xs_t)
                        iy.append(y_t)
            if HAS_FOLD_BRIDGE:
                iXs.append(fold_bXs)
                iy.append(fold_by)
            if iXs:
                stacked = np.vstack(iXs)
                iXs_n = tuh_scaler.transform(stacked) if HAS_TUH else stacked
                iy_arr = np.concatenate(iy)

        fold_tuh_pges = tuh_pges_n if HAS_TUH else []
        fold_tuh_base = tuh_base_n if HAS_TUH else []

        for enc_name, cond_str, n_ep in [('A', 'A', SEQ_EP_THAL),
                                         ('D', 'ABCD' if HAS_FOLD_BRIDGE else 'AB', SEQ_EP_PRETRAIN)]:
            use_b = 'B' in cond_str
            use_c = 'C' in cond_str and HAS_FOLD_BRIDGE
            model = CausalTransformer().to(DEVICE)
            model = pretrain_three_source(
                model, thal_sess,
                fold_tuh_pges if use_b else None,
                fold_tuh_base if use_b else None,
                iXs_n if use_b else None,
                iy_arr if use_b else None,
                fold_bXs_n if use_c else None,
                fold_bXt_n if use_c else None,
                fold_by if use_c else None,
                epochs=n_ep, conditions=cond_str)

            # Three K=0 variants
            r0_oracle = k0_oracle(model, seqs, lbls)
            r0_train = k0_train_prior(model, train_ps, scaler, seqs, lbls)
            r0_bio = k0_bio_prior(model, train_ps, scaler, seqs, lbls)
            r[enc_name]['k0_oracle'].append(r0_oracle)
            r[enc_name]['k0_train'].append(r0_train)
            r[enc_name]['k0_bio'].append(r0_bio)
            # Standard K>0
            for k in K_VALS:
                r[enc_name][k].append(kshot_eval(model, seqs, lbls, k))
            log(f' {enc_name}: oracle={r0_oracle:.4f} train={r0_train:.4f} '
                f'bio={r0_bio:.4f} K=10={r[enc_name][10][-1]:.4f}')
            del model
            gc.collect()
            if DEVICE.type == 'cuda':
                torch.cuda.empty_cache()

    # ── Results ───────────────────────────────────────────────────────────────
    log('\n' + '=' * 60)
    log('=== C14: Honest K=0 Evaluation ===')
    log('=' * 60)
    log(f"\n{'Method':<30} {'K0_oracle':>10} {'K0_train':>10} {'K0_bio':>10} "
        f"{'K=2':>8} {'K=5':>8} {'K=10':>8}")
    log('-' * 85)
    for enc in ['A', 'D']:
        label = 'A: TSM-only (baseline)' if enc == 'A' else 'D: C13 three-source [MAIN]'
        k0_or = np.nanmean(r[enc]['k0_oracle'])
        k0_tr = np.nanmean(r[enc]['k0_train'])
        k0_bi = np.nanmean(r[enc]['k0_bio'])
        k2 = np.nanmean(r[enc][2])
        k5 = np.nanmean(r[enc][5])
        k10 = np.nanmean(r[enc][10])
        log(f"{label:<30} {k0_or:>10.4f} {k0_tr:>10.4f} {k0_bi:>10.4f} "
            f"{k2:>8.4f} {k5:>8.4f} {k10:>8.4f}")

    log('\nKey comparison (D encoder):')
    log(f" Reported K=0 (oracle): {np.nanmean(r['D']['k0_oracle']):.4f} ← all prior C13 results")
    log(f" True K=0 (train prior): {np.nanmean(r['D']['k0_train']):.4f} ← real deployment scenario")
    log(f" True K=0 (bio prior): {np.nanmean(r['D']['k0_bio']):.4f} ← bio-informed zero-shot")
    gap = np.nanmean(r['D']['k0_oracle']) - np.nanmean(r['D']['k0_train'])
    log(f" Oracle inflation: {gap:+.4f} ← gap between reported and honest K=0")

    # Paired Wilcoxon: train_prior vs bio_prior. Drop a fold only when either
    # value is NaN, so pairs stay aligned. The original filtered each list
    # separately and then truncated, which could pair different patients.
    pairs = [(a, b) for a, b in zip(r['D']['k0_train'], r['D']['k0_bio'])
             if not (np.isnan(a) or np.isnan(b))]
    if len(pairs) >= 5:
        tp = [a for a, _ in pairs]
        bp = [b for _, b in pairs]
        try:
            stat, pv = wilcoxon(tp, bp)
            # p > 0.05 is only a failure to reject, not evidence of equivalence.
            log(f"\n Wilcoxon train_prior vs bio_prior: p={pv:.4f} "
                f"({'not significantly different' if pv > 0.05 else 'different'})")
        except ValueError as exc:
            log(f"\n Wilcoxon train_prior vs bio_prior not computed: {exc}")

    # Bootstrap CIs
    log('\nBootstrap 95% CI (D encoder, 2000 resamples):')
    for key, label in [('k0_oracle', 'K0 oracle'), ('k0_train', 'K0 train'),
                       ('k0_bio', 'K0 bio'), (10, 'K=10')]:
        v = np.array([x for x in r['D'][key] if not np.isnan(x)])
        if len(v) >= 3:
            bs = [np.mean(np.random.choice(v, len(v))) for _ in range(2000)]
            log(f" {label:<12}: {np.mean(v):.4f} 95% CI [{np.percentile(bs, 2.5):.4f}, "
                f"{np.percentile(bs, 97.5):.4f}]")

    # Save
    metric_keys = ['k0_oracle', 'k0_train', 'k0_bio', 2, 5, 10]
    rows = []
    for enc in ['A', 'D']:
        for key in metric_keys:
            v = [x for x in r[enc][key] if not np.isnan(x)]
            rows.append({'encoder': enc, 'metric': str(key),
                         'mean': np.nanmean(v) if v else np.nan,
                         'std': np.nanstd(v) if v else np.nan,
                         'n': len(v)})
    pd.DataFrame(rows).to_csv(str(OUT_ROOT / 'c14_results.csv'), index=False)

    # Result lists are indexed by evaluated fold, not by position in `patients`
    # (the original used fold_i, which misattributed scores after any skip).
    res_idx = {pid: i for i, pid in enumerate(evaluated_pids)}
    per_pat = []
    for test_p in patients:
        pid = test_p['pid']
        for enc in ['A', 'D']:
            for key in metric_keys:
                v = r[enc][key][res_idx[pid]] if pid in res_idx else np.nan
                per_pat.append({'pid': pid, 'encoder': enc, 'metric': str(key), 'F1': v})
    pd.DataFrame(per_pat).to_csv(str(OUT_ROOT / 'c14_per_patient.csv'), index=False)

    # Figure
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    k0_keys = ['k0_oracle', 'k0_train', 'k0_bio']
    k0_labels = ['K0\noracle\n(prior work)', 'K0\ntrain\n(honest)', 'K0\nbio\n(prior)']
    colors = {'A': '#7f8c8d', 'D': '#27ae60'}
    for ax, enc, title in [(axes[0], 'A', 'Encoder A: TSM-only'),
                           (axes[1], 'D', 'Encoder D: C13 Three-Source [MAIN]')]:
        k0_vals = [np.nanmean(r[enc][k]) for k in k0_keys]
        k0_stds = [np.nanstd(r[enc][k]) for k in k0_keys]
        kn_vals = [np.nanmean(r[enc][k]) for k in K_VALS]
        x = np.arange(len(k0_labels))
        bars = ax.bar(x, k0_vals, yerr=k0_stds, capsize=5, color=[colors[enc]] * 3,
                      alpha=0.9, edgecolor='black', linewidth=0.8)
        for bar, a in zip(bars, [0.5, 0.9, 0.7]):
            bar.set_alpha(a)
        # Shade oracle bar differently to signal "not honest"
        bars[0].set_hatch('//')
        bars[0].set_alpha(0.35)
        ax.set_xticks(x)
        ax.set_xticklabels(k0_labels, fontsize=9)
        ax.set_ylabel('F1 Score')
        # Lower the floor when needed so honest K=0 scores below 0.4 stay
        # visible; the original fixed floor of 0.4 hid them.
        finite = [v for v in k0_vals + kn_vals if np.isfinite(v)]
        y_lo = min([0.4] + [v - 0.05 for v in finite])
        ax.set_ylim(max(0.0, y_lo), 1.05)
        ax.set_title(f'{title}\nK=0 Variants')
        ax.axhline(0.5, color='red', linestyle='--', alpha=0.4, label='chance')
        if np.isfinite(kn_vals[0]):
            ax.axhline(kn_vals[0], color='blue', linestyle=':', alpha=0.5,
                       label=f'K=2={kn_vals[0]:.3f}')
        ax.legend(fontsize=8)
        ax.grid(axis='y', alpha=0.3)
        for bar, val in zip(bars, k0_vals):
            if np.isfinite(val):
                ax.text(bar.get_x() + bar.get_width() / 2, val + 0.01, f'{val:.3f}',
                        ha='center', va='bottom', fontsize=9, fontweight='bold')
    plt.suptitle('C14: Honest K=0 — Oracle vs Training-Prior vs Bio-Prior\n'
                 '(hatched = uses test patient labels, NOT deployable)', fontsize=11)
    plt.tight_layout()
    fig.savefig(str(OUT_ROOT / 'c14_honest_k0.png'), dpi=150)
    plt.close()
    log(f'\nFigure → {OUT_ROOT}/c14_honest_k0.png')
    log('COMPLETE')


if __name__ == '__main__':
    main()