# -*- coding: utf-8 -*-
"""
DACTRL C13 — Three-Source Integrated Contrastive Pre-training
==============================================================

The core argument
-----------------
Every prior scalp transfer experiment failed because it tried to cross the
scalp→thalamic domain gap in ONE step (feature CycleGAN, waveform translator,
etc.). The gap is real, but it has a structured bridge: patient P2.

    TUH scalp ──same domain──► P2 scalp ──paired supervision──► P2 thalamic ──same domain──► other thalamic patients
              (no gap)                  (known, supervised)                  (no gap)

There are only TWO domain gaps in the entire pipeline, and BOTH can be collapsed:

  1. TUH scalp ↔ P2 scalp: same modality, same signal, same 10-20 layout.
     A contrastive loss on PGES windows from both should trivially align them.
  2. P2 scalp ↔ P2 thalamic: the ONLY true cross-modal gap, but we have
     ground-truth pairs (same patient, same time window) to supervise it.

Gaps that do NOT exist:

  3. P2 thalamic ↔ other patients' thalamic: same modality, same electrode
     type, same PGES physiology. The TSM self-supervised loss handles this.

Strategy: unified shared encoder
--------------------------------
One CausalTransformer E maps all feature sequences to the same embedding
space. Three losses are applied simultaneously during pre-training:

  L1 — TSM (next-window prediction) on thalamic baseline sequences.
       Teaches the temporal dynamics of the thalamic signal.
  L2 — Same-domain SupCon: TUH scalp windows against the institutional scalp
       pool (P2 + P10/P12 + GTC A2/A4); PGES should embed close to PGES and
       far from baseline. This aligns TUH with the institutional scalp
       recordings without any translation.
  L3 — Cross-modal SupCon on paired windows (P2 + GTC A2/A4): scalp at time t
       and thalamic at time t have the same label and the same neural state —
       force them to embed close together. This is the bridge: the encoder
       must find a representation in which scalp-PGES and thalamic-PGES are
       indistinguishable.

Final embedding space (after all three losses):

    TUH-scalp-PGES ≈ P2-scalp-PGES ≈ P2-thalamic-PGES ≈ other-thalamic-PGES
    (via L2 transitivity)

No explicit translation is required. The encoder learns to read thalamic PGES
from scalp features because L3 forces the scalp and thalamic representations
of the SAME EVENT to coincide.

Conditions
----------
  A: Thalamic-only TSM (canonical baseline — L1 only)
  B: L1 + L2 (TSM + TUH/institutional-scalp same-domain alignment)
     Does aligning TUH with the institutional scalp alone help? (no bridge yet)
  C: L1 + L3 (TSM + paired cross-modal bridge)
     Does the bridge alone (without TUH scale) help?
  D: L1 + L2 + L3 (all three — the full integrated approach) [MAIN]
  E: D + Day-0 temporal heuristic

LOSO note: when the test patient is P2, P2's own paired windows are removed
from the L3 bridge; the external GTC A2/A4 pairs are always kept. Only when a
fold has no bridge data at all does C reuse the condition-A result and D
reduce to L1 + L2.
"""
import os

# NOTE: PYTHONIOENCODING is read at interpreter start-up, so this only affects
# child processes started from here, not this process's own stdout.
os.environ.setdefault('PYTHONIOENCODING', 'utf-8')

import gc
import glob
import random
import threading
import warnings
import copy
from pathlib import Path
from datetime import datetime
from collections import Counter

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from scipy.signal import resample_poly, butter, filtfilt
from scipy.stats import wilcoxon
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score
import torch
import torch.nn as nn
import torch.nn.functional as F

try:
    import mne
    mne.set_log_level('ERROR')
except ImportError:
    pass  # only the EDF readers need mne, and they catch their own errors

warnings.filterwarnings('ignore')

DEVICE = torch.device('mps' if torch.backends.mps.is_available()
                      else ('cuda' if torch.cuda.is_available() else 'cpu'))
print(f"[GPU/MPS] {DEVICE}")
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# ── Paths ─────────────────────────────────────────────────────────────────────
SEEG_ROOT = Path(r"/Volumes/Expansion/phd_datasets/Data/Thalamus/SEEG Seizure Data")
METADATA = SEEG_ROOT / "metadata_SEEG.xlsx"
TUH_BASE = r"/Volumes/Expansion/phd_datasets/Data/Scalp/tueeg_data/tuh_eeg_seizure/v2.0.3/edf"
GTC_ROOT = Path(r"/Volumes/Expansion/phd_datasets/Data/Thalamus/eeg_ecg_us_clinical/GTC_Focal_SEEG")
OUT_ROOT = Path(r"/Volumes/Expansion/phd_datasets/Code/pges_toolkit_mac/results/dactrl_three_source")
OUT_ROOT.mkdir(parents=True, exist_ok=True)

# ── Constants ──────────────────────────────────────────────────────────────────
FS_NATIVE = 2048
FS_TARGET = 250
WIN_SEC = 5
WIN_TARGET = WIN_SEC * FS_TARGET          # samples per feature window at FS_TARGET
MAX_TUH = 300
N_FEAT = 17
D_MODEL = 64; N_HEADS = 4; N_LAYERS = 4; N_CTX = 8
SEQ_EP_THAL = 60; SEQ_EP_PRETRAIN = 30; SEQ_LR = 3e-4
SUPCON_T = 0.07   # SupCon temperature
LAM_L2 = 0.5      # weight for L2 (same-domain SupCon)
LAM_L3 = 1.0      # weight for L3 (cross-modal bridge) — higher: bridge matters more
K_VALS = [0, 2, 5, 10]
N_TRIALS = 5

NUCLEUS_MAP = {
    'P1': 'CeM', 'P2': 'CL', 'P3': 'CeM', 'P4': 'MD', 'P5': 'CeM',
    'P6': 'MD', 'P7': 'CL', 'P8': 'CL', 'P9': 'CeM', 'P10': 'ANT',
    'P11': 'ANT', 'P12': 'ANT', 'P13': 'ANT', 'P14': 'ANT', 'P15': 'ANT',
}

# Confirmed LT/LTP thalamic channel patients (verified via EDF header scan)
# P6=LTHAL, P9=RT, P10-P14=wrong-hemisphere/non-thalamic contacts → excluded
THAL_PIDS = ['P1', 'P2', 'P3', 'P4', 'P5', 'P7', 'P8', 'P15']

# Indices into the compute_features() vector: 10 = supp (suppression
# fraction), 0 = rms, 3 = var ("SR, RMS, Variance (C2)").
INVERT_IDX = [10, 0, 3]

# Topology-informed scalp channels for left CL thalamus (LT1-LT2)
# Odd = LEFT, Even = RIGHT, Z = MIDLINE
TOPO_CH = ['FZ', 'CZ', 'C3', 'F3']
ALL_SCALP = ['FP1', 'FP2', 'F7', 'F8', 'F3', 'F4', 'FZ',
             'T3', 'T4', 'C3', 'C4', 'CZ', 'T5', 'T6', 'P3', 'P4', 'PZ', 'O1', 'O2']


def log(msg):
    """Print `msg` prefixed with a [HH:MM:SS] wall-clock stamp, flushed immediately."""
    print(f"[{datetime.now().strftime('%H:%M:%S')}] {msg}", flush=True)


