"""
DACTRL — Simultaneous Multi-Region Seizure Lifecycle Analysis
=============================================================
Preictal / Ictal / Postictal classification across the thalamocortical network
using simultaneous SEEG recordings (thalamus + hippocampus + amygdala + OFC +
cingulate).

Parts:
  A — Within-region 3-class phase classifier (LOSO, SVM+RBF per region)
  B — Cross-region zero-shot phase transfer (5×5 macro-F1 matrix, evaluated
      leave-one-patient-out; the diagonal is a held-out within-region score)
  C — Ictal propagation timing (per-seizure onset lag per region vs clinical label)
  D — TUH scalp → intracranial transfer (binary ictal/non-ictal, SVM)

Window protocol:
  Preictal  : [onset - 120s, onset - 10s]      (110s; buffer avoids the transition)
  Ictal     : [onset, onset + min(dur, 120s)]  (capped for class balance)
  Postictal : [offset + 5s, offset + 125s]     (120s; skips the transition)
  All       : 10s windows, 5s step (50% overlap)

All 69 seizures are used (FBTCS + FIAS + FAS + ES) — not restricted to
PGES-producing seizures. Seizures shorter than 5 s, or whose EDF file is
missing, are skipped while loading.
"""
import os
import sys
import threading
import time
import pathlib
import warnings
from datetime import datetime

import numpy as np
import pandas as pd
import pyedflib
from sklearn.svm import SVC
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import f1_score, confusion_matrix
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

warnings.filterwarnings('ignore')

# ── Paths ──────────────────────────────────────────────────────────────────────
SEEG_ROOT = pathlib.Path(r'G:/PHD Datasets/Data/Thalamus/SEEG Seizure Data')
METADATA = SEEG_ROOT / 'metadata_SEEG.xlsx'
TUH_BASE = pathlib.Path(r'G:/PHD Datasets/Data/Scalp/tueeg_data/tuh_eeg_seizure/v2.0.3/edf')
OUT_DIR = pathlib.Path(r'D:/Projects/phd/PSEG/pges_toolkit/results/dactrl_seizure_lifecycle')
OUT_DIR.mkdir(parents=True, exist_ok=True)

# ── Parameters ─────────────────────────────────────────────────────────────────
WIN_SEC = 10.0    # window length
STEP_SEC = 5.0    # step (50% overlap)
PRE_DUR = 110.0   # preictal window duration (onset-120s to onset-10s)
ICT_DUR = 120.0   # max ictal window, capped for balance
POST_DUR = 120.0  # postictal window duration

PHASE_NAMES = ['Preictal', 'Ictal', 'Postictal']
LABEL_MAP = {'Preictal': 0, 'Ictal': 1, 'Postictal': 2}
# Integer class labels in PHASE_NAMES order (matches enumerate(PHASE_NAMES)).
_PHASE_LABELS = [LABEL_MAP[p] for p in PHASE_NAMES]

# Electrode-name prefixes per region. Contacts are assumed to be labelled
# <prefix><contact number> (e.g. 'LA1', 'LAH3') — TODO confirm against headers.
REGIONS = {
    'Thalamus': ['LT', 'RT'],
    'Hippocampus': ['LAH', 'LPH', 'RAH', 'RPH'],
    'Amygdala': ['LA', 'RA'],
    'Orbitofrontal': ['LAOF', 'LPOF', 'RAOF', 'RPOF'],
    'Cingulate': ['LAC', 'RAC'],
}

TUH_ICTAL_LABELS = ('gnsz', 'tcsz', 'fnsz', 'cpsz', 'absz', 'mysz', 'tnsz')
MAX_TUH_FILES = 200  # scalp training corpus size


def log(msg):
    """Print ``msg`` with an HH:MM:SS timestamp, flushing immediately."""
    print(f'[{datetime.now().strftime("%H:%M:%S")}] {msg}', flush=True)


# ── Small shared helpers ───────────────────────────────────────────────────────

def _window_starts(start, stop, win_samp, stp_samp):
    """Return start indices of every full window of ``win_samp`` samples in [start, stop).

    The range end is ``stop - win_samp + 1`` so a window that ends exactly on
    ``stop`` is included (``range(start, stop - win_samp, step)`` drops it).
    """
    if win_samp <= 0 or stp_samp <= 0 or stop - start < win_samp:
        return range(0)
    return range(start, stop - win_samp + 1, stp_samp)


def _make_svm():
    """Return the RBF-SVM configuration shared by Parts A, B and D."""
    return SVC(kernel='rbf', C=10, gamma='scale', class_weight='balanced')


def _phase_xy(all_data, pids, region):
    """Stack the phase windows of ``region`` for ``pids`` into (X, y).

    ``y`` holds the index of the phase in PHASE_NAMES. Every pid passed in must
    have ``all_data[pid][region]`` not None. Returns (None, None) if no windows
    are available.
    """
    X, y = [], []
    for pid in pids:
        for lbl, phase in enumerate(PHASE_NAMES):
            arr = all_data[pid][region][phase]
            if arr is not None:
                X.append(arr)
                y.extend([lbl] * len(arr))
    if not X:
        return None, None
    return np.vstack(X), np.array(y)


