""" DACTRL — Cross-Region sEEG Generalization Experiment ====================================================== Platform vision: Does the DACTRL-TSM model (trained on thalamic LFP) detect PGES in other brain regions recorded in the same SEEG session? The SEEG EDF files contain simultaneous recordings from: - Thalamus: LT1-LT16 (current system — DBS contact) - Hippocampus: LAH/LPH/RAH channels - Amygdala: LA/RA channels - Orbitofrontal: LAOF/LPOF/RAOF/RPOF channels - Cingulate cortex: LAC/RAC channels Protocol: 1. Load same SEEG EDF files used in main pipeline 2. Extract bipolar LFP from each brain region (same 17 features) 3. Two tests: A: Zero-shot — use thalamic-trained model directly on other region channels B: Same-region LOSO — train AND test on same non-thalamic region This answers: does PGES generalise across brain regions (thalamocortical collapse)? """ import os, sys, warnings, copy, threading from pathlib import Path from datetime import datetime import numpy as np import pandas as pd import matplotlib; matplotlib.use('Agg') import matplotlib.pyplot as plt from sklearn.preprocessing import StandardScaler from sklearn.metrics import f1_score import torch import torch.nn as nn import torch.nn.functional as F warnings.filterwarnings('ignore') DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') print(f"[{'GPU' if torch.cuda.is_available() else 'CPU'}]", flush=True) SEEG_ROOT = Path(r"G:\PHD Datasets\Data\Thalamus\SEEG Seizure Data") METADATA = SEEG_ROOT / "metadata_SEEG.xlsx" OUT_DIR = Path(r"D:\Projects\phd\PSEG\pges_toolkit\results\dactrl_cross_region") OUT_DIR.mkdir(parents=True, exist_ok=True) WIN_SEC = 5 N_CTX = 8 SEQ_EP = 100 SEQ_LR = 3e-4 D_MODEL = 64 N_HEADS = 4 N_LAYERS = 4 K_VALS = [0, 2, 5, 10] # Brain region channel prefixes present in these EDF files REGIONS = { 'Thalamus': ['LT'], # baseline — current system 'Hippocampus': ['LAH', 'LPH'], # left anterior + posterior hippocampus 'Amygdala': ['LA'], # left amygdala 
'Orbitofrontal': ['LAOF', 'LPOF'],# left orbitofrontal cortex 'Cingulate': ['LAC'], # left anterior cingulate } def log(msg): print(f'[{datetime.now().strftime("%H:%M:%S")}] {msg}', flush=True) # ── Model (identical to main pipeline) ─────────────────────────────────────── class CausalTransformer(nn.Module): def __init__(self, n_feat=17): super().__init__() self.proj = nn.Linear(n_feat, D_MODEL) enc = nn.TransformerEncoderLayer(D_MODEL, N_HEADS, D_MODEL*4, dropout=0.1, batch_first=True) self.enc = nn.TransformerEncoder(enc, N_LAYERS) self.head = nn.Linear(D_MODEL, n_feat) mask = torch.triu(torch.ones(N_CTX, N_CTX), 1).bool() self.register_buffer('mask', mask) def forward(self, x, return_hidden=False): B, T, _ = x.shape h = self.enc(self.proj(x), mask=self.mask[:T, :T]) if return_hidden: return h return self.head(h) # ── Feature extraction (same 17 features as main pipeline) ─────────────────── def compute_features_from_signal(sig, fs): from numpy.fft import rfft, rfftfreq sig = sig - sig.mean() n = len(sig) rms = float(np.sqrt(np.mean(sig**2))) ll = float(np.mean(np.abs(np.diff(sig)))) zc = float(np.sum(np.diff(np.sign(sig)) != 0)) var = float(np.var(sig)) freqs = rfftfreq(n, 1/fs) psd = np.abs(rfft(sig))**2 def band(lo, hi): idx = (freqs >= lo) & (freqs < hi) return float(psd[idx].sum()) if idx.any() else 0. 
total = float(psd.sum()) + 1e-10 delta = band(0.5,4); theta=band(4,8); alpha=band(8,13) beta = band(13,30); gamma=band(80,150) sr = (delta+theta)/(alpha+beta+1e-10) p = psd/(psd.sum()+1e-10); p=p[p>0] shan = float(-np.sum(p*np.log(p+1e-10))) supp = float(np.mean(np.abs(sig)<0.05*np.max(np.abs(sig)+1e-10))) u = sig[:min(200,n)] r = 0.2*np.std(u)+1e-10 def phi(mm): x2=np.array([u[i:i+mm] for i in range(len(u)-mm+1)]) C=np.sum(np.max(np.abs(x2[:,None]-x2[None,:]),axis=2)<=r,axis=0)/(len(u)-mm+1) return np.sum(np.log(C+1e-10))/(len(u)-mm+1) apen = float(abs(phi(2)-phi(3))) sampen = 0.5 # simplified for speed in cross-region hist,_ = np.histogram(sig, bins=10) hist = hist/(hist.sum()+1e-10) etc = float(-np.sum(hist[hist>0]*np.log(hist[hist>0]))) bsig = (sig>np.median(sig)).astype(np.uint8) _b = bsig.astype(np.int16) _idx4 = _b[:-3]*8 + _b[1:-2]*4 + _b[2:-1]*2 + _b[3:] lzc = float(len(np.unique(_idx4))) _wins = np.lib.stride_tricks.sliding_window_view(sig.astype(np.float64), 3) _ranks = np.argsort(_wins, axis=1, kind='quicksort') _penc = _ranks[:,0]*4 + _ranks[:,1]*2 + _ranks[:,2] _,_cnts = np.unique(_penc, return_counts=True) _tot = float(_cnts.sum()) pent = float(-np.sum((_cnts/_tot)*np.log(_cnts/_tot+1e-10))) return np.array([rms,ll,zc,var,delta/total,theta/total,alpha/total,beta/total, sr,shan,supp,apen,sampen,etc,lzc,pent,gamma/total],dtype=np.float32) def _edf_read_bipolar(fp, ch1_name, ch2_name, tmin_sec, tmax_sec): """ Read a bipolar channel pair from an EDF using pyedflib random access. Returns (sig_uV, fs) or (None, None). O(window) — handles 276ch files instantly. 
""" import pyedflib f = pyedflib.EdfReader(str(fp)) try: fs = f.getSampleFrequency(0) chs = f.getSignalLabels() idx1 = next((i for i,c in enumerate(chs) if c == ch1_name), None) idx2 = next((i for i,c in enumerate(chs) if c == ch2_name), None) if idx1 is None or idx2 is None: return None, None ns = f.getNSamples() fs1 = f.getSampleFrequency(idx1) n_tot1 = ns[idx1]; n_tot2 = ns[idx2] s_start = max(0, int(tmin_sec * fs1)) s_end1 = min(n_tot1, int(tmax_sec * fs1)) s_end2 = min(n_tot2, int(tmax_sec * fs1)) n1 = s_end1 - s_start; n2 = s_end2 - s_start if n1 <= 0 or n2 <= 0: return None, None d1 = f.readSignal(idx1, start=s_start, n=n1) d2 = f.readSignal(idx2, start=s_start, n=min(n2, len(d1))) n_min = min(len(d1), len(d2)) return (d1[:n_min] - d2[:n_min]) * 1e6, fs1 finally: f.close() def _edf_find_region_pair(fp, region_prefixes): """Return first two channel names matching region_prefixes, or (None, None).""" import pyedflib f = pyedflib.EdfReader(str(fp)) chs = f.getSignalLabels() f.close() matched = [c for c in chs if any(c.upper().startswith(p.upper()) for p in region_prefixes)] return (matched[0], matched[1]) if len(matched) >= 2 else (None, None) def load_patient_region(pid, meta_df, region_prefixes, thal_contact): """ Load EDF files for a patient, extract features from the specified brain region. Returns (X, labels) where labels=1 is PGES, 0 is baseline. Uses pyedflib for O(window) random access — no sequential scan of large files. 
""" patient_dir = SEEG_ROOT / f"{pid}_SEEG" if not patient_dir.exists(): return None, None baseline_X = [] pat_meta_base = meta_df[meta_df['Patient ID'] == pid] for _, row in pat_meta_base.iterrows(): if str(row['Seizure_Type']) not in ['FBTCS']: continue fp = patient_dir / row['Seizure_Filename'] if not fp.exists(): continue try: sz_onset = float(row['Seizure_Onset_Sec']) ch1, ch2 = _edf_find_region_pair(fp, region_prefixes) if ch1 is None: continue tmin = max(0.0, sz_onset - 125.0) tmax = sz_onset - 25.0 sig, fs = _edf_read_bipolar(fp, ch1, ch2, tmin, tmax) if sig is None: continue win_samp = int(WIN_SEC * fs) pre_end = max(0, int((sz_onset - 30 - tmin) * int(fs))) for i in range(0, pre_end - win_samp, win_samp): seg = sig[i:i+win_samp] if len(seg) == win_samp: baseline_X.append(compute_features_from_signal(seg, fs)) except Exception: continue seiz_X, seiz_lbls = [], [] pat_meta = meta_df[meta_df['Patient ID'] == pid] for _, row in pat_meta.iterrows(): fp = patient_dir / row['Seizure_Filename'] if not fp.exists(): continue if str(row['Seizure_Type']) not in ['FBTCS']: continue try: sz_onset = float(row['Seizure_Onset_Sec']) sz_end = float(row['Seizure_Offset_Sec']) ch1, ch2 = _edf_find_region_pair(fp, region_prefixes) if ch1 is None: continue # Post-ictal window sig_pi, fs = _edf_read_bipolar(fp, ch1, ch2, sz_end + 5.0, sz_end + 305.0) if sig_pi is not None: win_samp = int(WIN_SEC * fs) for i in range(0, len(sig_pi) - win_samp, win_samp): seg = sig_pi[i:i+win_samp] if len(seg) == win_samp: seiz_X.append(compute_features_from_signal(seg, fs)) seiz_lbls.append(1) # Pre-ictal baseline from seizure file sig_pre, fs = _edf_read_bipolar(fp, ch1, ch2, max(0, sz_onset - 130.0), sz_onset - 10.0) if sig_pre is not None: win_samp = int(WIN_SEC * fs) for i in range(0, len(sig_pre) - win_samp, win_samp): seg = sig_pre[i:i+win_samp] if len(seg) == win_samp: seiz_X.append(compute_features_from_signal(seg, fs)) seiz_lbls.append(0) except Exception: continue if not seiz_X: return 
None, None X_base = np.array(baseline_X, dtype=np.float32) if baseline_X else np.zeros((0,17),dtype=np.float32) X_seiz = np.array(seiz_X, dtype=np.float32) y_seiz = np.array(seiz_lbls, dtype=np.int32) y_base = np.zeros(len(X_base), dtype=np.int32) X_all = np.vstack([X_base, X_seiz]) if len(X_base) > 0 else X_seiz y_all = np.concatenate([y_base, y_seiz]) if len(X_base) > 0 else y_seiz return X_all, y_all def build_seqs(X, y, scaler): X_n = scaler.transform(X) seqs, lbls = [], [] for i in range(N_CTX, len(X_n)): seqs.append(X_n[i-N_CTX:i]); lbls.append(y[i]) if not seqs: return None, None return np.array(seqs, dtype=np.float32), np.array(lbls, dtype=np.int32) def encode(model, seqs): model.eval(); z = [] for i in range(0, len(seqs), 64): b = torch.tensor(seqs[i:i+64], dtype=torch.float32).to(DEVICE) with torch.no_grad(): z.append(model(b, return_hidden=True)[:,-1,:].cpu().numpy()) return np.vstack(z) def pretrain_thalamic(model, train_data, scaler, epochs=SEQ_EP): seqs = [] for X, y in train_data: X_n = scaler.transform(X) base = X_n[y == 0] for i in range(N_CTX+1, len(base)): seqs.append(base[i-N_CTX-1:i]) if len(seqs) < 10: return model seqs = np.array(seqs, dtype=np.float32) ds = torch.utils.data.TensorDataset( torch.tensor(seqs[:,:N_CTX]), torch.tensor(seqs[:,-1])) ld = torch.utils.data.DataLoader(ds, batch_size=64, shuffle=True) opt = torch.optim.Adam(model.parameters(), lr=SEQ_LR) model.train() for _ in range(epochs): for xc, xt in ld: xc,xt = xc.to(DEVICE),xt.to(DEVICE) pred = model(xc)[:,-1,:] loss = (1.-F.cosine_similarity(pred,xt,dim=1).mean())+0.5*F.mse_loss(pred,xt) opt.zero_grad(); loss.backward(); opt.step() model.eval(); return model def kshot_eval(model, seqs, lbls, K, n_trials=5): Z = encode(model, seqs) if lbls.sum() == 0: return float('nan') if K == 0: pp=Z[lbls==1].mean(0); pb=Z[lbls==0].mean(0) preds=(np.linalg.norm(Z-pp,axis=1)0 and p!='P13'] log(f'Loaded {len(thal_patients)} thalamic patients.') # Thalamic contacts per patient (from metadata) 
thal_contacts = meta_df.groupby('Patient ID')['TH_Contact'].first().to_dict()


def _mean_or_nan(values):
    """Mean of the non-NaN entries of ``values``; NaN when none remain."""
    kept = [v for v in values if not np.isnan(v)]
    return float(np.mean(kept)) if kept else float('nan')


# ── Load each brain region for each patient ───────────────────────────────
# Each load runs in a daemon thread so a stalled EDF read can be abandoned
# after 120 s; after three timeouts in a row the rest of the patient's regions
# are skipped.
log('\nLoading brain region data...')
region_data = {r: {} for r in REGIONS}
for tp in thal_patients:
    pid = tp['pid']
    contact = thal_contacts.get(pid, 'LT2-LT3')
    pid_timeouts = 0
    for region, prefixes in REGIONS.items():
        if pid_timeouts >= 3:
            log(f' [SKIP] {pid} {region} — 3 consecutive timeouts, skipping patient')
            continue

        result = [None, None]

        # `out=result` binds THIS iteration's list.  Without it, a worker that
        # timed out and finishes later would write into whichever `result`
        # list is current at that moment, i.e. another region's or patient's.
        def _run(p=pid, pf=prefixes, c=contact, out=result):
            out[0], out[1] = load_patient_region(p, meta_df, pf, c)

        t = threading.Thread(target=_run, daemon=True)
        t.start()
        t.join(timeout=120)
        if t.is_alive():
            log(f' [TIMEOUT] {pid} {region} — skipping')
            X, y = None, None
            pid_timeouts += 1
        else:
            X, y = result[0], result[1]
            pid_timeouts = 0

        if X is not None and y.sum() > 0:
            region_data[region][pid] = {'X': X, 'labels': y}
            log(f' {pid} {region}: {len(X)} windows, {y.sum()} PGES')

# ── Experiment A: Zero-shot cross-region ─────────────────────────────────
# Train thalamic model (LOSO), apply directly to other regions
log('\n--- Test A: Zero-shot cross-region (thalamic model → other regions) ---')
results_zero = {r: {k: [] for k in K_VALS} for r in REGIONS}
for fold_i, test_p in enumerate(thal_patients):
    pid = test_p['pid']
    train_ps = [p for p in thal_patients if p['pid'] != pid]
    if not train_ps:
        # Leave-one-out needs at least one other patient to train on.
        log(f' [{fold_i+1:02d}] {pid} skipped — no training patients')
        continue
    X_tr = np.vstack([p['X'].astype(np.float32) for p in train_ps])
    scaler = StandardScaler().fit(X_tr)

    # Train thalamic model
    model = CausalTransformer().to(DEVICE)
    model = pretrain_thalamic(
        model, [(p['X'], p['labels']) for p in train_ps], scaler)

    for region in REGIONS:
        seqs, lbls = None, None
        if pid in region_data[region]:
            rd = region_data[region][pid]
            # Region features are scaled with the THALAMIC training scaler.
            seqs, lbls = build_seqs(rd['X'], rd['labels'], scaler)
        if seqs is None or lbls.sum() == 0:
            for k in K_VALS:
                results_zero[region][k].append(float('nan'))
            continue
        for k in K_VALS:
            results_zero[region][k].append(kshot_eval(model, seqs, lbls, k))
    del model
    log(f' [{fold_i+1:02d}] {pid} done')

# ── Experiment B: Same-region LOSO ───────────────────────────────────────
log('\n--- Test B: Same-region LOSO (train and test on same non-thalamic region) ---')
results_same = {r: {k: [] for k in K_VALS} for r in REGIONS}
for region in REGIONS:
    rdata = region_data[region]
    pids = list(rdata.keys())
    if len(pids) < 3:
        log(f' {region}: insufficient patients ({len(pids)}), skip')
        continue
    log(f' Region: {region} ({len(pids)} patients)')
    for test_pid in pids:
        train_pids = [p for p in pids if p != test_pid]
        X_tr = np.vstack([rdata[p]['X'].astype(np.float32) for p in train_pids])
        scaler = StandardScaler().fit(X_tr)
        model = CausalTransformer().to(DEVICE)
        model = pretrain_thalamic(
            model,
            [(rdata[p]['X'], rdata[p]['labels']) for p in train_pids],
            scaler)
        seqs, lbls = build_seqs(rdata[test_pid]['X'],
                                rdata[test_pid]['labels'], scaler)
        if seqs is None or lbls.sum() == 0:
            for k in K_VALS:
                results_same[region][k].append(float('nan'))
        else:
            for k in K_VALS:
                results_same[region][k].append(kshot_eval(model, seqs, lbls, k))
        del model

# ── Summary ───────────────────────────────────────────────────────────────
log('\n' + '='*60)
log('=== Cross-Region Results (K=10 F1) ===')
log(f'{"Region":<20} {"Zero-shot K=0":>14} {"Zero-shot K=10":>14} {"Same-region K=10":>16}')
log('-'*66)
for region in REGIONS:
    log(f'{region:<20} {_mean_or_nan(results_zero[region][0]):>14.4f} '
        f'{_mean_or_nan(results_zero[region][10]):>14.4f} '
        f'{_mean_or_nan(results_same[region][10]):>16.4f}')

# Save and plot
# The result dicts are stored as pickled object arrays; reading them back
# needs np.load(..., allow_pickle=True).
np.save(str(OUT_DIR/'cross_region_zero_shot.npy'), results_zero)
np.save(str(OUT_DIR/'cross_region_same.npy'), results_same)

regions_plot = list(REGIONS.keys())
k10_zero = [_mean_or_nan(results_zero[r][10]) for r in regions_plot]
k10_same = [_mean_or_nan(results_same[r][10]) for r in regions_plot]
x = np.arange(len(regions_plot))

fig, ax = plt.subplots(figsize=(10, 5))
ax.bar(x-0.2, k10_zero, 0.35, label='Zero-shot (thalamic model)', color='steelblue')
ax.bar(x+0.2, k10_same, 0.35, label='Same-region LOSO', color='darkorange')
ax.axhline(0.898, color='green', ls='--', lw=1.5, label='Thalamic LOSO baseline (0.898)')
ax.set_xticks(x)
ax.set_xticklabels(regions_plot, fontsize=11)
ax.set_ylabel('F1 Score (K=10)')
ax.set_ylim(0, 1.05)
ax.set_title('DACTRL Cross-Region sEEG Generalization')
ax.legend()
ax.grid(axis='y', alpha=0.3)
fig.tight_layout()
fig.savefig(str(OUT_DIR/'cross_region_bar.png'), dpi=150)
plt.close(fig)
log(f'Figure saved -> {OUT_DIR}/cross_region_bar.png')
log('COMPLETE')