# ══════════════════════════════════════════════════════════════════════════════
# Feature extraction (17-feature pipeline)
# ══════════════════════════════════════════════════════════════════════════════
def compute_features(sig, fs):
    """Compute the 17-feature vector for one single-channel window.

    Args:
        sig: 1-D signal window (its mean is removed internally).
        fs:  sampling rate in Hz.

    Returns:
        float32 array of length 17, in the order
        [rms, line_length, zero_crossings, variance,
         rel_delta, rel_theta, rel_alpha, rel_beta,
         (delta+theta)/(alpha+beta), spectral_shannon_entropy,
         suppression_fraction, apen, sampen, histogram_entropy,
         distinct_4bit_patterns, permutation_entropy, rel_gamma(80-150 Hz)],
        or None when the window is shorter than 0.5 s.
    """
    from numpy.fft import rfft, rfftfreq

    sig = sig - sig.mean()
    n = len(sig)
    if n < int(fs * 0.5):
        return None

    # ── time-domain amplitude features ──
    rms = float(np.sqrt(np.mean(sig ** 2)))
    ll = float(np.mean(np.abs(np.diff(sig))))
    zc = float(np.sum(np.diff(np.sign(sig)) != 0))
    var = float(np.var(sig))

    # ── spectral features from the raw periodogram ──
    freqs = rfftfreq(n, 1 / fs)
    psd = np.abs(rfft(sig)) ** 2

    def band(lo, hi):
        idx = (freqs >= lo) & (freqs < hi)
        return float(psd[idx].sum()) if idx.any() else 0.0

    total = float(psd.sum()) + 1e-10
    delta = band(0.5, 4); theta = band(4, 8); alpha = band(8, 13)
    beta = band(13, 30)
    # NOTE: effectively capped by Nyquist (125 Hz at FS_TARGET) and by the
    # 100 Hz low-pass the callers apply before resampling.
    gamma = band(80, 150)
    sr = (delta + theta) / (alpha + beta + 1e-10)
    p = psd / (psd.sum() + 1e-10)
    p = p[p > 0]
    shan = float(-np.sum(p * np.log(p + 1e-10)))
    # Fraction of samples below 5 % of the peak absolute amplitude.
    supp = float(np.mean(np.abs(sig) < 0.05 * np.max(np.abs(sig) + 1e-10)))

    # ── entropy / complexity features on the first ≤200 samples ──
    u = sig[:min(200, n)]

    def _apen(x, m=2, r=None):
        if r is None:
            r = 0.2 * np.std(x) + 1e-10
        N = len(x)

        def phi(mm):
            tpl = np.array([x[i:i + mm] for i in range(N - mm + 1)])
            C = np.sum(np.max(np.abs(tpl[:, None] - tpl[None, :]), axis=2) <= r,
                       axis=0) / (N - mm + 1)
            return np.sum(np.log(C + 1e-10)) / (N - mm + 1)

        return abs(phi(m) - phi(m + 1))

    def _sampen(x, m=2, r=None):
        if r is None:
            r = 0.2 * np.std(x) + 1e-10
        N = len(x)

        def _count(mm):
            # Ordered pairs (i, j), i != j, of length-mm templates starting at
            # 0..N-mm-1 whose Chebyshev distance is <= r. Vectorised form of a
            # per-template Python loop: same pairs, same count, far fewer
            # Python-level allocations.
            if N - mm < 2:
                return 0
            tpl = np.array([x[j:j + mm] for j in range(N - mm)])
            close = np.max(np.abs(tpl[:, None] - tpl[None, :]), axis=2) <= r
            np.fill_diagonal(close, False)  # drop self-matches (i == j)
            return int(close.sum())

        A = _count(m + 1)
        B = _count(m)
        return float(-np.log((A + 1e-10) / (B + 1e-10)))

    apen = float(_apen(u))
    sampen = float(_sampen(u))

    hist, _ = np.histogram(sig, bins=10)
    hist = hist / (hist.sum() + 1e-10)
    etc = float(-np.sum(hist[hist > 0] * np.log(hist[hist > 0])))

    # Number of distinct 4-symbol patterns in the median-binarised signal
    # (a cheap complexity proxy, not a full Lempel-Ziv parse). The bit string
    # is built once instead of once per slice.
    bits = ''.join(map(str, (sig > np.median(sig)).astype(int)))
    lzc = float(len({bits[i:i + 4] for i in range(len(bits) - 3)}))

    # Order-3 permutation entropy over the whole window.
    m3 = 3
    perms = [tuple(np.argsort(sig[i:i + m3])) for i in range(n - m3 + 1)]
    cnt2 = Counter(perms)
    tot = sum(cnt2.values())
    pent = float(-sum((v / tot) * np.log(v / tot + 1e-10) for v in cnt2.values()))

    return np.array([rms, ll, zc, var, delta / total, theta / total, alpha / total,
                     beta / total, sr, shan, supp, apen, sampen, etc, lzc, pent,
                     gamma / total], dtype=np.float32)


def _norm_ch(name):
    """Normalise a channel label for matching: upper-case, '-' and spaces removed."""
    return name.upper().replace('-', '').replace(' ', '')


def _find_ch(ch_names, targets):
    """Return the index of the first channel matching any of `targets`, else None.

    Targets are tried in order. For each target an exact (normalised) match
    wins. Otherwise the first channel that CONTAINS the target is used, provided
    the target is not immediately followed by a digit. So 'LT1' may match
    'EEG LT1-REF' but never 'LT10'/'LT12' on a 10+ contact electrode.
    """
    cu = [_norm_ch(c) for c in ch_names]
    for t in targets:
        t2 = _norm_ch(t)
        if t2 in cu:
            return cu.index(t2)
        for i, c in enumerate(cu):
            pos = c.find(t2)
            while pos != -1:
                end = pos + len(t2)
                if end == len(c) or not c[end].isdigit():
                    return i
                pos = c.find(t2, pos + 1)
    return None


def _scalp_avg(data, ch_names, targets):
    """Average the rows of `data` for the channels matching `targets`.

    Each target is resolved independently with `_find_ch`. A channel matched
    by several targets is counted once. Returns None if nothing matches.
    """
    idxs = []
    for t in targets:
        i = _find_ch(ch_names, [t])
        if i is not None and i not in idxs:
            idxs.append(i)
    return data[idxs].mean(axis=0) if idxs else None