def _log_transfer_matrix(transfer_mat, region_list):
    """Log ``transfer_mat`` as a fixed-width table (rows = train, cols = test region)."""
    log(f' {"":>15}' + ''.join(f'{r[:8]:>10}' for r in region_list))
    for i, src in enumerate(region_list):
        row_str = ''.join(f'{transfer_mat[i, j]:>10.4f}' if not np.isnan(transfer_mat[i, j])
                          else f'{"--":>10}'
                          for j in range(len(region_list)))
        log(f' {src:<15}{row_str}')


# ── Feature extraction ─────────────────────────────────────────────────────────

def compute_features(sig, fs):
    """Return the 17-element handcrafted feature vector (float32) for one window.

    Feature order:
      0 RMS, 1 line length (mean |diff|), 2 zero-crossing count, 3 variance,
      4-7 relative delta/theta/alpha/beta power, 8 (delta+theta)/(alpha+beta),
      9 spectral Shannon entropy, 10 suppression fraction (|x| < 5% of max),
      11 approximate entropy (m=2) on the first <=200 samples,
      12 constant 0.5 placeholder (carries no information),
      13 entropy of a 10-bin amplitude histogram,
      14 number of distinct 4-sample binary words (median-thresholded),
      15 order-3 permutation entropy, 16 relative 80-150 Hz power.

    Args:
        sig: 1-D signal window (the mean is removed first).
        fs: sampling rate in Hz.
    """
    from numpy.fft import rfft, rfftfreq

    sig = sig - sig.mean()
    n = len(sig)

    # Time-domain amplitude features
    rms = float(np.sqrt(np.mean(sig**2)))
    ll = float(np.mean(np.abs(np.diff(sig))))
    zc = float(np.sum(np.diff(np.sign(sig)) != 0))
    var = float(np.var(sig))

    # Band powers from the raw periodogram |FFT|^2
    freqs = rfftfreq(n, 1/fs)
    psd = np.abs(rfft(sig))**2
    total = float(psd.sum()) + 1e-10

    def band(lo, hi):
        idx = (freqs >= lo) & (freqs < hi)
        return float(psd[idx].sum()) if idx.any() else 0.

    delta = band(0.5, 4)
    theta = band(4, 8)
    alpha = band(8, 13)
    beta = band(13, 30)
    gamma = band(80, 150)  # partially/fully above Nyquist when fs < 300 Hz
    sr = (delta + theta) / (alpha + beta + 1e-10)

    # Spectral Shannon entropy of the normalised periodogram
    p = psd / (psd.sum() + 1e-10)
    p = p[p > 0]
    shan = float(-np.sum(p * np.log(p + 1e-10)))

    # Fraction of samples below 5% of the peak absolute amplitude
    supp = float(np.mean(np.abs(sig) < 0.05 * np.max(np.abs(sig) + 1e-10)))

    # Approximate entropy; truncated to 200 samples because phi() builds an
    # O(len^2) pairwise-distance array.
    u = sig[:min(200, n)]
    r = 0.2 * np.std(u) + 1e-10

    def phi(mm):
        x2 = np.array([u[i:i + mm] for i in range(len(u) - mm + 1)])
        C = np.sum(np.max(np.abs(x2[:, None] - x2[None, :]), axis=2) <= r,
                   axis=0) / (len(u) - mm + 1)
        return np.sum(np.log(C + 1e-10)) / (len(u) - mm + 1)

    apen = float(abs(phi(2) - phi(3)))

    # Shannon entropy of a 10-bin amplitude histogram
    hist, _ = np.histogram(sig, bins=10)
    hist = hist / (hist.sum() + 1e-10)
    etc = float(-np.sum(hist[hist > 0] * np.log(hist[hist > 0])))

    # Distinct 4-bit words in the median-binarised signal (range 0..16): a
    # cheap complexity proxy, not a true Lempel-Ziv parse.
    _b = (sig > np.median(sig)).astype(np.int16)
    _idx4 = _b[:-3]*8 + _b[1:-2]*4 + _b[2:-1]*2 + _b[3:]
    lzc = float(len(np.unique(_idx4)))

    # Order-3 permutation entropy; the 4/2/1 weighting gives each of the six
    # rank permutations a distinct code.
    _wins = np.lib.stride_tricks.sliding_window_view(sig.astype(np.float64), 3)
    _rnks = np.argsort(_wins, axis=1, kind='quicksort')
    _penc = _rnks[:, 0]*4 + _rnks[:, 1]*2 + _rnks[:, 2]
    _, _cnts = np.unique(_penc, return_counts=True)
    _tot = float(_cnts.sum())
    pent = float(-np.sum((_cnts/_tot) * np.log(_cnts/_tot + 1e-10)))

    return np.array([rms, ll, zc, var,
                     delta/total, theta/total, alpha/total, beta/total,
                     sr, shan, supp, apen,
                     0.5,  # placeholder feature slot (constant)
                     etc, lzc, pent, gamma/total], dtype=np.float32)


