Ablation results for all 7 OpenAI-requested research architectures #1500
Replies: 1 comment
Final Results (All 22 Runs Complete)

All runs finished on the DGX Spark. Here is the complete data across all 7 architectures. Full logs and CSV are available as a public gist: https://gist.github.com/dentity007/324ac35505c27acd18e7ffb468f4fa08

Combined Results Table (Sorted by val_bpb)
Final Findings Per Architecture

1. Universal Transformer: Doubling iterations from 6 to 24 (4x compute per step) produces identical BPB (3.2483 vs 3.2490). Full weight sharing plateaus immediately. Mini depth recurrence is the way.

2. Text Diffusion: All three AR/diffusion ratios produce identical BPB to 4 decimal places. The diffusion loss contributes literally nothing to causal eval. A fundamental mismatch, not a tuning problem.

3. Random Adapters: Frozen random orthogonal projections with ~600K trainable params reach 2.51 BPB. That is 1.5 BPB better than chance and suggests the transformer architecture itself (attention, residuals, norms) carries most of the inductive bias. Wider adapters (RND-2) actually hurt, landing at 2.63.

4. JEPA: Three different JEPA weights (10%, 30%, 50%) produce identical BPB. JEPA as a concurrent auxiliary loss has zero effect at 200 steps. It might need a pre-training-stage curriculum.

5. Mamba SSM: Winner on raw BPB at 2.0295, but with a catch. The pure PyTorch SSM runs at 37 seconds per step (50x slower than attention). At that speed, only ~5 effective training steps completed in 200 iterations, yet it still reached the lowest BPB. Strong signal that fast SSM kernels (Triton/CUDA selective scan) would be competitive. SSM-4 with a larger state also hit 2.18 but never completed training.

6. H-Net: Fastest architecture at 513ms per step while still reaching 2.06 BPB. But the chunker configuration makes zero difference: all three variants (default, large chunker, boundary regularizer) produce identical BPB to 4 decimal places. The chunker learns the identity function regardless of what you do to it.

7. Megakernels: Without Triton kernels (unavailable on ARM), we tested PyTorch-equivalent configurations. MEGA-2 with d=640 hit 2.16 BPB, beating MEGA-3 with 11 layers at 2.20. Wider is better than deeper at this model size. Note: the actual Triton kernel speedup would be additive on H100.

Cross-Architecture Insights
Caveats
Raw Data Access

Everything is in the gist: https://gist.github.com/dentity007/324ac35505c27acd18e7ffb468f4fa08
Public domain. Use for any purpose. If you rerun with different configurations or longer training, please share back; I'm interested to see how the ordering shifts.
Sharing early results from an overnight ablation study on all 7 of the "Requests for PRs" architectures from the README. These are non-record submissions (PRs #1191-#1197) but I wanted to go deeper than the initial proof-of-concept runs.
All tests were run on a single NVIDIA DGX Spark GB10 (128GB unified memory, no torch.compile), sp1024 data, 200 training steps, SEED=42. The hardware is slower than 8xH100, but the relative ordering between configurations should hold. No TTT, no SLOT, no eval-time tricks.
Results So Far (13 of 22 runs complete)
1. Universal Transformer (PR #1193)
Finding: More iterations do not help. 24 iterations is 4x slower per step than 6 and achieves virtually identical BPB (3.2490 vs 3.2483). The shared-weight architecture hits a ceiling quickly. This aligns with PR #363's earlier findings. The practical approach is mini depth recurrence (repeat 2-3 specific layers) rather than full weight sharing, which is what PR #1204 and PR #1334 ended up doing.
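The mini-depth-recurrence idea can be sketched as a small wrapper module. This is a hypothetical sketch (the class name and structure are mine, not the actual code from PRs #1204/#1334); `blocks` stands in for 2-3 distinct transformer blocks whose weights get reused across repeats:

```python
import torch
import torch.nn as nn

class MiniDepthRecurrence(nn.Module):
    """Loop over a small set of distinct layers instead of sharing
    weights across the full depth (as the Universal Transformer does)."""
    def __init__(self, blocks: nn.ModuleList, repeats: int):
        super().__init__()
        self.blocks = blocks    # e.g. 2-3 distinct blocks with their own weights
        self.repeats = repeats  # extra effective depth at no extra parameter cost

    def forward(self, x):
        for _ in range(self.repeats):
            for block in self.blocks:
                x = block(x)    # same weights revisited each repeat
        return x
```

The parameter count stays that of the few distinct blocks, while effective depth is `len(blocks) * repeats`; full weight sharing is the degenerate case where every layer in the model is one of these repeated blocks.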
2. Text Diffusion (PR #1194)
Finding: The diffusion loss contributes nothing. All three configurations produce identical BPB at 200 steps. The diffusion head learns to predict masked tokens during training, but since eval is purely autoregressive (causal, left-to-right), none of that knowledge transfers. The 70/30 split is actually slower (1388ms vs 997ms) because the diffusion forward pass adds overhead. Diffusion for text compression appears to be a dead end unless the eval protocol changes.
3. Random Linear Map Adapters (PR #1195)
Finding: Random projections provide a surprisingly strong starting point. All configurations land near 2.51 BPB regardless of adapter configuration. Wider adapters (RND-2) actually hurt, suggesting the default diagonal scale+shift is already well-matched to random orthogonal projections. The progressive unfreezing strategy (RND-4) shows no benefit at 200 steps. The gap from 2.51 to the trained baseline (~1.57 at 200 steps) represents the value of learning actual projection directions, not just scales.
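A minimal sketch of the frozen-random-projection setup described above, assuming the adapter is a fixed random orthogonal matrix with only a trainable per-channel scale and shift (the class and names are mine for illustration, not PR #1195's actual code):

```python
import torch
import torch.nn as nn

class RandomOrthogonalAdapter(nn.Module):
    """Frozen random orthogonal projection; only a diagonal
    scale+shift (2*dim params) is trained."""
    def __init__(self, dim: int):
        super().__init__()
        # QR of a Gaussian matrix gives a random orthogonal matrix.
        q, _ = torch.linalg.qr(torch.randn(dim, dim))
        self.register_buffer("proj", q)          # frozen: a buffer, not a Parameter
        self.scale = nn.Parameter(torch.ones(dim))
        self.shift = nn.Parameter(torch.zeros(dim))

    def forward(self, x):
        return x @ self.proj * self.scale + self.shift
```

Registering the projection as a buffer (rather than a `Parameter`) is what keeps it out of the optimizer, so all the learning capacity sits in the diagonal terms.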
4. JEPA (PR #1196)
Finding: JEPA weight has zero effect. Whether the JEPA auxiliary loss is 10%, 30%, or 50% of the total, val_bpb is identical to 4 decimal places. The JEPA predictor learns something (its loss decreases during training), but that knowledge does not transfer to the AR objective within 200 steps. This could change with longer training or with JEPA as a pre-training stage rather than a concurrent auxiliary.
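The weight ablation above amounts to sweeping the mixing coefficient in a convex combination of the two losses. A trivial sketch, assuming this is how the total loss is formed (the function is mine, not the PR's training code):

```python
def combined_loss(ar_loss: float, jepa_loss: float, w: float) -> float:
    """Total = (1 - w) * autoregressive loss + w * JEPA auxiliary loss.
    The ablation swept w over 0.1, 0.3, 0.5 with no effect on val_bpb."""
    return (1.0 - w) * ar_loss + w * jepa_loss
```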
5. Mamba SSM Hybrid (PR #1197) - still running
Note: Pure PyTorch SSM is extremely slow without custom CUDA kernels (35s per step vs ~700ms for attention). The BPB at step 100 (2.2066) is actually competitive with JEPA, suggesting SSMs have real potential if the speed issue is solved with Triton/CUDA selective scan kernels.
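The 35s/step cost comes from the sequential recurrence: without a fused kernel, the scan is a Python-level loop over time steps. A reference sketch of a diagonal-SSM scan (hypothetical shapes and names, not the PR's implementation) shows the bottleneck:

```python
import torch

def selective_scan_reference(x, A, B, C):
    """Naive scan: h[t] = A * h[t-1] + B * x[t],  y[t] = sum_n C * h[t].
    Shapes: x (T, d); A, B, C (d, n) for a diagonal state update.
    The per-timestep Python loop is what Triton/CUDA selective-scan
    kernels eliminate by fusing the recurrence on-device."""
    T, d = x.shape
    n = A.shape[1]
    h = torch.zeros(d, n)
    ys = []
    for t in range(T):                       # strictly sequential over time
        h = A * h + B * x[t].unsqueeze(-1)   # elementwise state update
        ys.append((h * C).sum(-1))           # readout over the state dim
    return torch.stack(ys)
```

Each of the T iterations launches several small ops and synchronizes with Python, which is why per-step time balloons from ~700ms to tens of seconds at realistic sequence lengths.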
Still Running (9 more runs)
Will update this thread when they finish.
Early Conclusions
Depth recurrence plateaus fast. 6 iterations match 24 at a fraction of the cost. Mini recurrence (2-3 layers), as in the record PR (ParallelResiduals + MiniDepthRecurrence: 1.1063 BPB / 1.8679 nats, -0.0072 vs PR #1179, -0.0143 vs merged SOTA #1204), is the right approach.
Diffusion is incompatible with causal eval. The knowledge from bidirectional masked prediction does not help left-to-right scoring. This is a fundamental mismatch, not a tuning problem.
Random projections are surprisingly capable. Diagonal adapters on frozen random orthogonal matrices reach 2.51 BPB. That is far better than chance and suggests the transformer architecture itself (attention patterns, residual connections, normalization) provides most of the inductive bias. The learned weights add ~1 BPB of improvement on top.
JEPA is neutral as an auxiliary loss. Neither helps nor hurts at any weight. May need a curriculum (pre-train with JEPA, then switch to AR) rather than concurrent training.
SSMs are promising but bottlenecked by implementation. The 2.21 BPB at step 100 (with barely any training due to 35s/step) hints that a fast SSM implementation could be competitive. Someone with Triton skills could make this work.
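For readers comparing the nats and BPB figures quoted in these conclusions: the standard conversion divides the mean cross-entropy by ln 2 and by the bytes-per-token ratio of the tokenizer/corpus. A sketch (the function name is mine, and the exact bytes-per-token constant is dataset-specific, so treat it as an assumption):

```python
import math

def nats_per_token_to_bpb(loss_nats: float, bytes_per_token: float) -> float:
    """Convert mean cross-entropy (nats/token) to bits per byte."""
    bits_per_token = loss_nats / math.log(2)
    return bits_per_token / bytes_per_token
```

If this is the conversion in use, the quoted pair (1.8679 nats, 1.1063 BPB) implies roughly 2.44 bytes per token.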
Hardware Note
All runs on a DGX Spark GB10 (single GPU, 128GB unified memory, ARM architecture). No torch.compile (Triton/inductor unsupported on aarch64), no flash attention (SDPA fallback). This hardware is ~6x slower than 8xH100 per step, so absolute BPB numbers are higher than competition runs. The relative rankings between configurations are what matter here.
Full logs available on request. PRs: #1191, #1192, #1193, #1194, #1195, #1196, #1197.
Corrections and feedback welcome, especially if anyone has tried these architectures with different configurations.