def _downsample(sig, fs_in, fs_out):
    """Polyphase-resample `sig` from fs_in to fs_out (rates truncated to int)."""
    from math import gcd
    g = gcd(int(fs_in), int(fs_out))
    return resample_poly(sig, int(fs_out) // g, int(fs_in) // g)


def _bp(sig, lo, hi, fs, order=4):
    """Zero-phase Butterworth band-pass [lo, hi] Hz (filtfilt)."""
    nyq = fs / 2
    b, a = butter(order, [lo / nyq, hi / nyq], btype='band')
    return filtfilt(b, a, sig)


def _bp_resample(sig, fs):
    """Band-pass 0.5–100 Hz at the native rate, then resample to FS_TARGET."""
    return _downsample(_bp(sig, 0.5, 100, fs), fs, FS_TARGET)


def _window_starts(stop, start=0, win=WIN_TARGET):
    """Start indices of consecutive non-overlapping `win`-sample windows lying
    entirely inside [start, stop). The `+ 1` keeps the final window when the
    span is an exact multiple of `win`."""
    return range(start, stop - win + 1, win)


def _window_feature_list(sig, start=0, stop=None):
    """compute_features() on every full WIN_TARGET window of `sig` inside
    [start, stop) (stop defaults to, and is clamped to, len(sig))."""
    stop = len(sig) if stop is None else min(stop, len(sig))
    feats = []
    for i in _window_starts(stop, start):
        f = compute_features(sig[i:i + WIN_TARGET], FS_TARGET)
        if f is not None:
            feats.append(f)
    return feats


# ══════════════════════════════════════════════════════════════════════════════
# P2 paired feature extraction (scalp features + thalamic features, aligned)
# ══════════════════════════════════════════════════════════════════════════════
def _crop_load_segment(edf_path, t_start_s, t_end_s):
    """
    Memory-safe EDF read: crop to [t_start_s, t_end_s] before loading.
    Returns (data_array, fs, ch_names) or (None, None, None).
    Never loads more than the requested segment into RAM.

    The window is clamped to [0, time of the last sample]. Segments shorter
    than 1 s after clamping return (None, None, None). Read errors are logged
    and also return (None, None, None).
    """
    raw = None
    try:
        raw = mne.io.read_raw_edf(str(edf_path), preload=False, verbose=False)
        fs = raw.info['sfreq']
        # crop() rejects a tmax beyond the LAST SAMPLE time, (n_times - 1) / fs.
        # Clamping to n_times / fs instead made every segment that reached
        # end-of-file raise and be silently dropped.
        t_last = (raw.n_times - 1) / fs
        t0 = max(0.0, t_start_s)
        t1 = min(t_end_s, t_last)
        if t1 - t0 < 1.0:
            return None, None, None
        raw.crop(tmin=t0, tmax=t1)
        raw.load_data()
        data = raw.get_data()
        chs = raw.ch_names
        # Rescale to microvolts when the median magnitude is tiny
        # (presumably SI volts as returned by MNE).
        if np.abs(np.median(data)) < 0.01:
            data *= 1e6
        return data, fs, chs
    except Exception as e:
        log(f"  [WARN] EDF read failed: {Path(str(edf_path)).name} "
            f"[{t_start_s:.1f}-{t_end_s:.1f} s]: {e}")
        return None, None, None
    finally:
        if raw is not None:
            raw.close()


def _paired_segment_features(edf_path, t0, t1, scalp_targets, thal_pos, thal_neg,
                             scalp_fallback=None):
    """Paired (scalp, thalamic) window features for one [t0, t1] EDF segment.

    Scalp = average of `scalp_targets` (or `scalp_fallback` if none match).
    Thalamic = bipolar `thal_pos` - `thal_neg`. Both are band-passed, resampled
    and cut into the same windows; a window is kept only when both feature
    vectors exist. Returns two equal-length lists, which are empty when the
    segment or the required channels are unavailable.
    """
    data, fs, chs = _crop_load_segment(edf_path, t0, t1)
    if data is None:
        return [], []
    s_raw = _scalp_avg(data, chs, scalp_targets)
    if s_raw is None and scalp_fallback is not None:
        s_raw = _scalp_avg(data, chs, scalp_fallback)
    i1 = _find_ch(chs, thal_pos)
    i2 = _find_ch(chs, thal_neg)
    if s_raw is None or i1 is None or i2 is None:
        return [], []
    s_ds = _bp_resample(s_raw, fs)
    t_ds = _bp_resample(data[i1] - data[i2], fs)
    del data
    Xs, Xt = [], []
    for i in _window_starts(min(len(s_ds), len(t_ds))):
        fs_f = compute_features(s_ds[i:i + WIN_TARGET], FS_TARGET)
        ft_f = compute_features(t_ds[i:i + WIN_TARGET], FS_TARGET)
        if fs_f is not None and ft_f is not None:
            Xs.append(fs_f)
            Xt.append(ft_f)
    return Xs, Xt


def _scalp_segment_features(edf_path, t0, t1):
    """Scalp-only window features (TOPO_CH average, ALL_SCALP fallback) for one
    [t0, t1] EDF segment; [] when unavailable."""
    data, fs, chs = _crop_load_segment(edf_path, t0, t1)
    if data is None:
        return []
    s_raw = _scalp_avg(data, chs, TOPO_CH)
    if s_raw is None:
        s_raw = _scalp_avg(data, chs, ALL_SCALP)
    del data
    if s_raw is None:
        return []
    return _window_feature_list(_bp_resample(s_raw, fs))


def _seizure_segments(sz_start, sz_end):
    """(t0, t1, label) crops around one seizure, PGES first:
    PGES = [offset+5, offset+185] s (label 1); baseline = [onset-130, onset-10] s (label 0)."""
    return ((sz_end + 5, sz_end + 185, 1), (sz_start - 130, sz_start - 10, 0))


# Heuristic GTC timing (each GTC file is a ~240 s seizure recording):
# baseline = 0-50 s (pre-ictal, label 0); PGES = 130-220 s (post-ictal, label 1).
_GTC_SEGMENTS = ((0.0, 50.0, 0), (130.0, 220.0, 1))


def extract_gtc_bridge_features(edf_name):
    """
    Paired scalp + thalamic features from GTC A2/A4 (simultaneous LT1-8 + scalp 10-20).
    Each file is a ~240s seizure recording. Heuristic timing:
      baseline : t=0-50s (pre-ictal)
      PGES     : t=130-220s (post-ictal)
    Scalp = Fz/Cz/C3/F3 average; thalamic = LT1-LT2 bipolar.
    Returns (Xs, Xt, y) or (None, None, None).
    """
    edf_path = GTC_ROOT / edf_name
    if not edf_path.exists():
        return None, None, None
    GTC_TOPO = ['EEG FZ', 'EEG CZ', 'EEG C3', 'EEG F3',
                'EEG FZ-REF', 'EEG CZ-REF', 'EEG C3-REF', 'EEG F3-REF',
                'FZ', 'CZ', 'C3', 'F3']
    Xs, Xt, Y = [], [], []
    for t0, t1, label in _GTC_SEGMENTS:
        s_f, t_f = _paired_segment_features(edf_path, t0, t1, GTC_TOPO,
                                            ['EEG LT1', 'LT1'], ['EEG LT2', 'LT2'])
        Xs.extend(s_f)
        Xt.extend(t_f)
        Y.extend([label] * len(s_f))
        gc.collect()
    if not Xs:
        return None, None, None
    log(f' GTC {edf_name}: {len(Y)} bridge windows '
        f'(PGES={sum(Y)}, base={sum(1 for y in Y if y==0)})')
    return (np.array(Xs, dtype=np.float32),
            np.array(Xt, dtype=np.float32),
            np.array(Y, dtype=np.int32))


def extract_gtc_thalamic_features(edf_name):
    """
    Thalamic-only features from GTC B2/B3 (LTP1-LTP6, no scalp).
    Signal = LTP1-LTP2 bipolar, or LTP1 alone when LTP2 is missing.
    Returns (X, y) for adding to L1 thalamic pool, or (None, None).
    """
    edf_path = GTC_ROOT / edf_name
    if not edf_path.exists():
        return None, None
    X, Y = [], []
    for t0, t1, label in _GTC_SEGMENTS:
        data, fs, chs = _crop_load_segment(str(edf_path), t0, t1)
        if data is None:
            continue
        i1 = _find_ch(chs, ['EEG LTP1', 'LTP1'])
        i2 = _find_ch(chs, ['EEG LTP2', 'LTP2'])
        if i1 is None:
            del data
            continue
        ref = data[i2] if i2 is not None else np.zeros_like(data[i1])
        t_ds = _bp_resample(data[i1] - ref, fs)
        del data
        feats = _window_feature_list(t_ds)
        X.extend(feats)
        Y.extend([label] * len(feats))
        gc.collect()
    if not X:
        return None, None
    return np.array(X, dtype=np.float32), np.array(Y, dtype=np.int32)


def extract_p2_paired_features(meta_df):
    """
    Returns three arrays (all aligned — same time window):
      Xs_topo : (N, 17) scalp features from Fz/Cz/C3/F3 average
                (falls back to the ALL_SCALP average if none are found)
      Xt      : (N, 17) thalamic features from LT1-LT2 bipolar
      y       : (N,)    label 1=PGES 0=baseline
    or (None, None, None) if no window could be extracted.
    Uses crop-before-load to avoid OOM on large EDF files.
    """
    p2_rows = meta_df[meta_df['Patient ID'] == 'P2']
    pdir = SEEG_ROOT / 'P2_SEEG'
    Xs, Xt, Y = [], [], []
    for _, row in p2_rows.iterrows():
        sz_file = str(row['Seizure_Filename'])
        sz_start = float(row['Seizure_Onset_Sec'])
        sz_end = float(row['Seizure_Offset_Sec'])
        edf_path = pdir / sz_file
        if not edf_path.exists():
            continue
        try:
            for t0, t1, label in _seizure_segments(sz_start, sz_end):
                s_f, t_f = _paired_segment_features(edf_path, t0, t1, TOPO_CH,
                                                    ['LT1'], ['LT2'],
                                                    scalp_fallback=ALL_SCALP)
                Xs.extend(s_f)
                Xt.extend(t_f)
                Y.extend([label] * len(s_f))
            gc.collect()
        except Exception as e:
            log(f" [ERR] P2 {sz_file}: {e}")
    log(f" P2 paired features: {len(Y)} windows "
        f"({sum(Y)} PGES, {sum(1 for y in Y if y==0)} baseline)")
    if not Xs:
        return None, None, None
    return (np.array(Xs, dtype=np.float32),
            np.array(Xt, dtype=np.float32),
            np.array(Y, dtype=np.int32))


def extract_scalp_pges_windows(pid, meta_df):
    """
    Extract labeled scalp PGES and baseline windows from patients who have
    scalp EEG but NOT necessarily thalamic — i.e. P10, P12.
    Uses crop-before-load to avoid OOM on large EDF files (P12 = 18 GB).
    Returns (Xs, y) — scalp features, label (1=PGES/0=baseline) — or (None, None).
    """
    rows = meta_df[meta_df['Patient ID'] == pid]
    pdir = SEEG_ROOT / f'{pid}_SEEG'
    Xs, Y = [], []
    for _, row in rows.iterrows():
        sz_file = str(row['Seizure_Filename'])
        sz_start = float(row['Seizure_Onset_Sec'])
        sz_end = float(row['Seizure_Offset_Sec'])
        edf_path = pdir / sz_file
        if not edf_path.exists():
            continue
        try:
            for t0, t1, label in _seizure_segments(sz_start, sz_end):
                feats = _scalp_segment_features(edf_path, t0, t1)
                Xs.extend(feats)
                Y.extend([label] * len(feats))
            gc.collect()
        except Exception as e:
            log(f" [ERR] {pid} {sz_file}: {e}")
    log(f" {pid} scalp: {len(Y)} windows ({sum(Y)} PGES, {sum(1 for y in Y if y==0)} base)")
    if not Xs:
        return None, None
    return np.array(Xs, dtype=np.float32), np.array(Y, dtype=np.int32)


# ══════════════════════════════════════════════════════════════════════════════
# TUH scalp feature extraction (topology-informed: Fz/Cz/C3/F3)
# ══════════════════════════════════════════════════════════════════════════════
def _read_tuh_seizure_intervals(csv_path, target_labels):
    """Parse a TUH seizure csv annotation file and return sorted, merged
    [start, stop] intervals (s) whose label column is in `target_labels`.

    Lines starting with '#' or 'channel' and rows with fewer than 5 fields are
    skipped. Per-channel duplicates and overlapping intervals are merged.
    """
    szs = set()
    with open(csv_path) as f:
        for line in f:
            if line.startswith(('#', 'channel')):
                continue
            parts = line.strip().split(',')
            if len(parts) < 5 or parts[3].strip() not in target_labels:
                continue
            try:
                szs.add((float(parts[1]), float(parts[2])))
            except ValueError:
                continue
    merged = []
    for s, e in sorted(szs):
        if merged and s <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], e)
        else:
            merged.append([s, e])
    return merged