# ── EDF reading ────────────────────────────────────────────────────────────────

def _match_shaft_channels(labels, ch_prefixes):
    """Pick two contacts on the same electrode shaft for a bipolar derivation.

    A label belongs to prefix P when it starts with P (case-insensitive) and the
    next character is not a letter. So 'LA' matches 'LA1' but not 'LAH1' or
    'LAOF1', which belong to other regions in REGIONS. When several prefixes
    match, the longest wins.

    Shafts are tried in the order in which their first contact appears in
    ``labels``. The first shaft with >= 2 contacts is used, and its first two
    contacts (in file order) are returned.

    Returns:
        (idx1, idx2) channel indices, or None if no shaft has two contacts.
    """
    prefixes = sorted({p.upper() for p in ch_prefixes}, key=len, reverse=True)
    shafts = {}  # prefix -> channel indices; dict keeps first-appearance order
    for i, label in enumerate(labels):
        lab = str(label).strip().upper()
        for p in prefixes:
            rest = lab[len(p):]
            if lab.startswith(p) and (not rest or not rest[0].isalpha()):
                shafts.setdefault(p, []).append(i)
                break
    for idxs in shafts.values():
        if len(idxs) >= 2:
            return idxs[0], idxs[1]
    return None


def _edf_read_channel(fp, ch_prefixes, tmin_sec, tmax_sec):
    """
    Read a bipolar channel from EDF using pyedflib random access.

    The two contacts come from the same shaft (see _match_shaft_channels). The
    returned signal is (contact1 - contact2) * 1e6. This assumes the EDF
    physical unit is volts — NOTE(review): pyedflib returns values in the
    header's physical dimension; confirm it is V, not uV.

    Returns (signal, fs) or (None, None) when no valid pair or no samples exist
    in [tmin_sec, tmax_sec), or when the two contacts' sampling rates differ.
    """
    f = pyedflib.EdfReader(str(fp))
    try:
        pair = _match_shaft_channels(f.getSignalLabels(), ch_prefixes)
        if pair is None:
            return None, None
        idx1, idx2 = pair
        fs1 = f.getSampleFrequency(idx1)
        if not np.isclose(fs1, f.getSampleFrequency(idx2)):
            return None, None  # sample-wise subtraction would be misaligned
        ns = f.getNSamples()
        s_s = max(0, int(tmin_sec * fs1))
        s_e = min(ns[idx1], ns[idx2], int(tmax_sec * fs1))
        n_read = s_e - s_s
        if n_read <= 0:
            return None, None
        d1 = f.readSignal(idx1, start=s_s, n=n_read)
        d2 = f.readSignal(idx2, start=s_s, n=n_read)
        nm = min(len(d1), len(d2))
        return (d1[:nm] - d2[:nm]) * 1e6, fs1
    finally:
        f.close()


def extract_phase_windows(fp, ch_prefixes, onset, offset, fs_hint=2048):
    """
    Extract preictal/ictal/postictal feature windows from one seizure EDF.

    Args:
        fp: EDF path.
        ch_prefixes: electrode prefixes of one region (see REGIONS).
        onset, offset: seizure onset/offset in seconds from the recording start.
        fs_hint: unused; kept for backward compatibility (the EDF's own
            sampling rate is used for all length checks).

    Returns dict: {phase_name: (N, 17) float32 array or None}
    """
    phases = {
        'Preictal': (onset - 10.0 - PRE_DUR, onset - 10.0),
        'Ictal': (onset, onset + min(offset - onset, ICT_DUR)),
        'Postictal': (offset + 5.0, offset + 5.0 + POST_DUR),
    }
    windows = {}
    for phase, (t0, t1) in phases.items():
        t0 = max(0.0, t0)  # nothing exists before the start of the recording
        if t1 - t0 < WIN_SEC:
            windows[phase] = None
            continue
        sig, fs = _edf_read_channel(fp, ch_prefixes, t0, t1)
        if sig is None:
            windows[phase] = None
            continue
        fs_i = int(round(fs))
        win_samp = int(WIN_SEC * fs_i)
        stp_samp = int(STEP_SEC * fs_i)
        feats = [compute_features(sig[i:i + win_samp], fs_i)
                 for i in _window_starts(0, len(sig), win_samp, stp_samp)]
        windows[phase] = np.array(feats, dtype=np.float32) if feats else None
    return windows


# ── Data loading ───────────────────────────────────────────────────────────────

