Closed
Changes from all commits
131 commits
ae03494
[TorchComms] add testing badge at experiments readme (#2010)
mori360 Nov 10, 2025
f4514ef
[compiler toolkit] specify passes through config (#2006)
yiming0416 Nov 10, 2025
02990b0
[simplefsdp] fix region ac in zero2-style FSDP (#1970)
ruisizhang123 Nov 10, 2025
fddd9eb
[SimpleFSDP] Add typing to simple_fsdp.py (#2001)
fegin Nov 11, 2025
e37f83f
[Full DTensor][Reland] Add full_dtensor flag (#2013)
fegin Nov 11, 2025
20fcfd7
set pg names (#1986)
tushar00jain Nov 11, 2025
11d73a2
Fix the error message of maybe_enable_async_tp() (#2011)
fegin Nov 11, 2025
f5d2b18
Add dry run mode (#2012)
fegin Nov 11, 2025
edbf349
[easy] [compiler toolkit] Clean up unused function (#2014)
yiming0416 Nov 11, 2025
2f9b44d
Run Torchtitan ROCm workflow on cron schedule & push to Main branch o…
akashveramd Nov 12, 2025
55c63c1
Revert PR-2016 & Redo "Run Torchtitan ROCm workflow on cron schedule …
akashveramd Nov 12, 2025
cbfb8e1
[compiler toolkit] Add tests and scripts for numerics check (#2015)
yiming0416 Nov 12, 2025
4b2b31c
Add .claude to .gitignore (#2026)
fegin Nov 13, 2025
ce1c0fc
Fix dry run mode (#2027)
fegin Nov 13, 2025
e7ee95a
[Compiler Toolkit] Make compiler toolkit work with checkpoint (#2030)
fegin Nov 13, 2025
23c993c
[Flux] Update integration test badge in README.md (#2019)
rm-wu Nov 13, 2025
028a455
Print device and stride when print module (#2045)
BoyuanFeng Nov 14, 2025
d9bdfbb
[SimpleFSDP] add manual bucketing pass (#1881)
ruisizhang123 Nov 16, 2025
22e959a
Add export_dtype parameter to `convert_to_hf` function (#2041)
idoh Nov 17, 2025
3819737
[compiler toolkit] Port joint_ac_pass from simplefsdp (#2051)
yiming0416 Nov 18, 2025
bfdc974
[compiler toolkit] Port manual bucketing from SimpleFSDP experiment (…
yiming0416 Nov 18, 2025
4a5fa99
Re:Run Torchtitan ROCm workflow on cron schedule & push to Main branc…
akashveramd Nov 19, 2025
c8ebd7a
Add a loss comparison script (#2029)
fegin Nov 19, 2025
605a9a1
Fix integration test gpu_arch_type field (#2060)
yiming0416 Nov 19, 2025
f541d91
[compiler toolkit] Add Trainer subclass for compiler toolkit (#2064)
yiming0416 Nov 19, 2025
8bf2265
Let loss_compare.py check the repo cleaness (#2062)
fegin Nov 20, 2025
f5e3a84
CUDAGraph support for SimpleFSDP and TP (#2050)
BoyuanFeng Nov 20, 2025
d167a20
compiler_toolkit: fix args access (#2067)
crcrpar Nov 20, 2025
58fa181
3outeille/transformers backend (Dense model only) (#2048)
3outeille Nov 20, 2025
f8fa21e
adding variable length attention to llama3 8b (#2000)
liangel-02 Nov 21, 2025
e1f7f31
remove scatter_add in MoE implementation (#1974)
garrett361 Nov 22, 2025
ad9f188
Update transformers backend name (#2075)
3outeille Nov 23, 2025
c70310c
Enhance loss_compare.py: Add Import/Export Options and Enable CI Comp…
fegin Nov 24, 2025
7e1edb6
Print out the version number (#2083)
fegin Nov 24, 2025
7e10d60
Autoparallel as an experiment in main (#2054)
xmfan Nov 25, 2025
607c70d
skip varlen integration test on rocm (#2085)
liangel-02 Nov 25, 2025
d0393b3
[Local Tensor] Replace dry_run.py with fake mode implementation (#2057)
fegin Nov 25, 2025
1b9cfda
add varlen attention for qwen 3 (#2084)
liangel-02 Nov 25, 2025
cbdb311
[FLUX] Add FLUX inference test in CI (#1969)
wwwjn Nov 25, 2025
befb7ae
Improve logging by formatting the dict as JSON. (#2094)
rakkit Dec 1, 2025
b39377f
add all SDPA backends to op_sac_save_list (#2095)
rakkit Dec 1, 2025
53e949c
modify save list for varlen attn (#2082)
liangel-02 Dec 2, 2025
571ce7c
Make sure log after distributed initialized. (#2102)
CptGit Dec 3, 2025
b3da1a2
[mxfp8] [docs] [BE] add MXFP8 usage documentation and benchmarks (#2096)
danielvegamyhre Dec 3, 2025
8d020cc
Mark input tokens to routed experts as dynamic to avoid a recompile (…
xmfan Dec 3, 2025
341b155
fix mxfp8 loss image (#2104)
danielvegamyhre Dec 3, 2025
1168f9e
Update hf_assets_path for llama4 (#2110)
H-Huang Dec 4, 2025
e98ae99
Enables parsing of --compile.components through CLI (#2115)
syed-ahmed Dec 5, 2025
303f284
fix `ForgeEngine` compatibility issue with (#2121)
JenniferWang Dec 7, 2025
575674a
Remove the hack for SAC + FlexAttention (#2118)
fegin Dec 8, 2025
b41832a
Add warning to run_tests (#2123)
H-Huang Dec 9, 2025
d192411
[compiler toolkit] Disable CUDAGraph integration test (#2127)
yiming0416 Dec 9, 2025
1ebd914
Add CI for Autoparallel experiment llama3 on 4 GPUs (#2105)
xmfan Dec 9, 2025
f1d41a1
Support rope cache indexing using positions (#2112)
acisseJZhong Dec 9, 2025
f3f2e8f
[forge] allow torchforges to set checkpoint base folder (#2131)
rakkit Dec 9, 2025
fbafd44
Rename auto_parallel experiment to autoparallel (#2128)
xmfan Dec 9, 2025
a632855
PyTorch depends on psutil (#2132)
fegin Dec 11, 2025
4389efd
Remove caching for attention masks (#2117)
wwwjn Dec 11, 2025
669845f
Clarify contribution guidelines. (#2134)
dcci Dec 12, 2025
fcc5643
Enable PP and EP overlap for MoE (#1721)
H-Huang Dec 12, 2025
7a398ea
Fix apply_compile called multiple times in PP initialization (#2135)
xmfan Dec 12, 2025
64dc922
Enable static type checking with Pyrefly (#2136)
rchen152 Dec 12, 2025
995154f
[Autoparallel] Add local_map variant of DSv3 and 2D mesh AP (#2129)
xmfan Dec 13, 2025
9bc50ea
Implement ciflow/rocm on Torchtitan (#2114)
akashveramd Dec 13, 2025
2aac20a
[MoE] Add node limited routing support (#2111)
shuhuayu Dec 14, 2025
c1f4e94
Upgrade GitHub Actions to latest versions (#2152)
salmanmkc Dec 14, 2025
f3748d8
Upgrade GitHub Actions for Node 24 compatibility (#2151)
salmanmkc Dec 14, 2025
c283a84
Improve the loss_compare.sh logic (#2143)
fegin Dec 15, 2025
64997d2
[GPT-OSS] Add HF state dict adapter to support loading from HF checkp…
shuhuayu Dec 15, 2025
c08fa57
Add local built pytorch path for pyrefly (#2155)
fegin Dec 15, 2025
e36d027
Run vLLM inference using torchtitan model definition (single GPU) (#2…
wwwjn Dec 16, 2025
f64bbad
[RELAND] Let CUDA and ROCm read different loss result (#2157)
fegin Dec 16, 2025
183a0d2
Use new DeviceMesh unflatten to rewrite parallel_dims (#1660)
fegin Dec 17, 2025
36a4b69
Integrate DeepEP to torchtitan (#2107)
elfiegg Dec 18, 2025
4438764
Fix pypa/gh-action-pypi-publish version to use SHA pinning (#2161)
salmanmkc Dec 19, 2025
fd49b4b
Upgrade GitHub Actions for Node 24 compatibility (#2164)
salmanmkc Dec 19, 2025
658f94c
Expose common dataloader args (#2097)
divyanshk Dec 19, 2025
b786a3d
Replace `logger.warn()` to `logger.warning()` , allow `log_validation…
EquationWalker Dec 19, 2025
b21555f
Add Dependabot for GitHub Actions updates (#2163)
salmanmkc Dec 19, 2025
1bd2548
Bump tj-actions/changed-files from d6e91a2266cdb9d62096cebf1e8546899c…
dependabot[bot] Dec 19, 2025
4b3d25a
Multiprocess simple RL loop (#2158)
acisseJZhong Dec 22, 2025
29aafb9
Fix qwen3 attention scaling calculation (#2173)
wwwjn Dec 23, 2025
a452121
Add rocm support for models, flux & torchft integration tests. (#2172)
akashveramd Dec 24, 2025
30ab580
[RL] Support Trainer and Generator Unified Model (#2174)
acisseJZhong Dec 24, 2025
a95d203
Support TP when using vLLM engine to run inference w/ torchtitan mode…
wwwjn Dec 26, 2025
5077be6
add safety checks for varlen (#2179)
liangel-02 Dec 26, 2025
64b5e15
Bump torchtitan version to v0.2.1 (#2180)
wwwjn Dec 26, 2025
81af883
Remove psutil as part of requirements (#2181)
wwwjn Dec 26, 2025
5dd9f4c
add attention scaling to varlen for qwen3 (#2178)
liangel-02 Dec 29, 2025
62f5806
make get tp mesh optional in llama4 parallelize (#2185)
danielvegamyhre Dec 29, 2025
7e4ab85
Add docs to explain COMM_MODE (#2162)
fegin Dec 30, 2025
e16af85
[docs] Fix missing --model.flavor flags in compiler_toolkit README (#…
BryanBradfo Jan 6, 2026
795a7a0
[GPT-OSS] Graduate from experiments to main (#2203)
shuhuayu Jan 6, 2026
9f211ec
[Compiler Toolkit] Add option for full inductor. (#2150)
aditvenk Jan 7, 2026
ec246c9
[autoparallel] Update local_map_deepseek_v3 device mesh usage (#2231)
xmfan Jan 14, 2026
c26ea60
Disable dynamo LRU cache when AC is enabled (#2204)
soulitzer Jan 14, 2026
6408426
Enable memory snapshot for generic devices (#2228)
frost-intel Jan 15, 2026
9240172
Add test for dsv3 with flexattn + fsdp + ep + pp + sac op (#2234)
shuhuayu Jan 15, 2026
5ef90fa
[lint] ignore all existing pyrefly errors (#2240)
xmfan Jan 15, 2026
1556971
[Experimental][rl][vllm compat] Update simple_rl example to work with…
Lucaskabela Jan 16, 2026
a085b0e
[Experimental][rl][unified] Update infer.py example to work with vLLM…
Lucaskabela Jan 19, 2026
09c6d74
fix sdpa-varlen attention mismatch in qwen3 (#2229)
francesco-bertolotti Jan 19, 2026
2a642d0
Update README with libnvshmem_host.so troubleshooting
dmahan93 Jan 20, 2026
a25dd8f
[ROCm] Support mxfp8 on gfx950. (#2222)
RuibinCheung Jan 20, 2026
8e5f859
Merge pull request #44 from NousResearch/deepep-install-readme-updates
jquesnelle Jan 20, 2026
7fde8b6
[Typing] Fix CI Typing Issues (#2245)
fegin Jan 20, 2026
42fd903
[Typing] Improve ModelProtocol typing (#2246)
fegin Jan 20, 2026
69cf207
[Typing] Remove deprecated enable_symm_mem_for_group (#2260)
fegin Jan 20, 2026
1e8f9ac
[CP] Refactor Context Parallel to use new PyTorch CP APIs (#2144)
fegin Jan 21, 2026
0a2107f
[CP] Enable FlexCP for llama3 (#2145)
fegin Jan 21, 2026
8ff9e42
[MoE] Fix experts DTensor metadata bug for dcp (#2227)
shuhuayu Jan 22, 2026
3263b15
Update GRPO.md
dmahan93 Jan 23, 2026
5621112
[varlen_attn] change is_causal to window_size (#2267)
liangel-02 Jan 23, 2026
81f5a5a
Add ROCm CI support for simple fsdp experiments test (#2220)
akashveramd Jan 24, 2026
a8ac852
Merge branch 'dev-updated-again' into upstream-2026-24-01
jquesnelle Jan 26, 2026
6d35673
Merge branch 'dev-updated-again' into upstream-2026-24-01
jquesnelle Jan 27, 2026
865ebb8
context parallel support in dsv3 and qwen3
jquesnelle Jan 28, 2026
2ad47cb
fast path for initing bfloat16 params on cpu
jquesnelle Jan 21, 2026
81e54a4
add reference for init scheme
jquesnelle Jan 22, 2026
f04236d
overlapped cpu offload muon
jquesnelle Jan 23, 2026
e7ccfdc
merge fixups
jquesnelle Jan 29, 2026
98f53ee
merge fixups
jquesnelle Jan 30, 2026
668f23e
Add memory tracking and BF16 optimizer state features with Kimi K2 co…
xrsrke Jan 31, 2026
4071454
Add NaN tracker config, FSDP prefetch control, and nvidia-smi memory …
xrsrke Jan 31, 2026
375762b
Add partial resharding support (fsdp_reshard_after_forward accepts int)
xrsrke Jan 31, 2026
0a06429
Add device mesh visualizer for distributed training debugging
xrsrke Jan 31, 2026
fe8d1f0
add option to filter data when preprocessing by a specific string
jquesnelle Feb 1, 2026
f50b804
add kimi_k2_sft
jquesnelle Feb 1, 2026
ed6b753
fix wrong arg used for --push-to-hub
jquesnelle Feb 1, 2026
7f6f3a3
fix attention args, add kimi_k2_ep64_cp1_seq24k_lbs1 160 tps config
xrsrke Feb 4, 2026
c3a14a1
Merge branch 'upstream-2026-24-01' of https://github.com/NousResearch…
xrsrke Feb 4, 2026
1 change: 1 addition & 0 deletions .ci/docker/common/install_conda.sh
@@ -43,6 +43,7 @@ install_pip_dependencies() {
pip_install -r /opt/conda/requirements.txt
pip_install -r /opt/conda/requirements-flux.txt
pip_install -r /opt/conda/requirements-vlm.txt
pip_install -r /opt/conda/requirements-transformers-modeling-backend.txt
popd
}

1 change: 1 addition & 0 deletions .ci/docker/requirements-dev.txt
@@ -2,5 +2,6 @@ expecttest==0.1.6
pytest==7.3.2
pytest-cov
pre-commit
pyrefly==0.45.1
tomli-w >= 1.1.0
transformers
2 changes: 0 additions & 2 deletions .ci/docker/requirements-flux.txt
@@ -1,4 +1,2 @@
transformers>=4.51.1
einops
sentencepiece
pillow
1 change: 1 addition & 0 deletions .ci/docker/requirements-transformers-modeling-backend.txt
@@ -0,0 +1 @@
transformers==4.57.1
2 changes: 2 additions & 0 deletions .ci/docker/requirements.txt
@@ -8,3 +8,5 @@ fsspec
tyro
tokenizers >= 0.15.0
safetensors
einops
pillow
1 change: 1 addition & 0 deletions .ci/docker/ubuntu/Dockerfile
@@ -33,6 +33,7 @@ COPY requirements-dev.txt /opt/conda/
COPY requirements.txt /opt/conda/
COPY requirements-flux.txt /opt/conda/
COPY requirements-vlm.txt /opt/conda/
COPY requirements-transformers-modeling-backend.txt /opt/conda/
COPY conda-env-ci.txt /opt/conda/
COPY ./common/install_conda.sh install_conda.sh
COPY ./common/utils.sh utils.sh
4 changes: 4 additions & 0 deletions .gitignore
@@ -44,3 +44,7 @@ slurm-*

# env files
.env
.venv/

# Vibe coding
.claude
8 changes: 8 additions & 0 deletions .pre-commit-config.yaml
@@ -53,3 +53,11 @@ repos:
args: ["--ignore-words-list=MIS"]
additional_dependencies:
- tomli

- repo: https://github.com/facebook/pyrefly-pre-commit
rev: 0.45.1
hooks:
- id: pyrefly-check
name: Pyrefly (type checking)
pass_filenames: false
language: system
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -4,7 +4,7 @@ possible. Contributions should follow the [Contributing Guidelines](#contributin

### Setup
```
pip install -r requirements-dev.txt
pip install -r requirements.txt -r requirements-dev.txt
```

### Pull Requests
17 changes: 17 additions & 0 deletions GRPO.md
@@ -4,6 +4,8 @@ GRPO instructions

## Installation instructions
```shell
mkdir logs
chmod g+rw ./logs
pip install uv
uv pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu129
uv pip install -r requirements.txt
@@ -12,8 +14,23 @@ export VLLM_COMMIT=2918c1b49c88c29783c86f78d2c4221cb9622379
uv pip install vllm torch==2.9.0 --torch-backend=cu129 --prerelease=allow --extra-index-url https://wheels.vllm.ai/${VLLM_COMMIT} --extra-index-url https://download.pytorch.org/whl/cu129
pip install flashinfer-python==0.4.1 flashinfer-cubin==0.4.1
pip install flashinfer-jit-cache==0.4.1 --index-url https://flashinfer.ai/whl/cu129
pip install transformers==4.57.1
```

## Configuration instructions

See `torchtitan/grpo/configs/qwen25-7b-math.toml` for good initial values.

## sbatch script

`online_multinode_vllm.slurm` contains some paths to edit:
- TRAIN_PATH - where this repo is installed on the cluster
- TRAIN_ENV - if your venv is not initialized at `.venv`, change this to point to that venv
- VLLM_ENV - same as TRAIN_ENV unless you're doing something different
- API_ENV - the atropos venv

Once that's done, you can do something like
```bash
sbatch --export=ALL,CONFIG_FILE=/home/dakota/github/torchtitan/torchtitan/grpo/configs/qwen25-7b-math.toml,MODEL_NAME=Qwen/Qwen2.5-7B,PYTHON_SCRIPT=/home/dakota/github/atropos/environments/math_server_zero.py,WANDB_PROJECT=qwen7b_debug online_multinode_vllm.slurm
```
to launch a run.
10 changes: 8 additions & 2 deletions README.md
@@ -40,9 +40,14 @@ The Guiding Principles when building `torchtitan`
* Minimal changes to the model code when applying multi-dimensional parallelism.
* Bias towards a clean, minimal codebase while providing basic reusable / swappable components.

`torchtitan` has been showcasing PyTorch's latest distributed training features, via pretraining Llama 3.1 LLMs of various sizes.
To accelerate contributions to and innovations around torchtitan, we host an [`experiments`](torchtitan/experiments) folder. We look forward to your contributions!
`torchtitan` has been showcasing PyTorch's latest distributed training features, via support for pretraining Llama 3.1 LLMs of various sizes.

## Contributing

We look forward to your contributions!

* To accelerate contributions to and innovations around torchtitan, we host an [`experiments`](torchtitan/experiments) folder. New ideas should start there. To contribute, follow the [`experiments guidelines`](torchtitan/experiments/README.md).
* For fixes and contributions to core, follow these [`guidelines`](CONTRIBUTING.md).

## Llama 3.1 training

@@ -59,6 +64,7 @@ To accelerate contributions to and innovations around torchtitan, we host an [`e
- [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine-tuning
5. `torch.compile` support
6. [Float8](https://discuss.pytorch.org/t/distributed-w-torchtitan-enabling-float8-all-gather-in-fsdp2/209323) support ([how-to](docs/float8.md))
7. [MXFP8 training for dense and MoE models](docs/mxfp8.md) on Blackwell GPUs.
7. DDP and HSDP
8. [TorchFT](https://github.com/pytorch/torchft) integration
9. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) and support for [custom datasets](docs/datasets.md)
Expand Down
Binary file added assets/images/mxfp8_with_loss.png
2 changes: 1 addition & 1 deletion assets/version.txt
@@ -1 +1 @@
0.2.0
0.2.1
2 changes: 1 addition & 1 deletion docs/checkpoint.md
@@ -68,7 +68,7 @@ NGPU=1 CONFIG_FILE=<path_to_model_config> ./run_train.sh --checkpoint.enable --c
### HuggingFace
`torchtitan` offers two ways to work with Hugging Face models: either by directly saving and loading a Hugging Face checkpoint during training, or by using an example conversion script to directly reformat the model weights on cpu.

1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, enable `--checkpoint.initial_load_in_hf`, and set either `--model.hf_assets_path` or `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint. `--checkpoint.initial_load_path` overrides `--model.hf_assets_path` if both are set.
1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_hf` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, enable `--checkpoint.initial_load_in_hf`, and set either `--model.hf_assets_path` or `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint. `--checkpoint.initial_load_path` overrides `--model.hf_assets_path` if both are set.

2. To directly reformat the weights without the need to run a training loop, run the corresponding conversion script. The naming scheme is `torchtitan`-centric, e.g. convert_from_hf means convert hf->tt.

63 changes: 63 additions & 0 deletions docs/debugging.md
@@ -54,6 +54,69 @@ python -m torchtitan.config.manager --help

This will print a structured configuration to `stdout`, allowing you to verify that overrides are being applied correctly.

## Communication Mode (COMM_MODE) for Debugging

The `COMM_MODE` environment variable provides specialized debugging modes that allow you to test and validate your training setup without requiring full multi-GPU distributed execution. This is particularly useful for rapid iteration during development and debugging.

### Available Modes

#### 1. `fake_backend` - Configuration Validation Mode

This mode enables dry-run validation of your configuration, model setup, and rank-0 program logic without actual distributed communication:

```bash
NGPU=32 COMM_MODE="fake_backend" ./run_train.sh
```

**What it does:**
- Uses fake process groups that simulate distributed communication without actual data transfer
- Runs on a single GPU without `torchrun` or NCCL initialization
- Validates configuration parsing, model initialization, and overall training workflow
- Executes only one training step by default

**When to use it:**
- Quick validation of configuration files before launching expensive multi-GPU jobs
- Debugging training and parallelism logic that doesn't require actual communication. Note that no data-dependent logic should be validated with `fake_backend`.

**Example use case:**
```bash
# Validate a 128-GPU configuration on a single GPU
NGPU=128 COMM_MODE="fake_backend" CONFIG_FILE="./train_configs/llama3_70b.toml" ./run_train.sh
```

#### 2. `local_tensor` - Single-GPU Distributed Simulation

This mode simulates the full distributed training workflow on a single GPU by executing all communication and computation locally:

```bash
NGPU=32 COMM_MODE="local_tensor" ./run_train.sh
```

**What it does:**
- Simulates multi-GPU behavior on a single shared GPU
- Executes all collectives (all-reduce, all-gather, etc.) locally without network communication
- Maintains the same code paths as distributed training for accurate debugging
- Runs only one training step by default

**When to use it:**
- Debugging distributed training logic (FSDP, TP, PP, CP, EP) with data dependencies without multi-GPU setup. Note that local tensor doesn't support FSDP2 but should support SimpleFSDP.
- Verifying correctness of parallelism strategies locally
- Testing gradient synchronization and communication patterns
- Reproducing distributed training bugs in a simplified environment

**Example use case:**
```bash
# Debug 8-way TP + 2-way FSDP on a single GPU
NGPU=16 COMM_MODE="local_tensor" ./run_train.sh \
--parallelism.tensor_parallel_degree 8 \
--parallelism.data_parallel_shard_degree 2
```
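The contrast between the two modes can be illustrated with a toy, stdlib-only sketch (hypothetical helper names, not torchtitan or PyTorch APIs): a `fake_backend`-style all-reduce skips communication entirely and leaves each rank's data untouched, while a `local_tensor`-style all-reduce actually performs the reduction across all simulated ranks in one process:

```python
# Toy model of COMM_MODE semantics; `per_rank` holds each simulated rank's
# gradient shard as a list of floats. These helpers are illustrative only.

def all_reduce_fake(per_rank):
    # fake_backend: the collective is a no-op; no data moves between ranks
    return per_rank

def all_reduce_local(per_rank):
    # local_tensor: the collective really runs, locally, across all ranks
    total = [sum(vals) for vals in zip(*per_rank)]
    return [list(total) for _ in per_rank]

grads = [[1.0, 2.0], [3.0, 4.0]]  # world_size=2, two grad elements per rank
assert all_reduce_fake(grads) == [[1.0, 2.0], [3.0, 4.0]]
assert all_reduce_local(grads) == [[4.0, 6.0], [4.0, 6.0]]
```

This is why only `local_tensor` is suitable for validating data-dependent logic: with `fake_backend`, a tensor that should have been reduced still carries rank-local values.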

### Limitations

- **Performance testing**: Neither mode provides accurate performance metrics; use actual distributed runs for benchmarking
- **Memory requirement**: Local tensor runs require more memory on a single GPU than the actual distributed runs

## Troubleshooting jobs that timeout

If you encounter jobs that timeout, you'll need to debug them to identify the root cause. To help with this process, we've enabled Flight Recorder, a tool that continuously collects diagnostic information about your jobs.
190 changes: 190 additions & 0 deletions docs/mxfp8.md
@@ -0,0 +1,190 @@
## MXFP8 Training on B200 GPUs

MXFP8 training can provide substantial training speedups for models where the majority of GEMMs are sufficiently large. MXFP8 is a microscaling format from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) that uses block-based scaling to maintain numerical accuracy while leveraging low-precision tensor cores. On NVIDIA B200 GPUs, MXFP8 training achieves up to **28% speedup** over bfloat16 baseline with minimal accuracy degradation.

> **📖 For a comprehensive case study of using TorchTitan MXFP8 to train dense models at scale**, see our blog post: [Accelerating 2K+ Scale Pre-training up to 1.28x with TorchAO MXFP8 and TorchTitan on Crusoe B200 Cluster](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/)

### Table of Contents

- [Requirements](#requirements)
- [How MXFP8 Works](#how-mxfp8-works)
- [MXFP8 for Linear Modules](#mxfp8-for-linear-modules)
- [Usage](#usage)
- [MXFP8 for Grouped GEMMs (MoE)](#mxfp8-for-grouped-gemms-moe)
- [Usage](#usage-1)
- [Example TOML Configuration](#example-toml-configuration)
- [Performance](#performance)
- [Dense Models](#dense-models)
- [MoE models](#moe-models)
- [Composability](#composability)
- [Known Limitations](#known-limitations)
- [Additional Resources](#additional-resources)

### Requirements

- NVIDIA B200 (SM100 or SM100a)
- PyTorch nightly
- TorchAO v0.14.0 or newer ([TorchAO Installation Guide](https://github.com/pytorch/ao#installation))

Note: GB200 is also supported but requires building torchao from source (see installation guide above).

### How MXFP8 Works

MXFP8 differs from standard Float8 training in its scaling approach:

- **Granular scaling factor**: Instead of using a single scale factor per tensor (tensorwise) or per row/column (rowwise), MXFP8 uses a more granular, block-based scaling with a default block size of 1x32 elements. Each block of 32 elements shares a common scale factor. The data dtype is `torch.float8_e4m3fn`, and the scale factor dtype is `torch.float8_e8m0fnu`.
- **Native hardware support**: On NVIDIA B200 (Blackwell) GPUs, MXFP8 GEMMs and Grouped GEMMs are accelerated using cuBLAS and CUTLASS kernels exposed via `torch._scaled_mm` and `torch._scaled_grouped_mm`, achieving up to 2x speedup over bfloat16 on common shapes.
- **Dynamic quantization**: For every MXFP8 Linear or Grouped GEMM, activations and weights are dynamically quantized to MXFP8, then an MXFP8 GEMM/Grouped GEMM is performed, resulting in a net speedup.
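As a rough illustration of the block scaling described above, here is a stdlib-only sketch that picks a power-of-two scale per 32-element block so each block fits the e4m3 range (simplified: real kernels also round the scaled values to e4m3 and support an rceil rounding mode for the scale):

```python
import math

E4M3_MAX = 448.0  # largest representable magnitude in torch.float8_e4m3fn
BLOCK = 32        # MXFP8 block size: 32 elements share one scale

def mx_block_scale(block):
    """Pick a power-of-two scale (as an e8m0-style exponent would encode)
    so the block's max magnitude fits the e4m3 range after scaling."""
    amax = max(abs(v) for v in block)
    if amax == 0.0:
        return 1.0
    return 2.0 ** math.ceil(math.log2(amax / E4M3_MAX))

def quantize_mx(values):
    scaled, scales = [], []
    for i in range(0, len(values), BLOCK):
        blk = values[i:i + BLOCK]
        s = mx_block_scale(blk)
        scales.append(s)
        scaled.extend(v / s for v in blk)  # real kernels then round to e4m3
    return scaled, scales

scaled, scales = quantize_mx([float(i - 16) * 37.5 for i in range(64)])
assert all(abs(v) <= E4M3_MAX for v in scaled)  # every block fits e4m3
```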

### MXFP8 for Linear Modules

#### Usage

To enable MXFP8 training for linear layers, launch your training job with the following command (or alternatively set configs in toml files):

```bash
CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \
--model.converters="quantize.linear.mx" \
--quantize.linear.mx.recipe_name="mxfp8_cublas" \
--compile.enable
```

**Configuration Options:**

* `--model.converters="quantize.linear.mx"`: Swap `nn.Linear` with `MXLinear` to perform MXFP8 matmul.
* `--quantize.linear.mx.recipe_name="mxfp8_cublas"`: Use the cuBLAS-based MXFP8 recipe for best performance on B200 GPUs. Alternative: `"mxfp8_cublas_rceil"` uses round-ceiling mode for scale calculation.
* `--quantize.linear.mx.mxfp8_dim1_cast_kernel_choice="triton"`: Choose the kernel for dimension-1 quantization. Options: `"triton"` (default), `"cuda"`, or `"torch"`.
* `--quantize.linear.mx.filter_fqns="..."` (optional): Comma-separated list of fully qualified names of modules not to convert to MXFP8 training.
* Example: `--quantize.linear.mx.filter_fqns="attention.wq,attention.wk,attention.wv,output"`
* This allows you to selectively apply MXFP8 only to layers that will benefit from it.
* `--compile.enable` (required for competitive performance): Use `torch.compile` to fuse the MXFP8 scaling/casting kernels.

**Hardware Requirements:**

MXFP8 training requires NVIDIA B200 (SM100) or newer GPUs.

### MXFP8 for Grouped GEMMs (MoE)

For Mixture-of-Experts (MoE) models, MXFP8 can accelerate the expert computation through dynamically quantized grouped GEMMs.

#### Usage

To enable MXFP8 for MoE expert layers:

```bash
CONFIG_FILE="./torchtitan/models/llama4/train_configs/llama4_17bx16e.toml" ./run_train.sh \
--model.converters="quantize.grouped_mm.mx" \
--quantize.grouped_mm.mx.fqns="experts" \
--quantize.grouped_mm.mx.recipe_name="mxfp8" \
--compile.enable \
--model.print_after_conversion
```

**Combined usage**: You can use MXFP8 for both linear modules and grouped GEMMs simultaneously by specifying both converters:
```bash
--model.converters="quantize.linear.mx,quantize.grouped_mm.mx"
```

**Configuration Options:**

* `--model.converters="quantize.grouped_mm.mx"`: Enable MXFP8 grouped GEMM conversion for MoE layers.
* `--quantize.grouped_mm.mx.fqns="experts"`: Comma-separated list of fully qualified names of MoE modules to apply MXFP8 dynamic quantization on grouped GEMM operations. Any module that matches an FQN will be converted, provided it has (1) experts represented as 3d nn.Parameter instances (which is the case for TorchTitan MoEs), and (2) a `torch._grouped_mm` op that performs the actual routed expert computation using those 3d expert weights.
* You can specify multiple FQNs to target different MoE layers in your model.
* `--quantize.grouped_mm.mx.recipe_name="mxfp8"`: Quantization recipe for grouped GEMMs (currently only `"mxfp8"` is supported).
* `--compile.enable`: Use `torch.compile` for best performance.

**Important Notes:**

* **Token group alignment**: For MoE training with MXFP8, token group sizes must be multiples of 32 (the MXFP8 block size). This is automatically configured [here](https://github.com/pytorch/torchtitan/blob/b39377f9fe33865fefb9bf64a33f6d74a598be87/torchtitan/components/quantization/mx.py#L131) when you enable MXFP8 grouped GEMMs in TorchTitan.

* **torch.compile recommendation**: All benchmarks in this document were run with `torch.compile` enabled. We recommend using `torch.compile` for best performance.
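The token-group alignment requirement amounts to rounding each expert's token count up to a multiple of the 32-element block size; a minimal sketch of the padding logic (illustrative only, not the torchtitan implementation):

```python
BLOCK = 32  # MXFP8 scaling block size

def pad_group_sizes(sizes, align=BLOCK):
    """Round each expert's token-group size up to a multiple of `align`."""
    return [((n + align - 1) // align) * align for n in sizes]

# e.g. experts that received 0, 5, 32, and 100 tokens respectively:
assert pad_group_sizes([0, 5, 32, 100]) == [0, 32, 32, 128]
```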

### Example TOML Configuration

Here's an example configuration for MXFP8 training in a TOML file:

```toml
[model]
converters = ["quantize.linear.mx", "quantize.grouped_mm.mx"]

[quantize.linear.mx]
recipe_name = "mxfp8_cublas"
mxfp8_dim1_cast_kernel_choice = "cuda"
filter_fqns = ["output", "router.gate"]

[quantize.grouped_mm.mx]
recipe_name = "mxfp8"
fqns = ["experts"]

[compile]
enable = true
components = ["model"]
```

### Performance

#### Dense Models

Single-node training on 8x power limited B200 GPUs, batch size 1, sequence length 8192, steps 100, torch.compile, FSDP2, per-op SAC:

| Scaling Method | Peak Memory (GB) | Median tokens/s | Speedup over BF16 |
|------------------------|------------------|-----------------|-------------------|
| None (bfloat16) | 33.71 | 8307.5 | - |
| mxfp8_cublas | 33.88 | 9969.0 | +20.0% |
| mxfp8_cublas_rceil | 33.88 | 9642.0 | +16.1% |
| float8 tensorwise | 33.38 | 10417.0 | +25.4% |

- pytorch version: `2.9.0.dev20250815+cu128`
- torchao version: `0.13.0+gite4e681be`
- torchtitan commit: `6fc499f6f5b32151a799188be2208cfb09faed30`

*Source: [TorchAO MX Formats Benchmarks](https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats#training-e2e-benchmarks-on-nvidia-b200)*
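The speedup column follows directly from the median tokens/s figures; a quick arithmetic sanity check:

```python
# Recompute the "Speedup over BF16" column from the median tokens/s values.
bf16_tps = 8307.5
rows = {
    "mxfp8_cublas": 9969.0,
    "mxfp8_cublas_rceil": 9642.0,
    "float8 tensorwise": 10417.0,
}
speedups = {k: round((v / bf16_tps - 1.0) * 100.0, 1) for k, v in rows.items()}
assert speedups == {
    "mxfp8_cublas": 20.0,
    "mxfp8_cublas_rceil": 16.1,
    "float8 tensorwise": 25.4,
}
```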

#### MoE models

512 GPU training on 64 node GB200 cluster:

| Scaling Method | Median tokens/s | Speedup over BF16 |
|------------------------|-----------------|-------------------|
| None (bfloat16) | 6169 | - |
| mxfp8 | 7401 | +20.3% |

Training runs on 64 node GB200 cluster with TorchTitan Llama4 Scout show that MXFP8 MoE training has equivalent convergence to bfloat16 training baseline. In fact, after 3,000 steps it finishes with slightly *lower* loss than bfloat16! This is consistent with our scaling experiments with [MXFP8 training for dense models](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/).

![MXFP8 vs BF16 Training Loss Curves](../assets/images/mxfp8_with_loss.png)

*Training loss curves over 3,000 steps showing MXFP8 achieves equivalent convergence to bfloat16 baseline.*

Training and model configurations for this run:
- Model: Llama4 Scout
- Dataset: C4
- Sequence length: 8192
- Local batch size: 10
- Learning rate: 1e-4
- LR scheduler warmup steps: 2000
- Parallelisms (64 nodes of 4 devices each = 256 chips):
- FSDP=256 (on attention layers, shared experts, dense layer FFNs) and 256/4=64 (on routed experts)
- EP=16 (on routed experts)
- Activation checkpointing mode: `none` (ideally this should use selective per op AC but there was a bug at the time preventing us from using it).
- `torch.compile` enabled
- `mxfp8` applied to routed experts computation (grouped GEMMs)
- `mxfp8` applied to all linear layers except: `output`, `router.gate`, `attention.wk`, `attention.wv` (Wk and Wv too small to benefit from mxfp8)

### Composability
For distributed training, MXFP8 is compatible with:
- `torch.compile`
- FSDP2/TP/EP/PP
- Full activation checkpointing

All distributed communication for MXFP8 training is currently done in high precision.

### Known Limitations
- Currently in prototype stage - no BC guarantees.
- Requires torch nightly - important bug fixes have landed since 2.9.1
- For GB200s, requires building torchao from source

### Additional Resources

- [Accelerating 2K+ Scale Pre-training up to 1.28x with TorchAO MXFP8 and TorchTitan on Crusoe B200 Cluster](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/) - Blog post on accelerating dense model training with MXFP8
- [TorchAO MX Formats Documentation](https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats)
- [TorchAO MoE Training Documentation](https://github.com/pytorch/ao/tree/main/torchao/prototype/moe_training)