def extract_tuh_topo_features(edf_path, csv_path, target_labels=('gnsz', 'tcsz')):
    """Per-seizure PGES and baseline feature sessions from one TUH recording.

    Signal = Fz/Cz/C3/F3 average (ALL_SCALP fallback), band-passed 0.5–100 Hz
    and resampled to FS_TARGET. It is only band-passed when fs is already
    within 5 Hz of FS_TARGET. For each merged seizure:
      PGES session     : windows in [offset+5, offset+185] s
      baseline session : windows in [onset-130, onset-10] s (clipped at 0)
    A session is kept only if it has >= N_CTX + 2 windows.

    Returns (pges_sessions or None, base_sessions or None); each session is an
    (n_windows, 17) float32 array. Best-effort: any failure gives (None, None).
    """
    try:
        # Parse annotations first so recordings without target seizures are
        # rejected before the (comparatively expensive) EDF load.
        szs = _read_tuh_seizure_intervals(csv_path, target_labels)
        if not szs:
            return None, None
        raw = mne.io.read_raw_edf(edf_path, preload=True, verbose=False)
        fs = raw.info['sfreq']
        chs = raw.ch_names
        data = raw.get_data()
        del raw
        if np.abs(np.median(data)) < 0.01:
            data *= 1e6

        sig_raw = _scalp_avg(data, chs, TOPO_CH)
        if sig_raw is None:
            sig_raw = _scalp_avg(data, chs, ALL_SCALP)
        del data
        if sig_raw is None:
            return None, None
        if abs(fs - FS_TARGET) > 5:
            sig = _downsample(_bp(sig_raw, 0.5, 100, fs), fs, FS_TARGET)
        else:
            sig = _bp(sig_raw, 0.5, 100, fs)

        pges_wins, base_wins = [], []
        for sz_start, sz_end in szs:
            pi_w = _window_feature_list(sig,
                                        int((sz_end + 5) * FS_TARGET),
                                        int((sz_end + 5 + 180) * FS_TARGET))
            if len(pi_w) >= N_CTX + 2:
                pges_wins.append(np.array(pi_w, dtype=np.float32))
            pr_w = _window_feature_list(sig,
                                        max(0, int((sz_start - 10 - 120) * FS_TARGET)),
                                        int((sz_start - 10) * FS_TARGET))
            if len(pr_w) >= N_CTX + 2:
                base_wins.append(np.array(pr_w, dtype=np.float32))
        return (pges_wins or None), (base_wins or None)
    except Exception:
        # Deliberately best-effort: a bad TUH file is skipped by the caller.
        return None, None