def load_all_region_data(meta_df):
    """
    Load preictal/ictal/postictal feature windows for all patients × all regions.

    A seizure contributes a phase only if it yields >= 2 windows for it. A region
    is kept for a patient only if all three phases have data; otherwise it is None.

    Returns: dict pid → dict region → (dict phase → (N,17) array) or None
    """
    log('Loading multi-region seizure lifecycle data...')
    all_data = {}
    for pid in sorted(meta_df['Patient ID'].unique()):
        pat_dir = SEEG_ROOT / f'{pid}_SEEG'
        if not pat_dir.exists():
            continue
        pat_rows = meta_df[meta_df['Patient ID'] == pid]
        pid_data = {r: {ph: [] for ph in PHASE_NAMES} for r in REGIONS}

        for _, row in pat_rows.iterrows():
            fp = pat_dir / str(row['Seizure_Filename'])
            if not fp.exists():
                continue
            onset = float(row['Seizure_Onset_Sec'])
            offset = float(row['Seizure_Offset_Sec'])
            if offset - onset < 5.0:
                continue
            for region, prefixes in REGIONS.items():
                try:
                    wins = extract_phase_windows(fp, prefixes, onset, offset)
                except Exception as exc:  # best effort: one bad file/region must not stop the run
                    log(f' {pid} {fp.name} {region}: skipped ({exc})')
                    continue
                for phase in PHASE_NAMES:
                    if wins[phase] is not None and len(wins[phase]) >= 2:
                        pid_data[region][phase].append(wins[phase])

        # Stack across seizures; a region is valid only if every phase has data.
        pid_stacked = {}
        for region in REGIONS:
            per_phase = pid_data[region]
            if all(per_phase[ph] for ph in PHASE_NAMES):
                pid_stacked[region] = {ph: np.vstack(per_phase[ph]) for ph in PHASE_NAMES}
            else:
                pid_stacked[region] = None
        all_data[pid] = pid_stacked
        n_valid = sum(1 for r in REGIONS if pid_stacked[r] is not None)
        log(f' {pid}: {n_valid}/{len(REGIONS)} regions valid')
    return all_data


# ── Part A: Within-region LOSO SVM ─────────────────────────────────────────────

def part_A(all_data):
    """Within-region 3-class leave-one-patient-out SVM classifier per region.

    Returns: region → {macro_f1 (mean over folds), std, per_class (pooled F1 per
    phase), cm (pooled 3×3 confusion matrix), n_patients}.
    """
    log('\n' + '='*60)
    log('Part A: Within-region 3-class phase classifier (LOSO SVM)')
    log('='*60)
    results = {}
    for region in REGIONS:
        pids_with_data = [p for p, d in all_data.items() if d.get(region) is not None]
        if len(pids_with_data) < 4:
            log(f' {region}: only {len(pids_with_data)} patients — skip')
            continue

        macro_f1s, all_true, all_pred = [], [], []
        for test_pid in pids_with_data:
            train_pids = [p for p in pids_with_data if p != test_pid]
            X_tr, y_tr = _phase_xy(all_data, train_pids, region)
            X_te, y_te = _phase_xy(all_data, [test_pid], region)
            if X_tr is None or X_te is None:
                continue
            scaler = StandardScaler()
            clf = _make_svm()
            clf.fit(scaler.fit_transform(X_tr), y_tr)
            y_pred = clf.predict(scaler.transform(X_te))
            macro_f1s.append(f1_score(y_te, y_pred, average='macro', zero_division=0))
            all_true.extend(y_te)
            all_pred.extend(y_pred)

        if not macro_f1s:
            continue
        mean_f1 = float(np.mean(macro_f1s))
        std_f1 = float(np.std(macro_f1s))
        # Explicit labels guarantee three entries / a 3×3 matrix.
        per_class = f1_score(all_true, all_pred, labels=_PHASE_LABELS,
                             average=None, zero_division=0)
        cm = confusion_matrix(all_true, all_pred, labels=_PHASE_LABELS)
        results[region] = {'macro_f1': mean_f1, 'std': std_f1, 'per_class': per_class,
                           'cm': cm, 'n_patients': len(pids_with_data)}
        log(f' {region:<15} macro-F1={mean_f1:.4f}±{std_f1:.4f} | '
            f'Pre={per_class[0]:.3f} Ict={per_class[1]:.3f} Post={per_class[2]:.3f} '
            f'(N={len(pids_with_data)})')
    return results


# ── Part B: Cross-region zero-shot transfer ────────────────────────────────────

