Accepted to ACL 2026 SRW. Go check out our paper! 🎉
First probing study of diffusion language model (DLM) hidden states. Linear classifiers on intermediate denoising steps predict whether outputs will be functionally correct.
- Correctness signal emerges across denoising steps (+0.08-0.11 AUC on reasoning tasks). Unlike AR models, DLMs accumulate additional signal through iterative refinement on GSM8K, MBPP, and ARC.
- Step-0 signal is prompt difficulty, not diffusion-specific (AUC 0.61-0.80). Initial hidden states already carry above-chance correctness information, comparable to AR probes on the same prompts.
- Task-dependent emergence patterns. Structural tasks (JSON) remain flat (~0.80 from step 0), reasoning tasks show gradual buildup (0.08-0.11 AUC gain).
- Distinct layer dynamics. LLaDA concentrates signal in upper layers (L22-28). Dream migrates from upper to lower layers on JSON schema.
- Offline filtering avoids wasted compute. Per-step probe confidence identifies likely failures, skipping 36-98% of generations depending on task.
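The filtering rule above amounts to thresholding the probe's predicted probability of correctness at an intermediate denoising step. A minimal sketch (function names and the 0.3 threshold are illustrative, not tuned values from the paper):

```python
def should_skip(probe_correct_prob: float, threshold: float = 0.3) -> bool:
    """Abort a generation when the mid-step probe's predicted probability
    of functional correctness falls below the threshold.

    The 0.3 default is illustrative; in practice the threshold would be
    tuned per task to trade skipped compute against lost correct outputs."""
    return probe_correct_prob < threshold


def filter_generations(probe_probs: list[float], threshold: float = 0.3) -> list[int]:
    """Return indices of generations worth completing under the threshold rule."""
    return [i for i, p in enumerate(probe_probs) if not should_skip(p, threshold)]
```

Raising the threshold skips more generations (saving more compute) at the cost of discarding some that would have been correct.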
| Key | Model | Layers |
|---|---|---|
| llada | GSAI-ML/LLaDA-8B-Instruct | 33 |
| dream | Dream-org/Dream-v0-Instruct-7B | 29 |
| Key | Source | N | Gen length | Correctness check |
|---|---|---|---|---|
| jsonschema | eth-sri/json-mode-eval-extended | 272 | 256 | JSON parse + reference match |
| gsm8k | openai/gsm8k (test) | 1,319 | 512 | Numeric answer match |
| mbpp | google-research-datasets/mbpp (sanitized test) | 257 | 256 | Code execution + test assertions |
| arc | allenai/ai2_arc (ARC-Challenge test) | 1,172 | 256 | Answer letter match |
Stars mark the best layer per step. JSON schema shows strong signal from step 0 (flat emergence), while GSM8K shows gradual buildup. Dream's best layer migrates from upper to lower layers on JSON schema.
Best AUC across layers at each step. JSON schema is flat (~0.80 from step 0), while GSM8K, MBPP, and ARC rise gradually.
- Probe: PCA(64) + StandardScaler + LogisticRegression, 5-fold stratified CV
- Steps: 7 checkpoints (0, 1, 4, 16, 32, 64, 127) during 128-step denoising
- Regions: Generation region split into 4 equal-length position regions, mean-pooled
- Metric: AUC (control probes on shuffled labels yield ~0.50)
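The probe pipeline above can be sketched with scikit-learn. This is a minimal reimplementation of the stated setup; `probe_auc` and its defaults are illustrative, and feature extraction from DLM denoising checkpoints is not shown:

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def probe_auc(hidden_states: np.ndarray, labels: np.ndarray, seed: int = 0) -> float:
    """Mean 5-fold CV AUC of a linear probe on pooled hidden states.

    hidden_states: (n_samples, d_model) array, one row per generation
                   (e.g. mean-pooled over a position region at one step).
    labels: binary functional-correctness labels (1 = correct).
    """
    # PCA(64) -> StandardScaler -> LogisticRegression, as in the setup above.
    pipe = make_pipeline(
        PCA(n_components=64, random_state=seed),
        StandardScaler(),
        LogisticRegression(max_iter=1000),
    )
    cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=seed)
    return cross_val_score(pipe, hidden_states, labels, cv=cv, scoring="roc_auc").mean()
```

Running the same probe on shuffled labels gives the ~0.50 control baseline mentioned above.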
Scripts are organized by purpose:
- `src/core/` — Main experiments and baselines
- `src/ablations/` — Length, region, std ablations
- `src/applications/` — Early exit, seed rerank, rebuttal
- `src/utils/` — Data processing and result comparison
Core experiments:
```bash
.venv/bin/modal run src/core/modal_midstep_probe.py --dataset jsonschema --model llada --chunks 8
.venv/bin/modal run src/core/modal_ar_probe.py --dataset jsonschema --chunks 4
.venv/bin/modal run src/core/modal_baseline_probes.py --baseline-type shuffle --dataset gsm8k
```

Ablations:

```bash
.venv/bin/modal run src/ablations/modal_length_ablation_probe.py --dataset arc --mode output_matched
.venv/bin/modal run src/ablations/modal_region_ablation.py --dataset jsonschema --model llada
.venv/bin/modal run src/ablations/modal_probe_with_std.py --dataset gsm8k --model dream
```

Applications:

```bash
.venv/bin/modal run src/applications/modal_early_exit_sim.py --dataset jsonschema
.venv/bin/modal run src/applications/modal_seed_rerank.py --dataset gsm8k --model llada
```