# ══════════════════════════════════════════════════════════════════════════════
# Three-source integrated encoder (CausalTransformer with multi-source losses)
# ══════════════════════════════════════════════════════════════════════════════
class CausalTransformer(nn.Module):
    """Causal Transformer over sequences of N_FEAT-dimensional window features.

    proj (N_FEAT → D_MODEL) → N_LAYERS TransformerEncoder layers with an
    upper-triangular attention mask (no attending to later steps) →
    head (D_MODEL → N_FEAT), which predicts the next window's features.
    """

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(N_FEAT, D_MODEL)
        enc = nn.TransformerEncoderLayer(D_MODEL, N_HEADS, D_MODEL * 4,
                                         dropout=0.1, batch_first=True)
        self.enc = nn.TransformerEncoder(enc, N_LAYERS)
        self.head = nn.Linear(D_MODEL, N_FEAT)
        # Boolean mask: True above the diagonal = may not attend to future steps.
        mask = torch.triu(torch.ones(N_CTX, N_CTX), 1).bool()
        self.register_buffer('mask', mask)

    def forward(self, x, return_hidden=False):
        """x: (batch, seq <= N_CTX, N_FEAT). Returns hidden states
        (batch, seq, D_MODEL) if `return_hidden`, else per-step next-window
        predictions (batch, seq, N_FEAT)."""
        L = x.shape[1]
        h = self.enc(self.proj(x), mask=self.mask[:L, :L])
        return h if return_hidden else self.head(h)

    def embed_windows(self, X_2d):
        """Embed (N, 17) flat windows as single-step sequences → (N, D_MODEL)."""
        x = X_2d.unsqueeze(1)          # (N, 1, 17)
        h = self.enc(self.proj(x))
        return h[:, 0, :]              # (N, D_MODEL)


def _supcon_loss(z1, z2, y1, y2, temp=SUPCON_T):
    """
    SupCon between two sets of embeddings z1, z2 with labels y1, y2.
    Positives: same class across both sets.
    z1, z2: (N, D) — can be different sizes; loss computed from z1's perspective.
    The softmax denominator runs over ALL of z2. Anchors with no positive in z2
    contribute 0 but still count in the mean.
    """
    z1 = F.normalize(z1, dim=1)
    z2 = F.normalize(z2, dim=1)
    sim = torch.mm(z1, z2.t()) / temp                      # (N1, N2)
    pos = (y1.unsqueeze(1) == y2.unsqueeze(0)).float()     # same class => positive
    sim = sim - sim.max(dim=1, keepdim=True)[0].detach()   # numerical stability
    log_prob = sim - torch.log(torch.exp(sim).sum(1, keepdim=True) + 1e-8)
    loss = -(pos * log_prob).sum(1) / (pos.sum(1) + 1e-8)
    return loss.mean()


def _tsm_windows(sessions):
    """Slice each session with >= N_CTX + 2 rows into overlapping
    (N_CTX + 1)-row windows (N_CTX context rows + 1 target row).

    NOTE(review): i stops at len(sess) - 1, so the window ending on a session's
    final row is never produced. Kept as-is to match the canonical TSM
    utilities shared with sibling scripts.
    """
    seqs = []
    for sess in sessions:
        if len(sess) < N_CTX + 2:
            continue
        for i in range(N_CTX + 1, len(sess)):
            seqs.append(sess[i - N_CTX - 1: i])
    return seqs


def pretrain_three_source(model, thal_sessions, tuh_pges_wins, tuh_base_wins,
                          scalp_pool_Xs, scalp_pool_y,     # L2: institutional scalp pool (labelled)
                          bridge_Xs, bridge_Xt, bridge_y,  # L3: scalp↔thalamic SAME-TIME pairs
                          thal_scaler, epochs=60, conditions='ABCD'):
    """
    Three-source integrated pre-training (one random mini-batch per loss per epoch).

    L1: TSM next-window prediction on thalamic baseline sessions
        (1 - cosine + 0.5 * MSE). Active whenever >= 10 sequences exist.
    L2: SupCon between TUH scalp windows and the institutional scalp pool
        (same-domain). Enabled by 'B' in `conditions`.
    L3: SupCon between paired scalp and thalamic windows at the same time t
        (bridge). Enabled by 'C' in `conditions`; needs >= 8 pairs.

    `thal_scaler` is not used here; it is kept for call-site compatibility.
    If no loss term is active, training is skipped. Returns `model` in eval mode.
    """
    opt = torch.optim.Adam(model.parameters(), lr=SEQ_LR)
    model.train()

    # ── L1: TSM sequences ──────────────────────────────────────────────────
    tsm_list = _tsm_windows(thal_sessions)
    tsm_seqs = np.array(tsm_list, dtype=np.float32) if tsm_list else None

    # ── L2: flat, class-labelled TUH scalp windows ─────────────────────────
    if tuh_pges_wins is not None and len(tuh_pges_wins) > 0 and 'B' in conditions:
        pges_list = list(tuh_pges_wins)
        base_list = list(tuh_base_wins or [])
        tuh_flat = np.vstack(pges_list + base_list)
        tuh_flat_y = np.concatenate([
            np.ones(sum(len(s) for s in pges_list), dtype=np.int64),
            np.zeros(sum(len(s) for s in base_list), dtype=np.int64)])
        max_tuh = 2000  # subsample for memory
        if len(tuh_flat) > max_tuh:
            idx = np.random.choice(len(tuh_flat), max_tuh, replace=False)
            tuh_flat, tuh_flat_y = tuh_flat[idx], tuh_flat_y[idx]
    else:
        tuh_flat = tuh_flat_y = None

    # Institutional scalp pool for L2 (passed in by the caller).
    p2s_flat = scalp_pool_Xs
    p2s_flat_y = scalp_pool_y.astype(np.int64) if scalp_pool_y is not None else None

    # ── L3: same-time paired bridge windows ────────────────────────────────
    if bridge_Xs is not None and bridge_Xt is not None and 'C' in conditions:
        p2_pair_Xs, p2_pair_Xt = bridge_Xs, bridge_Xt
        p2_pair_y = bridge_y.astype(np.int64)
    else:
        p2_pair_Xs = p2_pair_Xt = p2_pair_y = None

    def _to_dev(arr, dtype=torch.float32):
        return torch.tensor(arr, dtype=dtype).to(DEVICE)

    for ep in range(epochs):
        total_loss = 0.0
        n_batches = 0

        # ── L1 batch (TSM) ──────────────────────────────────────────────────
        if tsm_seqs is not None and len(tsm_seqs) >= 10:
            idx = np.random.choice(len(tsm_seqs), min(128, len(tsm_seqs)), replace=False)
            xc = _to_dev(tsm_seqs[idx, :N_CTX])
            xt = _to_dev(tsm_seqs[idx, N_CTX])
            pred = model(xc)[:, -1, :]
            L1 = (1. - F.cosine_similarity(pred, xt, dim=1).mean()) \
                + 0.5 * F.mse_loss(pred, xt)
            total_loss += L1
            n_batches += 1

        # ── L2 batch (TUH scalp ↔ institutional scalp same-domain SupCon) ───
        if tuh_flat is not None and p2s_flat is not None and len(p2s_flat) > 0:
            t_idx = np.random.choice(len(tuh_flat), min(64, len(tuh_flat)), replace=False)
            s_idx = np.random.choice(len(p2s_flat), min(64, len(p2s_flat)), replace=False)
            z_tuh = model.embed_windows(_to_dev(tuh_flat[t_idx]))
            z_p2s = model.embed_windows(_to_dev(p2s_flat[s_idx]))
            y_tuh = _to_dev(tuh_flat_y[t_idx], torch.long)
            y_p2s = _to_dev(p2s_flat_y[s_idx], torch.long)
            L2 = _supcon_loss(z_tuh, z_p2s, y_tuh, y_p2s)
            total_loss += LAM_L2 * L2
            n_batches += 1

        # ── L3 batch (scalp ↔ thalamic paired bridge SupCon) ────────────────
        if p2_pair_Xs is not None and len(p2_pair_Xs) >= 8:
            idx = np.random.choice(len(p2_pair_Xs), min(64, len(p2_pair_Xs)), replace=False)
            z_s = model.embed_windows(_to_dev(p2_pair_Xs[idx]))
            z_t = model.embed_windows(_to_dev(p2_pair_Xt[idx]))
            y_b = _to_dev(p2_pair_y[idx], torch.long)
            L3 = _supcon_loss(z_s, z_t, y_b, y_b)  # same labels (same windows)
            total_loss += LAM_L3 * L3
            n_batches += 1

        if n_batches == 0:
            # Nothing to optimise (and total_loss is still a plain float).
            log("  [WARN] pretrain_three_source: no active loss terms; skipping pre-training")
            break
        opt.zero_grad()
        total_loss.backward()
        opt.step()
        if (ep + 1) % 15 == 0:
            log(f" ep {ep+1}/{epochs} loss={total_loss.item()/n_batches:.4f}")

    model.eval()
    return model