def part_B(all_data):
    """Cross-region phase transfer, evaluated leave-one-patient-out.

    Cell (i, j) is computed as follows. For every patient p with region-j data,
    an SVM is trained on region-i windows from all *other* patients and
    predicts p's region-j windows. Predictions are pooled over patients and
    scored with macro-F1.

    Holding the test patient out keeps the diagonal a genuine within-region
    generalisation score; scoring the training data itself would inflate it.
    Source regions with < 3 patients stay NaN.

    Returns: (transfer_mat [n_regions × n_regions], region_list)
    """
    log('\n' + '='*60)
    log('Part B: Cross-region zero-shot phase transfer (5×5 macro-F1, leave-one-patient-out)')
    log('='*60)
    region_list = list(REGIONS.keys())
    n_reg = len(region_list)
    transfer_mat = np.full((n_reg, n_reg), np.nan)

    pids_by_region = {r: [p for p, d in all_data.items() if d.get(r) is not None]
                      for r in region_list}
    test_pids = [p for p, d in all_data.items()
                 if any(d.get(r) is not None for r in region_list)]

    for i, src in enumerate(region_list):
        src_pids = pids_by_region[src]
        if len(src_pids) < 3:
            continue
        y_true = {j: [] for j in range(n_reg)}
        y_pred = {j: [] for j in range(n_reg)}
        for test_pid in test_pids:
            X_tr, y_tr = _phase_xy(all_data, [p for p in src_pids if p != test_pid], src)
            if X_tr is None:
                continue
            scaler = StandardScaler()
            clf = _make_svm()
            clf.fit(scaler.fit_transform(X_tr), y_tr)
            for j, tgt in enumerate(region_list):
                if all_data[test_pid].get(tgt) is None:
                    continue
                X_te, y_te = _phase_xy(all_data, [test_pid], tgt)
                if X_te is None:
                    continue
                y_true[j].extend(y_te)
                y_pred[j].extend(clf.predict(scaler.transform(X_te)))
        for j in range(n_reg):
            if y_true[j]:
                transfer_mat[i, j] = f1_score(np.array(y_true[j]), np.array(y_pred[j]),
                                              average='macro', zero_division=0)

    _log_transfer_matrix(transfer_mat, region_list)
    return transfer_mat, region_list


# ── Part C: Ictal propagation timing ───────────────────────────────────────────

def _detect_onset_lag(fp, prefixes, onset, offset):
    """Return the RMS-threshold detection lag (s) vs ``onset`` for one region, or None.

    Windows are labelled by their centre time. The baseline is windows centred
    before onset-5s. The detection is the first of 2 consecutive windows,
    centred at or after onset, whose RMS exceeds baseline mean + 2σ.

    Because the search starts at onset, returned lags are always >= 0: a region
    that leads the clinical label cannot be reported as negative.
    """
    # Read wide window: 60s preictal + ictal + 30s postictal (capped at onset+180s)
    t0 = max(0.0, onset - 60.0)
    t1 = min(offset + 30.0, onset + 180.0)
    sig, fs = _edf_read_channel(fp, prefixes, t0, t1)
    if sig is None:
        return None
    fs_i = int(round(fs))
    win_samp = int(WIN_SEC * fs_i)
    stp_samp = int(STEP_SEC * fs_i)

    starts = np.asarray(_window_starts(0, len(sig), win_samp, stp_samp))
    if starts.size == 0:
        return None
    times = t0 + starts / fs_i + WIN_SEC / 2  # window centres (s)
    rms_vals = np.array([np.sqrt(np.mean(sig[i:i + win_samp]**2)) for i in starts])

    pre_mask = times < (onset - 5.0)
    if pre_mask.sum() < 3:
        return None
    pre_mean = rms_vals[pre_mask].mean()
    pre_std = rms_vals[pre_mask].std() + 1e-8
    thresh = pre_mean + 2.0 * pre_std

    post_mask = times >= onset
    if post_mask.sum() < 2:
        return None
    t_post = times[post_mask]
    above = rms_vals[post_mask] > thresh
    hits = np.flatnonzero(above[:-1] & above[1:])
    if hits.size == 0:
        return None
    return float(t_post[hits[0]] - onset)


def part_C(meta_df):
    """
    For each FBTCS seizure: detect when each region's LFP features cross the
    ictal threshold, report lag vs clinical onset label.
    Threshold: RMS exceeds preictal mean + 2σ for ≥2 consecutive windows.

    Returns: region → (mean lag, std lag, n detections); (nan, nan, 0) if none.
    """
    log('\n' + '='*60)
    log('Part C: Ictal propagation timing (lag vs clinical onset label)')
    log('='*60)
    fbtcs = meta_df[meta_df['Seizure_Type'] == 'FBTCS']
    lags = {r: [] for r in REGIONS}

    for _, row in fbtcs.iterrows():
        pid = row['Patient ID']
        fp = SEEG_ROOT / f'{pid}_SEEG' / str(row['Seizure_Filename'])
        if not fp.exists():
            continue
        onset = float(row['Seizure_Onset_Sec'])
        offset = float(row['Seizure_Offset_Sec'])
        for region, prefixes in REGIONS.items():
            try:
                lag = _detect_onset_lag(fp, prefixes, onset, offset)
            except Exception as exc:  # best effort per region
                log(f' {pid} {row["Seizure_Filename"]} {region}: skipped ({exc})')
                continue
            if lag is not None:
                lags[region].append(lag)
                log(f' {pid} {row["Seizure_Filename"]} {region}: lag={lag:+.1f}s')

    log('\n Propagation summary (mean ± std lag vs clinical onset):')
    prop_results = {}
    for region in REGIONS:
        if lags[region]:
            m = float(np.mean(lags[region]))
            s = float(np.std(lags[region]))
            prop_results[region] = (m, s, len(lags[region]))
            log(f' {region:<15}: {m:+.2f} ± {s:.2f}s (N={len(lags[region])})')
        else:
            prop_results[region] = (np.nan, np.nan, 0)
            log(f' {region:<15}: no data')
    return prop_results


