@@ -0,0 +1,130 @@
# Nemotron-H Inspired Mamba-3 Hybrid + Hinge Point Depth Recurrence

**Non-record submission. First Mamba depth recurrence and first hinge-point multi-recurrence in the competition.**

## Summary

This submission explores a hybrid Mamba-3 / Transformer architecture inspired by NVIDIA's Nemotron-H, with a novel depth-recurrence strategy focused on the U-Net hinge point. While the absolute bpb does not beat the current record, the systematic ablation study contributes new findings for the SSM track.

**Key result:** post-quant val_bpb = **1.4765** (1000 steps, 1xH100, SP1024, GPTQ int6+LZMA, 8.2MB artifact)

## Architecture

- **7 Mamba-3 SISO layers + 1 Attention layer** (8 physical layers)
- Mamba-3 config: d_state=64, expand=2, headdim=64, chunk_size=64, ngroups=1
- Attention: GQA with 8 heads, 4 KV heads, RoPE base=10000
- Attention placed at layer 4 (evenly spaced, Nemotron-H style)
- U-Net encoder-decoder with skip connections
- `torch.compile(dynamic=False, fullgraph=False)`
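
The even placement above can be sketched as follows. The helper name and spacing formula are assumptions (not the actual code in `train_nemotron_hybrid.py`), but for 8 layers with 1 attention layer they reproduce the `attn_indices:[4]` seen in the training log:

```python
def even_attn_indices(num_layers: int, num_attn: int) -> list[int]:
    """Spread attention layers evenly among the SSM layers (Nemotron-H style).

    Hypothetical sketch: the actual placement formula in the training
    script is assumed, but this yields index 4 for an 8-layer stack
    with a single attention layer.
    """
    return [round((i + 1) * num_layers / (num_attn + 1)) for i in range(num_attn)]

print(even_attn_indices(8, 1))  # [4]
```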

### Depth Recurrence (Novel)

**Hinge point multi-recurrence:** Layers 3 and 4 (the U-Net hinge) are replayed twice after their first pass, creating 12 virtual layers from 8 physical layers with zero extra parameters.

```
Physical: [M0, M1, M2, M3, A4, M5, M6, M7]
Virtual: [M0, M1, M2, M3, A4, M3, A4, M3, A4, M5, M6, M7]
↑ hinge layers 3,4 repeated 2x
```

Recurrence is enabled at 35% of training (step 350/1000), letting the model converge initially without the extra compute of the replayed layers.
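
A minimal sketch of how the virtual schedule can be expanded from the physical stack. The helper is hypothetical, but its output matches the `schedule=[0, 1, 2, 3, 4, 3, 4, 3, 4, 5, 6, 7]` line in the training log; `repeats=2` means two extra replays of the hinge block:

```python
def build_virtual_schedule(num_layers: int, recur_layers: list[int],
                           repeats: int) -> list[int]:
    """Expand the physical layer list into a virtual schedule by replaying
    the contiguous hinge block `repeats` extra times after its first pass.
    Sketch only; the training script's expansion logic is assumed."""
    schedule = []
    for i in range(num_layers):
        schedule.append(i)
        if i == recur_layers[-1]:              # just finished the hinge block
            for _ in range(repeats):
                schedule.extend(recur_layers)  # replay the whole block
    return schedule

print(build_virtual_schedule(8, [3, 4], 2))
# [0, 1, 2, 3, 4, 3, 4, 3, 4, 5, 6, 7]
```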

## Ablation Results

### Depth Recurrence (first-ever on Mamba layers)

| Config | val_bpb (2000 steps) | Virtual layers | vs no-recur |
|--------|---------------------|----------------|-------------|
| No recurrence | 1.2916 | 8 | — |
| Block recur 2,3 | 1.2851 | 10 | -0.0065 |
| Block recur 2,3,4 | 1.2830 | 11 | -0.0086 |
| **Hinge recur 3,4 x2** | **1.2824** | **12** | **-0.0092** |
| 4-layer recur 2,3,4,5 | 1.2864 | 12 | -0.0052 |
| Dual Attn@hinge | 1.2899 | 11 | -0.0017 |

**Finding:** Focused recurrence at the hinge point outperforms spread recurrence. Repeating hinge layers 2x (12 virtual) beats 4-layer 1x (also 12 virtual) by 0.004 bpb.

### Approaches Tested and Ruled Out

| Approach | Result | Finding |
|----------|--------|---------|
| Remove RoPE (ROPE_FRACTION=0) | +0.072 worse | Small models (26M) need explicit position encoding, unlike Jamba (1.3B) |
| Ternary Mamba (BitLinear 1.58-bit) | +0.397 worse | 26M params insufficient for ternary (literature confirms min ~1.3B) |
| Q-Mamba DSQ (A=FP16 + mixed precision) | +0.066 worse than standard GPTQ | Full Hessian GPTQ already handles SSM outliers well |

### Quantization

Standard full-Hessian GPTQ int6 with autoregressive (AR) self-generated calibration data (from the PR #1355 pipeline), followed by LZMA-9 compression.

- Pre-quant val_bpb: 1.3948
- Post-quant val_bpb: 1.4765
- Quantization gap: 0.082
- Artifact size: 8.2MB (well under 16MB cap)
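
For intuition on what int6 quantization does to a weight row, here is a plain round-to-nearest (RTN) baseline. This is illustration only, not the full-Hessian GPTQ used in the submission, which additionally compensates rounding error layer by layer using calibration Hessians:

```python
def rtn_int6(row: list[float]) -> tuple[list[float], float]:
    """Symmetric per-row round-to-nearest int6 fake-quantization (sketch)."""
    qmax = 2 ** (6 - 1) - 1                       # int6 range is [-32, 31]
    scale = max(abs(w) for w in row) / qmax       # one scale per row
    q = [max(-qmax - 1, min(qmax, round(w / scale))) for w in row]
    return [v * scale for v in q], scale          # dequantized row + scale

deq, scale = rtn_int6([0.6, -1.0, 0.25])
# each reconstructed weight lands within scale/2 of the original
assert all(abs(a - b) <= scale / 2 for a, b in zip(deq, [0.6, -1.0, 0.25]))
```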

## Reproduction

### Setup (RunPod or Modal with H100)

```bash
# Install dependencies
pip install -r requirements.txt

# Additionally, Mamba-3 modules need to be copied from mamba3-release branch:
git clone --depth 1 --branch mamba3-release https://github.com/state-spaces/mamba.git /tmp/mamba3src
PKG=$(python -c 'import mamba_ssm,os; print(os.path.dirname(mamba_ssm.__file__))')
cp /tmp/mamba3src/mamba_ssm/modules/mamba3.py $PKG/modules/
cp -r /tmp/mamba3src/mamba_ssm/ops/triton/mamba3 $PKG/ops/triton/
cp /tmp/mamba3src/mamba_ssm/ops/triton/angle_cumsum.py $PKG/ops/triton/
rm -rf /tmp/mamba3src

# Download dataset
python3 data/cached_challenge_fineweb.py --variant sp1024 --train-shards 10
```

### Training (1xH100, ~17 min for 1000 steps + GPTQ)

```bash
RUN_ID=nemotron_hinge \
DATA_PATH=./data/datasets/fineweb10B_sp1024/ \
TOKENIZER_PATH=./data/tokenizers/fineweb_1024_bpe.model \
VOCAB_SIZE=1024 \
NUM_LAYERS=8 \
NUM_ATTN_LAYERS=1 \
ATTN_PLACEMENT=even \
MAMBA3_D_STATE=64 \
RECUR_LAYERS=3,4 \
RECUR_MODE=block \
RECUR_REPEATS=2 \
RECUR_START_FRAC=0.35 \
ITERATIONS=1000 \
torchrun --standalone --nproc_per_node=1 train_nemotron_hybrid.py
```
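
The `RECUR_*` variables in the command above map onto the recurrence config. A hypothetical parsing sketch (the variable names come from the command; the helper itself is assumed, not the script's actual code):

```python
import os

def parse_recur_config(total_steps: int) -> dict:
    """Read the RECUR_* environment variables from the run command.
    Sketch only; how train_nemotron_hybrid.py parses them is assumed."""
    raw = os.environ.get("RECUR_LAYERS", "")
    layers = [int(x) for x in raw.split(",") if x]
    start_frac = float(os.environ.get("RECUR_START_FRAC", "0.0"))
    return {
        "layers": layers,
        "mode": os.environ.get("RECUR_MODE", "none"),
        "repeats": int(os.environ.get("RECUR_REPEATS", "1")),
        "start_step": round(start_frac * total_steps),  # 0.35 * 1000 -> 350
    }

os.environ.update(RECUR_LAYERS="3,4", RECUR_MODE="block",
                  RECUR_REPEATS="2", RECUR_START_FRAC="0.35")
print(parse_recur_config(1000))
```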

### Training (8xH100, 10 min — pending compute grant)

```bash
# Same config but with:
# torchrun --standalone --nproc_per_node=8
# MAX_WALLCLOCK_SECONDS=600
# Expected: val_bpb ~1.25-1.30 post-quant
```

## Credits / Built On

- **PR #1355** (@mamba3-hybrid author): Mamba-3 Hybrid base, GPTQ pipeline, MuonEq-R optimizer
- **NVIDIA Nemotron-H** (arXiv 2504.03624): Hybrid architecture inspiration (92% SSM + 8% attention)
- **Mamba-3** (ICLR 2026, Gu et al.): SISO SSM with complex-valued states
- **PR #1204** (@sisovic): Depth recurrence concept (adapted from Transformer to SSM)
- **Q-Mamba, Mamba-PTQ, Quamba2**: Mamba quantization research informing our ablations

## Compute

All experiments run on Modal.com 1xH100 instances. Pending OpenAI compute grant for 8xH100 runs.
Total compute used: ~$30 Modal credits across 20+ experiments.

## What's Next

1. Full 8xH100 10-min run with best config (pending compute)
2. SP8192 tokenizer (expected ~0.05 bpb improvement)
3. Long-context evaluation (Mamba's O(n) advantage for 8K-32K eval)
4. Enable TTT and EMA for additional gains
@@ -0,0 +1,3 @@
mamba_ssm @ https://github.com/state-spaces/mamba/releases/download/v2.3.1/mamba_ssm-2.3.1%2Bcu12torch2.9cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
causal_conv1d @ https://github.com/Dao-AILab/causal-conv1d/releases/download/v1.6.1.post4/causal_conv1d-1.6.1%2Bcu12torch2.9cxx11abiTRUE-cp312-cp312-linux_x86_64.whl
einops
@@ -0,0 +1,12 @@
{
"author": "Yongkang Zou",
"github_id": "inin-zou",
"name": "Nemotron-H Inspired Mamba-3 Hybrid + Hinge Point Depth Recurrence",
"blurb": "First Mamba depth recurrence in Parameter Golf: 7 Mamba-3 + 1 Attention hybrid with hinge-point multi-recurrence (12 virtual layers from 8 physical). Inspired by NVIDIA Nemotron-H architecture.",
"date": "2026-04-13T23:00:00Z",
"val_loss": 2.4930,
"val_bpb": 1.4765,
"bytes_total": 8295138,
"bytes_code": 90450,
"notes": "Run on 1xH100 (1000 steps). Pending OpenAI compute grant for full 8xH100 10-min run. Pre-quant val_bpb=1.3948, estimated 2000-step post-quant ~1.35."
}
@@ -0,0 +1,77 @@
logs/baseline_gptq.txt
val_bpb:enabled tokenizer_kind=sentencepiece tokenizer_path=/app/data/tokenizers/fineweb_1024_bpe.model
train_loader:dataset:fineweb10B_sp1024 train_shards:1
val_loader:shards pattern=/app/data/datasets/fineweb10B_sp1024/fineweb_val_*.bin tokens:62021632
attn_placement:even attn_indices:[4]
depth_recurrence: layers=[3, 4] mode=block
model_params:26216040
world_size:1 grad_accum_steps:8
mode:mamba3_hybrid num_attn_layers:1 attn_indices:[4]
ssd: d_state:64 expand:2 headdim:64
attn: num_heads:8 num_kv_heads:4 rope_base:10000.0
num_layers:8 mlp_mult:3
tie_embeddings:True embed_lr:0.05 head_lr:0.0 matrix_lr:0.025 scalar_lr:0.02
train_batch_tokens:524288 train_seq_len:4096 iterations:1000 warmup_steps:20 max_wallclock_seconds:0.000
seed:1337
warmup_step:1/20
warmup_step:2/20
warmup_step:3/20
warmup_step:4/20
warmup_step:5/20
warmup_step:6/20
warmup_step:7/20
warmup_step:8/20
warmup_step:9/20
warmup_step:10/20
warmup_step:11/20
warmup_step:12/20
warmup_step:13/20
warmup_step:14/20
warmup_step:15/20
warmup_step:16/20
warmup_step:17/20
warmup_step:18/20
warmup_step:19/20
warmup_step:20/20
step:0/1000 val_loss:6.9357 val_bpb:4.1077 train_time:0ms step_avg:0.01ms
step:1/1000 train_loss:6.9356 train_time:458ms step_avg:458.05ms
step:2/1000 train_loss:6.5663 train_time:916ms step_avg:458.24ms
step:3/1000 train_loss:6.0412 train_time:1374ms step_avg:458.04ms
step:4/1000 train_loss:5.8702 train_time:1832ms step_avg:458.07ms
step:5/1000 train_loss:5.5853 train_time:2290ms step_avg:458.04ms
step:6/1000 train_loss:5.3060 train_time:2748ms step_avg:458.07ms
step:7/1000 train_loss:5.0688 train_time:3206ms step_avg:458.05ms
step:8/1000 train_loss:4.8584 train_time:3664ms step_avg:458.03ms
step:9/1000 train_loss:4.7453 train_time:4122ms step_avg:458.01ms
step:10/1000 train_loss:4.6856 train_time:4580ms step_avg:458.03ms
step:200/1000 train_loss:2.9027 train_time:91875ms step_avg:459.37ms
depth_recurrence:enabled at step 350 frac=0.35 schedule=[0, 1, 2, 3, 4, 3, 4, 3, 4, 5, 6, 7]
step:400/1000 train_loss:2.5676 train_time:195471ms step_avg:488.68ms
late_qat:enabled bits=6 at step 476 scale=0.1497
step:500/1000 val_loss:2.6478 val_bpb:1.5682 train_time:276969ms step_avg:553.94ms
step:600/1000 train_loss:2.6001 train_time:346158ms step_avg:576.93ms
step:800/1000 train_loss:2.3457 train_time:485299ms step_avg:606.62ms
step:1000/1000 train_loss:2.4538 train_time:624107ms step_avg:624.11ms
step:1000/1000 val_loss:2.3551 val_bpb:1.3948 train_time:624107ms step_avg:624.11ms
peak memory allocated: 21397 MiB reserved: 21700 MiB
ema:applying EMA weights
Serialized model: 102806059 bytes
Code size: 90450 bytes
Total submission size: 102896509 bytes
gptq:generating autoregressive calibration data...
gptq:generated 32 seqs in 966.8s
gptq:collecting hessians...
gptq:collected hessians for 35 layers
gptq:quantization complete in 993.9s total
Serialized model int6+lzma-9: 8204688 bytes (payload:26989856 raw_torch:27033765 payload_ratio:3.81x)
Total submission size int6+lzma-9: 8295138 bytes
[rank0]:W0414 00:04:21.959000 9 site-packages/torch/_dynamo/convert_frame.py:1358] [13/8] torch._dynamo hit config.recompile_limit (8)
[rank0]:W0414 00:04:21.959000 9 site-packages/torch/_dynamo/convert_frame.py:1358] [13/8] function: 'forward' (/app/train_nemotron_hybrid.py:1102)
[rank0]:W0414 00:04:21.959000 9 site-packages/torch/_dynamo/convert_frame.py:1358] [13/8] last reason: 13/7: tensor 'self._modules['attn']._modules['rotary']._cos_cached' size mismatch at index 2. expected 4095, actual 4096
[rank0]:W0414 00:04:21.959000 9 site-packages/torch/_dynamo/convert_frame.py:1358] [13/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
[rank0]:W0414 00:04:21.959000 9 site-packages/torch/_dynamo/convert_frame.py:1358] [13/8] To diagnose recompilation issues, see https://pytorch.org/docs/main/torch.compiler_troubleshooting.html
final_int8_zlib_roundtrip val_loss:2.4930 val_bpb:1.4765 eval_mode:standard eval_time:40523ms
final_int8_zlib_roundtrip_exact val_loss:2.49298788 val_bpb:1.47648785
Stopping app - local entrypoint completed.
✓ App completed. View run at
https://modal.com/apps/yongkang-zou1999/main/ap-g38RYWF6UCTTyFVL8834tz