# ══════════════════════════════════════════════════════════════════════════════
# Standard TSM utilities (identical across all scripts)
# ══════════════════════════════════════════════════════════════════════════════
def pretrain_tsm_only(model, thal_sessions, epochs, lr=SEQ_LR):
    """L1-only TSM pre-training on (N_CTX + 1)-row windows of `thal_sessions`.

    Returns `model` unchanged when fewer than 10 windows are available.
    Otherwise it trains with shuffled batches of 128 using Adam(lr) and the
    loss (1 - cosine + 0.5 * MSE) on the next-window prediction, then returns
    the model in eval mode.
    """
    seqs = _tsm_windows(thal_sessions)
    if len(seqs) < 10:
        return model
seqs = np.array(seqs, dtype=np.float32) ds = torch.utils.data.DataLoader( torch.utils.data.TensorDataset( torch.tensor(seqs[:, :N_CTX]), torch.tensor(seqs[:, N_CTX])), batch_size=128, shuffle=True) opt = torch.optim.Adam(model.parameters(), lr=lr) model.train() for _ in range(epochs): for xc, xt in ds: xc, xt = xc.to(DEVICE), xt.to(DEVICE) pred = model(xc)[:, -1, :] loss = (1. - F.cosine_similarity(pred, xt, dim=1).mean()) + 0.5*F.mse_loss(pred, xt) opt.zero_grad(); loss.backward(); opt.step() model.eval(); return model def build_seqs(patient, scaler): X_n = scaler.transform(patient['X'].astype(np.float32)) y = patient['labels'].astype(np.int32) seqs, lbls = [], [] for i in range(N_CTX, len(X_n)): seqs.append(X_n[i-N_CTX:i]); lbls.append(y[i]) if not seqs: return None, None return np.array(seqs, dtype=np.float32), np.array(lbls, dtype=np.int32) def encode(model, seqs): model.eval(); z = [] for i in range(0, len(seqs), 64): b = torch.tensor(seqs[i:i+64], dtype=torch.float32).to(DEVICE) with torch.no_grad(): z.append(model(b, return_hidden=True)[:,-1,:].cpu().numpy()) return np.vstack(z) def kshot_eval(model, seqs, lbls, K, n_trials=N_TRIALS): Z = encode(model, seqs) if lbls.sum() == 0: return float('nan') if K == 0: pp = Z[lbls==1].mean(0); pb = Z[lbls==0].mean(0) return float(f1_score(lbls, (np.linalg.norm(Z-pp,axis=1) < np.linalg.norm(Z-pb,axis=1)).astype(int), zero_division=0)) scores = [] for _ in range(n_trials): pos = np.where(lbls==1)[0]; neg = np.where(lbls==0)[0] if len(pos) P2+A2+A4 scalp --(L3)--> thalamic bridge --(L1)--> 8+B2+B3 patients') log('=' * 60) # ── Step 1: Thalamic patients ──────────────────────────────────────────── log('\nStep 1: Loading thalamic patients...') log(f' Restricted to confirmed LT/LTP patients: {THAL_PIDS}') log(' (P6/P9-P14 excluded: wrong-hemisphere or non-thalamic contacts)') _V3 = Path(__file__).parent / "dactrl_v3_episodic_protonet.py" _v3g = {'__file__': str(_V3)} with open(_V3, 'r', errors='replace') as _f: 
exec(compile(_f.read().replace("if __name__=='__main__':", "if __name__=='__never__':"), str(_V3), 'exec'), _v3g) load_all_seeg = _v3g['load_all_seeg'] meta_df = pd.read_excel(METADATA) meta_thal = meta_df[meta_df['Patient ID'].isin(THAL_PIDS)].copy() raw_data = load_all_seeg(meta_thal) patients = [{'pid': p, 'nucleus': NUCLEUS_MAP.get(p,'UNK'), 'X': d['X'], 'labels': d['y_temporal']} for p, d in raw_data.items() if d['y_temporal'].sum() > 0] log(f' {len(patients)} institutional thalamic patients.') # Add GTC B2/B3 (thalamic-only LTP1-6) to L1 pool log(' Adding GTC B2/B3 thalamic data (LTP1-LTP6, approximate labels)...') for gtc_b in ['B2.edf', 'B3.edf']: Xb, yb = extract_gtc_thalamic_features(gtc_b) if Xb is not None: patients.append({'pid': gtc_b.replace('.edf',''), 'nucleus': 'ANT_GTC', 'X': Xb, 'labels': yb}) log(f' {gtc_b}: {len(Xb)} windows added to L1 pool') log(f' Total thalamic pool: {len(patients)} sources') # ── Step 2: Bridge feature extraction (P2 + GTC A2/A4 simultaneous scalp+thalamic) log('\nStep 2: Extracting bridge features (simultaneous scalp + thalamic)...') log(' P2: institutional bridge (LT1-LT2 + Fz/Cz/C3/F3)') p2_Xs, p2_Xt, p2_y = extract_p2_paired_features(meta_df) HAS_P2 = p2_Xs is not None and len(p2_Xs) >= 10 if HAS_P2: log(f' P2 bridge: {len(p2_Xs)} windows (PGES={p2_y.sum()}, base={(p2_y==0).sum()})') else: log(' [WARN] P2 paired data unavailable') # GTC A2/A4: new bridge patients (LT1-LT8 + full scalp 10-20, heuristic labels) log(' GTC A2/A4: new bridge patients (LT1-LT8 + scalp 10-20, heuristic PGES labels)') all_bridge_Xs, all_bridge_Xt, all_bridge_y = [], [], [] if HAS_P2: all_bridge_Xs.append(p2_Xs); all_bridge_Xt.append(p2_Xt); all_bridge_y.append(p2_y) for gtc_a in ['A2.edf', 'A4.edf']: aXs, aXt, ay = extract_gtc_bridge_features(gtc_a) if aXs is not None: all_bridge_Xs.append(aXs); all_bridge_Xt.append(aXt); all_bridge_y.append(ay) if all_bridge_Xs: bridge_Xs = np.vstack(all_bridge_Xs) bridge_Xt = np.vstack(all_bridge_Xt) 
bridge_y = np.concatenate(all_bridge_y) log(f' Combined bridge: {len(bridge_y)} windows ' f'(PGES={bridge_y.sum()}, base={(bridge_y==0).sum()}) ' f'from P2+A2+A4') HAS_BRIDGE = True else: bridge_Xs = bridge_Xt = bridge_y = None HAS_BRIDGE = False # Scalp pool for L2: P2 scalp + P10/P12 scalp (confirmed scalp-only SEEG patients) log(' P10/P12: scalp-only SEEG patients (no thalamic contact) — scalp pool for L2') inst_scalp_Xs, inst_scalp_y = [], [] if HAS_P2: inst_scalp_Xs.append(p2_Xs); inst_scalp_y.append(p2_y) for pid_s in ['P10', 'P12']: Xs_s, y_s = extract_scalp_pges_windows(pid_s, meta_df) if Xs_s is not None: inst_scalp_Xs.append(Xs_s); inst_scalp_y.append(y_s) # Also add GTC A2/A4 scalp to pool if all_bridge_Xs: inst_scalp_Xs.append(bridge_Xs); inst_scalp_y.append(bridge_y) if inst_scalp_Xs: inst_scalp_Xs = np.vstack(inst_scalp_Xs) inst_scalp_y = np.concatenate(inst_scalp_y) log(f' Total scalp pool (L2): {len(inst_scalp_Xs)} windows ' f'(PGES={inst_scalp_y.sum()}, base={(inst_scalp_y==0).sum()}) ' f'from P2+P10+P12+A2+A4') HAS_INST_SCALP = True else: inst_scalp_Xs = inst_scalp_y = None HAS_INST_SCALP = False # ── Step 3: TUH topology features ──────────────────────────────────────── log('\nStep 3: Building TUH corpus (topology-informed: Fz/Cz/C3/F3)...') csvs = [f for f in glob.glob(os.path.join(TUH_BASE, '**', '*.csv'), recursive=True) if 'worksheet' not in f.lower()] def _has_target(f): try: return any(t in open(f, errors='ignore').read() for t in ['tcsz','gnsz']) except: return False tgt_csvs = [f for f in csvs if _has_target(f)] tgt_pairs_tuh = [(f, f.replace('.csv','.edf')) for f in tgt_csvs if os.path.exists(f.replace('.csv','.edf'))] np.random.seed(42) if len(tgt_pairs_tuh) > MAX_TUH: idxs = np.random.choice(len(tgt_pairs_tuh), MAX_TUH, replace=False) tgt_pairs_tuh = [tgt_pairs_tuh[i] for i in idxs] log(f' Using {len(tgt_pairs_tuh)} TUH files.') def _with_timeout(fn, *args, timeout=240): result = [None, None] def _r(): try: result[0], result[1] = 
fn(*args) except: pass t = threading.Thread(target=_r, daemon=True) t.start(); t.join(timeout=timeout) return result[0], result[1] tuh_pges, tuh_base = [], [] skipped = 0 for k, (csv_p, edf_p) in enumerate(tgt_pairs_tuh): p, b = _with_timeout(extract_tuh_topo_features, edf_p, csv_p) if p: tuh_pges.extend(p) if b: tuh_base.extend(b) if p is None and b is None: skipped += 1 if (k+1) % 20 == 0: log(f' TUH {k+1}/{len(tgt_pairs_tuh)} | ' f'PGES sessions={len(tuh_pges)} | skipped={skipped}') log(f' TUH: {len(tuh_pges)} PGES sessions | {len(tuh_base)} base sessions | skipped={skipped}') HAS_TUH = len(tuh_pges) > 0 # Normalise TUH features for L2 (flat windows, class-labeled) if HAS_TUH: tuh_scaler = StandardScaler().fit( np.vstack([w for sess in tuh_pges+tuh_base for w in [sess]])) tuh_pges_n = [tuh_scaler.transform(s) for s in tuh_pges] tuh_base_n = [tuh_scaler.transform(s) for s in tuh_base] else: tuh_pges_n = tuh_base_n = [] # ── Step 4: LOSO ───────────────────────────────────────────────────────── log('\nStep 4: LOSO evaluation (5 conditions)...') log(' A = L1 only: Thalamic TSM (baseline)') log(' B = L1+L2: TSM + TUH/P2-scalp same-domain alignment') log(' C = L1+L3: TSM + P2 scalp↔thalamic bridge (no TUH scale)') log(' D = L1+L2+L3: All three — full integrated [MAIN]') log(' E = D + Day-0 temporal heuristic') results = {c: {k: [] for k in K_VALS} for c in ['A','B','C','D','E']} for fold_i, test_p in enumerate(patients): pid = test_p['pid'] train_ps = [p for p in patients if p['pid'] != pid] X_tr = np.vstack([p['X'].astype(np.float32) for p in train_ps]) scaler = StandardScaler().fit(X_tr) seqs, lbls = build_seqs(test_p, scaler) if seqs is None or lbls.sum() == 0: log(f' [{fold_i+1:02d}] {pid}: skip'); continue log(f'\n [{fold_i+1:02d}/{len(patients)}] Test={pid} ' f'PGES={int(lbls.sum())} Base={int((lbls==0).sum())}') # Thalamic baseline sessions (normalised with fold scaler) thal_sess = [] for p in train_ps: X_n = scaler.transform(p['X'].astype(np.float32)) sess = 
X_n[p['labels'] == 0] if len(sess) >= N_CTX + 2: thal_sess.append(sess) # Bridge data for L3 — combined P2 + GTC A2/A4 (exclude P2 fold if test=P2) # GTC patients are external so always available regardless of fold fold_bridge_Xs_parts, fold_bridge_Xt_parts, fold_bridge_y_parts = [], [], [] if HAS_P2 and pid != 'P2': fold_bridge_Xs_parts.append(p2_Xs) fold_bridge_Xt_parts.append(p2_Xt) fold_bridge_y_parts.append(p2_y) # Always add GTC A2/A4 bridge (external dataset, no LOSO conflict) for gtc_a in ['A2.edf', 'A4.edf']: aXs, aXt, ay = extract_gtc_bridge_features(gtc_a) if aXs is not None: fold_bridge_Xs_parts.append(aXs) fold_bridge_Xt_parts.append(aXt) fold_bridge_y_parts.append(ay) if fold_bridge_Xs_parts: fold_bridge_Xs = np.vstack(fold_bridge_Xs_parts) fold_bridge_Xt = np.vstack(fold_bridge_Xt_parts) fold_bridge_y = np.concatenate(fold_bridge_y_parts) fold_bridge_Xs_n = tuh_scaler.transform(fold_bridge_Xs) if HAS_TUH else fold_bridge_Xs fold_bridge_Xt_n = scaler.transform(fold_bridge_Xt) HAS_FOLD_BRIDGE = True else: fold_bridge_Xs_n = fold_bridge_Xt_n = fold_bridge_y = None HAS_FOLD_BRIDGE = False # Institutional scalp pool for L2 (P2+P10+P12+A2+A4 minus test patient) if HAS_INST_SCALP: inst_fold_Xs = [] inst_fold_y = [] for pid_s in ['P2', 'P10', 'P12']: if pid_s == pid: continue if pid_s == 'P2' and HAS_P2: inst_fold_Xs.append(p2_Xs); inst_fold_y.append(p2_y) else: Xs_tmp, y_tmp = extract_scalp_pges_windows(pid_s, meta_df) if Xs_tmp is not None: inst_fold_Xs.append(Xs_tmp); inst_fold_y.append(y_tmp) # GTC A2/A4 scalp always included if HAS_FOLD_BRIDGE: inst_fold_Xs.append(fold_bridge_Xs) inst_fold_y.append(fold_bridge_y) if inst_fold_Xs: inst_fold_Xs = np.vstack(inst_fold_Xs) inst_fold_y = np.concatenate(inst_fold_y) inst_fold_Xs_n = tuh_scaler.transform(inst_fold_Xs) if HAS_TUH else inst_fold_Xs else: inst_fold_Xs_n = inst_fold_y = None else: inst_fold_Xs_n = inst_fold_y = None # TUH normalised sessions (scaled once globally — consistent across folds) 
fold_tuh_pges = tuh_pges_n if HAS_TUH else [] fold_tuh_base = tuh_base_n if HAS_TUH else [] # ── Condition A: L1 only (thalamic TSM baseline) ───────────────── m_A = CausalTransformer().to(DEVICE) m_A = pretrain_three_source(m_A, thal_sess, None, None, None, None, None, None, None, scaler, epochs=SEQ_EP_THAL, conditions='A') res_A = {k: kshot_eval(m_A, seqs, lbls, k) for k in K_VALS} for k in K_VALS: results['A'][k].append(res_A[k]) log(f' A: K=0={res_A[0]:.4f} K=10={res_A[10]:.4f}') del m_A # ── Condition B: L1+L2 (TSM + TUH ↔ P2+P10+P12 scalp alignment) ─ # L2 uses all 3 institutional patients' scalp PGES as anchor for TUH m_B = CausalTransformer().to(DEVICE) m_B = pretrain_three_source(m_B, thal_sess, fold_tuh_pges, fold_tuh_base, inst_fold_Xs_n, inst_fold_y, # L2: scalp pool None, None, None, # L3: no bridge scaler, epochs=SEQ_EP_PRETRAIN, conditions='AB') res_B = {k: kshot_eval(m_B, seqs, lbls, k) for k in K_VALS} for k in K_VALS: results['B'][k].append(res_B[k]) log(f' B: K=0={res_B[0]:.4f} K=10={res_B[10]:.4f} [L1+L2: TUH↔P2+P10+P12 scalp]') del m_B # ── Condition C: L1+L3 (TSM + bridge: P2 + GTC A2/A4) ──────────── if HAS_FOLD_BRIDGE: m_C = CausalTransformer().to(DEVICE) m_C = pretrain_three_source(m_C, thal_sess, None, None, None, None, fold_bridge_Xs_n, fold_bridge_Xt_n, fold_bridge_y, scaler, epochs=SEQ_EP_PRETRAIN, conditions='AC') res_C = {k: kshot_eval(m_C, seqs, lbls, k) for k in K_VALS} del m_C else: res_C = res_A for k in K_VALS: results['C'][k].append(res_C[k]) log(f' C: K=0={res_C[0]:.4f} K=10={res_C[10]:.4f} ' f'[L1+L3: P2+A2+A4 bridge{"" if HAS_FOLD_BRIDGE else " — fallback to A"}]') # ── Condition D: L1+L2+L3 — full chain [MAIN] ──────────────────── m_D = CausalTransformer().to(DEVICE) m_D = pretrain_three_source(m_D, thal_sess, fold_tuh_pges, fold_tuh_base, inst_fold_Xs_n, inst_fold_y, fold_bridge_Xs_n if HAS_FOLD_BRIDGE else None, fold_bridge_Xt_n if HAS_FOLD_BRIDGE else None, fold_bridge_y if HAS_FOLD_BRIDGE else None, scaler, 
epochs=SEQ_EP_PRETRAIN, conditions='ABCD' if HAS_FOLD_BRIDGE else 'AB') res_D = {k: kshot_eval(m_D, seqs, lbls, k) for k in K_VALS} for k in K_VALS: results['D'][k].append(res_D[k]) log(f' D: K=0={res_D[0]:.4f} K=10={res_D[10]:.4f} [L1+L2+L3]') # ── Condition E: D + Day-0 heuristic ───────────────────────────── pgs = np.where(np.diff(np.concatenate([[0], lbls])) == 1)[0] if len(pgs) > 0: onset = pgs[0]; pend = onset while pend < len(lbls) and lbls[pend] == 1: pend += 1 auto_p = np.arange(onset, min(onset+10, pend)) base_idx = np.where(lbls==0)[0] pre_b = base_idx[base_idx < onset] if len(pre_b) < 10: pre_b = base_idx[:10] auto_b = pre_b[-10:] Z = encode(m_D, seqs) pp = Z[auto_p].mean(0); pb = Z[auto_b].mean(0) preds = (np.linalg.norm(Z-pp,axis=1) < np.linalg.norm(Z-pb,axis=1)).astype(int) results['E'][0].append(float(f1_score(lbls, preds, zero_division=0))) for k in [2,5,10]: results['E'][k].append(res_D[k]) else: for k in K_VALS: results['E'][k].append(res_D[k]) log(f' E: K=0={results["E"][0][-1]:.4f} K=10={results["E"][10][-1]:.4f}') del m_D gc.collect() if DEVICE.type == 'cuda': torch.cuda.empty_cache() elif DEVICE.type == 'mps': torch.mps.empty_cache() # ── Results ─────────────────────────────────────────────────────────────── log('\n' + '='*60) log('=== C13: Three-Source Integrated Contrastive Results ===') log('='*60) labels_m = { 'A': 'A: L1 only — Thalamic TSM (baseline)', 'B': 'B: L1+L2 — TSM + TUH/P2 scalp align', 'C': 'C: L1+L3 — TSM + P2+A2+A4 bridge (expanded)', 'D': 'D: L1+L2+L3 — Full integrated +GTC [MAIN]', 'E': 'E: D + Day-0 heuristic', } cond_means = {c: {k: np.nanmean(results[c][k]) for k in K_VALS} for c in ['A','B','C','D','E']} log(f"{'Condition':<50} {'K=0':>6} {'K=2':>6} {'K=5':>6} {'K=10':>6}") log('-'*75) for cond in ['A','B','C','D','E']: m = cond_means[cond] log(f"{labels_m[cond]:<50} {m[0]:>6.4f} {m[2]:>6.4f} {m[5]:>6.4f} {m[10]:>6.4f}") log('\nGain over A:') for cond in ['B','C','D','E']: log(f" {cond}: " + " ".join( f"K={k}: 
{cond_means[cond][k]-cond_means['A'][k]:+.4f}" for k in K_VALS)) # Wilcoxon D vs A d14 = [v for v in results['D'][10] if not np.isnan(v)] a14 = [v for v in results['A'][10] if not np.isnan(v)] if len(d14) >= 5: try: stat, pv = wilcoxon(d14[:len(a14)], a14[:len(d14)]) log(f'\nWilcoxon D vs A (K=10): stat={stat:.3f} p={pv:.4f}') except Exception as e: log(f'\nWilcoxon: {e}') # Save np.save(str(OUT_ROOT/'results_raw.npy'), results) rows = [] for c in ['A','B','C','D','E']: for k in K_VALS: v = [x for x in results[c][k] if not np.isnan(x)] rows.append({'cond':c,'K':k,'mean':np.mean(v) if v else np.nan, 'std':np.std(v) if v else np.nan,'n':len(v)}) pd.DataFrame(rows).to_csv(str(OUT_ROOT/'results_summary.csv'), index=False) # Figure fig, ax = plt.subplots(figsize=(10, 5)) colors = {'A':'#7f8c8d','B':'#e74c3c','C':'#f39c12','D':'#27ae60','E':'#2980b9'} for cond in ['A','B','C','D','E']: vals = [cond_means[cond][k] for k in K_VALS] ax.plot(K_VALS, vals, 'o-', color=colors[cond], label=labels_m[cond], linewidth=2, markersize=8) ax.set_xlabel('K'); ax.set_ylabel('F1') ax.set_title('C13: Three-Source Integrated Contrastive\n' 'L1(TSM) + L2(TUH↔P2 scalp) + L3(P2 scalp↔thalamic bridge)') ax.legend(fontsize=8, loc='lower right'); ax.grid(alpha=0.3); ax.set_ylim(0.5, 1.05) plt.tight_layout() fig.savefig(str(OUT_ROOT/'c13_three_source.png'), dpi=150) plt.close() log(f'Figure saved -> {OUT_ROOT}/c13_three_source.png') log('COMPLETE')