# ── Part D: TUH scalp → intracranial transfer ──────────────────────────────────

def _merge_intervals(intervals):
    """Sort (start, stop) intervals and merge overlapping/touching ones."""
    merged = []
    for start, stop in sorted(intervals):
        if merged and start <= merged[-1][1]:
            merged[-1][1] = max(merged[-1][1], stop)
        else:
            merged.append([start, stop])
    return [(a, b) for a, b in merged]


def _overlaps(t_start, t_stop, intervals):
    """True if [t_start, t_stop) overlaps any (start, stop) interval."""
    return any(t_start < b and a < t_stop for a, b in intervals)


def _parse_tuh_seizure_intervals(csv_path):
    """Return merged seizure (start, stop) intervals in seconds from a TUH .csv.

    Each data row is "channel,start,stop,label,...", i.e. annotations are per
    channel, so one seizure appears once per channel. Intervals are merged so
    each seizure's windows are extracted only once.
    """
    intervals = []
    with open(csv_path) as cf:
        for line in cf:
            if line.startswith('#') or line.startswith('channel'):
                continue
            parts = line.strip().split(',')
            if len(parts) < 4 or parts[3].strip() not in TUH_ICTAL_LABELS:
                continue
            try:
                intervals.append((float(parts[1]), float(parts[2])))
            except ValueError:
                continue
    return _merge_intervals(intervals)


def _load_tuh_ictal_features(max_files=MAX_TUH_FILES):
    """
    Load ictal (label=1) and pre-ictal/baseline (label=0) feature windows from
    TUH scalp EEG CSV-annotated files.

    The signal used is the average of the EEG channels, excluding outlier
    channels whose std exceeds 3× the 90th percentile. Ictal windows cover
    [start, min(stop, start + ICT_DUR)]. Baseline windows cover
    [start - 70s, start - 10s] and skip any window overlapping a seizure.

    Returns: (X_ictal, X_baseline) float32 arrays, or (None, None) if either is empty.
    """
    import mne
    log(' Scanning TUH files for ictal annotations...')
    edfs = list(TUH_BASE.rglob('*.edf'))[:max_files * 3]  # oversample, filter below
    X_ict, X_bas = [], []
    n_files = 0

    for edf_path in edfs:
        if n_files >= max_files:
            break
        csv_path = edf_path.with_suffix('.csv')
        if not csv_path.exists():
            continue
        try:
            # Parse annotations first so files without seizures never load the EDF.
            szs = _parse_tuh_seizure_intervals(csv_path)
            if not szs:
                continue

            raw = mne.io.read_raw_edf(str(edf_path), preload=True, verbose=False)
            fs = raw.info['sfreq']
            picks = mne.pick_types(raw.info, eeg=True)
            if len(picks) == 0:
                picks = list(range(len(raw.ch_names)))
            data = raw.get_data(picks=picks)
            if np.abs(np.median(data)) < 0.01:  # heuristic V → µV conversion
                data = data * 1e6
            ch_std = data.std(axis=1)
            good = ch_std < np.percentile(ch_std, 90) * 3
            if good.sum() == 0:
                good = np.ones(len(ch_std), dtype=bool)
            avg = data[good].mean(axis=0)

            win_samp = int(WIN_SEC * fs)
            stp_samp = int(STEP_SEC * fs)
            for sz_start, sz_end in szs:
                # Ictal windows
                i0 = int(sz_start * fs)
                i1 = int(min(sz_end, sz_start + ICT_DUR) * fs)
                for i in _window_starts(i0, i1, win_samp, stp_samp):
                    seg = avg[i:i + win_samp]
                    if len(seg) == win_samp:  # guards against annotations past file end
                        X_ict.append(compute_features(seg, fs))
                # Baseline windows (60s before seizure, 10s buffer)
                b0 = max(0, int((sz_start - 70.0) * fs))
                b1 = int((sz_start - 10.0) * fs)
                for i in _window_starts(b0, b1, win_samp, stp_samp):
                    if _overlaps(i / fs, (i + win_samp) / fs, szs):
                        continue  # would label another seizure as baseline
                    seg = avg[i:i + win_samp]
                    if len(seg) == win_samp:
                        X_bas.append(compute_features(seg, fs))

            n_files += 1
            if n_files % 20 == 0:
                log(f' TUH: {n_files} files | ictal={len(X_ict)} | baseline={len(X_bas)}')
        except Exception as exc:  # best effort per file
            log(f' TUH: skipped {edf_path.name} ({exc})')
            continue

    log(f' TUH total: {n_files} files | ictal={len(X_ict)} | baseline={len(X_bas)}')
    if not X_ict or not X_bas:
        return None, None
    return np.array(X_ict, dtype=np.float32), np.array(X_bas, dtype=np.float32)


