Reconstructing Visual Perception from EEG Signals via Multi-Scale Neural Encoding and Latent Diffusion
- Overview
- Architecture
- Latent Space Transformations
- Datasets
- Training Pipeline
- Evaluation
- Repository Structure
TTDnet is a two-stage framework for generating photorealistic images directly from electroencephalography (EEG) recordings. Given a raw multi-channel EEG epoch captured while a subject views an object, TTDnet reconstructs a semantically faithful image of that object. The system bridges the gap between low-SNR neural signals and high-fidelity image generation by chaining:
- A self-supervised EEG encoder (InceptSADNet) pre-trained to learn robust temporal-spatial neural representations,
- An IP-Adapter–style conditioning bridge that translates EEG latents into the semantic embedding space of a large-scale text-to-image diffusion model, and
- A Stable Diffusion XL (SDXL) backbone fine-tuned with LoRA to generate 512 × 512 images conditioned on EEG rather than text.
The pipeline requires no textual descriptions at any point; visual content is decoded entirely from brain activity.
The full architecture is divided into two sequential training stages. Stage A learns a general-purpose EEG representation via masked signal reconstruction; Stage B keeps the pre-trained encoder frozen and trains a conditioning bridge and LoRA adapters on an SDXL diffusion backbone using paired EEG–image data.
Goal: Learn transferable EEG representations without any image labels, using large-scale unlabelled EEG recordings.
The encoder transforms raw multi-channel EEG into a compact sequence of latent tokens. It consists of the following major blocks:
| Block | Description | Input → Output |
|---|---|---|
| Multi-Scale Temporal Convolutions | Three parallel 1-D convolution branches with kernel sizes 7, 15, and 31 time-steps extract features at short, medium, and long temporal scales simultaneously. Each branch applies Conv2d → BatchNorm → ELU. Outputs are concatenated along the filter dimension. | [B, 1, C, T] → [B, 3·F₁, C, T] |
| Depthwise Spatial Convolution | A grouped convolution with kernel (C, 1) collapses the spatial (channel) dimension, learning per-filter spatial weights across all EEG electrodes, followed by BatchNorm → ELU → AvgPool → Dropout. | [B, 3·F₁, C, T] → [B, 3·F₁·D, 1, T/4] |
| Squeeze-and-Excitation (SE) Block | A channel recalibration module that adaptively re-weights filter responses: GlobalAvgPool → FC → ReLU → FC → Sigmoid → scale. This suppresses noisy channels and amplifies informative ones. | [B, F, 1, T'] → [B, F, 1, T'] |
| Projection & Tokenization | A 1×1 Conv → BatchNorm → ELU → AvgPool → Dropout projection maps features to the target embedding dimension, then spatially reshapes the output into a token sequence. | [B, F, 1, T'] → [B, seq_len, embed_dim] |
| Transformer Encoder | A stack of L standard Transformer blocks, each with multi-head self-attention (MHSA) and a GELU feed-forward network (FFN), both wrapped in Pre-LayerNorm residual connections. | [B, seq_len, embed_dim] → [B, seq_len, embed_dim] |
Default hyperparameters: embed_dim = 1024, depth (L) = 6, num_heads = 8, F₁ = 8, D = 2.
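
As an illustration of the Squeeze-and-Excitation row above, the following is a minimal NumPy sketch of the GlobalAvgPool → FC → ReLU → FC → Sigmoid → scale sequence. It is not the repository's PyTorch module: the function name `se_block`, the reduction ratio `r = 4`, the weight shapes, and the choice F = 3·F₁·D = 48 with T' = 128 are assumptions made for the example.

```python
import numpy as np

def se_block(x, w1, b1, w2, b2):
    """Squeeze-and-Excitation over the filter axis of x with shape [B, F, 1, T']."""
    squeezed = x.mean(axis=(2, 3))                      # GlobalAvgPool -> [B, F]
    hidden = np.maximum(squeezed @ w1 + b1, 0.0)        # FC -> ReLU    -> [B, F // r]
    gates = 1.0 / (1.0 + np.exp(-(hidden @ w2 + b2)))   # FC -> Sigmoid -> [B, F]
    return x * gates[:, :, None, None]                  # re-weight every filter map

rng = np.random.default_rng(0)
B, F, T_prime, r = 2, 48, 128, 4    # F = 3 * F1 * D = 3 * 8 * 2; r is an assumed ratio
x = rng.standard_normal((B, F, 1, T_prime))
w1 = rng.standard_normal((F, F // r)) * 0.1
b1 = np.zeros(F // r)
w2 = rng.standard_normal((F // r, F)) * 0.1
b2 = np.zeros(F)

y = se_block(x, w1, b1, w2, b2)
print(y.shape)  # same shape as the input
```

Because each gate lies strictly between 0 and 1, the block can only attenuate a filter relative to the others; it never changes the tensor shape.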
Instead of the patch-level random masking used in standard MAE, we employ contiguous temporal masking tailored to EEG:
- Masking Stage: Multiple random contiguous time segments are zeroed out in the raw EEG input (default: 50% of all time-steps). This forces the encoder to infer missing temporal context from surrounding neural activity.
- Encoding: The masked signal passes through the full InceptSADEncoder, producing latent tokens.
- Lightweight Decoder: A 4-layer Transformer decoder with sinusoidal positional embeddings and a linear prediction head reconstructs the original (un-masked) signal from the encoder's output.
- Reconstruction Loss: MSE is computed only over the masked time regions, encouraging the encoder to build a holistic understanding of the temporal dynamics rather than merely copying visible segments.
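
The masking and loss steps above can be sketched in NumPy as follows. This is an illustrative sketch, not the code in `incept_pretrain.py`: the helper names, the number of segments (4), and the block-wise placement of segments (one segment per equal-sized block, so that exactly the requested fraction is masked) are assumptions.

```python
import numpy as np

def contiguous_time_mask(n_steps, mask_ratio=0.5, n_segments=4, rng=None):
    """Boolean mask of length n_steps; True marks masked time-steps.

    The time axis is split into n_segments equal blocks and one contiguous
    segment is masked at a random offset inside each block (assumed scheme).
    """
    rng = np.random.default_rng() if rng is None else rng
    block = n_steps // n_segments
    seg_len = int(round(block * mask_ratio))
    mask = np.zeros(n_steps, dtype=bool)
    for i in range(n_segments):
        offset = rng.integers(0, block - seg_len + 1)
        start = i * block + offset
        mask[start:start + seg_len] = True
    return mask

def masked_mse(pred, target, mask):
    """MSE over masked time-steps only. pred/target: [B, C, T]; mask: [T]."""
    diff = (pred - target)[..., mask]
    return float(np.mean(diff ** 2))

rng = np.random.default_rng(0)
eeg = rng.standard_normal((2, 64, 512))          # [B, C, T]
mask = contiguous_time_mask(512, rng=rng)
masked_input = np.where(mask, 0.0, eeg)          # zero out the masked segments
loss = masked_mse(np.zeros_like(eeg), eeg, mask) # loss of an all-zero prediction
```

Restricting the loss to `mask` means a decoder that simply copies the visible samples gains nothing; only its predictions inside the zeroed segments are scored.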
The pre-training loop (stageA1_incept_pretrain.py) trains the encoder + decoder jointly with AdamW (lr = 4e-4, cosine schedule) for 200 epochs on large-scale unlabelled EEG data. After pre-training, only the encoder weights are transferred to Stage B; the decoder is discarded.
Alternative encoder path: The codebase also retains a ViT-based Masked Autoencoder (sc_mbm/mae_for_eeg.py) with 1-D Patch Embedding and standard random patch masking, used with the original EEG dataset and the SD 1.5 backend. The InceptSADEncoder supersedes this for the SDXL pipeline.
Goal: Learn to generate images that match the visual stimulus a subject was viewing, given only their EEG recording.
The conditioning wrapper is a thin module that chains the frozen pre-trained InceptSADEncoder with the IP-Adapter Bridge and exposes a unified conditioning interface. On a forward pass it returns two outputs: (a) cross-attention conditioning tokens for the UNet, and (b) the raw EEG latent for CLIP alignment supervision.
The bridge converts the EEG encoder's variable-length token sequence into a fixed-size conditioning signal compatible with SDXL's cross-attention mechanism. It has two sub-modules:
| Sub-Module | Description | Input → Output |
|---|---|---|
| Perceiver Resampler | A set of 16 learnable latent query tokens attend to the EEG encoder tokens via iterative cross-attention (2 Resampler Layers). Each layer consists of: LayerNorm → MultiHeadCrossAttention → Residual → LayerNorm → FFN → Residual. The Resampler compresses the variable-length EEG sequence into a fixed set of 16 conditioning tokens projected to match SDXL's 2048-dim context space. | [B, seq_len, 1024] → [B, 16, 2048] |
| CLIP Alignment Head | Mean-pools the EEG latent → MLP (1024 → 1024, GELU, → 768) → L2-normalize. A learnable logit_scale parameter controls the temperature of the symmetric InfoNCE contrastive loss that aligns EEG embeddings with CLIP ViT-L/14 image embeddings. | [B, seq_len, 1024] → [B, 768] |
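
To make the Perceiver Resampler row concrete, here is a simplified NumPy sketch of how a fixed set of learnable queries turns a variable-length token sequence into a fixed-size output. It is not the code in `ip_adapter_bridge.py`: it uses a single attention head, a ReLU feed-forward block, no biases or learned LayerNorm parameters, and scaled-down widths (64 and 128 in place of 1024 and 2048); all names are hypothetical.

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def resampler_layer(latents, tokens, p):
    """latents: [B, Q, d] queries; tokens: [B, S, d] EEG tokens; p: weight dict."""
    q = layer_norm(latents) @ p["wq"]                               # [B, Q, d]
    k, v = tokens @ p["wk"], tokens @ p["wv"]                       # [B, S, d]
    attn = softmax(q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])) # [B, Q, S]
    latents = latents + (attn @ v) @ p["wo"]                        # cross-attn + residual
    hidden = np.maximum(layer_norm(latents) @ p["w1"], 0.0)         # FFN (ReLU assumed)
    return latents + hidden @ p["w2"]                               # FFN + residual

def resample(tokens, queries, layers, w_out):
    """Compress [B, S, d] tokens into [B, Q, d_out] conditioning tokens."""
    latents = np.broadcast_to(queries, (tokens.shape[0],) + queries.shape)
    for p in layers:
        latents = resampler_layer(latents, tokens, p)
    return latents @ w_out

rng = np.random.default_rng(0)
d, d_out, n_queries = 64, 128, 16        # scaled down from 1024 / 2048 for the demo
shapes = {"wq": (d, d), "wk": (d, d), "wv": (d, d), "wo": (d, d),
          "w1": (d, 4 * d), "w2": (4 * d, d)}
layers = [{n: rng.standard_normal(s) * s[0] ** -0.5 for n, s in shapes.items()}
          for _ in range(2)]             # 2 resampler layers
queries = rng.standard_normal((n_queries, d))
w_out = rng.standard_normal((d, d_out)) * d ** -0.5

out_short = resample(rng.standard_normal((2, 30, d)), queries, layers, w_out)
out_long = resample(rng.standard_normal((2, 75, d)), queries, layers, w_out)
print(out_short.shape, out_long.shape)   # both have 16 tokens regardless of seq_len
```

The output length is set by the number of queries rather than by the input, which is why the bridge can hand SDXL a fixed 16-token context for any EEG sequence length.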
The generative core is built on SDXL (Stable Diffusion XL Base 1.0), modified for EEG-conditioned generation:
| Component | Role | Details |
|---|---|---|
| VAE (AutoencoderKL) | Encodes images into / decodes images from a spatial latent space. Runs in FP32 for numerical stability (its internal exp() overflows in FP16). | [B, 3, 512, 512] ↔ [B, 4, 64, 64] (48× compression). Frozen during training. |
| UNet (UNet2DConditionModel) + LoRA | The denoising network. The base UNet is frozen; LoRA adapters (rank-16, α = 16, Gaussian initialization) are injected into all to_q, to_k, to_v, and to_out.0 attention projection matrices, enabling parameter-efficient fine-tuning. | Receives noisy latents + EEG conditioning tokens via cross-attention. Only LoRA weights (~20M params) are updated. |
| DPM++ 2M Karras Scheduler | A fast ODE-based noise scheduler that requires only 25 denoising steps (vs. 250 for PLMS/DDIM in SD 1.5), a tenfold reduction in denoising steps at inference. | Used for both training (noise addition) and inference (iterative denoising). |
| CLIP ViT-L/14 Image Encoder | Extracts 768-dim image embeddings from ground-truth images during training for the contrastive CLIP alignment loss. Frozen. | Not used at inference. |
| Pooled Conditioning Projection | Mean-pools the 16 cross-attention tokens and projects the result to 1280-dim with a linear layer, replacing SDXL's pooled text embedding (text_embeds). | [B, 16, 2048] → mean → [B, 1280] |
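
The LoRA row above can be illustrated with a small NumPy sketch of a single adapted projection. This is a generic low-rank adapter under stated assumptions, not the repository's injection code: the class name `LoRALinear`, the 640-dim example width, and the convention of a Gaussian-initialised down-projection `A` with a zero-initialised up-projection `B` are assumptions.

```python
import numpy as np

class LoRALinear:
    """Frozen weight W plus a trainable low-rank update (alpha / rank) * B @ A."""

    def __init__(self, weight, rank=16, alpha=16, rng=None):
        rng = np.random.default_rng() if rng is None else rng
        out_dim, in_dim = weight.shape
        self.weight = weight                                  # frozen base projection
        self.A = rng.standard_normal((rank, in_dim)) / rank   # Gaussian init (assumed std)
        self.B = np.zeros((out_dim, rank))                    # zeros: adapter starts as a no-op
        self.scale = alpha / rank                             # 16 / 16 = 1.0

    def __call__(self, x):
        return x @ self.weight.T + self.scale * (x @ self.A.T) @ self.B.T

    def num_trainable(self):
        return self.A.size + self.B.size

rng = np.random.default_rng(0)
w = rng.standard_normal((640, 640))       # e.g. one to_q projection (width assumed)
layer = LoRALinear(w, rank=16, alpha=16, rng=rng)
x = rng.standard_normal((2, 640))
print(layer.num_trainable(), w.size)      # adapter parameters vs. frozen parameters
```

With rank 16 on a 640 × 640 projection the adapter trains 20,480 values against 409,600 frozen ones, which is the source of the parameter efficiency noted in the table.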
The total loss during fine-tuning is a weighted sum:
L_total = L_diffusion + 0.5 · L_clip
- L_diffusion (MSE): Standard ε-prediction loss — the UNet predicts the noise added to the VAE latent, and the MSE between predicted and actual noise is minimized.
- L_clip (InfoNCE): Symmetric contrastive loss aligning the EEG-derived embedding with the CLIP image embedding in a shared 768-dim space. Computed every 4 batches to reduce CLIP encoder overhead.
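
A minimal NumPy sketch of the two loss terms and their combination is given below. It is illustrative only: the function names are hypothetical, `logit_scale` is passed as a plain multiplier (initialised here to 1 / 0.07, an assumed value), and the behaviour on batches where the CLIP term is skipped (diffusion loss alone) is an assumption about how "every 4 batches" is applied.

```python
import numpy as np

def l2_normalize(x, eps=1e-12):
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)

def log_softmax(x, axis):
    x = x - x.max(axis=axis, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=axis, keepdims=True))

def symmetric_info_nce(eeg_emb, img_emb, logit_scale):
    """Symmetric contrastive loss; row i of each input is a matched EEG/image pair."""
    logits = logit_scale * l2_normalize(eeg_emb) @ l2_normalize(img_emb).T  # [B, B]
    idx = np.arange(len(logits))
    loss_e2i = -log_softmax(logits, axis=1)[idx, idx].mean()  # EEG -> image direction
    loss_i2e = -log_softmax(logits, axis=0)[idx, idx].mean()  # image -> EEG direction
    return float(0.5 * (loss_e2i + loss_i2e))

def total_loss(l_diffusion, l_clip, step, clip_weight=0.5, clip_every=4):
    """L_total = L_diffusion + 0.5 * L_clip, with the CLIP term on every 4th batch."""
    if step % clip_every == 0:
        return l_diffusion + clip_weight * l_clip
    return l_diffusion

rng = np.random.default_rng(0)
img = rng.standard_normal((8, 768))
eeg = img + 0.1 * rng.standard_normal((8, 768))       # nearly aligned embeddings
aligned = symmetric_info_nce(eeg, img, logit_scale=1 / 0.07)
shuffled = symmetric_info_nce(np.roll(eeg, 1, axis=0), img, logit_scale=1 / 0.07)
print(aligned < shuffled)   # matched pairs give the lower loss
```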
At inference, no ground-truth images are needed:
Raw EEG → InceptSADEncoder → EEG Latent Tokens
→ Perceiver Resampler → 16 Conditioning Tokens [B, 16, 2048]
→ Pooled Projection → Pooled Embedding [B, 1280]
Random Gaussian Noise [B, 4, 64, 64]
→ 25 DPM++ Steps (UNet denoising, conditioned via cross-attention)
→ Denoised VAE Latent [B, 4, 64, 64]
→ VAE Decoder
→ Reconstructed Image [B, 3, 512, 512]
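
The tensor shapes in the flow above can be checked with a few lines of NumPy. This sketch only tracks shapes; the random weight `w_pool` is a stand-in for the learned pooled projection, and the 8× spatial downsampling factor of the SDXL VAE is used to derive the latent size.

```python
import numpy as np

rng = np.random.default_rng(0)
B = 2

# Conditioning branch: 16 resampler tokens -> pooled embedding
cond_tokens = rng.standard_normal((B, 16, 2048))            # stand-in for resampler output
w_pool = rng.standard_normal((2048, 1280)) * 2048 ** -0.5   # stand-in linear layer
pooled = cond_tokens.mean(axis=1) @ w_pool                  # [B, 1280]

# Generative branch: Gaussian noise in the VAE latent space
image_size, vae_factor, latent_channels = 512, 8, 4
side = image_size // vae_factor                             # 64
noise = rng.standard_normal((B, latent_channels, side, side))

# Ratio of pixel values to latent values per image
compression = (3 * image_size * image_size) / (latent_channels * side * side)
print(pooled.shape, noise.shape, compression)
```

The compression figure counts values rather than spatial resolution: each side shrinks by 8×, while the channel count grows from 3 to 4.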
Legacy SD 1.5 path: The codebase retains a full Stable Diffusion 1.5 pipeline (dc_ldm/ldm_for_eeg.py → eLDM) with DDIM/PLMS sampling (250 steps), a simpler cond_stage_model using linear projection, and the ViT-based eeg_encoder. This is used when config.model_type = 'sd15'.
| Property | Value |
|---|---|
| Source | ThingsEEG dataset (Gifford et al.) |
| Subjects | Up to 50 (default: subjects 1–5) |
| EEG channels | 63 |
| Epoch length | 512 time-points per epoch (after resampling) |
| Stimuli | 1,854 unique object concepts from the THINGS database |
| Split strategy | By concept — test set contains held-out object categories to evaluate generalization to unseen visual concepts |
| Test ratio | 10% of concepts |
| Image size | 512 × 512 (for SDXL) |
| Preprocessing | Raw .edf → band-pass filter → epoch extraction → resampling → per-channel z-normalization → .pth files (preprocess_things_eeg.py) |
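
Two of the steps in this table, per-channel z-normalization and the concept-wise split, can be sketched in NumPy as follows. This is not the code in `preprocess_things_eeg.py` or `things_dataset.py`: the function names are hypothetical, normalization is assumed to be computed per epoch over the time axis, and the four trials per concept in the demo are made up.

```python
import numpy as np

def zscore_per_channel(epoch, eps=1e-8):
    """epoch: [C, T]. Normalise each channel to zero mean and unit variance over time."""
    mean = epoch.mean(axis=1, keepdims=True)
    std = epoch.std(axis=1, keepdims=True)
    return (epoch - mean) / (std + eps)

def split_by_concept(concept_ids, test_ratio=0.1, seed=0):
    """Hold out whole concepts so that no test concept appears in training."""
    rng = np.random.default_rng(seed)
    concepts = np.unique(concept_ids)
    n_test = int(round(len(concepts) * test_ratio))
    test_concepts = rng.choice(concepts, size=n_test, replace=False)
    is_test = np.isin(concept_ids, test_concepts)
    return np.flatnonzero(~is_test), np.flatnonzero(is_test)

rng = np.random.default_rng(0)
epoch = rng.standard_normal((63, 512)) * 5.0 + 3.0   # 63 channels, 512 time-points
z = zscore_per_channel(epoch)

concept_ids = np.repeat(np.arange(1854), 4)          # 4 trials per concept (made up)
train_idx, test_idx = split_by_concept(concept_ids)
print(len(np.unique(concept_ids[test_idx])))         # number of held-out concepts
```

Splitting on concept IDs rather than on individual trials is what makes the test set measure generalization to unseen object categories.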
| Property | Value |
|---|---|
| Source | PhysioNet / MNE datasets (.edf → .npy via preprocess_edf_to_npy.py) |
| Channels | 64 |
| Purpose | Self-supervised temporal masking pre-training of InceptSADEncoder (no image labels) |
python stageA1_incept_pretrain.py
| Parameter | Value |
|---|---|
| Encoder | InceptSADEncoder (embed_dim=1024, depth=6, heads=8) |
| Masking | 50% contiguous temporal masking |
| Decoder | 4-layer Transformer (512-dim) |
| Optimizer | AdamW (lr=4e-4, β=(0.9, 0.95), weight_decay=0.05) |
| Epochs | 200 |
| Batch size | 256 |
| Loss | MSE (masked regions only) |
| Output | Encoder checkpoint (checkpoint_best.pth) |
python eeg_ldm.py
| Parameter | Value |
|---|---|
| Diffusion model | SDXL Base 1.0 |
| EEG encoder | Frozen InceptSADEncoder (from Stage A1) |
| Bridge | Perceiver Resampler (16 tokens, 2 layers) |
| UNet adaptation | LoRA (rank=16, α=16) on Q/K/V/Out projections |
| Optimizer | AdamW (lr=5.3e-5, weight_decay=0.01) |
| LR schedule | Cosine with 2-epoch linear warmup |
| Batch size | 2 per step × 16 gradient-accumulation steps = 32 effective |
| Samples per epoch | 15,000 (random subset; full dataset seen across epochs) |
| Epochs | 50 |
| Precision | Mixed (AMP): VAE in FP32, UNet in FP16 |
| Loss | MSE (diffusion) + 0.5 × InfoNCE (CLIP alignment) |
| CLIP loss frequency | Every 4th batch |
| Scheduler | DPM++ 2M Karras |
| Inference steps | 25 |
| Checkpoint | Per-epoch (latest + best) with full resume support |
| Tracking | Weights & Biases |
| Training time | ~42 hours (single GPU) |
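
The batch-size and learning-rate rows can be worked through with a short standard-library sketch. It assumes the schedule is stepped once per optimizer update (that is, once per 16 accumulated micro-batches) and decays to zero; the function name `lr_at` and both of those choices are assumptions, not confirmed details of `eeg_ldm.py`.

```python
import math

def lr_at(step, total_steps, warmup_steps, base_lr=5.3e-5, min_lr=0.0):
    """Linear warmup to base_lr, then cosine decay to min_lr."""
    if step < warmup_steps:
        return base_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

samples_per_epoch, micro_batch, accum_steps, epochs = 15_000, 2, 16, 50
effective_batch = micro_batch * accum_steps                        # 2 * 16
steps_per_epoch = math.ceil(samples_per_epoch / effective_batch)   # optimizer updates per epoch
total_steps = epochs * steps_per_epoch
warmup_steps = 2 * steps_per_epoch                                 # 2-epoch linear warmup

peak = lr_at(warmup_steps, total_steps, warmup_steps)
final = lr_at(total_steps, total_steps, warmup_steps)
print(effective_batch, steps_per_epoch, peak, final)
```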
After training, the framework generates multiple samples per test EEG epoch and evaluates them against ground-truth images using:
| Metric | Type | Measures |
|---|---|---|
| MSE | Pixel-level | Mean squared error (lower is better) |
| PCC | Pixel-level | Pearson correlation coefficient (higher is better) |
| SSIM | Structural | Structural similarity index (higher is better) |
| LPIPS | Perceptual | Learned perceptual similarity via AlexNet (lower is better) |
| FID | Distributional | Fréchet Inception Distance via InceptionV3 (lower is better) |
| Top-1 Accuracy | Semantic | 50-way classification accuracy using ViT-H/14 (higher is better) |
Evaluation is implemented in eval_metrics.py and operates in both pair-wise and n-way scoring modes.
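
For orientation, the pixel-level metrics and one common form of the n-way top-1 test can be sketched in NumPy as below. This is not the code in `eval_metrics.py`: the function names are hypothetical, and the n-way procedure shown (compare the classifier's logit for the ground-truth class against n − 1 randomly drawn distractor classes, averaged over trials) is an assumption about how the 50-way accuracy is computed.

```python
import numpy as np

def mse(a, b):
    return float(np.mean((a - b) ** 2))

def pcc(a, b):
    """Pearson correlation between two images, computed over all pixels."""
    a = a.ravel() - a.mean()
    b = b.ravel() - b.mean()
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))

def n_way_top1(gen_logits, gt_class, n_way=50, n_trials=100, rng=None):
    """Fraction of trials in which gt_class beats n_way - 1 random distractors.

    gen_logits: classifier logits for the generated image, shape [num_classes].
    gt_class:   class index assigned to the ground-truth image.
    """
    rng = np.random.default_rng() if rng is None else rng
    others = np.delete(np.arange(len(gen_logits)), gt_class)
    hits = 0
    for _ in range(n_trials):
        distractors = rng.choice(others, size=n_way - 1, replace=False)
        hits += int(gen_logits[gt_class] > gen_logits[distractors].max())
    return hits / n_trials

rng = np.random.default_rng(0)
image = rng.random((3, 64, 64))
logits = np.arange(1000.0)                            # toy logits: class 999 scores highest
print(pcc(image, 2.0 * image + 1.0))                  # close to 1: invariant to gain and offset
print(n_way_top1(logits, gt_class=999, rng=rng))      # ground-truth class always wins
```

PCC ignores global brightness and contrast changes, which is why it is reported alongside MSE rather than in place of it.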
TTDnet/
├── code/
│   ├── config.py                     # All hyperparameter configurations
│   ├── dataset.py                    # EEG pre-training & original EEG datasets
│   ├── things_dataset.py             # ThingsEEG paired dataset (EEG + images)
│   ├── eeg_ldm.py                    # Main entry point (training + generation)
│   ├── eval_metrics.py               # MSE, PCC, SSIM, LPIPS, FID, n-way accuracy
│   │
│   ├── stageA1_eeg_pretrain.py       # Stage A1: ViT-MAE pre-training script
│   ├── stageA1_incept_pretrain.py    # Stage A1: InceptSAD pre-training script
│   │
│   ├── preprocess_edf_to_npy.py      # Raw .edf → .npy conversion
│   ├── preprocess_things_eeg.py      # ThingsEEG preprocessing pipeline
│   │
│   ├── sc_mbm/                       # Self-supervised EEG encoders
│   │   ├── InceptSADNet.py           # InceptSADNet classification model
│   │   ├── incept_encoder.py         # InceptSADEncoder (used in pipeline)
│   │   ├── incept_pretrain.py        # Temporal masking pre-training wrapper
│   │   ├── mae_for_eeg.py            # ViT-based MAE encoder (legacy)
│   │   ├── trainer.py                # MAE training utilities
│   │   └── utils.py                  # Checkpoint save/load helpers
│   │
│   └── dc_ldm/                       # Diffusion & conditioning modules
│       ├── sdxl_pipeline.py          # SDXL pipeline (train + generate)
│       ├── ip_adapter_bridge.py      # Perceiver Resampler + CLIP alignment
│       ├── ldm_for_eeg.py            # SD 1.5 pipeline (legacy)
│       ├── util.py                   # Config instantiation utilities
│       │
│       ├── models/
│       │   ├── autoencoder.py        # VAE (SD 1.5 path)
│       │   └── diffusion/
│       │       ├── ddpm.py           # DDPM implementation
│       │       ├── ddim.py           # DDIM sampler
│       │       ├── plms.py           # PLMS sampler
│       │       └── classifier.py     # Classifier-free guidance
│       │
│       └── modules/
│           ├── attention.py          # Cross-/Self-/Linear Attention
│           ├── x_transformer.py      # Extended Transformer blocks
│           ├── ema.py                # Exponential Moving Average
│           ├── diffusionmodules/     # UNet blocks, timestep embed, utils
│           ├── encoders/             # CLIP/FrozenEncoder wrappers
│           ├── distributions/        # Gaussian distributions
│           └── losses/               # Perceptual & VQ losses