def part_D(all_data):
    """TUH scalp-trained binary (ictal/non-ictal) → test on each intracranial region.

    SEEG test set per region: Ictal windows (label 1) vs Preictal windows (label 0)
    pooled over all patients. Returns region → {'f1_ictal', 'macro_f1'}.
    """
    log('\n' + '='*60)
    log('Part D: TUH scalp → intracranial transfer (binary ictal/non-ictal)')
    log('='*60)
    X_ict_tuh, X_bas_tuh = _load_tuh_ictal_features()
    if X_ict_tuh is None:
        log(' TUH data not available — skipping Part D')
        return {}

    # Balance classes by random subsampling (fixed seed)
    n_min = min(len(X_ict_tuh), len(X_bas_tuh))
    idx = np.random.RandomState(42).permutation(len(X_ict_tuh))[:n_min]
    idx_b = np.random.RandomState(42).permutation(len(X_bas_tuh))[:n_min]
    X_tr = np.vstack([X_ict_tuh[idx], X_bas_tuh[idx_b]])
    y_tr = np.array([1]*n_min + [0]*n_min)
    scaler = StandardScaler()
    clf = _make_svm()
    clf.fit(scaler.fit_transform(X_tr), y_tr)
    log(f' TUH SVM trained: {n_min} ictal + {n_min} baseline windows')

    transfer_results = {}
    for region in REGIONS:
        pids = [p for p, d in all_data.items() if d.get(region) is not None]
        if not pids:
            continue
        X_ict_r = [all_data[p][region]['Ictal'] for p in pids
                   if all_data[p][region]['Ictal'] is not None]
        X_bas_r = [all_data[p][region]['Preictal'] for p in pids
                   if all_data[p][region]['Preictal'] is not None]
        if not X_ict_r or not X_bas_r:
            continue
        X_ict_r = np.vstack(X_ict_r)
        X_bas_r = np.vstack(X_bas_r)
        X_te = np.vstack([X_ict_r, X_bas_r])
        y_te = np.array([1]*len(X_ict_r) + [0]*len(X_bas_r))
        y_pred = clf.predict(scaler.transform(X_te))
        f1_ict = f1_score(y_te, y_pred, pos_label=1, zero_division=0)
        f1_mac = f1_score(y_te, y_pred, average='macro', zero_division=0)
        transfer_results[region] = {'f1_ictal': f1_ict, 'macro_f1': f1_mac}
        log(f' TUH→{region:<15}: macro-F1={f1_mac:.4f} ictal-F1={f1_ict:.4f}')
    return transfer_results


# ── Figures ────────────────────────────────────────────────────────────────────

def save_figures(part_a, transfer_mat, region_list, part_c, part_d):
    """Save a 2×2 summary figure (Parts A–D) to OUT_DIR/seizure_lifecycle_results.png."""
    log('\nSaving figures...')
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    fig.suptitle('DACTRL — Seizure Lifecycle: Multi-Region Analysis',
                 fontsize=13, fontweight='bold')

    # A: Within-region macro-F1 bar chart
    ax = axes[0, 0]
    if part_a:
        regs = list(part_a.keys())
        f1s = [part_a[r]['macro_f1'] for r in regs]
        stds = [part_a[r]['std'] for r in regs]
        colors = ['#2196F3', '#4CAF50', '#FF9800', '#E91E63', '#9C27B0'][:len(regs)]
        bars = ax.bar(regs, f1s, yerr=stds, color=colors, alpha=0.8, capsize=4)
        ax.axhline(1/3, color='gray', linestyle='--', alpha=0.5, label='Chance (0.33)')
        ax.set_ylim(0, 1.05)
        ax.set_ylabel('Macro F1')
        ax.legend(fontsize=8)
        ax.set_title('A: Within-region LOSO (3-class)')
        ax.set_xticks(range(len(regs)))  # fix tick positions before relabelling
        ax.set_xticklabels(regs, rotation=15, ha='right', fontsize=9)
        for bar, f1 in zip(bars, f1s):
            ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.02,
                    f'{f1:.3f}', ha='center', va='bottom', fontsize=8)

    # B: Transfer matrix heatmap
    ax = axes[0, 1]
    if transfer_mat is not None:
        im = ax.imshow(transfer_mat, cmap='RdYlGn', vmin=0, vmax=1, aspect='auto')
        ax.set_xticks(range(len(region_list)))
        ax.set_yticks(range(len(region_list)))
        short = [r[:5] for r in region_list]
        ax.set_xticklabels(short, fontsize=8)
        ax.set_yticklabels(short, fontsize=8)
        ax.set_xlabel('Test region')
        ax.set_ylabel('Train region')
        ax.set_title('B: Cross-region transfer (macro-F1)')
        plt.colorbar(im, ax=ax, shrink=0.8)
        for i in range(len(region_list)):
            for j in range(len(region_list)):
                if not np.isnan(transfer_mat[i, j]):
                    ax.text(j, i, f'{transfer_mat[i,j]:.2f}', ha='center', va='center',
                            fontsize=8,
                            color='black' if 0.3 < transfer_mat[i, j] < 0.7 else 'white')

    # C: Propagation timing
    ax = axes[1, 0]
    if part_c:
        regs = [r for r in REGIONS if part_c[r][2] > 0]
        means = [part_c[r][0] for r in regs]
        stds = [part_c[r][1] for r in regs]
        colors2 = ['#2196F3', '#4CAF50', '#FF9800', '#E91E63', '#9C27B0'][:len(regs)]
        ax.barh(regs, means, xerr=stds, color=colors2, alpha=0.8, capsize=4)
        ax.axvline(0, color='black', linewidth=1.5, label='Clinical onset')
        ax.set_xlabel('Lag vs clinical onset (s)')
        ax.legend(fontsize=8)
        ax.set_title('C: Ictal propagation timing')

    # D: TUH → intracranial transfer
    ax = axes[1, 1]
    if part_d:
        regs = list(part_d.keys())
        mac_f1 = [part_d[r]['macro_f1'] for r in regs]
        ict_f1 = [part_d[r]['f1_ictal'] for r in regs]
        x = np.arange(len(regs))
        w = 0.35
        ax.bar(x - w/2, mac_f1, w, label='Macro-F1', color='#2196F3', alpha=0.8)
        ax.bar(x + w/2, ict_f1, w, label='Ictal-F1', color='#FF9800', alpha=0.8)
        ax.axhline(0.5, color='gray', linestyle='--', alpha=0.5, label='Chance')
        ax.set_xticks(x)
        ax.set_xticklabels(regs, rotation=15, ha='right', fontsize=9)
        ax.set_ylim(0, 1.05)
        ax.set_ylabel('F1')
        ax.legend(fontsize=8)
        ax.set_title('D: TUH scalp → intracranial transfer')

    plt.tight_layout()
    out_path = OUT_DIR / 'seizure_lifecycle_results.png'
    plt.savefig(str(out_path), dpi=150, bbox_inches='tight')
    plt.close()
    log(f'Figure saved -> {out_path}')


# ── Main ───────────────────────────────────────────────────────────────────────

def main():
    """Run Parts A–D, save the summary figure and log a results summary."""
    log('=' * 60)
    log('DACTRL — Simultaneous Multi-Region Seizure Lifecycle Analysis')
    log('=' * 60)

    meta_df = pd.read_excel(METADATA)
    log(f'Loaded metadata: {len(meta_df)} seizures, {meta_df["Patient ID"].nunique()} patients')
    log(f'Seizure types: {dict(meta_df["Seizure_Type"].value_counts())}')

    # Load all data once
    all_data = load_all_region_data(meta_df)
    n_total = sum(1 for d in all_data.values() for r in REGIONS if d.get(r) is not None)
    log(f'Total valid patient×region combinations: {n_total}')

    # Run all parts
    part_a = part_A(all_data)
    transfer_mat, region_list = part_B(all_data)
    part_c = part_C(meta_df)
    part_d = part_D(all_data)

    # Save figures
    save_figures(part_a, transfer_mat, region_list, part_c, part_d)

    # Summary table
    log('\n' + '='*60)
    log('=== SEIZURE LIFECYCLE RESULTS SUMMARY ===')
    log('='*60)

    log('\nPart A — Within-region 3-class macro-F1 (LOSO SVM):')
    for r, res in part_a.items():
        pc = res['per_class']
        log(f' {r:<15} {res["macro_f1"]:.4f}±{res["std"]:.4f} | '
            f'Pre={pc[0]:.3f} Ict={pc[1]:.3f} Post={pc[2]:.3f}')

    log('\nPart B — Cross-region transfer, leave-one-patient-out '
        '(diagonal=within-region, off-diagonal=cross-region):')
    _log_transfer_matrix(transfer_mat, region_list)

    log('\nPart C — Propagation lag vs clinical onset:')
    for r, (m, s, n) in part_c.items():
        if n > 0:
            log(f' {r:<15} {m:+.2f} ± {s:.2f}s (N={n})')

    log('\nPart D — TUH scalp → intracranial binary transfer:')
    for r, res in part_d.items():
        log(f' TUH→{r:<15} macro={res["macro_f1"]:.4f} ictal={res["f1_ictal"]:.4f}')

    log(f'\nAll results saved to {OUT_DIR}')
    log('COMPLETE')


if __name__ == '__main__':
    main()