diff --git a/.ci/docker/common/install_conda.sh b/.ci/docker/common/install_conda.sh index c2f316b04b..f47dc4cb8f 100755 --- a/.ci/docker/common/install_conda.sh +++ b/.ci/docker/common/install_conda.sh @@ -43,6 +43,7 @@ install_pip_dependencies() { pip_install -r /opt/conda/requirements.txt pip_install -r /opt/conda/requirements-flux.txt pip_install -r /opt/conda/requirements-vlm.txt + pip_install -r /opt/conda/requirements-transformers-modeling-backend.txt popd } diff --git a/.ci/docker/requirements-dev.txt b/.ci/docker/requirements-dev.txt index 6d53b2f817..0e5a6e491c 100644 --- a/.ci/docker/requirements-dev.txt +++ b/.ci/docker/requirements-dev.txt @@ -2,5 +2,6 @@ expecttest==0.1.6 pytest==7.3.2 pytest-cov pre-commit +pyrefly==0.45.1 tomli-w >= 1.1.0 transformers diff --git a/.ci/docker/requirements-flux.txt b/.ci/docker/requirements-flux.txt index daefd67ff0..8d6797a36b 100644 --- a/.ci/docker/requirements-flux.txt +++ b/.ci/docker/requirements-flux.txt @@ -1,4 +1,2 @@ transformers>=4.51.1 -einops sentencepiece -pillow diff --git a/.ci/docker/requirements-transformers-modeling-backend.txt b/.ci/docker/requirements-transformers-modeling-backend.txt new file mode 100644 index 0000000000..76e8886ed0 --- /dev/null +++ b/.ci/docker/requirements-transformers-modeling-backend.txt @@ -0,0 +1 @@ +transformers==4.57.1 diff --git a/.ci/docker/requirements.txt b/.ci/docker/requirements.txt index 9bf30b502c..5925bfd1d3 100644 --- a/.ci/docker/requirements.txt +++ b/.ci/docker/requirements.txt @@ -8,3 +8,5 @@ fsspec tyro tokenizers >= 0.15.0 safetensors +einops +pillow diff --git a/.ci/docker/ubuntu/Dockerfile b/.ci/docker/ubuntu/Dockerfile index baaca85824..dfb753c1f4 100644 --- a/.ci/docker/ubuntu/Dockerfile +++ b/.ci/docker/ubuntu/Dockerfile @@ -33,6 +33,7 @@ COPY requirements-dev.txt /opt/conda/ COPY requirements.txt /opt/conda/ COPY requirements-flux.txt /opt/conda/ COPY requirements-vlm.txt /opt/conda/ +COPY requirements-transformers-modeling-backend.txt /opt/conda/ COPY conda-env-ci.txt /opt/conda/ COPY ./common/install_conda.sh install_conda.sh COPY ./common/utils.sh utils.sh diff --git a/.gitignore b/.gitignore index bd9969f16a..e396e3418b 100644 --- a/.gitignore +++ b/.gitignore @@ -44,3 +44,7 @@ slurm-* # env files .env +.venv/ + +# Vibe coding +.claude diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3dac48ec83..ca66e1f2c9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -53,3 +53,11 @@ repos: args: ["--ignore-words-list=MIS"] additional_dependencies: - tomli + +- repo: https://github.com/facebook/pyrefly-pre-commit + rev: 0.45.1 + hooks: + - id: pyrefly-check + name: Pyrefly (type checking) + pass_filenames: false + language: system diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 8de2b9df9d..de6373236a 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -4,7 +4,7 @@ possible. Contributions should follow the [Contributing Guidelines](#contributin ### Setup ``` -pip install -r requirements-dev.txt +pip install -r requirements.txt -r requirements-dev.txt ``` ### Pull Requests diff --git a/GRPO.md b/GRPO.md index 93653052fa..9c3e5a9c1b 100644 --- a/GRPO.md +++ b/GRPO.md @@ -4,6 +4,8 @@ GRPO instructions ## Installation instructions ```shell +mkdir logs +chmod g+rw ./logs pip install uv uv pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu129 uv pip install -r requirements.txt @@ -12,8 +14,23 @@ export VLLM_COMMIT=2918c1b49c88c29783c86f78d2c4221cb9622379 uv pip install vllm torch==2.9.0 --torch-backend=cu129 --prerelease=allow --extra-index-url https://wheels.vllm.ai/${VLLM_COMMIT} --extra-index-url https://download.pytorch.org/whl/cu129 pip install flashinfer-python==0.4.1 flashinfer-cubin==0.4.1 pip install flashinfer-jit-cache==0.4.1 --index-url https://flashinfer.ai/whl/cu129 +pip install transformers==4.57.1 ``` ## Configuration instructions see `torchtitan/grop/configs/qwen25-7b-math.toml` for good initial values + +## sbatch script + +`online_multinode_vllm.slurm` contains some paths to edit, +- TRAIN_PATH - where this is installed on the cluster +- TRAIN_ENV - if you don't init the venv to .venv, this needs to be changed to that venv +- VLLM_ENV - same as TRAIN_ENV unless you're doing something different +- API_ENV - atropos venv + +One that's done, you can do something like +```bash +sbatch --export=ALL,CONFIG_FILE=/home/dakota/github/torchtitan/torchtitan/grpo/configs/qwen25-7b-math.toml,MODEL_NAME=Qwen/Qwen2.5-7B,PYTHON_SCRIPT=/home/dakota/github/atropos/environments/math_server_zero.py,WANDB_PROJECT=qwen7b_debug online_multinode_vllm.slurm +``` +to launch a run diff --git a/README.md b/README.md index b9292c5e63..0fa3ae0466 100644 --- a/README.md +++ b/README.md @@ -40,9 +40,14 @@ The Guiding Principles when building `torchtitan` * Minimal changes to the model code when applying multi-dimensional parallelism. * Bias towards a clean, minimal codebase while providing basic reusable / swappable components. -`torchtitan` has been showcasing PyTorch's latest distributed training features, via pretraining Llama 3.1 LLMs of various sizes. -To accelerate contributions to and innovations around torchtitan, we host an [`experiments`](torchtitan/experiments) folder. We look forward to your contributions! +`torchtitan` has been showcasing PyTorch's latest distributed training features, via support for pretraining Llama 3.1 LLMs of various sizes. +## Contributing + +We look forward to your contributions! + +* To accelerate contributions to and innovations around torchtitan, we host an [`experiments`](torchtitan/experiments) folder. New ideas should start there. To contribute, follow the [`experiments guidelines`](torchtitan/experiments/README.md). +* For fixes and contributions to core, follow these [`guidelines`](CONTRIBUTING.md). ## Llama 3.1 training @@ -59,6 +64,7 @@ To accelerate contributions to and innovations around torchtitan, we host an [`e - [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine-tuning 5. `torch.compile` support 6. [Float8](https://discuss.pytorch.org/t/distributed-w-torchtitan-enabling-float8-all-gather-in-fsdp2/209323) support ([how-to](docs/float8.md)) +7. [MXFP8 training for dense and MoE models](docs/mxfp8.md) on Blackwell GPUs. 7. DDP and HSDP 8. [TorchFT](https://github.com/pytorch/torchft) integration 9. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries) and support for [custom datasets](docs/datasets.md) diff --git a/assets/images/mxfp8_with_loss.png b/assets/images/mxfp8_with_loss.png new file mode 100644 index 0000000000..47e2967aed Binary files /dev/null and b/assets/images/mxfp8_with_loss.png differ diff --git a/assets/version.txt b/assets/version.txt index 0ea3a944b3..0c62199f16 100644 --- a/assets/version.txt +++ b/assets/version.txt @@ -1 +1 @@ -0.2.0 +0.2.1 diff --git a/docs/checkpoint.md b/docs/checkpoint.md index 6e3112309b..8aca58eb06 100644 --- a/docs/checkpoint.md +++ b/docs/checkpoint.md @@ -68,7 +68,7 @@ NGPU=1 CONFIG_FILE= ./run_train.sh --checkpoint.enable --c ### HuggingFace `torchtitan` offers two ways to work with Hugging Face models: either by directly saving and loading a Hugging Face checkpoint during training, or by using an example conversion script to directly reformat the model weights on cpu. -1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_safetensors_format` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, enable `--checkpoint.initial_load_in_hf`, and set either `--model.hf_assets_path` or `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint. `--checkpoint.initial_load_path` overrides `--model.hf_assets_path` if both are set. +1. You can directly save huggingface model weights during training by using the `--checkpoint.last_save_in_hf` and `--checkpoint.last_save_model_only` options together. To directly load a `torchtitan` training session from a huggingface safetensors file, enable `--checkpoint.initial_load_in_hf`, and set either `--model.hf_assets_path` or `--checkpoint.initial_load_path` to the directory containing the huggingface checkpoint. `--checkpoint.initial_load_path` overrides `--model.hf_assets_path` if both are set. 2. To directly reformat the weights without the need to run a training loop, run the corresponding conversion script. The naming scheme is `torchtitan`-centric, e.g. convert_from_hf means convert hf->tt. diff --git a/docs/debugging.md b/docs/debugging.md index 4deb20bbac..7a14606b51 100644 --- a/docs/debugging.md +++ b/docs/debugging.md @@ -54,6 +54,69 @@ python -m torchtitan.config.manager --help This will print a structured configuration to `stdout`, allowing you to verify that overrides are being applied correctly. +## Communication Mode (COMM_MODE) for Debugging + +The `COMM_MODE` environment variable provides specialized debugging modes that allow you to test and validate your training setup without requiring full multi-GPU distributed execution. This is particularly useful for rapid iteration during development and debugging. + +### Available Modes + +#### 1. `fake_backend` - Configuration Validation Mode + +This mode enables dry-run validation of your configuration, model setup, and rank-0 program logic without actual distributed communication: + +```bash +NGPU=32 COMM_MODE="fake_backend" ./run_train.sh +``` + +**What it does:** +- Uses fake process groups that simulate distributed communication without actual data transfer +- Runs on a single GPU without `torchrun` or NCCL initialization +- Validates configuration parsing, model initialization, and overall training workflow +- Executes only one training step by default + +**When to use it:** +- Quick validation of configuration files before launching expensive multi-GPU jobs +- Debugging training and parallelism logic that doesn't require actual communication. Note that No data-dependent logic should be validated with "fake_backend". + +**Example use case:** +```bash +# Validate a 128-GPU configuration on a single GPU +NGPU=128 COMM_MODE="fake_backend" CONFIG_FILE="./train_configs/llama3_70b.toml" ./run_train.sh +``` + +#### 2. `local_tensor` - Single-GPU Distributed Simulation + +This mode simulates the full distributed training workflow on a single GPU by executing all communication and computation locally: + +```bash +NGPU=32 COMM_MODE="local_tensor" ./run_train.sh +``` + +**What it does:** +- Simulates multi-GPU behavior on a single shared GPU +- Executes all collectives (all-reduce, all-gather, etc.) locally without network communication +- Maintains the same code paths as distributed training for accurate debugging +- Runs only one training step by default + +**When to use it:** +- Debugging distributed training logic (FSDP, TP, PP, CP, EP) with data dependencies without multi-GPU setup. Note that local tensor doesn't support FSDP2 but should support SimpleFSDP. +- Verifying correctness of parallelism strategies locally +- Testing gradient synchronization and communication patterns +- Reproducing distributed training bugs in a simplified environment + +**Example use case:** +```bash +# Debug 8-way TP + 2-way FSDP on a single GPU +NGPU=16 COMM_MODE="local_tensor" ./run_train.sh \ + --parallelism.tensor_parallel_degree 8 \ + --parallelism.data_parallel_shard_degree 2 +``` + +### Limitations + +- **Performance testing**: Neither mode provides accurate performance metrics; use actual distributed runs for benchmarking +- **Memory requirement**: Local tensor runs require more memory on a single GPU than the actual distributed runs + ## Troubleshooting jobs that timeout If you encounter jobs that timeout, you'll need to debug them to identify the root cause. To help with this process, we've enabled Flight Recorder, a tool that continuously collects diagnostic information about your jobs. diff --git a/docs/mxfp8.md b/docs/mxfp8.md new file mode 100644 index 0000000000..ad9f62ee3c --- /dev/null +++ b/docs/mxfp8.md @@ -0,0 +1,190 @@ +## MXFP8 Training on B200 GPUs + +MXFP8 training can provide substantial training speedups for models where the majority of GEMMs are sufficiently large. MXFP8 is a microscaling format from the [MX OCP spec](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) that uses block-based scaling to maintain numerical accuracy while leveraging low-precision tensor cores. On NVIDIA B200 GPUs, MXFP8 training achieves up to **28% speedup** over bfloat16 baseline with minimal accuracy degradation. + +> **📖 For a comprehensive case study of using TorchTitan MXFP8 to train dense models at scale**, see our blog post: [Accelerating 2K+ Scale Pre-training up to 1.28x with TorchAO MXFP8 and TorchTitan on Crusoe B200 Cluster](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/) + +### Table of Contents + +- [Requirements](#requirements) +- [How MXFP8 Works](#how-mxfp8-works) +- [MXFP8 for Linear Modules](#mxfp8-for-linear-modules) + - [Usage](#usage) +- [MXFP8 for Grouped GEMMs (MoE)](#mxfp8-for-grouped-gemms-moe) + - [Usage](#usage-1) +- [Example TOML Configuration](#example-toml-configuration) +- [Performance](#performance) + - [Dense Models](#dense-models) + - [MoE models](#moe-models) +- [Composability](#composability) +- [Known Limitations](#known-limitations) +- [Additional Resources](#additional-resources) + +### Requirements + +- NVIDIA B200 (SM100 or SM100a) +- PyTorch nightly +- TorchAO v0.14.0 or newer ([TorchAO Installation Guide](https://github.com/pytorch/ao#installation)) + +Note: GB200 is also supported but requires building torchao from source (see installation guide above). + +### How MXFP8 Works + +MXFP8 differs from standard Float8 training in its scaling approach: + +- **Granular scaling factor**: Instead of using a single scale factor per tensor (tensorwise) or per row/column (rowwise), MXFP8 uses a more granular, block-based scaling with a default block size of 1x32 elements. Each block of 32 elements shares a common scale factor. The data dtype is `torch.float8_e4m3fn`, and the scale factor dtype is `torch.float8_e8mfnu`. +- **Native hardware support**: On NVIDIA B200 (Blackwell) GPUs, MXFP8 GEMMs and Grouped GEMMs are accelerated using cuBLAS and CUTLASS kernels exposed via `torch._scaled_mm` and `torch._scaled_grouped_mm`, achieving up to 2x speedup over bfloat16 on common shapes. +- **Dynamic quantization**: For every MXFP8 Linear or Grouped GEMM, activations and weights are dynamically quantized to MXFP8, then a MXFP8 GEMM/Grouped GEMM is performed, resulting in a net speedup. + +### MXFP8 for Linear Modules + +#### Usage + +To enable MXFP8 training for linear layers, launch your training job with the following command (or alternatively set configs in toml files): + +```bash +CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_train.sh \ + --model.converters="quantize.linear.mx" \ + --quantize.linear.mx.recipe_name="mxfp8_cublas" \ + --compile.enable +``` + +**Configuration Options:** + +* `--model.converters="quantize.linear.mx"`: Swap `nn.Linear` with `MXLinear` to perform MXFP8 matmul. +* `--quantize.linear.mx.recipe_name="mxfp8_cublas"`: Use the cuBLAS-based MXFP8 recipe for best performance on B200 GPUs. Alternative: `"mxfp8_cublas_rceil"` uses round-ceiling mode for scale calculation. +* `--quantize.linear.mx.mxfp8_dim1_cast_kernel_choice="triton"`: Choose the kernel for dimension-1 quantization. Options: `"triton"` (default), `"cuda"`, or `"torch"`. +* `--quantize.linear.mx.filter_fqns="..."` (optional): Comma-separated list of fully qualified names of modules not to convert to MXFP8 training. + * Example: `--quantize.linear.mx.filter_fqns="attention.wq,attention.wk,attention.wv,output"` + * This allows you to selectively apply MXFP8 only to layers that will benefit from it. +* `--compile.enable` (required for competitive performance): Use `torch.compile` to fuse the MXFP8 scaling/casting kernels. + +**Hardware Requirements:** + +MXFP8 training requires NVIDIA B200 (SM100) or newer GPUs. + +### MXFP8 for Grouped GEMMs (MoE) + +For Mixture-of-Experts (MoE) models, MXFP8 can accelerate the expert computation through dynamically quantized grouped GEMMs. + +#### Usage + +To enable MXFP8 for MoE expert layers: + +```bash +CONFIG_FILE="./torchtitan/models/llama4/train_configs/llama4_17bx16e.toml" ./run_train.sh \ + --model.converters="quantize.grouped_mm.mx" \ + --quantize.grouped_mm.mx.fqns="experts" \ + --quantize.grouped_mm.mx.recipe_name="mxfp8" \ + --compile.enable \ + --model.print_after_conversion +``` + +**Combined usage**: You can use MXFP8 for both linear modules and grouped GEMMs simultaneously by specifying both converters: + ```bash + --model.converters="quantize.linear.mx,quantize.grouped_mm.mx" + ``` + +**Configuration Options:** + +* `--model.converters="quantize.grouped_mm.mx"`: Enable MXFP8 grouped GEMM conversion for MoE layers. +* `--quantize.grouped_mm.mx.fqns="experts"`: Comma-separated list of fully qualified names of MoE modules to apply MXFP8 dynamic quantization on grouped GEMM operations. Any module that matches the FQN will be converted, if it has (1) experts represented as 3d nn.Parameter instances (which is the case for TorchTitan MoEs), and (2) a `torch._grouped_mm` op performs the actual routed expert computation using those 3d expert weights. + * You can specify multiple FQNs to target different MoE layers in your model. +* `--quantize.grouped_mm.mx.recipe_name="mxfp8"`: Quantization recipe for grouped GEMMs (currently only `"mxfp8"` is supported). +* `--compile.enable`: Use `torch.compile` for best performance. + +**Important Notes:** + +* **Token group alignment**: For MoE training with MXFP8, token group sizes must be multiples of 32 (the MXFP8 block size). This is automatically configured [here](https://github.com/pytorch/torchtitan/blob/b39377f9fe33865fefb9bf64a33f6d74a598be87/torchtitan/components/quantization/mx.py#L131) when you enable MXFP8 grouped GEMMs in TorchTitan. + +* **torch.compile recommendation**: All benchmarks in this document were run with `torch.compile` enabled. We recommend using `torch.compile` for best performance. + +### Example TOML Configuration + +Here's an example configuration for MXFP8 training in a TOML file: + +```toml +[model] +converters = ["quantize.linear.mx", "quantize.grouped_mm.mx"] + +[quantize.linear.mx] +recipe_name = "mxfp8_cublas" +mxfp8_dim1_cast_kernel_choice = "cuda" +filter_fqns = ["output", "router.gate"] + +[quantize.grouped_mm.mx] +recipe_name = "mxfp8" +fqns = ["experts"] + +[compile] +enable = true +components = ["model"] +``` + +### Performance + +#### Dense Models + +Single-node training on 8x power limited B200 GPUs, batch size 1, sequence length 8192, steps 100, torch.compile, FSDP2, per-op SAC: + +| Scaling Method | Peak Memory (GB) | Median tokens/s | Speedup over BF16 | +|------------------------|------------------|-----------------|-------------------| +| None (bfloat16) | 33.71 | 8307.5 | - | +| mxfp8_cublas | 33.88 | 9969.0 | +20.0% | +| mxfp8_cublas_rceil | 33.88 | 9642.0 | +16.1% | +| float8 tensorwise | 33.38 | 10417.0 | +25.4% | + +- pytorch version: `2.9.0.dev20250815+cu128` +- torchao version: `0.13.0+gite4e681be` +- torchtitan commit: `6fc499f6f5b32151a799188be2208cfb09faed30` + +*Source: [TorchAO MX Formats Benchmarks](https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats#training-e2e-benchmarks-on-nvidia-b200)* + +#### MoE models + +512 GPU training on 64 node GB200 cluster: + +| Scaling Method | Median tokens/s | Speedup over BF16 | +|------------------------|-----------------|-------------------| +| None (bfloat16) | 6169 | - | +| mxfp8 | 7401 | +20.3% | + +Training runs on 64 node GB200 cluster with TorchTitan Llama4 Scout show that MXFP8 MoE training has equivalent convergence to bfloat16 training baseline. In fact, after 3,000 steps it finishes with slightly *lower* loss than bfloat16! This is consistent with our scaling experiments with [MXFP8 training for dense models](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/). + +![MXFP8 vs BF16 Training Loss Curves](../assets/images/mxfp8_with_loss.png) + +*Training loss curves over 3,000 steps showing MXFP8 achieves equivalent convergence to bfloat16 baseline.* + +Training and model configurations for this run: +- Model: Llama4 Scout +- Dataset: C4 +- Sequence length: 8192 +- Local batch size: 10 +- Learning rate: 1e-4 +- LR scheduler warmup steps: 2000 +- Parallelisms (64 nodes of 4 devices each = 256 chips): + - FSDP=256 (on attention layers, shared experts, dense layer FFNs) and 256/4=64 (on routed experts) + - EP=16 (on routed experts) +- Activation checkpointing mode: `none` (ideally this should use selective per op AC but there was a bug at the time preventing us from using it). +- `torch.compile` enabled +- `mxfp8` applied to routed experts computation (grouped GEMMs) +- `mxfp8` applied to all linear layers except: `output`, `router.gate`, `attention.wk`, `attention.wv` (Wk and Wv too small to benefit from mxfp8) + +### Composability +For distributed training, MXFP8 is compatible with: +- `torch.compile` +- FSDP2/TP/EP/PP +- Full activation checkpointing + +All distributed communication for MXFP8 training is currently done in high precision. + +### Known Limitations +- Currently in prototype stage - no BC guarantees. +- Requires torch nightly - important bug fixes have landed since 2.9.1 +- For GB200s, requires building torchao from source + +### Additional Resources + +- [Accelerating 2K+ Scale Pre-training up to 1.28x with TorchAO MXFP8 and TorchTitan on Crusoe B200 Cluster](https://pytorch.org/blog/accelerating-2k-scale-pre-training-up-to-1-28x-with-torchao-mxfp8-and-torchtitan-on-crusoe-b200-cluster/) - Blog post on accelerating dense model training with MXFP8 +- [TorchAO MX Formats Documentation](https://github.com/pytorch/ao/tree/main/torchao/prototype/mx_formats) +- [TorchAO MoE Training Documentation](https://github.com/pytorch/ao/tree/main/torchao/prototype/moe_training) diff --git a/pyproject.toml b/pyproject.toml index ed47cd11fc..f26314671c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,6 +24,8 @@ dependencies = [ "fsspec", "tyro", "tensorboard", + "einops", + "pillow", ] dynamic = ["version"] @@ -63,3 +65,8 @@ include = ["torchtitan*"] [tool.pytest.ini_options] addopts = ["--showlocals"] # show local variables in tracebacks testpaths = ["tests"] + +[tool.pyrefly] +project-excludes = ["torchtitan/experiments", "**/tests/**"] +ignore-missing-imports = ["torchao.*", "torchft"] # optional dependencies +search-path = ["../pytorch"] # local built pytorch diff --git a/qwen3_30b_a3b_memory_test.toml b/qwen3_30b_a3b_memory_test.toml new file mode 100644 index 0000000000..aba127bbd9 --- /dev/null +++ b/qwen3_30b_a3b_memory_test.toml @@ -0,0 +1,78 @@ +# Qwen3 30B-A3B with memory tracking features enabled +# Tests: detailed memory tracking, aggressive memory manager, bf16 optimizer states +# Reduced settings to fit in memory + +[job] +dump_folder = "./outputs/qwen3_30b_a3b_memory_test" +description = "Qwen3 30B-A3B - memory tracking test" + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 1 +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "qwen3" +flavor = "30B-A3B" +hf_assets_path = "./tests/assets/tokenizer" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 +# Test BF16 optimizer states (50% memory savings) +state_dtype = "bfloat16" + +[lr_scheduler] +warmup_steps = 2 + +[training] +local_batch_size = 1 +seq_len = 2048 +max_norm = 1.0 +steps = 3 +dataset = "c4" +enable_cpu_offload = true +# Enable detailed memory tracking +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +# Enable aggressive memory management +aggressive_memory_mode = "maximum" +aggressive_memory_verbose = true + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" +tensor_parallel_degree = 1 +context_parallel_degree = 1 +enable_async_tensor_parallel = false +expert_parallel_degree = 8 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 500 +last_save_model_only = false +export_dtype = "float16" +async_mode = "disabled" + +[activation_checkpoint] +mode = "full" +selective_ac_option = "op" + +[compile] +enable = true +components = ["loss"] + +[debug] +# Test NaN tracker +enable_nan_tracker = true +nan_tracker_verbose = false diff --git a/run_train.sh b/run_train.sh index 3a42f53b22..069ea084b0 100755 --- a/run_train.sh +++ b/run_train.sh @@ -10,15 +10,38 @@ set -ex # use envs as local overwrites for convenience # e.g. # LOG_RANK=0,1 NGPU=4 ./run_train.sh +# +# COMM_MODE options for debugging: +# +# 1. "fake_backend" - Dry-run mode for config validation without GPU execution +# - Uses fake process groups (no actual communication) +# - Runs on a single GPU without torchrun or NCCL initialization +# - Useful for validating configuration and model setup +# Example: NGPU=32 COMM_MODE="fake_backend" ./run_train.sh +# +# 2. "local_tensor" - Single-GPU debugging mode with simulated multi-GPU behavior +# - All communication and computation execute on a single shared GPU +# - Simulates the full training workflow without actual distributed communication +# - Useful for debugging distributed training logic locally +# Example: NGPU=32 COMM_MODE="local_tensor" ./run_train.sh + NGPU=${NGPU:-"8"} export LOG_RANK=${LOG_RANK:-0} CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/debug_model.toml"} TRAIN_FILE=${TRAIN_FILE:-"torchtitan.train"} +COMM_MODE=${COMM_MODE:-""} TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE:-"http://localhost:29510"} -PYTORCH_ALLOC_CONF=${PYTORCH_ALLOC_CONF:-"expandable_segments:True"} \ -TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \ -torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ ---local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ --m ${TRAIN_FILE} --job.config_file ${CONFIG_FILE} "$@" +if [ -n "$COMM_MODE" ]; then + # Communication mode specified: validate configuration or run in debug mode + echo "Running with comm_mode=${COMM_MODE}" + NGPU="${NGPU}" LOCAL_RANK=0 python3 -m "${TRAIN_FILE}" --job.config_file "${CONFIG_FILE}" "$@" --comm.mode=${COMM_MODE} --training.steps=1 +else + # Normal training with torchrun + PYTORCH_ALLOC_CONF="expandable_segments:True" \ + TORCHFT_LIGHTHOUSE=${TORCHFT_LIGHTHOUSE} \ + torchrun --nproc_per_node=${NGPU} --rdzv_backend c10d --rdzv_endpoint="localhost:0" \ + --local-ranks-filter ${LOG_RANK} --role rank --tee 3 \ + -m ${TRAIN_FILE} --job.config_file ${CONFIG_FILE} "$@" +fi diff --git a/scripts/checkpoint_conversion/convert_from_hf.py b/scripts/checkpoint_conversion/convert_from_hf.py index db61b10148..b1076e4ea6 100644 --- a/scripts/checkpoint_conversion/convert_from_hf.py +++ b/scripts/checkpoint_conversion/convert_from_hf.py @@ -17,8 +17,6 @@ @torch.inference_mode() def convert_from_hf(input_dir, output_dir, model_name, model_flavor): - if model_name == "flux": - import torchtitan.experiments.flux # noqa: F401 # initialize model to allocate memory for state dict train_spec = train_spec_module.get_train_spec(model_name) model_args = train_spec.model_args[model_flavor] @@ -30,6 +28,7 @@ def convert_from_hf(input_dir, output_dir, model_name, model_flavor): model = train_spec.model_cls(model_args) model = ModelWrapper(model) + # pyrefly: ignore[bad-instantiation, not-callable] sd_adapter = train_spec.state_dict_adapter(model_args, None) assert ( sd_adapter is not None diff --git a/scripts/checkpoint_conversion/convert_to_hf.py b/scripts/checkpoint_conversion/convert_to_hf.py index ca7470d162..4a8f507804 100644 --- a/scripts/checkpoint_conversion/convert_to_hf.py +++ b/scripts/checkpoint_conversion/convert_to_hf.py @@ -13,6 +13,7 @@ import torchtitan.protocols.train_spec as train_spec_module from torch.distributed.checkpoint import HuggingFaceStorageWriter from torchtitan.components.checkpoint import ModelWrapper +from torchtitan.config import TORCH_DTYPE_MAP from torchtitan.config.job_config import PEFT # Config files to copy from source HF model @@ -53,9 +54,14 @@ def find_step_dirs(checkpoint_dir: Path) -> list[Path]: @torch.inference_mode() -def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_path): - if model_name == "flux": - import torchtitan.experiments.flux # noqa: F401 +def convert_to_hf( + input_dir, + output_dir, + model_name, + model_flavor, + hf_assets_path, + export_dtype, +): # load model and model args so that we can get the state dict shape train_spec = train_spec_module.get_train_spec(model_name) model_args = train_spec.model_args[model_flavor] @@ -67,6 +73,7 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_pat model = train_spec.model_cls(model_args) model = ModelWrapper(model) + # pyrefly: ignore[bad-instantiation, not-callable] sd_adapter = train_spec.state_dict_adapter(model_args, hf_assets_path) assert ( sd_adapter is not None @@ -90,6 +97,11 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_pat thread_count_consolidation=5, ) + # map and apply export dtype if needed + target_dtype = TORCH_DTYPE_MAP[export_dtype] + if target_dtype != torch.float32: + hf_state_dict = {k: v.to(target_dtype) for k, v in hf_state_dict.items()} + dcp.save( hf_state_dict, storage_writer=storage_writer, @@ -101,7 +113,7 @@ def convert_to_hf(input_dir, output_dir, model_name, model_flavor, hf_assets_pat def convert_all_checkpoints( - checkpoint_dir: Path, model_name: str, model_flavor: str, hf_assets_path: Path + checkpoint_dir: Path, model_name: str, model_flavor: str, hf_assets_path: Path, export_dtype, ): """Convert all step-* checkpoints in a directory to HF format. @@ -130,7 +142,7 @@ def convert_all_checkpoints( continue print(f"Converting {step_name}...") - convert_to_hf(step_dir, output_dir, model_name, model_flavor, hf_assets_path) + convert_to_hf(step_dir, output_dir, model_name, model_flavor, hf_assets_path, export_dtype) print(f" -> {output_dir}") @@ -156,6 +168,14 @@ def convert_all_checkpoints( ) parser.add_argument("--model_name", type=str, nargs="?", default="llama3") parser.add_argument("--model_flavor", type=str, nargs="?", default="8B") + parser.add_argument( + "--export_dtype", + type=str, + nargs="?", + choices=["float16", "bfloat16", "float32"], + default="float32", + help="Export dtype for HF checkpoint (default: float32)", + ) parser.add_argument( "--all", action="store_true", @@ -169,6 +189,7 @@ def convert_all_checkpoints( args.model_name, args.model_flavor, args.hf_assets_path, + args.export_dtype, ) else: if args.output_dir is None: @@ -179,4 +200,5 @@ def convert_all_checkpoints( args.model_name, args.model_flavor, args.hf_assets_path, + args.export_dtype, ) diff --git a/scripts/checkpoint_conversion/numerical_tests_example.py b/scripts/checkpoint_conversion/numerical_tests_example.py index 079faee211..7a7c648d75 100644 --- a/scripts/checkpoint_conversion/numerical_tests_example.py +++ b/scripts/checkpoint_conversion/numerical_tests_example.py @@ -15,6 +15,8 @@ from torchtitan.config.job_config import PEFT from torchtitan.protocols.train_spec import get_train_spec from torchtitan.tools.logging import logger + +# pyrefly: ignore[import-error] from transformers import AutoModelForCausalLM device_type = "cuda" if torch.cuda.is_available() else "cpu" @@ -26,7 +28,7 @@ def loss_fn(logits1, logits2): probs2 = F.softmax(logits2, dim=-1) # Calculate KL Divergence - kl_loss = F.kl_div(probs1, probs2, "mean") + kl_loss = F.kl_div(probs1, probs2, reduction="mean") return kl_loss @@ -124,6 +126,7 @@ def forward_tt(config_path, checkpoint_path, test_set): config_manager = ConfigManager() config = config_manager.parse_args([f"--job.config_file={config_path}"]) train_spec = get_train_spec(config.model.name) + # pyrefly: ignore [not-callable] tokenizer = train_spec.build_tokenizer_fn(config) # Build test set of randomly generated token ids @@ -154,10 +157,11 @@ def forward_tt(config_path, checkpoint_path, test_set): avg_losses = {} for test_name, (baseline_outputs, conversion_outputs) in test_configs.items(): - total_loss = 0 + total_loss: int | torch.Tensor = 0 for baseline, outputs in zip(baseline_outputs, conversion_outputs): total_loss += loss_fn(baseline, outputs) avg_loss = total_loss / len(test_set) + # pyrefly: ignore [missing-attribute] avg_losses[test_name] = avg_loss.item() for test_name, avg_loss in avg_losses.items(): diff --git a/scripts/download_hf_assets.py b/scripts/download_hf_assets.py index 3a0b06a4d9..75c57b725c 100644 --- a/scripts/download_hf_assets.py +++ b/scripts/download_hf_assets.py @@ -171,6 +171,7 @@ def should_download(patterns: list[str], filename: str) -> bool: missed_files = [] # Download files with progress bar + # pyrefly: ignore [bad-context-manager] with tqdm(total=len(files_found), desc="Downloading files", unit="file") as pbar: for filename in files_found: try: diff --git a/scripts/estimate/estimation.py b/scripts/estimate/estimation.py index 7694e9fdca..4e23818e67 100644 --- a/scripts/estimate/estimation.py +++ b/scripts/estimate/estimation.py @@ -115,7 +115,10 @@ def estimate_memory(job_config: JobConfig): # build optimizer after applying parallelisms to the model optimizers = build_optimizers([model], job_config.optimizer, parallel_dims) lr_schedulers = build_lr_schedulers( - optimizers.optimizers, job_config.lr_scheduler, job_config.training.steps + # pyrefly: ignore [bad-argument-type] + optimizers.optimizers, + job_config.lr_scheduler, + job_config.training.steps, ) # Post optimizer step model converters hook. # e.g. calculate float8 dynamic amax/scale for all-parameter for FSDP2 @@ -124,18 +127,23 @@ def estimate_memory(job_config: JobConfig): lambda *args, **kwargs: model_converters.post_optimizer_hook(model) ) + # pyrefly: ignore [missing-attribute] logger.info(f"Vocab size: {model_args.vocab_size}") # Create a dummy batch instead of loading from a dataset batch = ( torch.randint( 0, + # pyrefly: ignore [missing-attribute] model_args.vocab_size, + # pyrefly: ignore [missing-attribute] (job_config.training.local_batch_size, model_args.max_seq_len), device="cuda", ), torch.randint( 0, + # pyrefly: ignore [missing-attribute] model_args.vocab_size, + # pyrefly: ignore [missing-attribute] (job_config.training.local_batch_size, model_args.max_seq_len), device="cuda", ), @@ -156,7 +164,9 @@ def estimate_memory(job_config: JobConfig): # clip gradients torch.nn.utils.clip_grad_norm_( - model.parameters(), job_config.training.max_norm, foreach=True + model.parameters(), + job_config.training.max_norm, + foreach=True, ) # optimizer step optimizers.step() diff --git a/scripts/generate/test_generate.py b/scripts/generate/test_generate.py index c94997e6d1..66ceaa5208 100644 --- a/scripts/generate/test_generate.py +++ b/scripts/generate/test_generate.py @@ -37,6 +37,7 @@ wd = Path(__file__).parent.parent.resolve() sys.path.append(str(wd)) +# pyrefly: ignore[missing-import] from generate._generation import generate @@ -50,6 +51,7 @@ def apply_tp_minus_sp(model: nn.Module, tp_mesh: DeviceMesh): }, ) + # pyrefly: ignore [missing-attribute] for _, transformer_block in model.layers.items(): layer_plan = { "attention.wq": ColwiseParallel(), @@ -64,6 +66,7 @@ def apply_tp_minus_sp(model: nn.Module, tp_mesh: DeviceMesh): parallelize_module( module=transformer_block, device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] parallelize_plan=layer_plan, ) @@ -105,6 +108,7 @@ def test_generate( logger.info(f"World Size: {world_size}, Local Rank: {local_rank} on {device}") # Tokenizer setup + # pyrefly: ignore [not-callable] tokenizer = train_spec.build_tokenizer_fn(config) model_args = train_spec.model_args[config.model.flavor] @@ -123,7 +127,7 @@ def test_generate( except TypeError: model = train_spec.model_cls(model_args) - world_mesh = None + parallel_dims = None # Init distributed env if world_size > 1: dist_utils.init_distributed(config.comm) @@ -137,15 +141,25 @@ def test_generate( etp=1, world_size=world_size, ) - world_mesh = parallel_dims.world_mesh # apply_tp (with Sequence Parallel) on unevenly sharded # sequences would require https://github.com/pytorch/torchtitan/pull/686 - apply_tp_minus_sp(model, parallel_dims.world_mesh["tp"]) + apply_tp_minus_sp(model, parallel_dims.get_mesh("tp")) + else: + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) debug_config = DebugConfig(seed=seed, deterministic=deterministic) dist_utils.set_determinism( - world_mesh=world_mesh, + parallel_dims=parallel_dims, device=device, debug_config=debug_config, distinct_seed_mesh_dims=["pp"], @@ -226,7 +240,7 @@ def test_generate( "input_text": input_text, "output_text": output_text, } - output_data["responses"].append(_data) + output_data["responses"].append(_data) # pyrefly: ignore[missing-attribute] logger.info(f"{r}\n{input_text}{b}{output_text}\n{color.reset}") diff --git a/scripts/loss_compare.py b/scripts/loss_compare.py new file mode 100644 index 0000000000..a084880e0c --- /dev/null +++ b/scripts/loss_compare.py @@ -0,0 +1,1107 @@ +#!/usr/bin/env python3 + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +This script compares training losses between different git commits +and/or different training configurations. --debug.deterministic is +always enabled and seed checkpoint is also enabled by default for +reproducible comparisons. You can disable seed checkpoint with +--no-seed-checkpoint if you don't need it to speed up comparisons. +If --output-folder is specified, all outputs are organized in that +folder with detailed analysis and statistical summaries. + +The --assert-equal flag can be used for CI testing to verify that +losses are identical between runs. If losses differ, the script will +exit with a non-zero status code. + +Example usages: +1. Compare losses between two different git commits with default config: + loss_compare.py main my_branch + +2. Compare losses between two commits with custom config and options: + loss_compare.py main my_branch \ + --baseline-config='./custom.toml' \ + --baseline-options='--parallelism.tensor_parallel_degree=2' \ + --output-folder=my_comparison + +3. Compare commits with the same command but skip seed checkpoint for + faster execution: + loss_compare.py main my_branch --no-seed-checkpoint + +4. Compare the same commit with different training configurations: + loss_compare.py . . \ + --baseline-options='--parallelism.dp=1' \ + --test-options='--parallelism.dp=2' + +5. Compare with different train files: + loss_compare.py main my_branch \ + --baseline-train-file='torchtitan.train' \ + --test-train-file='torchtitan.custom_train' + +6. Assert that losses are equal (for CI testing): + loss_compare.py main my_branch --assert-equal +""" + +import argparse +import os +import re +import subprocess +import sys +import unittest +from typing import Any + +# ============================================================================= +# GLOBAL CONFIGURATION +# ============================================================================= + +LOG_PREFIX = "[LOSS_COMPARE]" + +# Fixed options that are always appended +FIXED_OPTIONS = "--debug.deterministic --debug.seed=42" + + +# ============================================================================= +# UTILITY FUNCTIONS +# ============================================================================= + + +def log_print(message: str = "") -> None: + """Print message with LOG_PREFIX.""" + if message: + print(f"{LOG_PREFIX} {message}") + else: + print(f"{LOG_PREFIX}") + + +def get_log_path(scenario: str, output_folder: str | None) -> str: + """Get log file path for a scenario.""" + if output_folder: + return f"{output_folder}/{scenario}_training.log" + return f"/tmp/{scenario}_training.log" + + +def get_loss_file_path(scenario: str, output_folder: str) -> str: + """Get loss file path for a scenario.""" + return f"{output_folder}/{scenario}_losses.txt" + + +def get_clean_log_path(scenario: str, output_folder: str) -> str: + """Get cleaned log file path for a scenario.""" + return f"{output_folder}/{scenario}_training_clean.log" + + +def build_base_command( + config_file: str, train_file: str, options: str, job_dump_folder: str +) -> str: + """Build the base command from config file, train file, and options.""" + cmd = f"TRAIN_FILE='{train_file}' CONFIG_FILE='{config_file}' ./run_train.sh" + cmd += f" --job.dump_folder={job_dump_folder}" + if options: + cmd += f" {options}" + return cmd + + +def strip_ansi_codes(input_file: str, output_file: str) -> None: + """Strip ANSI escape codes from log files.""" + ansi_escape = re.compile(r"\x1b\[[0-9;]*m") + with open(input_file, "r") as f_in: + with open(output_file, "w") as f_out: + for line in f_in: + f_out.write(ansi_escape.sub("", line)) + + +def run_with_realtime_output(cmd: str, logfile: str, env: dict[str, Any]) -> None: + """Run command with real-time output to both console and log file.""" + log_print(f"Executing: {cmd}") + + # Set PYTHONUNBUFFERED for better output handling + env["PYTHONUNBUFFERED"] = "1" + + # Run command and tee output to both stdout and log file + with open(logfile, "w") as log_f: + process = subprocess.Popen( + cmd, + shell=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env, + text=True, + bufsize=1, + ) + + # pyrefly: ignore [not-iterable] + for line in process.stdout: + print(line, end="") + log_f.write(line) + log_f.flush() + + process.wait() + + if process.returncode != 0: + raise subprocess.CalledProcessError(process.returncode, cmd) + + +def log_and_save(message: str, stats_file: str | None) -> None: + """Output message to both stdout and stats file if provided.""" + print(message) + if stats_file: + with open(stats_file, "a") as f: + f.write(message + "\n") + + +# ============================================================================= +# VALIDATION FUNCTIONS +# ============================================================================= + + +def validate_arguments( + baseline_commit: str, + test_commit: str, + baseline_config: str, + baseline_train_file: str, + baseline_options: str, + test_config: str, + test_train_file: str, + test_options: str, + steps: int, + assert_equal: bool, + export_result: str | None, + import_result: str | None, +) -> None: + """Validate command line arguments.""" + # Validate that we are comparing different settings + commits_differ = baseline_commit != test_commit + configs_differ = baseline_config != test_config + train_files_differ = baseline_train_file != test_train_file + options_differ = baseline_options != test_options + + if not (commits_differ or configs_differ or train_files_differ or options_differ): + log_print("Error: All settings are identical") + log_print(" Cannot compare identical configurations") + log_print( + " Please provide different commits, configs, train files, or options" + ) + sys.exit(1) + + # Validate steps is a positive integer + if steps <= 0: + log_print(f"Error: --steps must be a positive integer, got: {steps}") + sys.exit(1) + + # Validate export-result requires assert-equal + if export_result and not assert_equal: + log_print("Error: --export-result requires --assert-equal") + log_print(" Export only happens when losses are verified to match") + sys.exit(1) + + # Validate import-result requires assert-equal + if import_result and not assert_equal: + log_print("Error: --import-result requires --assert-equal") + log_print(" Import is used to verify all losses match") + sys.exit(1) + + # Validate export-result and import-result are mutually exclusive + if export_result and import_result: + log_print( + "Error: --export-result and --import-result cannot be " "used together" + ) + log_print( + " Use export to save results or import to compare " + "against saved results" + ) + sys.exit(1) + + # Validate import file exists + if import_result and not os.path.exists(import_result): + log_print(f"Error: Import file does not exist: {import_result}") + sys.exit(1) + + +# ============================================================================= +# SETUP FUNCTIONS +# ============================================================================= + + +def setup_output_directory(output_folder: str | None) -> str | None: + """Setup output directory and return stats file path. + Returns None if no output folder specified. + """ + if not output_folder: + return None + + # Check if output folder already exists + if os.path.exists(output_folder): + log_print(f"Error: Output folder '{output_folder}' already exists") + log_print(f"Please delete it first: rm -rf {output_folder}") + sys.exit(1) + + # Create the output folder + log_print(f"Creating output folder: {output_folder}") + os.makedirs(output_folder) + + # Set statistics file path + stats_file = os.path.join(output_folder, "comparison_statistics.txt") + return stats_file + + +def build_training_command( + config_file: str, + train_file: str, + options: str, + steps: int, + enable_seed_checkpoint: bool, + job_dump_folder: str, +) -> str: + """Build the final training command with all options.""" + base_cmd = build_base_command(config_file, train_file, options, job_dump_folder) + cmd = f"{base_cmd} {FIXED_OPTIONS} --training.steps={steps}" + if enable_seed_checkpoint: + cmd += ( + " --checkpoint.enable --checkpoint.export_dtype=bfloat16" + " --checkpoint.load_only" + ) + return cmd + + +def print_configuration( + baseline_commit: str, + test_commit: str, + baseline_config: str, + baseline_train_file: str, + baseline_options: str, + test_config: str, + test_train_file: str, + test_options: str, + steps: int, + enable_seed_checkpoint: bool, + job_dump_folder: str, +) -> None: + """Print configuration summary.""" + log_print( + f"Starting loss comparison between baseline commit: " + f"{baseline_commit} and test commit: {test_commit}" + ) + log_print(f"Training steps: {steps}") + log_print(f"Seed checkpoint enabled: {enable_seed_checkpoint}") + log_print() + + # Build and display final commands + baseline_final_cmd = build_training_command( + baseline_config, + baseline_train_file, + baseline_options, + steps, + enable_seed_checkpoint, + job_dump_folder, + ) + test_final_cmd = build_training_command( + test_config, + test_train_file, + test_options, + steps, + enable_seed_checkpoint, + job_dump_folder, + ) + + log_print("Baseline command:") + log_print(f" {baseline_final_cmd}") + log_print() + log_print("Test command:") + log_print(f" {test_final_cmd}") + log_print() + + +# ============================================================================= +# GIT OPERATIONS +# ============================================================================= + + +def check_git_clean_state() -> None: + """Check if git working directory is clean before switching commits. + + Raises SystemExit if there are uncommitted changes to tracked files. + Untracked files are ignored. + """ + result = subprocess.run( + ["git", "status", "--porcelain"], + capture_output=True, + text=True, + check=True, + ) + + # Filter out untracked files (lines starting with "??") + modified_tracked_files = [] + for line in result.stdout.strip().split("\n"): + if line and not line.startswith("??"): + modified_tracked_files.append(line) + + if modified_tracked_files: + log_print( + "Error: Git working directory has uncommitted changes to tracked files" + ) + log_print(" Cannot switch commits with uncommitted changes") + log_print("") + log_print("Modified tracked files:") + for line in modified_tracked_files: + log_print(f" {line}") + log_print("") + log_print( + "Please commit, stash, or discard your changes before running this script" + ) + log_print(" - To commit: git add -A && git commit -m 'message'") + log_print(" - To stash: git stash") + log_print(" - To discard: git checkout -- . && git clean -fd") + sys.exit(1) + + +def checkout_commit(commit: str, commit_name: str) -> None: + """Checkout git commit.""" + if commit != ".": + log_print(f"Checking out {commit_name} commit: {commit}") + subprocess.run(["git", "checkout", commit], check=True) + else: + log_print(f"Using current working directory for {commit_name} (commit: '.')") + + +def get_current_commit() -> str: + """Get the current git commit hash or branch name. + + Returns the current branch name if on a branch, otherwise returns the commit hash. + """ + # Try to get current branch name + result = subprocess.run( + ["git", "rev-parse", "--abbrev-ref", "HEAD"], + capture_output=True, + text=True, + check=True, + ) + ref = result.stdout.strip() + + # If in detached HEAD state, ref will be "HEAD", so get the commit hash instead + if ref == "HEAD": + result = subprocess.run( + ["git", "rev-parse", "HEAD"], + capture_output=True, + text=True, + check=True, + ) + ref = result.stdout.strip() + + return ref + + +def restore_original_commit(original_commit: str) -> None: + """Restore the original git commit/branch.""" + log_print(f"Restoring original commit/branch: {original_commit}") + subprocess.run(["git", "checkout", original_commit], check=True) + + +# ============================================================================= +# TRAINING OPERATIONS +# ============================================================================= + + +def create_seed_checkpoint( + enable_seed_checkpoint: bool, + config_file: str, + train_file: str, + output_folder: str | None, + job_dump_folder: str, +) -> None: + """Create seed checkpoint.""" + if enable_seed_checkpoint: + log_file = get_log_path("seed", output_folder) + log_print(f"Creating seed checkpoint and logging output to {log_file}") + + # Build seed checkpoint command + seed_cmd = ( + f"TRAIN_FILE='{train_file}' CONFIG_FILE='{config_file}' " + f"./run_train.sh --job.dump_folder={job_dump_folder} " + f"--checkpoint.create_seed_checkpoint " + f"--checkpoint.enable {FIXED_OPTIONS}" + ) + + env = os.environ.copy() + env["NGPU"] = "1" + + run_with_realtime_output(seed_cmd, log_file, env) + + +def run_training( + scenario: str, + config_file: str, + train_file: str, + options: str, + steps: int, + enable_seed_checkpoint: bool, + output_folder: str | None, + job_dump_folder: str, + ngpus: int, +) -> str: + """Run training for a specific scenario. Returns the log file path.""" + log_file = get_log_path(scenario, output_folder) + log_print( + f"Running training with {scenario} commit and logging output " f"to {log_file}" + ) + + # Build the final command + full_cmd = build_training_command( + config_file, train_file, options, steps, enable_seed_checkpoint, job_dump_folder + ) + + env = os.environ.copy() + env["NGPU"] = str(ngpus) + + run_with_realtime_output(full_cmd, log_file, env) + + return log_file + + +# ============================================================================= +# LOG PROCESSING AND ANALYSIS +# ============================================================================= + + +def extract_losses_from_log(log_file: str) -> dict[int, float]: + """Extract step and loss pairs from a log file.""" + losses = {} + step_loss_pattern = re.compile(r"step:\s*(\d+)\s*loss:\s*(\d+\.\d+)") + ansi_escape = re.compile(r"\x1b\[[0-9;]*m") + + with open(log_file, "r") as f: + for line in f: + # Strip ANSI codes before matching + clean_line = ansi_escape.sub("", line) + match = step_loss_pattern.search(clean_line) + if match: + step, loss = match.groups() + losses[int(step)] = float(loss) + + return losses + + +def read_losses_from_file(loss_file: str) -> dict[int, float]: + """Read losses from a processed loss file.""" + losses = {} + with open(loss_file, "r") as f: + for line in f: + step, loss = line.strip().split() + losses[int(step)] = float(loss) + return losses + + +def export_losses_to_file(losses: dict[int, float], export_path: str) -> None: + """Export losses to file and stdout. + + Args: + losses: Dictionary mapping step numbers to loss values + export_path: Path to export file + """ + log_print(f"Exporting losses to {export_path}") + + # Write to file and collect output for stdout + with open(export_path, "w") as f: + for step in sorted(losses.keys()): + loss = losses[step] + line = f"{step} {loss}" + f.write(line + "\n") + + log_print(f"Exported {len(losses)} loss values:") + log_print() + + # Output to stdout in same format + for step in sorted(losses.keys()): + loss = losses[step] + print(f"{step} {loss}") + + log_print() + log_print(f"Losses saved to: {export_path}") + + +def extract_loss_data(output_folder: str | None) -> None: + """Extract loss data from logs.""" + if not output_folder: + return + + log_print("Cleaning ANSI escape codes from log files...") + + # Strip ANSI escape codes from log files before processing + scenarios = ["baseline", "test"] + for scenario in scenarios: + strip_ansi_codes( + get_log_path(scenario, output_folder), + get_clean_log_path(scenario, output_folder), + ) + + # Extract step and loss from cleaned logs + step_loss_pattern = re.compile(r"step:\s*(\d+)\s*loss:\s*(\d+\.\d+)") + + for scenario in scenarios: + with open(get_clean_log_path(scenario, output_folder), "r") as f_in: + with open(get_loss_file_path(scenario, output_folder), "w") as f_out: + for line in f_in: + match = step_loss_pattern.search(line) + if match: + step, loss = match.groups() + f_out.write(f"{step} {loss}\n") + + +def generate_step_comparison( + baseline_losses: dict[int, float], + test_losses: dict[int, float], + stats_file: str | None, +) -> None: + """Generate step-by-step comparison.""" + log_and_save("", stats_file) + log_and_save(f"{LOG_PREFIX} Step-by-step loss comparison:", stats_file) + log_and_save( + f"{LOG_PREFIX} Step Baseline Loss Test Loss Difference", + stats_file, + ) + log_and_save( + f"{LOG_PREFIX} ---- ------------- --------- ----------", + stats_file, + ) + + # Generate comparison for common steps + for step in sorted(set(baseline_losses.keys()) & set(test_losses.keys())): + baseline_loss = baseline_losses[step] + test_loss = test_losses[step] + diff = test_loss - baseline_loss + + formatted_line = ( + f"{LOG_PREFIX} {step:<6} {baseline_loss:<13} " + f"{test_loss:<14} {diff:.6f}" + ) + log_and_save(formatted_line, stats_file) + + +def generate_summary_statistics( + baseline_losses: dict[int, float], + test_losses: dict[int, float], + stats_file: str | None, +) -> None: + """Generate summary statistics.""" + log_and_save(f"{LOG_PREFIX}", stats_file) + log_and_save(f"{LOG_PREFIX} Summary statistics:", stats_file) + + # Calculate average losses + def calculate_average(losses: dict[int, float]) -> float | None: + """Calculate average loss from losses dict.""" + if not losses: + return None + return sum(losses.values()) / len(losses) + + baseline_avg = calculate_average(baseline_losses) + test_avg = calculate_average(test_losses) + + baseline_avg_str = f"{baseline_avg}" if baseline_avg is not None else "N/A" + test_avg_str = f"{test_avg}" if test_avg is not None else "N/A" + + log_and_save(f"{LOG_PREFIX} Average baseline loss: {baseline_avg_str}", stats_file) + log_and_save(f"{LOG_PREFIX} Average test loss: {test_avg_str}", stats_file) + + # Calculate overall difference if both averages are available + if baseline_avg is not None and test_avg is not None: + avg_diff = test_avg - baseline_avg + log_and_save(f"{LOG_PREFIX} Average difference: {avg_diff:.6f}", stats_file) + + +def perform_loss_analysis( + baseline_log: str, test_log: str, stats_file: str | None +) -> None: + """Perform loss comparison analysis.""" + # Initialize stats file and add header + log_and_save(f"{LOG_PREFIX} ==========================================", stats_file) + log_and_save(f"{LOG_PREFIX} LOSS COMPARISON ANALYSIS", stats_file) + log_and_save(f"{LOG_PREFIX} ==========================================", stats_file) + + # Extract losses directly from log files + baseline_losses = extract_losses_from_log(baseline_log) + test_losses = extract_losses_from_log(test_log) + + # Check if losses were extracted successfully + name_losses = [("baseline", baseline_losses), ("test", test_losses)] + for name, losses in name_losses: + if not losses: + log_and_save( + f"{LOG_PREFIX} Warning: Could not extract loss data from " + f"{name} training log.", + stats_file, + ) + log_and_save( + f"{LOG_PREFIX} Please check that the training completed " + "successfully.", + stats_file, + ) + return + + # Generate comparison outputs + generate_step_comparison(baseline_losses, test_losses, stats_file) + generate_summary_statistics(baseline_losses, test_losses, stats_file) + + +def assert_losses_equal( + baseline_log: str, test_log: str, import_result: str | None = None +) -> None: + """Assert that losses are equal between baseline and test using unittest. + + If import_result is provided, also compares baseline with imported losses. + """ + log_print("Asserting losses are equal...") + log_print(f"Baseline log: {baseline_log}") + log_print(f"Test log: {test_log}") + if import_result: + log_print(f"Import file: {import_result}") + + # Extract losses from both logs + baseline_losses = extract_losses_from_log(baseline_log) + test_losses = extract_losses_from_log(test_log) + + log_print(f"Extracted {len(baseline_losses)} steps from baseline log") + log_print(f"Extracted {len(test_losses)} steps from test log") + + if not baseline_losses: + log_print("Error: No losses found in baseline log") + sys.exit(1) + + if not test_losses: + log_print("Error: No losses found in test log") + sys.exit(1) + + # Load imported losses if provided + imported_losses = None + if import_result: + imported_losses = read_losses_from_file(import_result) + log_print(f"Loaded {len(imported_losses)} steps from import file") + if not imported_losses: + log_print("Error: No losses found in import file") + sys.exit(1) + + # Create a test case + class LossEqualityTest(unittest.TestCase): + def test_losses_equal(self): + # Check that both have the same steps + baseline_steps = set(baseline_losses.keys()) + test_steps = set(test_losses.keys()) + + self.assertEqual( + baseline_steps, + test_steps, + f"Steps mismatch: baseline has {len(baseline_steps)} steps, " + f"test has {len(test_steps)} steps", + ) + + # If imported losses exist, check steps match + if imported_losses: + imported_steps = set(imported_losses.keys()) + self.assertEqual( + baseline_steps, + imported_steps, + f"Steps mismatch: baseline has {len(baseline_steps)} steps, " + f"imported has {len(imported_steps)} steps", + ) + + # Check that losses are equal for each step + for step in sorted(baseline_steps): + baseline_loss = baseline_losses[step] + test_loss = test_losses[step] + + # Compare baseline vs test + self.assertEqual( + baseline_loss, + test_loss, + f"Loss mismatch at step {step}: " + f"baseline={baseline_loss}, test={test_loss}", + ) + + # Compare baseline vs imported (if provided) + # No need to compare test vs imported since: + # baseline==test and baseline==imported implies test==imported + if imported_losses: + imported_loss = imported_losses[step] + self.assertEqual( + baseline_loss, + imported_loss, + f"Loss mismatch at step {step}: " + f"baseline={baseline_loss}, imported={imported_loss}", + ) + + # Run the test + suite = unittest.TestLoader().loadTestsFromTestCase(LossEqualityTest) + runner = unittest.TextTestRunner(verbosity=2) + result = runner.run(suite) + + if not result.wasSuccessful(): + log_print("Loss assertion failed!") + sys.exit(1) + else: + if import_result: + log_print( + "All losses are equal (baseline, test, and imported). " + "Assertion passed!" + ) + else: + log_print("All losses are equal. Assertion passed!") + + +def cleanup_temp_files(output_folder: str | None) -> None: + """Cleanup temporary files.""" + if not output_folder: + return + + scenarios = ["baseline", "test"] + for scenario in scenarios: + for temp_file in [ + get_loss_file_path(scenario, output_folder), + get_clean_log_path(scenario, output_folder), + ]: + if os.path.exists(temp_file): + os.remove(temp_file) + + +# ============================================================================= +# OUTPUT FUNCTIONS +# ============================================================================= + + +def print_completion_summary( + output_folder: str | None, enable_seed_checkpoint: bool +) -> None: + """Print completion summary.""" + log_print() + if output_folder: + log_print(f"Loss comparison complete. Results saved in {output_folder}/:") + log_print(" - baseline_outputs/") + log_print(" - test_outputs/") + if enable_seed_checkpoint: + log_print(" - seed_checkpoint_outputs/") + log_print() + log_print(f"Training logs saved in {output_folder}/:") + if enable_seed_checkpoint: + log_print(" - seed_checkpoint.log") + log_print(" - baseline_training.log") + log_print(" - test_training.log") + log_print() + log_print(f"All outputs organized in: {output_folder}/") + else: + log_print( + "Loss comparison complete. No results saved " + "(no output folder specified)." + ) + + +# ============================================================================= +# MAIN EXECUTION +# ============================================================================= + + +def parse_arguments() -> argparse.Namespace: + """Parse command line arguments.""" + parser = argparse.ArgumentParser( + description=( + "Compare training losses between different git commits " + "and/or different training configurations." + ), + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s abc123 def456 + %(prog)s abc123 def456 --steps=200 + %(prog)s abc123 def456 --baseline-config='./custom.toml' \\ + --baseline-options='--parallelism.tensor_parallel_degree=2' --steps=50 + %(prog)s abc123 def456 --no-seed-checkpoint + %(prog)s . . --baseline-options='--parallelism.dp=1' \\ + --test-options='--parallelism.dp=2' --steps=30 + """, + ) + + parser.add_argument("baseline_commit", help="Git commit hash for baseline") + parser.add_argument("test_commit", help="Git commit hash for test") + parser.add_argument( + "--baseline-config", + default="./torchtitan/models/llama3/train_configs/debug_model.toml", + help=( + "Config file for baseline run " + "(default: ./torchtitan/models/llama3/train_configs/" + "llama3_debug.toml)" + ), + ) + parser.add_argument( + "--test-config", + default="", + help="Config file for test run (default: uses baseline-config)", + ) + parser.add_argument( + "--baseline-options", + default="", + help="Additional CLI arguments for baseline run (default: empty)", + ) + parser.add_argument( + "--test-options", + default="", + help="Additional CLI arguments for test run (default: empty)", + ) + parser.add_argument( + "--baseline-train-file", + default="torchtitan.train", + help=( + "Train file (Python module path) for baseline run " + "(default: torchtitan.train)" + ), + ) + parser.add_argument( + "--test-train-file", + default="", + help=( + "Train file (Python module path) for test run " + "(default: uses baseline-train-file)" + ), + ) + parser.add_argument( + "--steps", + type=int, + default=100, + help="Number of training steps (default: 100)", + ) + parser.add_argument( + "--no-seed-checkpoint", + action="store_true", + help=("Disable seed checkpoint creation and checkpoint functionality"), + ) + parser.add_argument( + "--output-folder", + default="", + help=( + "Output folder for results (optional, if not specified, " + "results will not be saved)" + ), + ) + parser.add_argument( + "--assert-equal", + action="store_true", + help=( + "Assert that all losses are equal (for CI testing). " + "Script exits with error if losses differ." + ), + ) + parser.add_argument( + "--export-result", + default="", + help=( + "Export losses to specified file path (requires --assert-equal). " + "Exports only when losses match. Format: '{step} {loss}' per line." + ), + ) + parser.add_argument( + "--import-result", + default="", + help=( + "Import losses from specified file path for comparison " + "(requires --assert-equal). " + "Compares imported losses with both baseline and test " + "(all 3 must match)." + ), + ) + parser.add_argument( + "--job-dump-folder", + default="outputs", + help="Job dump folder path (default: outputs)", + ) + parser.add_argument( + "--baseline-ngpus", + type=int, + default=8, + help="Number of GPUs for baseline run (default: 8)", + ) + parser.add_argument( + "--test-ngpus", + type=int, + default=8, + help="Number of GPUs for test run (default: 8)", + ) + + args = parser.parse_args() + + # Set default values if not provided + if not args.test_config: + args.test_config = args.baseline_config + + if not args.test_train_file: + args.test_train_file = args.baseline_train_file + + # Convert empty output_folder to None + if not args.output_folder: + args.output_folder = None + + # Convert empty export_result to None + if not args.export_result: + args.export_result = None + + # Convert empty import_result to None + if not args.import_result: + args.import_result = None + + return args + + +def run_scenario( + scenario: str, + commit: str, + config_file: str, + train_file: str, + options: str, + steps: int, + enable_seed_checkpoint: bool, + output_folder: str | None, + job_dump_folder: str, + ngpus: int, +) -> str: + """Run training for a specific scenario (baseline or test). + + Args: + scenario: Name of the scenario ("baseline" or "test") + commit: Git commit to checkout + config_file: Config file path + train_file: Train file (Python module path) + options: Additional CLI options + steps: Number of training steps + enable_seed_checkpoint: Whether to use seed checkpoint + output_folder: Output folder for results + job_dump_folder: Job dump folder path + ngpus: Number of GPUs to use + + Returns: + Path to the log file + """ + checkout_commit(commit, scenario) + + log_file = run_training( + scenario, + config_file, + train_file, + options, + steps, + enable_seed_checkpoint, + output_folder, + job_dump_folder, + ngpus, + ) + + return log_file + + +def main() -> None: + """Main function that orchestrates the entire comparison process.""" + # Parse and validate arguments + args = parse_arguments() + validate_arguments( + args.baseline_commit, + args.test_commit, + args.baseline_config, + args.baseline_train_file, + args.baseline_options, + args.test_config, + args.test_train_file, + args.test_options, + args.steps, + args.assert_equal, + args.export_result, + args.import_result, + ) + + # Setup environment + stats_file = setup_output_directory(args.output_folder) + enable_seed_checkpoint = not args.no_seed_checkpoint + print_configuration( + args.baseline_commit, + args.test_commit, + args.baseline_config, + args.baseline_train_file, + args.baseline_options, + args.test_config, + args.test_train_file, + args.test_options, + args.steps, + enable_seed_checkpoint, + args.job_dump_folder, + ) + + # Check if git working directory is clean before switching commits + # Skip check only if both commits are "." (comparing configs on same commit) + needs_git_checkout = args.baseline_commit != "." or args.test_commit != "." + if needs_git_checkout: + check_git_clean_state() + + # Save original commit if we're going to do checkouts + original_commit = None + if needs_git_checkout: + original_commit = get_current_commit() + log_print(f"Saving original commit/branch: {original_commit}") + log_print() + + try: + create_seed_checkpoint( + enable_seed_checkpoint, + args.baseline_config, + args.baseline_train_file, + args.output_folder, + args.job_dump_folder, + ) + # Run baseline and test training + baseline_log = run_scenario( + "baseline", + args.baseline_commit, + args.baseline_config, + args.baseline_train_file, + args.baseline_options, + args.steps, + enable_seed_checkpoint, + args.output_folder, + args.job_dump_folder, + args.baseline_ngpus, + ) + + test_log = run_scenario( + "test", + args.test_commit, + args.test_config, + args.test_train_file, + args.test_options, + args.steps, + enable_seed_checkpoint, + args.output_folder, + args.job_dump_folder, + args.test_ngpus, + ) + log_print() + + # Assert losses are equal if requested + if args.assert_equal: + # Pass import_result if provided for 3-way comparison + assert_losses_equal(baseline_log, test_log, args.import_result) + + # Export losses if requested (only after assertion passes) + if args.export_result: + # Extract baseline losses (they equal test losses since assertion passed) + baseline_losses = extract_losses_from_log(baseline_log) + export_losses_to_file(baseline_losses, args.export_result) + + # Analysis and reporting + perform_loss_analysis(baseline_log, test_log, stats_file) + cleanup_temp_files(args.output_folder) + print_completion_summary(args.output_folder, enable_seed_checkpoint) + finally: + # Restore original commit if we did checkouts + if original_commit is not None: + log_print() + restore_original_commit(original_commit) + + +if __name__ == "__main__": + main() diff --git a/scripts/preprocess_data.py b/scripts/preprocess_data.py index 6184634a8a..aca7d0cd47 100644 --- a/scripts/preprocess_data.py +++ b/scripts/preprocess_data.py @@ -709,7 +709,29 @@ def _get_conversation_len(x): return len(x["messages"]) return 0 + len_before = len(dataset) dataset = dataset.filter(lambda x: _get_conversation_len(x) > 3) + print(f"Filtered by multiturn: {len_before} -> {len(dataset)} samples") + if args.required_text: + def _contains_required_text(x): + if args.chat: + if "conversations" in x: + messages = x["conversations"] + elif "messages" in x: + messages = x["messages"] + else: + return False + for message in messages: + content = message.get("content") or message.get("value") or "" + if args.required_text in content: + return True + return False + else: + return args.required_text in x.get("text", "") + + len_before = len(dataset) + dataset = dataset.filter(_contains_required_text) + print(f"Filtered by required_text '{args.required_text}': {len_before} -> {len(dataset)} samples") original_column_names = list(dataset.features.keys()) dataset = dataset.map( @@ -914,7 +936,7 @@ def _add_position_ids_and_seq_lengths(sample): dataset.save_to_disk(args.save_to_disk) if args.push_to_hub: print(f"Pushing to Hugging Face repo {args.push_to_hub}") - dataset.push_to_hub(args.save_to_disk, private=True) + dataset.push_to_hub(args.push_to_hub, private=True) example = dataset[0] @@ -960,6 +982,7 @@ def _add_position_ids_and_seq_lengths(sample): parser.add_argument("--limit", type=int) parser.add_argument("--chat", action="store_true") parser.add_argument("--multiturn-only", action="store_true") + parser.add_argument("--required-text", type=str) parser.add_argument("--pack-to-sequence-length", type=int) parser.add_argument( "--epochs", diff --git a/tests/assets/losses/llama3_cuda.txt b/tests/assets/losses/llama3_cuda.txt new file mode 100644 index 0000000000..5ccea64b17 --- /dev/null +++ b/tests/assets/losses/llama3_cuda.txt @@ -0,0 +1,10 @@ +1 8.1376 +2 7.841 +3 7.1815 +4 6.3509 +5 5.5272 +6 4.9244 +7 4.5606 +8 4.3724 +9 4.347 +10 4.2004 diff --git a/tests/assets/losses/llama3_rocm.txt b/tests/assets/losses/llama3_rocm.txt new file mode 100644 index 0000000000..3aa7c24a1d --- /dev/null +++ b/tests/assets/losses/llama3_rocm.txt @@ -0,0 +1,5 @@ +1 8.1376 +2 7.8409 +3 7.1815 +4 6.3509 +5 5.7090 diff --git a/tests/integration_tests/features.py b/tests/integration_tests/features.py index 8bf3a0249f..3662aa6bf6 100755 --- a/tests/integration_tests/features.py +++ b/tests/integration_tests/features.py @@ -121,7 +121,6 @@ def build_features_test_list() -> list[OverrideDefinitions]: ], "Checkpoint Integration Test - save load model only checkpoint in HF definition and format", "model_only_hf_checkpoint", - skip_rocm_test=True, ), OverrideDefinitions( [ @@ -346,6 +345,20 @@ def build_features_test_list() -> list[OverrideDefinitions]: "fsdp+flex_attn+per_op_sac", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--parallelism.data_parallel_shard_degree=4", + "--activation_checkpoint.mode=selective", + "--activation_checkpoint.selective_ac_option=op", + "--model.flavor=debugmodel_varlen_attn", + ] + ], + "FSDP+VARLEN_ATTN + per op SAC", + "fsdp+varlen_attn+per_op_sac", + ngpu=4, + skip_rocm_test=True, + ), OverrideDefinitions( [ [ @@ -543,6 +556,21 @@ def build_features_test_list() -> list[OverrideDefinitions]: "validation_tp_cp_pp", ngpu=8, ), + OverrideDefinitions( + [ + [ + "--training.dataloader.num_workers", + "2", + "--training.dataloader.pin_memory", + "--training.dataloader.persistent_workers", + "--training.dataloader.prefetch_factor", + "4", + ], + ], + "Dataloader kwargs (via CLI args)", + "dataloader_kwargs", + ngpu=2, + ), ] return integration_tests_flavors diff --git a/tests/integration_tests/flux.py b/tests/integration_tests/flux.py index 321ac1280c..a7ed51832f 100755 --- a/tests/integration_tests/flux.py +++ b/tests/integration_tests/flux.py @@ -26,20 +26,15 @@ def build_flux_test_list() -> list[OverrideDefinitions]: "--parallelism.data_parallel_shard_degree 2", "--parallelism.data_parallel_replicate_degree 2", "--parallelism.context_parallel_degree 2", - ] - ], - "HSDP+CP", - "hsdp+cp", - ngpu=8, - ), - OverrideDefinitions( - [ - [ "--validation.enable", - ] + "--validation.steps 5", + "--checkpoint.enable", + ], + [], ], - "Flux Validation Test", - "validation", + "HSDP+CP+Validation+Inference", + "hsdp+cp+validation+inference", + ngpu=8, ), ] return integration_tests_flavors @@ -63,7 +58,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir t5_encoder_version_arg = ( "--encoder.t5_encoder tests/assets/flux_test_encoders/t5-v1_1-xxl/" ) - tokenzier_path_arg = "--model.tokenizer_path tests/assets/tokenizer" + hf_assets_path_arg = "--model.hf_assets_path tests/assets/tokenizer" all_ranks = ",".join(map(str, range(test_flavor.ngpu))) @@ -73,7 +68,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir cmd = f'TORCH_TRACE="{output_dir}/{test_name}/compile_trace" ' + cmd # save checkpoint (idx == 0) and load it for generation (idx == 1) - if test_name == "test_generate" and idx == 1: + if test_name == "hsdp+cp+validation+inference" and idx == 1: # For flux generation, test using inference script cmd = ( f"CONFIG_FILE={full_path} NGPU={test_flavor.ngpu} LOG_RANK={all_ranks} " @@ -84,7 +79,7 @@ def run_single_test(test_flavor: OverrideDefinitions, full_path: str, output_dir cmd += " " + random_init_encoder_arg cmd += " " + clip_encoder_version_arg cmd += " " + t5_encoder_version_arg - cmd += " " + tokenzier_path_arg + cmd += " " + hf_assets_path_arg if override_arg: cmd += " " + " ".join(override_arg) diff --git a/tests/integration_tests/models.py b/tests/integration_tests/models.py index 37f588765b..5ba1c18c59 100755 --- a/tests/integration_tests/models.py +++ b/tests/integration_tests/models.py @@ -32,6 +32,21 @@ def build_model_tests_list() -> list[OverrideDefinitions]: "deepseek_v3_fsdp+ep+compile", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--model.name deepseek_v3", + "--parallelism.pipeline_parallel_degree 2", + "--parallelism.expert_parallel_degree 2", + "--parallelism.pipeline_parallel_schedule DualPipeV", + # AC is not supported for DualPipeV yet + "--activation_checkpoint.mode 'none'", + ], + ], + "PP dual pipe v schedule test", + "pp_dualpipev", + ngpu=4, + ), OverrideDefinitions( [ [ @@ -64,6 +79,23 @@ def build_model_tests_list() -> list[OverrideDefinitions]: "deepseek_v3_pp+fsdp+tp+ep+etp", ngpu=8, ), + OverrideDefinitions( + [ + [ + "--model.name deepseek_v3", + "--model.flavor debugmodel_flex_attn", + "--parallelism.data_parallel_shard_degree 4", + "--parallelism.pipeline_parallel_degree 2", + "--parallelism.pipeline_parallel_schedule Interleaved1F1B", + "--parallelism.expert_parallel_degree 4", + "--activation_checkpoint.mode 'selective'", + "--activation_checkpoint.selective_ac_option 'op'", + ], + ], + "DeepSeek V3 Flex+PP+FSDP+EP+SACOP", + "deepseek_v3_flex+pp+fsdp+ep+sacop", + ngpu=8, + ), # Integration Test Cases for Qwen3 dense and MoE model OverrideDefinitions( [ @@ -110,6 +142,22 @@ def build_model_tests_list() -> list[OverrideDefinitions]: "llama4_pp+fsdp+tp+ep+compile", ngpu=8, ), + # Integration Test Cases for gpt-oss + OverrideDefinitions( + [ + [ + "--model.name gpt_oss", + "--parallelism.data_parallel_shard_degree 4", + "--parallelism.tensor_parallel_degree 2", + "--parallelism.expert_parallel_degree 4", + "--parallelism.expert_tensor_parallel_degree 1", + "--compile.enable", + ], + ], + "Gpt-oss FSDP+TP+EP+compile", + "gpt_oss_fsdp+tp+ep+compile", + ngpu=8, + ), ] return model_tests diff --git a/tests/integration_tests/run_tests.py b/tests/integration_tests/run_tests.py index 011fa25554..7081215c83 100644 --- a/tests/integration_tests/run_tests.py +++ b/tests/integration_tests/run_tests.py @@ -25,9 +25,6 @@ } -TEST_WITH_ROCM = os.getenv("TEST_WITH_ROCM", "0") == "1" - - def _run_cmd(cmd): return subprocess.run([cmd], text=True, shell=True) @@ -83,6 +80,7 @@ def run_tests(args, test_list: list[OverrideDefinitions]): args.config_path ), f"Base config path {args.config_path} does not exist" + ran_any_test = False for test_flavor in test_list: # Filter by test_name if specified if args.test_name != "all" and test_flavor.test_name != args.test_name: @@ -92,7 +90,10 @@ def run_tests(args, test_list: list[OverrideDefinitions]): continue # Skip the test for ROCm - if TEST_WITH_ROCM and test_flavor.skip_rocm_test: + if ( + getattr(args, "gpu_arch_type", "cuda") == "rocm" + and test_flavor.skip_rocm_test + ): continue # Check if we have enough GPUs @@ -103,6 +104,14 @@ def run_tests(args, test_list: list[OverrideDefinitions]): ) else: run_single_test(test_flavor, args.config_path, args.output_dir) + ran_any_test = True + + if not ran_any_test: + available_tests = [t.test_name for t in test_list if not t.disabled] + logger.warning( + f"No tests were run for --test_name '{args.test_name}' in test suite '{args.test_suite}'.\n" + f"Available test names in '{args.test_suite}' suite: {available_tests}" + ) def main(): @@ -110,6 +119,12 @@ def main(): parser.add_argument( "output_dir", help="Directory to dump results generated by tests" ) + parser.add_argument( + "--gpu_arch_type", + default="cuda", + choices=["cuda", "rocm"], + help="GPU architecture type. Must be specified as either 'cuda' or 'rocm'.", + ) parser.add_argument( "--test_suite", default="features", diff --git a/tests/unit_tests/test_activation_checkpoint.py b/tests/unit_tests/test_activation_checkpoint.py index 202f7b1e48..2b05505e4a 100644 --- a/tests/unit_tests/test_activation_checkpoint.py +++ b/tests/unit_tests/test_activation_checkpoint.py @@ -19,12 +19,16 @@ torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + torch.ops.aten._scaled_dot_product_attention_math.default, + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, # for low precision training, it's useful to always save # the result of max, since the absolute maximum is # used to compute the scaling factor for quantization. torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, + torch.ops.torch_attn._varlen_attn, } @@ -84,7 +88,6 @@ def get_bw_flops(model_fn): model_selective_ac, ac_config_no_force, model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) flops_selective_ac = get_bw_flops(model_selective_ac) @@ -102,7 +105,6 @@ def get_bw_flops(model_fn): model_with_force_first, ac_config_with_force_first, model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) flops_with_force_first = get_bw_flops(model_with_force_first) @@ -119,7 +121,6 @@ def get_bw_flops(model_fn): model_with_force_last, ac_config_with_force_last, model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) flops_with_force_last = get_bw_flops(model_with_force_last) @@ -134,7 +135,6 @@ def get_bw_flops(model_fn): model_with_full_ac, ac_config_full_ac, model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) flops_full_ac = get_bw_flops(model_with_full_ac) @@ -177,7 +177,6 @@ def get_act_mem(model_fn): model_selective_ac, ac_config_no_force, model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) mem_selective_ac = get_act_mem(model_selective_ac) @@ -194,7 +193,6 @@ def get_act_mem(model_fn): model_with_force_first, ac_config_with_force_first, model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) mem_with_force_first = get_act_mem(model_with_force_first) @@ -210,7 +208,6 @@ def get_act_mem(model_fn): model_with_force_last, ac_config_with_force_last, model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) mem_with_force_last = get_act_mem(model_with_force_last) @@ -224,7 +221,6 @@ def get_act_mem(model_fn): model_with_full_ac, ac_config_full_ac, model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) mem_full_ac = get_act_mem(model_with_full_ac) @@ -251,7 +247,6 @@ def test_correctness(self): per_op_sac_force_recompute_mm_shapes_by_fqns=[], ), model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) model_force_first = ToyModule() @@ -264,7 +259,6 @@ def test_correctness(self): per_op_sac_force_recompute_mm_shapes_by_fqns=["moe.router.gate"], ), model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) @@ -278,7 +272,6 @@ def test_correctness(self): per_op_sac_force_recompute_mm_shapes_by_fqns=["output"], ), model_compile_enabled=False, - use_flex_attn=False, op_sac_save_list=_op_sac_save_list, ) diff --git a/tests/unit_tests/test_compile_moe.py b/tests/unit_tests/test_compile_moe.py new file mode 100644 index 0000000000..52a6b99ef5 --- /dev/null +++ b/tests/unit_tests/test_compile_moe.py @@ -0,0 +1,81 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +import torch.nn as nn + +from torchtitan.config.job_config import Compile as CompileConfig +from torchtitan.models.llama4.infra.parallelize import apply_compile + + +class TransformerBlock(nn.Module): + def __init__(self, dim=512): + super().__init__() + self.attention = nn.Linear(dim, dim, bias=False) + self.mlp = nn.Linear(dim, dim, bias=False) + self.moe_enabled = False + + def forward(self, x): + x = self.attention(x) + x = self.mlp(x) + return x + + +class TinyModel(nn.Module): + def __init__(self, num_layers=2, dim=512): + super().__init__() + self.layers = nn.ModuleDict( + {str(i): TransformerBlock(dim) for i in range(num_layers)} + ) + + def forward(self, x): + for layer in self.layers.values(): + x = layer(x) + return x + + +class TestApplyCompile(unittest.TestCase): + def test_patched_once(self): + """ + Calls apply_compile multiple times, as in the case with PP. + But patches should only happen once + """ + unused_model1 = TinyModel(num_layers=2, dim=128) + unused_model2 = TinyModel(num_layers=2, dim=128) + compile_config = CompileConfig(backend="eager") + + apply_compile(unused_model1, compile_config, ep_enabled=True) + apply_compile(unused_model2, compile_config, ep_enabled=True) + + from torchtitan.models.moe import moe as moe_module + + # Generate sample inputs for _run_experts_grouped_mm + num_experts = 8 + dim = 128 + hidden_dim = 256 + w1 = torch.randn(num_experts, hidden_dim, dim) + w2 = torch.randn(num_experts, dim, hidden_dim) + w3 = torch.randn(num_experts, hidden_dim, dim) + num_tokens_per_expert = torch.tensor( + [10, 8, 12, 9, 11, 7, 10, 13], dtype=torch.int32 + ) + total_tokens = num_tokens_per_expert.sum().item() + x = torch.randn(total_tokens, dim) + + # Call the function, should not error + output = moe_module._run_experts_grouped_mm( + w1, w2, w3, x, num_tokens_per_expert + ) + + print(f"Input shape: {x.shape}") + print(f"Output shape: {output.shape}") + print(f"Num tokens per expert: {num_tokens_per_expert}") + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/test_dataloader.py b/tests/unit_tests/test_dataloader.py new file mode 100644 index 0000000000..82625e5e06 --- /dev/null +++ b/tests/unit_tests/test_dataloader.py @@ -0,0 +1,153 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +from torch.utils.data import IterableDataset + +from torchtitan.components.dataloader import ParallelAwareDataloader +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.config import ConfigManager + + +class DummyDataset(IterableDataset): + """A simple dummy dataset for testing.""" + + def __iter__(self): + for i in range(100): + yield {"input": i}, i + + +class DummyTokenizer(BaseTokenizer): + """A dummy tokenizer for testing that implements BaseTokenizer interface.""" + + def __init__(self): + super().__init__() + self.eos_id = 2 + + def encode( + self, text: str, add_bos: bool = False, add_eos: bool = False + ) -> list[int]: + # Simple encoding: convert each character to its ASCII value + tokens = [ord(c) for c in text] + if add_bos: + tokens.insert(0, 1) # BOS token + if add_eos: + tokens.append(self.eos_id) + return tokens + + def decode(self, token_ids: list[int]) -> str: + # Simple decoding: convert ASCII values back to characters + return "".join(chr(t) for t in token_ids if t > 2) + + def get_vocab_size(self) -> int: + return 256 # ASCII range + + +class TestParallelAwareDataloader(unittest.TestCase): + def test_dataloader_yields_correct_batches(self): + """Test that the dataloader correctly yields batched data from the dataset.""" + dataset = DummyDataset() + batch_size = 4 + + dataloader = ParallelAwareDataloader( + dataset, + dp_rank=0, + dp_world_size=1, + batch_size=batch_size, + ) + + batches = list(dataloader) + + # DummyDataset yields 100 items, so we expect 25 batches of size 4 + self.assertEqual(len(batches), 25) + + # Check first batch structure and values + first_batch_input, first_batch_label = batches[0] + self.assertEqual(len(first_batch_input["input"]), batch_size) + self.assertEqual(len(first_batch_label), batch_size) + + # Verify first batch contains expected values (0, 1, 2, 3) + self.assertEqual(first_batch_input["input"].tolist(), [0, 1, 2, 3]) + self.assertEqual(first_batch_label.tolist(), [0, 1, 2, 3]) + + # Check last batch + last_batch_input, last_batch_label = batches[-1] + self.assertEqual(last_batch_input["input"].tolist(), [96, 97, 98, 99]) + self.assertEqual(last_batch_label.tolist(), [96, 97, 98, 99]) + + def test_validate_kwargs_rejects_invalid_kwargs(self): + """Test that passing invalid kwargs raises ValueError.""" + dataset = DummyDataset() + + with self.assertRaises(ValueError) as context: + ParallelAwareDataloader( + dataset, + dp_rank=0, + dp_world_size=1, + invalid_arg=42, + ) + + self.assertIn("Invalid dataloader kwargs", str(context.exception)) + self.assertIn("invalid_arg", str(context.exception)) + + def test_config_batch_size_overwritten_by_explicit_batch_size(self): + """Test that batch_size in config kwargs is overwritten by explicit batch_size.""" + dataset = DummyDataset() + + config_kwargs = {"batch_size": 2, "num_workers": 0} + + explicit_batch_size = 8 + + # Merge kwargs with explicit args taking precedence (same pattern as in dataset files) + dataloader_kwargs = { + **config_kwargs, + "batch_size": explicit_batch_size, + } + + dataloader = ParallelAwareDataloader( + dataset, + dp_rank=0, + dp_world_size=1, + **dataloader_kwargs, + ) + + # Verify that batch_size is the explicit one, not the config one + self.assertEqual(dataloader.batch_size, explicit_batch_size) + + def test_build_dataloader_with_job_config(self): + """Verify batch_size from job_config.training.local_batch_size is correctly used.""" + from torchtitan.hf_datasets.text_datasets import build_text_dataloader + + tokenizer = DummyTokenizer() + + config_manager = ConfigManager() + config = config_manager.parse_args( + [ + "--training.dataset", + "c4_test", + "--training.local_batch_size", + "8", + "--training.seq_len", + "512", + "--training.dataloader.num_workers", + "2", + ] + ) + + dataloader = build_text_dataloader( + tokenizer=tokenizer, + dp_world_size=1, + dp_rank=0, + job_config=config, + ) + + self.assertEqual(dataloader.batch_size, 8) + self.assertEqual(dataloader.num_workers, 2) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/test_parallel_dims.py b/tests/unit_tests/test_parallel_dims.py new file mode 100644 index 0000000000..86b860065e --- /dev/null +++ b/tests/unit_tests/test_parallel_dims.py @@ -0,0 +1,569 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest +from unittest.mock import patch + +import torch.distributed as dist +from torch.testing._internal.distributed._tensor.common_dtensor import ( + DTensorTestBase, + with_comms, +) +from torchtitan.distributed import ParallelDims + + +class TestParallelDimsValidation(unittest.TestCase): + """Test ParallelDims validation logic without mesh building.""" + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_basic_initialization(self): + """Test basic initialization with valid parameters.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=1, + tp=2, + pp=1, + ep=1, + etp=1, + world_size=8, + ) + self.assertEqual(parallel_dims.dp_replicate, 2) + self.assertEqual(parallel_dims.dp_shard, 2) + self.assertEqual(parallel_dims.cp, 1) + self.assertEqual(parallel_dims.tp, 2) + self.assertEqual(parallel_dims.pp, 1) + self.assertEqual(parallel_dims.ep, 1) + self.assertEqual(parallel_dims.etp, 1) + self.assertEqual(parallel_dims.world_size, 8) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_auto_calculate_dp_shard(self): + """Test automatic calculation of dp_shard when set to -1.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=-1, + cp=1, + tp=2, + pp=1, + ep=1, + etp=1, + world_size=8, + ) + self.assertEqual(parallel_dims.dp_shard, 2) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_validation_invalid_world_size(self): + """Test validation fails when parallelism degrees don't match world_size.""" + with self.assertRaises(AssertionError): + ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=1, + tp=2, + pp=1, + ep=1, + etp=1, + world_size=10, # Invalid: 2*2*1*2*1 = 8, not 10 + ) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_validation_invalid_etp(self): + """Test validation fails when etp is not equal to tp or 1.""" + with self.assertRaises(AssertionError): + ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=4, + pp=1, + ep=2, + etp=2, # Invalid: etp must be tp or 1 when ep > 1 + world_size=8, + ) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_validation_zero_parallelism(self): + """Test validation fails when parallelism degree is 0.""" + with self.assertRaises(AssertionError): + ParallelDims( + dp_replicate=0, # Invalid: must be >= 1 + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_validation_invalid_dp_shard(self): + """Test validation fails when dp_shard is invalid (not -1 and not >=1).""" + with self.assertRaises(AssertionError): + ParallelDims( + dp_replicate=1, + dp_shard=0, # Invalid: must be -1 or >= 1 + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_enabled_properties(self): + """Test all enabled properties.""" + # Test with DP enabled + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=1, + tp=2, + pp=1, + ep=1, + etp=1, + world_size=8, + ) + self.assertTrue(parallel_dims.dp_enabled) + self.assertTrue(parallel_dims.dp_replicate_enabled) + self.assertTrue(parallel_dims.dp_shard_enabled) + self.assertFalse(parallel_dims.cp_enabled) + self.assertTrue(parallel_dims.tp_enabled) + self.assertFalse(parallel_dims.pp_enabled) + self.assertFalse(parallel_dims.ep_enabled) + self.assertFalse(parallel_dims.etp_enabled) + self.assertTrue(parallel_dims.fsdp_enabled) + + # Test with CP enabled + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=2, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=2, + ) + self.assertFalse(parallel_dims.dp_enabled) + self.assertTrue(parallel_dims.cp_enabled) + self.assertTrue(parallel_dims.dp_cp_enabled) + self.assertTrue(parallel_dims.fsdp_enabled) + + # Test with EP and ETP enabled (EP * ETP must not contribute to world_size) + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=2, + cp=1, + tp=1, + pp=1, + ep=2, + etp=1, + world_size=2, + ) + self.assertTrue(parallel_dims.ep_enabled) + self.assertFalse(parallel_dims.etp_enabled) + + # Test with PP enabled + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=2, + ep=1, + etp=1, + world_size=2, + ) + self.assertTrue(parallel_dims.pp_enabled) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_fsdp_gradient_divide_factor(self): + """Test fsdp_gradient_divide_factor calculation.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=3, + cp=2, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=12, + ) + # Should be dp_replicate * dp_shard * cp = 2 * 3 * 2 = 12 + self.assertEqual(parallel_dims.fsdp_gradient_divide_factor, 12) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_non_data_parallel_size(self): + """Test non_data_parallel_size calculation.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=2, + tp=3, + pp=2, + ep=1, + etp=1, + world_size=48, + ) + # Should be cp * tp * pp = 2 * 3 * 2 = 12 + self.assertEqual(parallel_dims.non_data_parallel_size, 12) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_seq_len_divisor(self): + """Test seq_len_divisor calculation.""" + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=1, + cp=2, + tp=4, + pp=1, + ep=1, + etp=1, + world_size=16, + ) + # Should be tp * (cp * 2) = 4 * 4 = 16 + self.assertEqual(parallel_dims.seq_len_divisor, 16) + + +class TestParallelDimsMeshOperations(unittest.TestCase): + """Test ParallelDims mesh operations with single-rank distributed environment.""" + + def setUp(self): + """Initialize distributed environment for CPU testing.""" + if not dist.is_initialized(): + dist.init_process_group( + backend="gloo", + init_method="tcp://localhost:12356", + world_size=1, + rank=0, + ) + + def tearDown(self): + """Clean up distributed environment.""" + if dist.is_initialized(): + dist.destroy_process_group() + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_get_mesh_invalid_name(self): + """Test getting mesh with invalid name raises error.""" + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + parallel_dims.build_mesh() + + with self.assertRaises(ValueError) as context: + parallel_dims.get_mesh("invalid_mesh") + self.assertIn("Invalid mesh dim", str(context.exception)) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_get_mesh_lazy_initialization(self): + """Test that get_optional_mesh triggers build_mesh if not built yet.""" + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + # Don't call build_mesh explicitly + self.assertEqual(len(parallel_dims._meshes), 0) + + # get_optional_mesh should trigger build_mesh + result = parallel_dims.get_optional_mesh("tp") + # Result is None because tp has size 1, but build_mesh should have been called + self.assertGreater(len(parallel_dims._meshes), 0) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_single_rank_mesh_operations(self): + """Comprehensive test for all single-rank mesh operations. + + This test verifies mesh building, mesh retrieval, mesh sizes, and property + access when all parallelism dimensions are set to 1 (single rank). + """ + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + + # Test mesh building + world_mesh = parallel_dims.build_mesh() + self.assertIsNotNone(world_mesh) + self.assertEqual(world_mesh.size(), 1) + + # Verify all expected meshes are created + self.assertIsNotNone(parallel_dims._meshes) + self.assertIn("pp", parallel_dims._meshes) + self.assertIn("batch", parallel_dims._meshes) + self.assertIn("loss", parallel_dims._meshes) + self.assertIn("dp_replicate", parallel_dims._meshes) + self.assertIn("fsdp", parallel_dims._meshes) + self.assertIn("cp", parallel_dims._meshes) + self.assertIn("tp", parallel_dims._meshes) + + # Validate 1D mesh sizes - all should be 1 for single rank + self.assertEqual(parallel_dims._meshes["dp_replicate"].size(), 1) + self.assertEqual(parallel_dims._meshes["fsdp"].size(), 1) + self.assertEqual(parallel_dims._meshes["tp"].size(), 1) + self.assertEqual(parallel_dims._meshes["batch"].size(), 1) + self.assertEqual(parallel_dims._meshes["loss"].size(), 1) + self.assertEqual(parallel_dims._meshes["pp"].size(), 1) + self.assertEqual(parallel_dims._meshes["cp"].size(), 1) + self.assertEqual(parallel_dims._meshes["ep"].size(), 1) + self.assertEqual(parallel_dims._meshes["etp"].size(), 1) + self.assertEqual(parallel_dims._meshes["efsdp"].size(), 1) + + # Validate 2D mesh shapes + dp_replicate_fsdp_mesh = parallel_dims.get_optional_mesh( + ["dp_replicate", "fsdp"] + ) + self.assertIsNone(dp_replicate_fsdp_mesh) # Both dimensions have size 1 + dp_replicate_efsdp_mesh = parallel_dims.get_optional_mesh( + ["dp_replicate", "efsdp"] + ) + self.assertIsNone(dp_replicate_efsdp_mesh) # Both dimensions have size 1 + ep_etp_mesh = parallel_dims.get_optional_mesh(["ep", "etp"]) + self.assertIsNone(ep_etp_mesh) # Both dimensions have size 1 + + # Test get_optional_mesh returns None when all dimensions have size 1 + self.assertIsNone(parallel_dims.get_optional_mesh("tp")) + self.assertIsNone(parallel_dims.get_optional_mesh("dp_replicate")) + self.assertIsNone(parallel_dims.get_optional_mesh("pp")) + self.assertIsNone(parallel_dims.get_optional_mesh("cp")) + self.assertIsNone(parallel_dims.get_optional_mesh("fsdp")) + + # Test get_optional_mesh with list input + self.assertIsNone(parallel_dims.get_optional_mesh(["dp_replicate", "fsdp"])) + + # Test get_all_one_dimensional_meshes returns empty when all dimensions have size 1 + one_d_meshes = parallel_dims.get_all_one_dimensional_meshes() + self.assertEqual(len(one_d_meshes), 0) + + # Test world_mesh property + world_mesh_property = parallel_dims.world_mesh + self.assertIsNotNone(world_mesh_property) + self.assertEqual(world_mesh_property.size(), 1) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_get_mesh_with_list_input(self): + """Test get_optional_mesh accepts both string and list inputs.""" + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=1, + cp=1, + tp=1, + pp=1, + ep=1, + etp=1, + world_size=1, + ) + parallel_dims.build_mesh() + + # Should accept list input + result = parallel_dims.get_optional_mesh(["dp_replicate", "fsdp"]) + # Returns None because both dimensions have size 1 + self.assertIsNone(result) + + @patch("torchtitan.distributed.parallel_dims.device_type", "cpu") + def test_expert_parallelism_validation(self): + """Test expert parallelism configurations.""" + # EP with ETP = 1 (valid) - world_size = dp_replicate * dp_shard * cp * tp * pp + parallel_dims = ParallelDims( + dp_replicate=1, + dp_shard=2, + cp=1, + tp=1, + pp=1, + ep=2, + etp=1, + world_size=2, # 1 * 2 * 1 * 1 * 1 = 2 + ) + self.assertTrue(parallel_dims.ep_enabled) + self.assertFalse(parallel_dims.etp_enabled) + + # Test with larger configuration + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=1, + tp=1, + pp=1, + ep=3, + etp=1, + world_size=4, # 2 * 2 * 1 * 1 * 1 = 4 + ) + self.assertTrue(parallel_dims.ep_enabled) + self.assertFalse(parallel_dims.etp_enabled) + self.assertTrue(parallel_dims.dp_replicate_enabled) + self.assertTrue(parallel_dims.dp_shard_enabled) + + +class TestParallelDimsWorld8MeshOperations(DTensorTestBase): + """Test ParallelDims mesh operations with 8-rank distributed environment.""" + + @property + def world_size(self): + return 8 + + @with_comms + def test_world_size_8_mesh_operations(self): + """Comprehensive test for world_size=8 mesh operations. + + This test validates mesh building, mesh retrieval, mesh sizes, and properties + for a world_size=8 configuration with multiple parallelism dimensions enabled. + Configuration: dp_replicate=2, dp_shard=2, cp=1, tp=2, pp=1 (2*2*1*2*1 = 8) + """ + with patch( + "torchtitan.distributed.parallel_dims.device_type", self.device_type + ): + parallel_dims = ParallelDims( + dp_replicate=2, + dp_shard=2, + cp=1, + tp=2, + pp=1, + ep=1, + etp=1, + world_size=8, + ) + + # Test mesh building + world_mesh = parallel_dims.build_mesh() + self.assertIsNotNone(world_mesh) + self.assertEqual(world_mesh.size(), 8) + + # Verify all expected meshes are created + self.assertIsNotNone(parallel_dims._meshes) + self.assertIn("pp", parallel_dims._meshes) + self.assertIn("batch", parallel_dims._meshes) + self.assertIn("loss", parallel_dims._meshes) + self.assertIn("dp_replicate", parallel_dims._meshes) + self.assertIn("fsdp", parallel_dims._meshes) + self.assertIn("cp", parallel_dims._meshes) + self.assertIn("tp", parallel_dims._meshes) + self.assertIn("ep", parallel_dims._meshes) + self.assertIn("etp", parallel_dims._meshes) + self.assertIn("efsdp", parallel_dims._meshes) + + # Validate 1D mesh sizes match parallelism configuration + self.assertEqual(parallel_dims._meshes["pp"].size(), 1) + self.assertEqual( + parallel_dims._meshes["batch"].size(), 4 + ) # dp_replicate * dp_shard = 2 * 2 + self.assertEqual( + parallel_dims._meshes["loss"].size(), 4 + ) # dp_replicate * dp_shard * cp = 2 * 2 * 1 + self.assertEqual(parallel_dims._meshes["dp_replicate"].size(), 2) + self.assertEqual( + parallel_dims._meshes["fsdp"].size(), 2 + ) # dp_shard * cp = 2 * 1 + self.assertEqual(parallel_dims._meshes["cp"].size(), 1) + self.assertEqual(parallel_dims._meshes["tp"].size(), 2) + self.assertEqual(parallel_dims._meshes["ep"].size(), 1) + self.assertEqual(parallel_dims._meshes["etp"].size(), 1) + self.assertEqual( + parallel_dims._meshes["efsdp"].size(), 4 + ) # fsdp * tp / (etp * ep) = 2 * 2 / (1 * 1) = 4 + + # Validate 2D mesh shapes + dp_replicate_fsdp_mesh = parallel_dims.get_mesh(["dp_replicate", "fsdp"]) + self.assertIsNotNone(dp_replicate_fsdp_mesh) + self.assertEqual( + dp_replicate_fsdp_mesh.shape, (2, 2) + ) # (dp_replicate, fsdp) + # efsdp mesh only exists when ep > 1, so dp_replicate_efsdp should be None when ep=1 + dp_replicate_efsdp_mesh = parallel_dims.get_optional_mesh( + ["dp_replicate", "efsdp"] + ) + self.assertIsNone(dp_replicate_efsdp_mesh) # efsdp disabled when ep=1 + ep_etp_mesh = parallel_dims.get_optional_mesh(["ep", "etp"]) + self.assertIsNone(ep_etp_mesh) # Both dimensions have size 1 + + # Test get_mesh returns valid meshes for enabled dimensions (size > 1) + self.assertIsNotNone(parallel_dims.get_mesh("tp")) + self.assertIsNotNone(parallel_dims.get_mesh("dp_replicate")) + self.assertIsNotNone(parallel_dims.get_mesh("fsdp")) + self.assertIsNotNone(parallel_dims.get_mesh("batch")) + self.assertIsNotNone(parallel_dims.get_mesh("loss")) + + # Test get_optional_mesh returns None for disabled dimensions (size = 1) + self.assertIsNone(parallel_dims.get_optional_mesh("pp")) + self.assertIsNone(parallel_dims.get_optional_mesh("cp")) + self.assertIsNone(parallel_dims.get_optional_mesh("ep")) + + # Test get_mesh with 2D mesh names + self.assertIsNotNone(parallel_dims.get_mesh(["dp_replicate", "fsdp"])) + hsdp_mesh = parallel_dims.get_mesh(["dp_replicate", "fsdp"]) + self.assertEqual(hsdp_mesh.shape, (2, 2)) + + # Test get_all_one_dimensional_meshes returns only meshes with size > 1 + one_d_meshes = parallel_dims.get_all_one_dimensional_meshes() + self.assertGreater(len(one_d_meshes), 0) + # Should include: dp_replicate, fsdp, tp, batch, loss, efsdp (all with size > 1) + self.assertIn("dp_replicate", one_d_meshes) + self.assertIn("fsdp", one_d_meshes) + self.assertIn("tp", one_d_meshes) + self.assertIn("batch", one_d_meshes) + self.assertIn("loss", one_d_meshes) + self.assertIn("efsdp", one_d_meshes) + # Should not include: pp, cp, ep, etp (all with size = 1) + self.assertNotIn("pp", one_d_meshes) + self.assertNotIn("cp", one_d_meshes) + self.assertNotIn("ep", one_d_meshes) + self.assertNotIn("etp", one_d_meshes) + + # Test that we can get 2D meshes via get_mesh() instead + dp_replicate_fsdp = parallel_dims.get_mesh(["dp_replicate", "fsdp"]) + self.assertIsNotNone(dp_replicate_fsdp) + self.assertEqual(dp_replicate_fsdp.ndim, 2) + + # Test world_mesh property + world_mesh_property = parallel_dims.world_mesh + self.assertIsNotNone(world_mesh_property) + self.assertEqual(world_mesh_property.size(), 8) + + # Validate enabled properties + self.assertTrue(parallel_dims.dp_enabled) + self.assertTrue(parallel_dims.dp_replicate_enabled) + self.assertTrue(parallel_dims.dp_shard_enabled) + self.assertTrue(parallel_dims.fsdp_enabled) + self.assertTrue(parallel_dims.tp_enabled) + self.assertFalse(parallel_dims.cp_enabled) + self.assertFalse(parallel_dims.pp_enabled) + self.assertFalse(parallel_dims.ep_enabled) + + # Validate calculated properties + self.assertEqual( + parallel_dims.fsdp_gradient_divide_factor, 4 + ) # dp_replicate * dp_shard * cp = 2 * 2 * 1 + self.assertEqual( + parallel_dims.non_data_parallel_size, 2 + ) # cp * tp * pp = 1 * 2 * 1 + self.assertEqual( + parallel_dims.seq_len_divisor, 4 + ) # tp * (cp * 2) = 2 * (1 * 2) = 2 * 2 + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit_tests/test_set_determinism.py b/tests/unit_tests/test_set_determinism.py index c8087731c5..2be196b7e1 100644 --- a/tests/unit_tests/test_set_determinism.py +++ b/tests/unit_tests/test_set_determinism.py @@ -13,8 +13,8 @@ from torchtitan.distributed.utils import set_determinism -class FakeDeviceMesh: - """Fake DeviceMesh for testing seed uniqueness. +class FakeParallelDims: + """Fake ParallelDims for testing seed uniqueness. Args: mesh_dim_names: List of dimension names (e.g., ["dp", "pp", "tp"]) @@ -26,25 +26,68 @@ def __init__(self, mesh_dim_names, mesh_sizes, rank_coords): self.mesh_dim_names = mesh_dim_names self.mesh_sizes = dict(zip(mesh_dim_names, mesh_sizes)) self.rank_coords = dict(zip(mesh_dim_names, rank_coords)) - - def __getitem__(self, key): - """Return a submesh for the given dimension(s).""" + # Calculate world_size as product of all mesh sizes + self.world_size = 1 + for size in mesh_sizes: + self.world_size *= size + + # Add individual parallelism degree attributes to match real ParallelDims interface + self.pp = self.mesh_sizes.get("pp", 1) + self.tp = self.mesh_sizes.get("tp", 1) + self.cp = self.mesh_sizes.get("cp", 1) + self.dp_replicate = self.mesh_sizes.get("dp_replicate", 1) + self.dp_shard = self.mesh_sizes.get("dp_shard", 1) + self.ep = self.mesh_sizes.get("ep", 1) + self.etp = self.mesh_sizes.get("etp", 1) + + # For backward compatibility with 'dp' dimension name + if "dp" in self.mesh_sizes: + self.dp_replicate = self.mesh_sizes["dp"] + + # Create a world_mesh mock + self.world_mesh = MagicMock() + self.world_mesh.device_type = "cpu" + + def get_mesh(self, key): + """Return a submesh for the given dimension.""" if isinstance(key, str): # Single dimension + if key not in self.mesh_dim_names: + return None submesh = MagicMock() submesh.get_local_rank.return_value = self.rank_coords[key] submesh.size.return_value = self.mesh_sizes[key] submesh.get_coordinate.return_value = self.rank_coords[key] + submesh.device_type = "cpu" return submesh elif isinstance(key, list): - # Multiple dimensions + # Multiple dimensions - check if all exist + if not all(dim in self.mesh_dim_names for dim in key): + return None submesh = MagicMock() # For multiple dimensions, get_coordinate should return None # since we're not testing this path submesh.get_coordinate.return_value = None + submesh.device_type = "cpu" return submesh else: - raise ValueError(f"Unsupported key type: {type(key)}") + return None + + def get_optional_mesh(self, key): + """Return a submesh for the given dimension, or None if not available. + + This is the same as get_mesh() for FakeParallelDims since get_mesh() + already returns None for unavailable meshes. + """ + return self.get_mesh(key) + + def get_all_meshes(self): + """Return a dict of all meshes.""" + return {dim: self.get_mesh(dim) for dim in self.mesh_dim_names} + + def __getitem__(self, key): + """Return a submesh for the given dimension(s) - for backward compatibility.""" + return self.get_mesh(key) def get_coordinate(self): """Return the coordinate tuple for this rank.""" @@ -85,12 +128,12 @@ def test_seed_uniqueness_2d_mesh(self, mock_get_rank, mock_get_world_size): # Create fake mesh for this rank rank_coords = (dp_rank, pp_rank) - fake_mesh = FakeDeviceMesh(mesh_dim_names, mesh_sizes, rank_coords) + fake_mesh = FakeParallelDims(mesh_dim_names, mesh_sizes, rank_coords) # Call set_determinism with distinct seeds only on PP dimension debug_config = DebugConfig(seed=base_seed, deterministic=False) set_determinism( - world_mesh=fake_mesh, + parallel_dims=fake_mesh, device=self.device, debug_config=debug_config, distinct_seed_mesh_dims=["pp"], @@ -154,12 +197,14 @@ def test_seed_uniqueness_3d_mesh(self, mock_get_rank, mock_get_world_size): # Create fake mesh for this rank rank_coords = (dp_shard_rank, dp_replicate_rank, tp_rank) - fake_mesh = FakeDeviceMesh(mesh_dim_names, mesh_sizes, rank_coords) + fake_mesh = FakeParallelDims( + mesh_dim_names, mesh_sizes, rank_coords + ) # Call set_determinism with distinct seeds on dp_shard and dp_replicate only debug_config = DebugConfig(seed=base_seed, deterministic=False) set_determinism( - world_mesh=fake_mesh, + parallel_dims=fake_mesh, device=self.device, debug_config=debug_config, distinct_seed_mesh_dims=["dp_shard", "dp_replicate"], @@ -218,12 +263,15 @@ def test_set_determinism_single_gpu(self, mock_get_rank, mock_get_world_size): base_seed = 42 fake_mesh = MagicMock() - fake_mesh.mesh_dim_names = None - fake_mesh.get_coordinate.return_value = None + fake_mesh.world_size = 1 + fake_mesh.world_mesh = MagicMock() + fake_mesh.get_mesh.return_value = None + fake_mesh.get_optional_mesh.return_value = None + fake_mesh.get_all_meshes.return_value = {} debug_config = DebugConfig(seed=base_seed, deterministic=False) set_determinism( - world_mesh=fake_mesh, + parallel_dims=fake_mesh, device=self.device, debug_config=debug_config, distinct_seed_mesh_dims=["pp"], diff --git a/tests/unit_tests/test_train_spec.py b/tests/unit_tests/test_train_spec.py index 2f8986705e..07d5cd94e6 100644 --- a/tests/unit_tests/test_train_spec.py +++ b/tests/unit_tests/test_train_spec.py @@ -26,9 +26,9 @@ ) -class FakeModel(nn.Module, ModelProtocol): +class FakeModel(ModelProtocol): def __init__(self, model_args: BaseModelArgs) -> None: - super().__init__() + super().__init__(model_args) self.linear = nn.Linear(8, 8) def forward(self, x: torch.Tensor) -> torch.Tensor: diff --git a/torchtitan/__init__.py b/torchtitan/__init__.py index 176bce9b60..52c3ff3e22 100644 --- a/torchtitan/__init__.py +++ b/torchtitan/__init__.py @@ -4,5 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from importlib.metadata import version + # Import to register quantization modules. import torchtitan.components.quantization # noqa: F401 + +try: + __version__ = version("torchtitan") +except Exception as e: + __version__ = "0.0.0+unknown" diff --git a/torchtitan/components/checkpoint.py b/torchtitan/components/checkpoint.py index 9e7bc81e27..776cd05810 100644 --- a/torchtitan/components/checkpoint.py +++ b/torchtitan/components/checkpoint.py @@ -42,7 +42,10 @@ set_model_state_dict, StateDictOptions, ) -from torch.distributed.checkpoint.state_dict_saver import AsyncCheckpointerType +from torch.distributed.checkpoint.state_dict_saver import ( + AsyncCheckpointerType, + AsyncSaveResponse, +) from torch.distributed.checkpoint.stateful import Stateful from torchtitan.components.dataloader import BaseDataLoader @@ -188,6 +191,9 @@ class CheckpointManager: """ + mp_queue_send: queue.Queue + purge_thread: threading.Thread | None + def __init__( self, dataloader: BaseDataLoader | None, @@ -225,12 +231,13 @@ def __init__( ) if self.ft_manager and not self.enable_ft_dataloader_checkpoints: - logger.warn( + logger.warning( "Fault tolerance is enabled but enable_ft_dataloader_checkpoints is False. " "This means replicas can retrain over the same data multiple times, which can result in overfitting." ) if self.ft_manager: + # pyrefly: ignore [missing-attribute] optimizers.init_cache_state_dict() def state_dict(): @@ -250,7 +257,9 @@ def load_state_dict(state_dict): for k, v in state_dict.items(): self.states[k].load_state_dict(v) + # pyrefly: ignore [missing-attribute] self.ft_manager.set_state_dict_fns(load_state_dict, state_dict) + # pyrefly: ignore [missing-attribute] self.ft_replica_id = ft_manager.replica_id async_mode = checkpoint_config.async_mode.lower() @@ -361,7 +370,7 @@ def dcp_save( async_mode: AsyncMode, enable_garbage_collection: bool = False, to_hf: bool = False, - ) -> Future | None: + ) -> Future | AsyncSaveResponse | None: """Save the checkpoint with dcp. Args: state_dict (dict): The state dict to save. @@ -374,7 +383,7 @@ def dcp_save( Future: The future object if the checkpoint is async, otherwise None. """ - ret: Future | None = None + ret: Future | AsyncSaveResponse | None = None storage_writer: HuggingFaceStorageWriter | None = None checkpoint_save_id: str | None = None @@ -415,6 +424,7 @@ def dcp_save( state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_save_id, + # pyrefly: ignore [bad-argument-type] process_group=self.pg, ) elif async_mode == AsyncMode.ASYNC_WITH_PINNED_MEM: @@ -422,6 +432,7 @@ def dcp_save( state_dict, storage_writer=storage_writer, checkpoint_id=checkpoint_save_id, + # pyrefly: ignore [bad-argument-type] process_group=self.pg, async_checkpointer_type=AsyncCheckpointerType.PROCESS, async_stager=self.stager, @@ -436,10 +447,12 @@ def dcp_save( process_group=self.pg, ) + # pyrefly: ignore [missing-attribute] if to_hf and self.sd_adapter.fqn_to_index_mapping: consolidate_safetensors_files_on_every_rank( input_dir=os.path.join(checkpoint_id, "sharded"), output_dir=checkpoint_id, + # pyrefly: ignore [bad-argument-type] fqn_to_index_mapping=self.sd_adapter.fqn_to_index_mapping, num_threads=5, ) @@ -572,7 +585,9 @@ def save(self, curr_step: int, last_step: bool = False) -> None: begin = time.monotonic() if not self.enable_ft_dataloader_checkpoints or ( - self.ft_manager and self.ft_manager.participating_rank() == 0 + self.ft_manager + # pyrefly: ignore [missing-attribute] + and self.ft_manager.participating_rank() == 0 ): logger.info("Saving the checkpoint (or staging if async is enabled).") checkpoint_id = self._create_checkpoint_id(curr_step) @@ -598,17 +613,21 @@ def save(self, curr_step: int, last_step: bool = False) -> None: ) GarbageCollection.collect("GC collection invoked by checkpointer.") if self.stager is None: + # pyrefly: ignore[bad-assignment] self.stager = DefaultStager(StagingOptions(True, True, True, True)) result = self.dcp_save( states, checkpoint_id=checkpoint_id, async_mode=self.async_mode, ) + # pyrefly: ignore [missing-attribute] self.save_future = result.upload_completion + # pyrefly: ignore [missing-attribute] self.staging_future = result.staging_completion self.staging = True elif self.async_mode == AsyncMode.ASYNC: GarbageCollection.collect("GC collection invoked by checkpointer.") + # pyrefly: ignore[bad-assignment] self.save_future = self.dcp_save( states, checkpoint_id=checkpoint_id, async_mode=self.async_mode ) @@ -630,6 +649,7 @@ def save(self, curr_step: int, last_step: bool = False) -> None: assert self.ft_manager is not None logger.info( "Replica %d doesn't save checkpoint.", + # pyrefly: ignore [missing-attribute] self.ft_manager.participating_rank(), ) @@ -689,6 +709,7 @@ def load(self, step: int = -1) -> bool: f"loading from HF safetensors from --checkpoint.initial_load_path: {self.initial_load_path}" ) elif from_hf: + # pyrefly: ignore [missing-attribute] checkpoint_id = self.sd_adapter.hf_assets_path if not os.path.isdir(checkpoint_id): raise ValueError( @@ -696,6 +717,7 @@ def load(self, step: int = -1) -> bool: Either make sure hf_assets_path is correct or provide a valid checkpoint.initial_load_path" ) logger.info( + # pyrefly: ignore [missing-attribute] f"loading HF safetensors from --model.hf_assets_path: {self.sd_adapter.hf_assets_path}" ) else: @@ -746,6 +768,7 @@ def maybe_wait_for_staging(self) -> None: with ``async_checkpoint_with_pinned_memory``. """ if self.enable_staging and self.staging: + # pyrefly: ignore [missing-attribute] self.staging_future.result() self.staging = False @@ -791,6 +814,7 @@ def _ft_save(self, step: int) -> None: begin = time.monotonic() self._async_wait() checkpoint_id = self._create_checkpoint_id(step, folder=self._ft_folder()) + # pyrefly: ignore[bad-assignment] self.save_future = self.dcp_save( self.ft_states, checkpoint_id=checkpoint_id, async_mode=AsyncMode.ASYNC ) @@ -937,7 +961,7 @@ def _async_wait(self) -> None: ): if self.save_future is not None: self.save_future.result() - self.save_future = None + self.save_future = None # pyrefly: ignore[bad-assignment] elif self.save_future is not None: raise RuntimeError( "self.save_future is not None, but self.async_mode is not enabled " @@ -951,6 +975,7 @@ def _purge_stale_checkpoints(self): and os.path.isdir(self.folder) and ( not self.enable_ft_dataloader_checkpoints + # pyrefly: ignore [missing-attribute] or (self.ft_manager and self.ft_manager.participating_rank() == 0) ) ): diff --git a/torchtitan/components/dataloader.py b/torchtitan/components/dataloader.py index 071af84d54..a1fc08e39f 100644 --- a/torchtitan/components/dataloader.py +++ b/torchtitan/components/dataloader.py @@ -6,9 +6,9 @@ # # Copyright (c) Meta Platforms, Inc. All Rights Reserved. +import inspect import pickle from abc import ABC, abstractmethod -from collections.abc import Callable from typing import Any from torch.distributed.checkpoint.stateful import Stateful @@ -16,6 +16,7 @@ from torchdata.stateful_dataloader import StatefulDataLoader from torchtitan.tools.logging import logger + # NOTE: This class deliberately inherits from `Exception` and not `StopIteration`. # According to PEP 479, raising a `StopIteration` or its subclass from within a # generator will wrap it in a `RuntimeError`. Since this exception is designed @@ -41,6 +42,7 @@ def __iter__(self): ... +# pyrefly: ignore [inconsistent-inheritance] class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader): """Dataloader that is aware of distributed data parallelism. @@ -52,28 +54,63 @@ class ParallelAwareDataloader(StatefulDataLoader, BaseDataLoader): dataset (IterableDataset): The dataset to iterate over. dp_rank: Data parallelism rank for this dataloader. dp_world_size: The world size of the data parallelism. - batch_size: The batch size to use for each iteration. - collate_fn: Optional function to collate samples in a batch. + **kwargs: Additional keyword arguments passed to StatefulDataLoader (e.g., + batch_size, collate_fn, num_workers, persistent_workers, prefetch_factor, + pin_memory). """ dp_rank: int dp_world_size: int - batch_size: int def __init__( self, dataset: IterableDataset, dp_rank: int, dp_world_size: int, - batch_size: int, - collate_fn: Callable | None = None, + **kwargs, ): + self._validate_kwargs(kwargs) + self.dp_world_size = dp_world_size self.dp_rank = dp_rank - self.batch_size = batch_size - super().__init__(dataset, batch_size, collate_fn=collate_fn) self._rank_id = f"dp_rank_{dp_rank}" + super().__init__(dataset, **kwargs) + + @staticmethod + def _validate_kwargs(kwargs: dict[str, Any]) -> None: + """Validate and sanitize kwargs passed to the dataloader. + + Args: + kwargs: Dictionary of keyword arguments to validate. This dict is + modified in-place to remove invalid combinations. + + Raises: + ValueError: If 'dataset' is in kwargs or if any invalid kwargs are passed. + """ + if "dataset" in kwargs: + raise ValueError( + "'dataset' should not be passed in kwargs; " + "it must be provided as the first positional argument." + ) + + sig = inspect.signature(StatefulDataLoader.__init__) + valid_kwargs = frozenset( + name for name in sig.parameters.keys() if name not in ("self", "dataset") + ) + invalid_kwargs = set(kwargs.keys()) - valid_kwargs + if invalid_kwargs: + raise ValueError( + f"Invalid dataloader kwargs: {invalid_kwargs}. " + f"Valid kwargs are: {sorted(valid_kwargs)}" + ) + + # persistent_workers and prefetch_factor are only valid when num_workers > 0. + # Removing them here if num_workers is 0 to avoid StatefulDataLoader errors + if kwargs.get("num_workers", 0) == 0: + kwargs.pop("persistent_workers", None) + kwargs.pop("prefetch_factor", None) + def state_dict(self) -> dict[str, Any]: # Store state only for dp rank to avoid replicating the same state across other dimensions. return { diff --git a/torchtitan/components/ft/manager.py b/torchtitan/components/ft/manager.py index 5d64d34b09..03778dd6d0 100644 --- a/torchtitan/components/ft/manager.py +++ b/torchtitan/components/ft/manager.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import importlib +import importlib.util from contextlib import nullcontext from datetime import timedelta from typing import Callable, ContextManager, Optional, TYPE_CHECKING, Union diff --git a/torchtitan/components/lr_scheduler.py b/torchtitan/components/lr_scheduler.py index 6384feb641..15a3fc6bd1 100644 --- a/torchtitan/components/lr_scheduler.py +++ b/torchtitan/components/lr_scheduler.py @@ -176,6 +176,8 @@ def linear_warmup_stable_decay( curr_adjustment = 1 - math.sqrt(progress) elif lr_decay_type == "cosine": curr_adjustment = 0.5 * (1.0 + math.cos(math.pi * progress)) + else: + raise ValueError(f"Unknown lr_decay_type: {lr_decay_type}") curr_adjustment = min_lr_factor + (1 - min_lr_factor) * curr_adjustment return curr_adjustment diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index 62ec4331ef..b3e4499329 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -36,12 +36,15 @@ "max_reserved_pct", "num_alloc_retries", "num_ooms", + "nvidia_smi_used_gib", # nvidia-smi reported memory for verification + "nvidia_smi_used_pct", ], ) class DeviceMemoryMonitor: def __init__(self, device: str = f"{device_type}:0"): + # pyrefly: ignore [read-only] self.device = torch.device(device) # device object self.device_name = device_module.get_device_name(self.device) self.device_index = device_module.current_device() @@ -62,6 +65,48 @@ def _to_gib(self, memory_in_bytes): def _to_pct(self, memory): return 100 * memory / self.device_capacity + def _get_nvidia_smi_memory(self): + """Get GPU memory usage from nvidia-smi for verification.""" + try: + import subprocess + + # In SLURM with CUDA_VISIBLE_DEVICES, PyTorch device index 0-7 maps to + # physical GPUs listed in CUDA_VISIBLE_DEVICES. We need the physical GPU index. + cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") + if cuda_visible: + # CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" means device 0 is physical GPU 0 + # But it could also be "4,5,6,7,0,1,2,3" meaning device 0 is physical GPU 4 + visible_gpus = [ + int(x.strip()) for x in cuda_visible.split(",") if x.strip() + ] + if self.device_index < len(visible_gpus): + physical_gpu_index = visible_gpus[self.device_index] + else: + physical_gpu_index = self.device_index + else: + physical_gpu_index = self.device_index + + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=memory.used", + "--format=csv,noheader,nounits", + f"--id={physical_gpu_index}", + ], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + # nvidia-smi reports in MiB + used_mib = float(result.stdout.strip()) + used_gib = used_mib / 1024 + used_pct = (used_mib * 1024 * 1024) / self.device_capacity * 100 + return used_gib, used_pct + except Exception: + pass + return -1.0, -1.0 + def get_peak_stats(self): device_info = device_module.memory_stats(self.device) @@ -83,6 +128,9 @@ def get_peak_stats(self): if num_ooms > 0: logger.warning(f"{num_ooms} {device_type.upper()} OOM errors thrown.") + # Get nvidia-smi memory for verification + nvidia_smi_gib, nvidia_smi_pct = self._get_nvidia_smi_memory() + return DeviceMemStats( max_active_gib, max_active_pct, @@ -90,6 +138,8 @@ def get_peak_stats(self): max_reserved_pct, num_retries, num_ooms, + nvidia_smi_gib, + nvidia_smi_pct, ) def reset_peak_stats(self): @@ -164,16 +214,20 @@ def __init__( # Create logging directory os.makedirs(log_dir, exist_ok=True) - if group is None: - group = wandb.sdk.lib.runid.generate_id() self.wandb.init( entity=os.getenv("WANDB_TEAM", None), project=os.getenv("WANDB_PROJECT", "torchtitan"), name=os.getenv("WANDB_RUN_NAME", None), + id=os.getenv("WANDB_RUN_ID", None), + notes=os.getenv("WANDB_RUN_NOTES", None), + tags=os.getenv("WANDB_RUN_TAGS", None), + group=os.getenv("WANDB_RUN_GROUP", None), + job_type=os.getenv("WANDB_RUN_JOB_TYPE", None), + resume_from=os.getenv("WANDB_RESUME_FROM", None), + fork_from=os.getenv("WANDB_FORK_FROM", None), dir=log_dir, config=job_config.to_dict(), - group=group, ) logger.info("WandB logging enabled") @@ -384,7 +438,7 @@ class MetricsProcessor: device_memory_monitor: DeviceMemoryMonitor color: utils.NoColor | utils.Color - gpu_peak_flops: int + gpu_peak_flops: float ntokens_since_last_log: int data_loading_times: list[float] time_last_log: float @@ -496,7 +550,9 @@ def log( self.time_last_log = time.perf_counter() self.device_memory_monitor.reset_peak_stats() - def log_validation(self, loss: float, step: int): + def log_validation( + self, loss: float, step: int, extra_metrics: dict[str, Any] | None = None + ): time_delta = time.perf_counter() - self.time_last_log device_mem_stats = self.device_memory_monitor.get_peak_stats() @@ -514,6 +570,10 @@ def log_validation(self, loss: float, step: int): "validation_metrics/memory/max_reserved(GiB)": device_mem_stats.max_reserved_gib, "validation_metrics/memory/max_reserved(%)": device_mem_stats.max_reserved_pct, } + + if extra_metrics: + metrics.update(extra_metrics) + self.logger.log(metrics, step) color = self.color diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 77d317c0b9..c271923fdc 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -8,6 +8,7 @@ from typing import Any, Generic, Iterator, TypeVar import torch +import torch.distributed.tensor import torch.nn as nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import CheckpointImpl from torch.distributed.checkpoint.state_dict import ( @@ -16,11 +17,13 @@ StateDictOptions, ) from torch.distributed.checkpoint.stateful import Stateful +from torch.distributed.tensor import Replicate from torch.optim import Optimizer from torchtitan.components.ft import FTManager, has_torchft from torchtitan.config import Optimizer as OptimizerConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims +from torchtitan.tools.logging import logger # Dion optimizer availability will be checked lazily when needed DION_AVAILABLE = None @@ -76,6 +79,115 @@ def _check_muon_availability(): T = TypeVar("T", bound=Optimizer) +def preinit_optimizer_states_bf16(optimizers_container: "OptimizersContainer") -> None: + """ + Pre-initialize optimizer states (exp_avg, exp_avg_sq) directly in bfloat16. + This MUST be called BEFORE the first optimizer.step() to avoid fp32 allocation spike. + + This reduces optimizer state memory by ~50% (from fp32 to bf16). + States are allocated in bf16 from the start, avoiding the memory spike from fp32 allocation. + """ + total_params = 0 + total_bytes = 0 + dtype_device_samples = [] + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + + for opt_idx, optimizer in enumerate(optimizers_container.optimizers): + for pg_idx, param_group in enumerate(optimizer.param_groups): + for p_idx, p in enumerate(param_group["params"]): + if p.requires_grad: + if total_params < 5: + dtype_device_samples.append( + f"param[{opt_idx}][{pg_idx}][{p_idx}]: dtype={p.dtype}, device={p.device}, shape={list(p.shape)}" + ) + + state = optimizer.state[p] + if len(state) == 0: + state["step"] = torch.tensor(0, dtype=torch.float32, device=p.device) + state["exp_avg"] = torch.zeros_like( + p, dtype=p.dtype, device=p.device + ) + state["exp_avg_sq"] = torch.zeros_like( + p, dtype=p.dtype, device=p.device + ) + total_params += 1 + bytes_per_element = 2 if p.dtype == torch.bfloat16 else 4 + total_bytes += p.numel() * 2 * bytes_per_element + + if total_params <= 3: + logger.info( + f"[Rank {rank}] State init sample: param dtype={p.dtype}, device={p.device}, " + f"exp_avg dtype={state['exp_avg'].dtype}, device={state['exp_avg'].device}" + ) + + for sample in dtype_device_samples: + logger.info(f"[Rank {rank}] {sample}") + + logger.info( + f"[Rank {rank}] Pre-initialized {total_params} optimizer states matching param dtype, " + f"this rank: {total_bytes / 1e9:.2f} GB" + ) + + +class BF16StateOptimizersContainer(Generic[T]): + """ + Wrapper that pre-initializes optimizer states in bfloat16 BEFORE first step. + This prevents the memory spike from fp32 state allocation. + + IMPORTANT: Call init_bf16_states() BEFORE the first step() to avoid + rank skew during state allocation. This should be called after model + setup but before training starts, ideally with a barrier afterwards. + """ + + def __init__( + self, + base_container: "OptimizersContainer", + state_dtype: torch.dtype = torch.bfloat16, + ): + self._base = base_container + self._state_dtype = state_dtype + self._states_initialized = False + + def init_bf16_states(self): + """ + Pre-initialize optimizer states in bf16. + Call this BEFORE training starts, then call a distributed barrier. + This avoids rank skew during the first optimizer.step(). + """ + if not self._states_initialized: + logger.info("Pre-initializing optimizer states in bfloat16...") + preinit_optimizer_states_bf16(self._base) + self._states_initialized = True + logger.info("BF16 optimizer state pre-initialization complete.") + + def step(self, *args, **kwargs) -> None: + if not self._states_initialized: + logger.warning( + "BF16 optimizer states not pre-initialized! " + "Call init_bf16_states() before training to avoid rank skew." + ) + self.init_bf16_states() + self._base.step(*args, **kwargs) + + def zero_grad(self, *args, **kwargs) -> None: + self._base.zero_grad(*args, **kwargs) + + def state_dict(self) -> dict[str, Any]: + return self._base.state_dict() + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + self._base.load_state_dict(state_dict) + + def __iter__(self): + return iter(self._base) + + def __len__(self): + return len(self._base) + + def __getattr__(self, name): + return getattr(self._base, name) + + class OptimizersContainer(Optimizer, Stateful, Generic[T]): """A container for multiple optimizers. @@ -127,6 +239,7 @@ def __iter__(self) -> Iterator[T]: def __len__(self) -> int: return len(self.optimizers) + # pyrefly: ignore [bad-override] def step(self, *args, **kwargs) -> None: for optimizer in self.optimizers: optimizer.step(*args, **kwargs) @@ -209,9 +322,11 @@ def optim_hook(param) -> None: ) self._post_init(all_params, optimizer_kwargs) + # pyrefly: ignore [bad-override] def step(self) -> None: pass + # pyrefly: ignore [bad-override] def zero_grad(self) -> None: pass @@ -504,7 +619,15 @@ def build_optimizers( use_ft_optimizer=ft_manager.use_async_quorum, ) - return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) + container = OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) + + # Wrap with BF16 state container if configured + state_dtype = getattr(optimizer_config, "state_dtype", "float32") + if state_dtype == "bfloat16": + logger.info("Using bfloat16 optimizer states (will pre-init before first step)") + return BF16StateOptimizersContainer(container, torch.bfloat16) + + return container def build_optimizers_with_moe_load_balancing( @@ -522,9 +645,12 @@ def build_optimizers_with_moe_load_balancing( def _should_register_moe_balancing_hook(model_parts: list[nn.Module]) -> bool: for model_part in model_parts: + # pyrefly: ignore [not-callable] for transformer_block in model_part.layers.values(): + # pyrefly: ignore [missing-attribute] if transformer_block.moe_enabled: # Assumption: load_balance_coeff is set universally on all moe blocks. + # pyrefly: ignore [missing-attribute] return bool(transformer_block.moe.load_balance_coeff) return False @@ -536,18 +662,20 @@ def _update_expert_bias( model_parts: list[nn.Module], parallel_dims: ParallelDims, ): - dp_cp_mesh = ( - parallel_dims.world_mesh["dp_cp"] if parallel_dims.dp_cp_enabled else None - ) + loss_mesh = parallel_dims.get_optional_mesh("loss") # TODO: Currently this sync is blocking (thus exposed) and happens on the # default compute stream. Need to assess if this is OK performance-wise. tokens_per_expert_list = [] for model_part in model_parts: + # pyrefly: ignore [not-callable] for transformer_block in model_part.layers.values(): + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: continue + # pyrefly: ignore [missing-attribute] if transformer_block.moe.load_balance_coeff is None: return + # pyrefly: ignore [missing-attribute] tokens_per_expert = transformer_block.moe.tokens_per_expert if _is_recomputation_enabled(transformer_block): # TODO: This is a hack, we assume with full AC, the tokens_per_expert is counted twice. @@ -559,19 +687,30 @@ def _update_expert_bias( tokens_per_expert_by_layer = torch.vstack(tokens_per_expert_list) - if dp_cp_mesh is not None: - # Perform single all-reduce to get global statistics across all processes - pg = dp_cp_mesh.get_group() - torch.distributed.all_reduce( - tokens_per_expert_by_layer, group=pg, op=torch.distributed.ReduceOp.SUM - ) + if loss_mesh is not None: + if isinstance(tokens_per_expert_by_layer, torch.distributed.tensor.DTensor): + tokens_per_expert_by_layer = tokens_per_expert_by_layer.redistribute( + placements=[Replicate()] + * tokens_per_expert_by_layer.device_mesh.ndim + ) + else: + # Perform single all-reduce to get global statistics across all processes + pg = loss_mesh.get_group() + torch.distributed.all_reduce( + tokens_per_expert_by_layer, + group=pg, + op=torch.distributed.ReduceOp.SUM, + ) moe_layer_idx = 0 with torch.no_grad(): for model_part in model_parts: + # pyrefly: ignore [not-callable] for transformer_block in model_part.layers.values(): + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: continue + # pyrefly: ignore [missing-attribute] moe = transformer_block.moe tokens_per_expert = tokens_per_expert_by_layer[ diff --git a/torchtitan/components/quantization/__init__.py b/torchtitan/components/quantization/__init__.py index de94c37b3e..49faf60733 100644 --- a/torchtitan/components/quantization/__init__.py +++ b/torchtitan/components/quantization/__init__.py @@ -42,7 +42,7 @@ def _validate(job_config: JobConfig): # quantization converter format: # `quantize.[linear | grouped_mm].[float8 | mx]` quantization_type = lambda converter: converter.split(".")[-1] - existing_quantization_converter = None + existing_quantization_converter: str | None = None for converter in job_config.model.converters: if "quantize" in converter: if existing_quantization_converter is None: diff --git a/torchtitan/components/quantization/float8.py b/torchtitan/components/quantization/float8.py index 86932a17bd..9b575876e7 100644 --- a/torchtitan/components/quantization/float8.py +++ b/torchtitan/components/quantization/float8.py @@ -6,6 +6,7 @@ from functools import partial import torch +import torch._inductor.config import torch.nn as nn from torchtitan.components.quantization import ( FP8_GROUP_ALIGNMENT_SIZE, diff --git a/torchtitan/components/quantization/mx.py b/torchtitan/components/quantization/mx.py index a474cc3918..3bdd250c15 100644 --- a/torchtitan/components/quantization/mx.py +++ b/torchtitan/components/quantization/mx.py @@ -19,7 +19,7 @@ from torchtitan.models.moe.utils import set_token_group_alignment_size_m from torchtitan.protocols.model_converter import register_model_converter from torchtitan.tools.logging import logger -from torchtitan.tools.utils import has_cuda_capability +from torchtitan.tools.utils import has_cuda_capability, has_rocm_capability from .utils import module_filter_fn @@ -39,9 +39,9 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): ) # Can be removed if we enable the emulated versions - assert has_cuda_capability( - 10, 0 - ), "MXFP8 is only supported on SM100 or architectures" + assert has_cuda_capability(10, 0) or has_rocm_capability( + 9, 5 + ), "MXFP8 is only supported on CUDA SM100 or later, or ROCm gfx950 or later" # TP not yet supported with torch.compile model_compile_enabled = ( diff --git a/torchtitan/components/tokenizer.py b/torchtitan/components/tokenizer.py index 022fcbc266..6956d3298f 100644 --- a/torchtitan/components/tokenizer.py +++ b/torchtitan/components/tokenizer.py @@ -56,6 +56,7 @@ def __init__( # Initialize BOS/EOS token attributes (frequently used) self.bos_id = None + # pyrefly: ignore [bad-assignment] self.eos_id = None self.bos_token = None self.eos_token = None @@ -217,14 +218,15 @@ def _process_special_token( # Store BOS/EOS tokens as class attributes if they match if token_str == config_bos_token: - self.bos_token = token_str + self.bos_token = token_str # pyrefly: ignore[bad-assignment] self.bos_id = ( + # pyrefly: ignore[bad-assignment] token_id if token_id is not None else self.tokenizer.token_to_id(token_str) ) elif token_str == config_eos_token: - self.eos_token = token_str + self.eos_token = token_str # pyrefly: ignore[bad-assignment] self.eos_id = ( token_id if token_id is not None @@ -316,7 +318,7 @@ def _infer_should_add_bos_eos(self): # First, determine if underlying tokenizer auto-adds BOS/EOS tokens empirically encoded_empty_str = self.tokenizer.encode("").ids if self.bos_id is not None and self.bos_id in encoded_empty_str: - self.hf_adds_bos = True + self.hf_adds_bos = True # pyrefly: ignore[bad-assignment] if self.eos_id is not None and self.eos_id in encoded_empty_str: self.hf_adds_eos = True diff --git a/torchtitan/components/validate.py b/torchtitan/components/validate.py index 93fb68a3cc..7917ff2fc2 100644 --- a/torchtitan/components/validate.py +++ b/torchtitan/components/validate.py @@ -4,7 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Generator +from collections.abc import Callable +from contextlib import AbstractContextManager +from typing import Any, TypeAlias import torch import torch.nn as nn @@ -15,10 +17,13 @@ from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.distributed.context_parallel import prepare_context_parallel_input from torchtitan.hf_datasets.text_datasets import build_text_validation_dataloader from torchtitan.tools import utils from torchtitan.tools.logging import logger +ValidationContext: TypeAlias = Callable[[], AbstractContextManager[None]] + class BaseValidator: def __init__(self, job_config: JobConfig): @@ -52,14 +57,15 @@ def __init__( tokenizer: BaseTokenizer, parallel_dims: ParallelDims, loss_fn: LossFunction, - validation_context: Generator[None, None, None], - maybe_enable_amp: Generator[None, None, None], + validation_context: ValidationContext, + maybe_enable_amp: AbstractContextManager[None], metrics_processor: MetricsProcessor, pp_schedule: _PipelineSchedule | None = None, pp_has_first_stage: bool | None = None, pp_has_last_stage: bool | None = None, ): self.job_config = job_config + self.tokenizer = tokenizer self.parallel_dims = parallel_dims self.loss_fn = loss_fn self.validation_dataloader = build_text_validation_dataloader( @@ -82,7 +88,73 @@ def __init__( "unequal sample counts across ranks when dataset is exhausted." ) + def post_dataloading_process( + self, + input_dict: dict[str, torch.Tensor], + labels: torch.Tensor, + model_parts: list[nn.Module], + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]: + """ + Post-processing hook after data loading and before model forward pass. + + This method processes the raw data from the dataloader and prepares it for + the model's forward pass. It separates the main input tensor from auxiliary + inputs and constructs additional keyword arguments (e.g., attention masks). + + Args: + input_dict: Dictionary containing tensors from the dataloader. Must + contain an "input" key with the main input tensor. May contain + additional keys for auxiliary inputs (e.g., position ids). + labels: Target labels for the batch. + model_parts: List of model parts for accessing model methods. + + Returns: + A tuple of (inputs, labels, extra_inputs, extra_kwargs) where: + - inputs: Main input tensor extracted from input_dict["input"]. + - labels: Target labels (potentially modified by CP sharding). + - extra_inputs: Dict of auxiliary input tensors (all keys except + "input" from input_dict). These are passed to the model forward + but are NOT forwarded across pipeline parallel stages. + - extra_kwargs: Dict of additional keyword arguments for model forward. + These ARE forwarded across pipeline parallel stages. Contains + attention_masks if flex attention is enabled. + + Note: + The distinction between extra_inputs and extra_kwargs is important for + pipeline parallelism: extra_kwargs are forwarded to all pipeline stages, + while extra_inputs are only available to the first stage. + """ + inputs = input_dict["input"] + extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} + # For arguments, like attention_masks, we have to put them in a separate + # dict as extra_inputs are not forwarded to other stages in PP, but + # extra_kwargs are. + extra_kwargs: dict[str, Any] = {} + + try: + # pyrefly: ignore [not-callable] + extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks( + input_batch=inputs, + tokenizer=self.tokenizer, + extra_inputs=extra_inputs, + ) + except TypeError: + pass + + if self.parallel_dims.cp_enabled: + inputs, labels, extra_kwargs = prepare_context_parallel_input( + inputs, + labels, + extra_kwargs, + self.parallel_dims.get_mesh("cp"), + inputs.device, + self.job_config.parallelism.context_parallel_load_balancer, + ) + + return inputs, labels, extra_inputs, extra_kwargs + @torch.no_grad() + # pyrefly: ignore [bad-override] def validate( self, model_parts: list[nn.Module], @@ -98,6 +170,7 @@ def validate( device_type = utils.device_type num_steps = 0 + # pyrefly: ignore [not-iterable] for input_dict, labels in self.validation_dataloader: if ( self.job_config.validation.steps != -1 @@ -108,19 +181,11 @@ def validate( self.metrics_processor.ntokens_since_last_log += labels.numel() for k, v in input_dict.items(): input_dict[k] = v.to(device_type) - inputs = input_dict["input"] labels = labels.to(device_type) - optional_context_parallel_ctx = ( - dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], - cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], - cp_seq_dims=[1, 1] + [0 for _ in model_parts], - cp_no_restore_buffers={inputs, labels}, - cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, - ) - if parallel_dims.cp_enabled - else None + # Process data (extract inputs, handle attention masks, CP sharding) + inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process( + input_dict, labels, model_parts ) if parallel_dims.pp_enabled: @@ -128,18 +193,24 @@ def validate( assert self.pp_has_first_stage is not None assert self.pp_has_last_stage is not None # Pipeline Parallel forward inside eval() call - with self.validation_context(optional_context_parallel_ctx): + with self.validation_context(): targets, losses = ( (labels, []) if self.pp_has_last_stage else (None, None) ) if self.pp_has_first_stage: self.pp_schedule.eval( inputs, + **extra_inputs, + **extra_kwargs, target=targets, losses=losses, ) else: - self.pp_schedule.eval(target=targets, losses=losses) + self.pp_schedule.eval( + **extra_kwargs, + target=targets, + losses=losses, + ) # accumulate losses across pipeline microbatches # TODO: PP+FSDP unexpectedly puts the loss back to the CPU @@ -152,10 +223,12 @@ def validate( else torch.tensor([-1.0], device=device_type) ) else: - with self.validation_context(optional_context_parallel_ctx): + with self.validation_context(): assert len(model_parts) == 1 with self.maybe_enable_amp: - predictions = model_parts[0](inputs) + predictions = model_parts[0]( + inputs, **extra_inputs, **extra_kwargs + ) loss = self.loss_fn(predictions, labels) accumulated_losses.append(loss.detach()) @@ -167,7 +240,7 @@ def validate( loss /= num_steps if parallel_dims.dp_cp_enabled: global_avg_loss = dist_utils.dist_mean( - loss, parallel_dims.world_mesh["dp_cp"] + loss, parallel_dims.get_optional_mesh("loss") ) else: global_avg_loss = loss.item() @@ -186,8 +259,8 @@ def build_validator( tokenizer: BaseTokenizer, parallel_dims: ParallelDims, loss_fn: LossFunction, - validation_context: Generator[None, None, None], - maybe_enable_amp: Generator[None, None, None], + validation_context: ValidationContext, + maybe_enable_amp: AbstractContextManager[None], metrics_processor: MetricsProcessor | None = None, pp_schedule: _PipelineSchedule | None = None, pp_has_first_stage: bool | None = None, @@ -203,6 +276,7 @@ def build_validator( loss_fn=loss_fn, validation_context=validation_context, maybe_enable_amp=maybe_enable_amp, + # pyrefly: ignore [bad-argument-type] metrics_processor=metrics_processor, pp_schedule=pp_schedule, pp_has_first_stage=pp_has_first_stage, diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index ff909e6ae9..3eb2c20238 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -5,7 +5,6 @@ # LICENSE file in the root directory of this source tree. import json - import os from dataclasses import asdict, dataclass, field from typing import Any, List, Literal @@ -287,6 +286,13 @@ class Optimizer: use_triton: bool = False """Whether to use Triton kernel for Newton-Schulz in Muon optimizer.""" + state_dtype: Literal["float32", "bfloat16"] = "float32" + """ + Dtype for optimizer states (exp_avg, exp_avg_sq for Adam/AdamW). + Using bfloat16 reduces memory by ~50% but may affect training stability. + Only applies to Adam/AdamW optimizers. + """ + @dataclass class LRScheduler: @@ -320,6 +326,40 @@ class LRScheduler: """ +@dataclass +class DataLoader: + """ + Configuration for PyTorch DataLoader settings. + + These settings are passed directly to StatefulDataLoader. + + Note: + persistent_workers and prefetch_factor are only valid if num_workers > 0. + + Example (TOML config file): + [training.dataloader] + num_workers = 4 + pin_memory = true + persistent_workers = true + prefetch_factor = 2 + """ + + num_workers: int = 0 + """Number of worker processes for data loading.""" + + persistent_workers: bool = False + """Keep workers alive between epochs. Only valid when num_workers > 0.""" + + pin_memory: bool = False + """Copy tensors to CUDA pinned memory before returning them.""" + + prefetch_factor: int | None = None + """ + Number of batches loaded in advance by each worker. Only valid when num_workers > 0. + Default is 2 when num_workers > 0, otherwise None. + """ + + @dataclass class Training: dataset: str = "c4_test" @@ -402,6 +442,44 @@ class Training: many temporary files. """ + dataloader: DataLoader = field(default_factory=DataLoader) + """DataLoader configuration""" + + enable_detailed_memory_tracking: bool = False + """ + Whether to enable detailed memory tracking at every training phase + """ + + clear_cache_between_steps: bool = False + """ + Whether to clear CUDA cache between training steps to measure minimum memory requirements + """ + + skip_optimizer_step: bool = False + """ + Whether to skip the optimizer step (for memory profiling purposes only) + """ + + aggressive_memory_mode: Literal[ + "minimal", "balanced", "aggressive", "maximum" + ] | None = None + """ + Enable aggressive memory management to reduce CUDA memory fragmentation. + This clears CUDA cache and Python GC at strategic points (post-backward, post-optimizer). + Modes: + - None: Disabled (default) + - "minimal": Only clear on high fragmentation (<1% overhead) + - "balanced": Clear after backward and optimizer (2-3% overhead) + - "aggressive": Clear frequently with sync (5-8% overhead) + - "maximum": Clear after every operation (10-15% overhead, for debugging) + """ + + aggressive_memory_verbose: bool = False + """ + Enable verbose logging for aggressive memory manager. + Logs detailed memory stats after each clear operation. + """ + @dataclass class Parallelism: @@ -428,19 +506,28 @@ class Parallelism: only `data_parallel_shard_degree` can be negative. 1 means disabled. """ - fsdp_reshard_after_forward: Literal["default", "always", "never"] = "default" + fsdp_reshard_after_forward: Literal["default", "always", "never"] | int = "default" """ `reshard_after_forward` specifies the policy for applying `reshard_after_forward` within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward, trading off memory and communication. See torch's `fully_shard` API for more documentation on `reshard_after_forward`. - The supported policies include "default", "always" and "never": + The supported policies include "default", "always", "never", or an integer: - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. - "always" will enable `reshard_after_forward` for all forward passes. - "never" will disable `reshard_after_forward` for all forward passes. + - integer N: Partially reshard to groups of N GPUs after forward. Must be a factor of + the FSDP shard world size. Use N=8 for intra-node resharding (reduces memory while + keeping communication fast via NVLink). This trades memory for communication. + """ + + fsdp_disable_prefetch: bool = False + """ + Whether to disable FSDP forward/backward prefetching. Disabling prefetch can reduce memory + at the cost of performance (less overlap of communication and computation). """ tensor_parallel_degree: int = 1 @@ -512,9 +599,32 @@ class Parallelism: The global training batch size must be evenly divisible by pipeline_parallel_microbatch_size. """ + pipeline_parallel_expert_parallel_overlap: bool = True + """Whether to turn on the optimization to overlap expert parallel and pipeline parallel + communication. This is only effective when the pipeline parallel schedule is DualPipeV and + pipeline_parallel_degree > 1 and expert_parallel_degree > 1. + + TODO: Does not support activation_checkpoint, set mode="none" + """ + context_parallel_degree: int = 1 """Context parallelism degree. 1 means disabled.""" + context_parallel_load_balancer: str | None = "headtail" + """ + Load balancer type for context parallelism. Options: + - "headtail": Use HeadTailLoadBalancer for SDPA + - "ptrr": Use PTRRLoadBalancer for FlexAttention + - None: Disable load balancing + """ + + def __post_init__(self): + if self.context_parallel_load_balancer == "": + raise ValueError( + "context_parallel_load_balancer cannot be an empty string. " + "Use None to disable load balancing." + ) + context_parallel_rotate_method: Literal["allgather", "alltoall"] = "allgather" """ The collective to use in context parallel SDPA for kv shards exchange. @@ -527,19 +637,7 @@ class Parallelism: """ Expert parallelism degree. 1 means disabled. No effect for non-MoE models. - Currently, it is supported with the following constraints: - - - when etp = tp: - - - cp <= ep <= dp_shard * cp - - ep % cp == 0 - - dp_shard * cp % ep == 0 - - - when etp = 1: - - - cp * tp <= ep <= dp_shard * cp * tp - - ep % (cp * tp) == 0 - - dp_shard * cp * tp % ep == 0 + Currently, etp is either 1 or is the same as tp. Note that this is still an experimental feature. Some constraints will be relaxed soon when we have more flexible DeviceMesh support. @@ -555,6 +653,17 @@ class Parallelism: Note that this is still an experimental feature. """ + expert_parallel_comm_backend: Literal["standard", "deepep"] = "standard" + """ + Expert-parallel communication backend. No effect for non-MoE models or when ep = 1. + + - "standard": Uses PyTorch all-to-all collectives (default) + - "deepep": Uses DeepEP custom kernels for more efficient communication + + DeepEP requires installation: + https://github.com/deepseek-ai/DeepEP. + """ + @dataclass class DeepEP: @@ -852,10 +961,7 @@ class Compile: enable: bool = False """Whether to apply torch.compile""" - components: list[Literal["model", "loss"]] = field( - default_factory=lambda: ["model", "loss"] - ) - + components: list[str] = field(default_factory=lambda: ["model", "loss"]) """Which components to compile""" backend: str = "inductor" @@ -991,6 +1097,22 @@ class Comm: save_traces_file_prefix: str = "rank_" """Flight recorder trace files prefix""" + mode: Literal["default", "fake_backend", "local_tensor"] = "default" + """ + Communication mode for distributed training. + + Options: + - "default": Normal distributed training with real communication + - "fake_backend": Fake comm backend for dry run mode only (configuration validation without GPU) + - "local_tensor": Local tensor mode for debugging purposes. There will be only one process + regardless of the number of GPUs. LocalTensor will simulate the computation by running one + rank after another. While the performance will be slow, the numerics should be the same. + This enables us to verify numerics with fewer GPUs. For example, we can directly run 5D + parallelisms within a single node to reduce the combinations we need to use in integration tests. + + NOTE: local_tensor is an experimental feature and automatically uses fake_backend internally. + """ + @dataclass class MemoryEstimation: @@ -1092,6 +1214,9 @@ class Validation: WARNING: When setting to -1 there could be hangs due to mismatch among ranks """ + dataloader: DataLoader = field(default_factory=DataLoader) + """DataLoader configuration""" + def __post_init__(self): assert ( self.steps > 0 or self.steps == -1 @@ -1221,6 +1346,12 @@ class Debug: moe_force_load_balance: bool = False """If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only.""" + enable_nan_tracker: bool = False + """If True, enable lightweight NaN/Inf tracking to find where NaN first appears in the model.""" + + nan_tracker_verbose: bool = False + """If True, print stats for every layer (very verbose output).""" + @dataclass class JobConfig: @@ -1257,7 +1388,9 @@ def to_dict(self) -> dict[str, Any]: def maybe_log(self) -> None: if self.job.print_config: - logger.info(f"Running with configs: {self.to_dict()}") + logger.info( + f"Running with configs: {json.dumps(self.to_dict(), indent=2, ensure_ascii=False)}" + ) if self.job.save_config_file is not None: config_file = os.path.join(self.job.dump_folder, self.job.save_config_file) diff --git a/torchtitan/config/manager.py b/torchtitan/config/manager.py index f5e3f31157..14f4b44836 100644 --- a/torchtitan/config/manager.py +++ b/torchtitan/config/manager.py @@ -16,6 +16,7 @@ try: import tomllib except ModuleNotFoundError: + # pyrefly: ignore[import-error] import tomli as tomllib from torchtitan.tools.logging import logger @@ -178,7 +179,7 @@ def _dict_to_dataclass(self, cls, data: dict[str, Any]) -> Any: result[f.name] = self._dict_to_dataclass(f.type, value) else: result[f.name] = value - return cls(**result) + return cls(**result) # pyrefly: ignore[not-callable, bad-instantiation] def _validate_config(self) -> None: if self.config.experimental.custom_args_module: @@ -253,10 +254,20 @@ def list_str_rule(type_info: tyro.constructors.PrimitiveTypeInfo): # # ----------------------------------------------------------------------------- - from rich import print as rprint - from rich.pretty import Pretty + try: - config_manager = ConfigManager() - config = config_manager.parse_args() + # pyrefly: ignore[missing-import] + from rich import print as rprint - rprint(Pretty(config)) + # pyrefly: ignore[missing-import] + from rich.pretty import Pretty + + config_manager = ConfigManager() + config = config_manager.parse_args() + + rprint(Pretty(config)) + except ImportError: + config_manager = ConfigManager() + config = config_manager.parse_args() + logger.info(config) + logger.warning("rich is not installed, show the raw config") diff --git a/torchtitan/distributed/__init__.py b/torchtitan/distributed/__init__.py index 63690a660b..72d1298648 100644 --- a/torchtitan/distributed/__init__.py +++ b/torchtitan/distributed/__init__.py @@ -14,8 +14,10 @@ from torchtitan.distributed.parallel_dims import ParallelDims - -__all__ = ["ParallelDims", "NoParallel"] +__all__ = [ + "ParallelDims", + "NoParallel", +] # NOTE: This is to achieve replicate computation on the gate module in the MoE router. @@ -65,7 +67,10 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: device_mesh, None, partial( - self._prepare_input_fn, self.input_layout, self.desired_input_layout + # pyrefly: ignore [bad-argument-type] + self._prepare_input_fn, + self.input_layout, + self.desired_input_layout, ), partial(self._prepare_output_fn, self.output_layout, self.use_local_output), ) diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 8359f71730..9107b8bc73 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -11,13 +11,14 @@ from collections import defaultdict import torch +import torch._functorch.config import torch.nn as nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( checkpoint_wrapper as ptd_checkpoint_wrapper, ) from torchtitan.config.job_config import ActivationCheckpoint as ACConfig -from torchtitan.tools.logging import logger, warn_once +from torchtitan.tools.logging import logger _layer_sac_count = 0 @@ -155,88 +156,12 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module: ) -def _apply_op_sac_to_transformer_block_with_flex( - module: nn.Module, - ac_config: ACConfig, - *, - base_fqn: str | None = None, - model_compile_enabled: bool = False, - op_sac_save_list: set[torch._ops.OpOverload], -) -> nn.Module: - """Apply SAC to the transformer block that uses FlexAttention. - - Args: - module (nn.Module): The transformer block to apply SAC to. - ac_config (ACConfig): The Activation Checkpoint config. - base_fqn (str, optional): The base fqn of the module. Defaults to None. - model_compile_enabled (bool): Whether model compilation is enabled. - Defaults to False. - op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead - of recomputing. - - Returns: - nn.Module: The transformer block with SAC applied. - """ - - warn_once( - logger, - ( - "Flex Attention requires compilation for good performance.\n" - "Thus, torch.compile is always used for Flex Attention, " - "regardless of the compile.enable flag.\n" - "However, when selective activation checkpointing (SAC) is enabled, " - "torch.compile may be invalidated:\n" - "1. If compile.enable is False, SAC will ignore any torch.compile " - "inside the SAC region.\n" - "2. If compile.enable is True but the transformer block contains an MoE module.\n\n" - "For both cases, we will not wrap the entire TransformerBlock with SAC:\n" - " - For case 1: SAC will be used for MoE and FeedForward modules, " - "while full AC will be used for the Attention module.\n" - " - For case 2: SAC will be applied to MoE and Attention modules if the block " - "is sparse. But we still apply SAC to an entire dense block.\n" - ), - ) - - def wrap_submodule(name: str, full_ac: bool = False) -> None: - submodule = getattr(module, name) - if full_ac: - submodule = _apply_full_ac(submodule, ac_config) - else: - submodule = _apply_op_sac( - submodule, - ac_config, - base_fqn=f"{base_fqn}.{name}" if base_fqn else name, - op_sac_save_list=op_sac_save_list, - ) - module.register_module(name, submodule) - - if hasattr(module, "moe"): - wrap_submodule("moe", full_ac=False) - if model_compile_enabled: - wrap_submodule("attention", full_ac=False) - else: - wrap_submodule("attention", full_ac=True) - else: - if model_compile_enabled: - module = _apply_op_sac( - module, - ac_config, - base_fqn=base_fqn, - op_sac_save_list=op_sac_save_list, - ) - else: - wrap_submodule("feed_forward", full_ac=False) - wrap_submodule("attention", full_ac=True) - return module - - def _apply_ac_to_transformer_block( module: nn.Module, ac_config: ACConfig, *, base_fqn: str | None = None, model_compile_enabled: bool = False, - use_flex_attn: bool = False, op_sac_save_list: set[torch._ops.OpOverload] | None = None, ) -> nn.Module: valid_ac_modes = ("full", "selective") @@ -259,26 +184,9 @@ def _apply_ac_to_transformer_block( if use_op_sac: op_sac_save_list = op_sac_save_list or set() - if use_flex_attn: - """ - For Flex Attention, we need to apply SAC carefully to avoid invalidating - torch.compile. Any torch.compile inside the SAC region will be ignored, - and any torch.compile outside the SAC region will also be ignored if the - SAC region contains a graph break (e.g., MoE). - - TODO: remove this once SAC issues are resolved. - """ - return _apply_op_sac_to_transformer_block_with_flex( - module, - ac_config, - base_fqn=base_fqn, - model_compile_enabled=model_compile_enabled, - op_sac_save_list=op_sac_save_list, - ) - else: - return _apply_op_sac( - module, ac_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list - ) + return _apply_op_sac( + module, ac_config, base_fqn=base_fqn, op_sac_save_list=op_sac_save_list + ) return _apply_layer_sac(module, ac_config) @@ -288,26 +196,33 @@ def apply_ac( ac_config: ACConfig, *, model_compile_enabled: bool = False, - use_flex_attn: bool = False, op_sac_save_list: set[torch._ops.OpOverload] | None = None, base_folder: str = "", ) -> None: """Apply activation checkpointing to the model. - Note that SAC, Flex Attention and model compilation have some conflicts. - We explicitly ask the user to pass these configs to warn as the wrapping - will be different. - Args: model (nn.Module): The model to apply activation checkpointing to. ac_config (ACConfig): The activation checkpointing config. model_compile_enabled (bool): Whether torch.compile is enabled for the model. - use_flex_attn (bool): Whether flex attention is enabled for the model. op_sac_save_list (set[torch._ops.OpOverload]): The list of ops to save instead of recomputing. Returns: None """ + # Disable dynamo LRU cache to workaround an interaction between SAC, PP, and Flex: + # + # When forward runs with a second PP microbatch, it triggers recompilation with dynamic + # shapes enabled. Now there are two valid compiled graphs. By default, dynamo selects + # the latest one (the dynamic shapes version), so the runtime wrapper expects an extra + # symint output. When SAC caches the inductor HOP output from the static graph for + # batch_idx=0, it would miss that symint and cause an assertion failure. The workaround + # here is to disable the LRU cache, and select graphs in insertion order instead. + # + # Also see: https://github.com/pytorch/pytorch/issues/166926 + # pyrefly: ignore [missing-attribute] + torch._C._dynamo.eval_frame._set_lru_cache(False) + if ac_config.mode == "memory_budget": assert model_compile_enabled, "Memory budget mode requires model to be compiled" if ac_config.visualize_memory_budget_pareto: @@ -320,15 +235,16 @@ def apply_ac( torch._functorch.config.activation_memory_budget = ac_config.memory_budget logger.info(f"Selected {ac_config.memory_budget} budget option") else: + # pyrefly: ignore [missing-attribute] for layer_id, transformer_block in model.layers.named_children(): transformer_block = _apply_ac_to_transformer_block( transformer_block, ac_config, base_fqn=f"layers.{layer_id}", model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, op_sac_save_list=op_sac_save_list, ) + # pyrefly: ignore [missing-attribute] model.layers.register_module(layer_id, transformer_block) logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") diff --git a/torchtitan/distributed/context_parallel.py b/torchtitan/distributed/context_parallel.py new file mode 100644 index 0000000000..7214f6a603 --- /dev/null +++ b/torchtitan/distributed/context_parallel.py @@ -0,0 +1,248 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Sequence +from typing import Any, cast + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.experimental._attention import ( + _context_parallel_shard, + _ContextParallel, + _enable_context_parallel_dispatcher, + _HeadTailLoadBalancer, + _PTRRLoadBalancer, +) +from torch.distributed.tensor.parallel import parallelize_module +from torch.nn.attention.flex_attention import BlockMask + +from torchtitan.protocols.model import AttentionMasksType +from torchtitan.tools.logging import logger + + +def apply_cp_to_attention_module( + attention_modules: Sequence[nn.Module], + cp_mesh: DeviceMesh, + attention_type: str, +) -> None: + """ + Apply context parallelism to attention modules. + + CP splits the sequence dimension across devices to enable training with + longer sequences. This function applies CP to the provided attention + modules. + + Args: + attention_modules: Sequence of attention modules to apply CP to + cp_mesh: Device mesh for context parallel dimension + attention_type: Type of attention mechanism. Must be one of: + - "sdpa": scaled_dot_product_attention() + - "flex": flex_attention() + - "varlen": varlen_attn() (not yet implemented) + + Raises: + NotImplementedError: If attention_type is "varlen" + """ + # Apply context parallelism to every attention module + # TODO: make seq_dim configurable once the implementation doesn't assume 2 + # internally. + match attention_type: + case "flex": + cp_plan = _ContextParallel( + seq_dim=2, attention_type=_ContextParallel.AttentionType.FLEX + ) + case "sdpa": + # Enable the DTensor dispatcher to route SDPA operations to the + # Context Parallel implementation. This is required for CP to work + # with SDPA (but not FlexAttention). + # Note: Use _disable_context_parallel_dispatcher() if you need to + # turn this off. In TorchTitan, we currently don't disable the CP + # dispatcher. + _enable_context_parallel_dispatcher() + cp_plan = _ContextParallel( + seq_dim=2, attention_type=_ContextParallel.AttentionType.SDPA + ) + case "varlen": + raise NotImplementedError( + "Variable-length attention CP is not yet supported" + ) + case _: + raise ValueError( + f"Invalid attention_type '{attention_type}'. " + f"Must be one of: 'sdpa', 'flex', 'varlen'" + ) + + for attention_module in attention_modules: + parallelize_module( + module=attention_module, + device_mesh=cp_mesh, + parallelize_plan=cp_plan, + ) + + logger.info("Applied Context Parallel to the model") + + +def prepare_context_parallel_input( + inputs: torch.Tensor, + labels: torch.Tensor, + extra_kwargs: dict[str, Any], + cp_mesh: DeviceMesh, + device: torch.device, + load_balancer_type: str | None = "headtail", +) -> tuple[torch.Tensor, torch.Tensor, dict[str, Any]]: + """ + Prepare inputs, labels, and attention masks for Context Parallel forward pass. + + This function prepares tensors for context parallel by: + 1. Creating position indices based on input sequence length + 2. Sharding inputs, labels, and positions across the CP mesh + 3. Sharding attention masks if present + + Args: + inputs: Input tensor of shape [batch_size, seq_len] + labels: Label tensor of shape [batch_size, seq_len] + extra_kwargs: Dictionary that may contain 'attention_masks' to be sharded + cp_mesh: Device mesh for context parallel dimension + device: Device to create position tensor on + load_balancer_type: Type of load balancer to use for sharding. + Options: "headtail", "ptrr", or None. Defaults to "headtail". + + Returns: + Tuple of (sharded_inputs, sharded_labels, updated_extra_kwargs) where: + - sharded_inputs: Inputs sharded along sequence dimension + - sharded_labels: Labels sharded along sequence dimension + - updated_extra_kwargs: Dict with sharded 'positions' and optionally + sharded 'attention_masks' + """ + attention_masks = extra_kwargs.get("attention_masks", None) + positions = extra_kwargs.get("positions", None) + if positions is None: + positions = torch.arange( + 0, inputs.shape[1], dtype=torch.int32, device=device + ).expand(inputs.shape) + (inputs, labels, positions), attention_masks = cp_shard( + cp_mesh, + (inputs, labels, positions), + attention_masks, + load_balancer_type, + ) + extra_kwargs["positions"] = positions + if attention_masks is not None: + extra_kwargs["attention_masks"] = attention_masks + + return inputs, labels, extra_kwargs + + +def cp_shard( + cp_mesh: DeviceMesh, + inputs: tuple[torch.Tensor, ...], + attention_masks: AttentionMasksType | None, + load_balancer_type: str | None = "headtail", + input_seq_dim: int = 1, +) -> tuple[tuple[torch.Tensor, ...], AttentionMasksType | None]: + """ + Shard inputs and attention masks across the context parallel mesh. + + This function distributes input tensors across devices in the CP mesh + along the sequence dimension, enabling efficient processing. It optionally + uses a load balancer to handle uneven computation workload. + + Args: + cp_mesh: Device mesh for context parallel dimension + inputs: Tuple of input tensors to be sharded along the sequence + dimension + attention_masks: Attention masks to be sharded. Supports None, + BlockMask, or dict[str, BlockMask] + load_balancer_type: Type of load balancer to use. Options: + - "headtail": Use HeadTailLoadBalancer (for SDPA) + - "ptrr": Use PTRRLoadBalancer (for FlexAttention) + - None: Disable load balancing + Defaults to "headtail". + input_seq_dim: Sequence dimension index for sharding. Defaults to 1, + which covers most use cases where tensors have shape + [batch_size, seq_len]. Can be changed by passing a + different value if your tensors use a different sequence + dimension layout. + + Returns: + Tuple of (sharded_inputs, attention_masks) where: + - sharded_inputs: Tuple of input tensors sharded along the + sequence dimension + - attention_masks: Sharded attention masks (BlockMask or + dict[str, BlockMask]) or None + + Raises: + ValueError: If load_balancer_type is "ptrr" and attention_masks + is None or a dict + """ + seq_len = inputs[0].size(input_seq_dim) + cp_world_size = cp_mesh.size(0) + + load_balancer = None + if load_balancer_type: + match load_balancer_type: + case "headtail": + # For SDPA, we use the _HeadTailLoadBalancer. + load_balancer = _HeadTailLoadBalancer( + seq_len, cp_world_size, cp_mesh.device_type + ) + case "ptrr": + # For FlexAttention, we use _PTRRLoadBalancer. + # _PTRRLoadBalancer requires attention_masks to be a BlockMask. + # For dict[str, BlockMask], _PTRRLoadBalancer currently doesn't + # support the case where there are multiple masks. + if attention_masks is None or isinstance(attention_masks, dict): + raise ValueError( + "PTRRLoadBalancer requires attention_masks to be a " + "BlockMask, but got None or dict[str, BlockMask]" + ) + if not isinstance(attention_masks, BlockMask): + raise ValueError( + f"PTRRLoadBalancer requires attention_masks to be a " + f"BlockMask, but got {type(attention_masks)}" + ) + load_balancer = _PTRRLoadBalancer(attention_masks, cp_world_size) + case _: + raise ValueError( + f"Invalid load_balancer_type '{load_balancer_type}'. " + f"Must be one of: 'headtail', 'ptrr', or None" + ) + + inputs = cast( + tuple[torch.Tensor, ...], + _context_parallel_shard( + mesh=cp_mesh, + buffers=inputs, + seq_dims=tuple(input_seq_dim for _ in inputs), + load_balancer=load_balancer, + ), + ) + + # BlockMask, has shape, [B, H, Q, KV], and we can only shard + # on the Q seq dimension, not KV. + MASK_Q_SEQ_DIM = 2 + if attention_masks is not None: + assert isinstance(attention_masks, (BlockMask, dict[str, BlockMask])) + masks = ( + [attention_masks] + if isinstance(attention_masks, BlockMask) + else list(attention_masks.values()) + ) + masks = _context_parallel_shard( + mesh=cp_mesh, + buffers=masks, + seq_dims=(MASK_Q_SEQ_DIM,) * len(masks), + load_balancer=load_balancer, + ) + attention_masks = cast( + (BlockMask | dict[str, BlockMask]), + masks[0] + if isinstance(attention_masks, BlockMask) + else {k: v for k, v in zip(attention_masks.keys(), masks)}, + ) + + return inputs, attention_masks diff --git a/torchtitan/distributed/deepep/README.md b/torchtitan/distributed/deepep/README.md index 94012de27e..3f05d5161a 100644 --- a/torchtitan/distributed/deepep/README.md +++ b/torchtitan/distributed/deepep/README.md @@ -175,6 +175,9 @@ uv pip install git+https://github.com/deepseek-ai/DeepEP.git --no-build-isolatio > > See [GitHub Issue #224](https://github.com/deepseek-ai/DeepEP/issues/224#issuecomment-2985783610) +> If you see /usr/bin/ld: cannot find -l:libnvshmem_host.so: No such file or directory +> try ln -s /path/to/libnvshmem_host.so.3 /path/to/libnvshmem_host.so + ### Step 3: Verify Installation ```bash diff --git a/torchtitan/distributed/deepep/__init__.py b/torchtitan/distributed/deepep/__init__.py new file mode 100644 index 0000000000..53001938a8 --- /dev/null +++ b/torchtitan/distributed/deepep/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""DeepEP distributed communication primitives for MoE.""" + +from .deepep import combine_tokens, dispatch_tokens, DispatchState + +__all__ = [ + "dispatch_tokens", + "combine_tokens", + "DispatchState", +] diff --git a/torchtitan/distributed/deepep/deepep.py b/torchtitan/distributed/deepep/deepep.py new file mode 100644 index 0000000000..ce44fc232e --- /dev/null +++ b/torchtitan/distributed/deepep/deepep.py @@ -0,0 +1,462 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +DeepEP primitives for MoE Expert Parallel. + +Provides low-level functions and autograd wrappers for DeepEP communication. +Used by DeepEPExpertParallel in expert_parallel.py. +""" + +from dataclasses import dataclass +from typing import Optional, Tuple + +import torch +from torch.distributed import ProcessGroup + +try: + from deep_ep import Buffer # pyrefly: ignore[missing-import] + from deep_ep.utils import ( # pyrefly: ignore[missing-import] + EventHandle, + EventOverlap, + ) +except ImportError as e: + raise ImportError( + "DeepEP is required for this module. " + "Install from: https://github.com/deepseek-ai/deepep" + ) from e + + +# Global buffer (single buffer per process, recreated if group changes) +_buffer: Buffer = None + +# Global cache for dispatch handles, keyed by cache_id +# SAC saves the cache_id tensor; we use it to retrieve the non-tensor handle +_handle_cache: dict = {} +_cache_counter: int = 0 + + +def _get_next_cache_id() -> torch.Tensor: + """Generate a unique cache_id tensor on CPU to avoid GPU-CPU sync.""" + global _cache_counter + _cache_counter += 1 + return torch.tensor([_cache_counter], dtype=torch.int64, device="cpu") + + +# ============================================================================ +# Custom Op Registration for SAC Integration +# ============================================================================ + +_lib = torch.library.Library("deepep", "DEF") + +# dispatch returns: (recv_x, recv_indices, recv_scores, num_recv_per_expert, cache_id) +_lib.define( + "dispatch(Tensor x, Tensor topk_idx, Tensor topk_weights, " + "Tensor num_tokens_per_rank, Tensor num_tokens_per_rdma_rank, " + "Tensor is_token_in_rank, Tensor num_tokens_per_expert) " + "-> (Tensor, Tensor, Tensor, Tensor, Tensor)" +) + +# combine returns: combined_x +_lib.define("combine(Tensor x, Tensor cache_id) -> Tensor") + + +@torch.library.impl(_lib, "dispatch", "CUDA") +def _dispatch_op_impl( + x: torch.Tensor, + topk_idx: torch.Tensor, + topk_weights: torch.Tensor, + num_tokens_per_rank: torch.Tensor, + num_tokens_per_rdma_rank: torch.Tensor, + is_token_in_rank: torch.Tensor, + num_tokens_per_expert: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """Execute DeepEP dispatch.""" + global _buffer + + buffer = _buffer + assert buffer is not None, "Buffer must be initialized before dispatch" + + previous_event = _create_event_if_async(True) + + ( + recv_x, + recv_indices, + recv_scores, + num_recv_list, + handle, + after_event, + ) = buffer.dispatch( + x=x, + topk_idx=topk_idx, + topk_weights=topk_weights.to(torch.float32), + num_tokens_per_rank=num_tokens_per_rank, + num_tokens_per_rdma_rank=num_tokens_per_rdma_rank, + is_token_in_rank=is_token_in_rank, + num_tokens_per_expert=num_tokens_per_expert, + previous_event=previous_event, + async_finish=True, + allocate_on_comm_stream=True, + ) + + _sync_stream_if_async(True, after_event) + + cache_id = _get_next_cache_id() + _handle_cache[cache_id.item()] = handle + + num_recv_tensor = torch.tensor(num_recv_list, dtype=torch.int32, device="cpu") + return recv_x, recv_indices, recv_scores, num_recv_tensor, cache_id + + +@torch.library.impl(_lib, "combine", "CUDA") +def _combine_op_impl(x: torch.Tensor, cache_id: torch.Tensor) -> torch.Tensor: + """Execute DeepEP combine.""" + global _buffer + + buffer = _buffer + assert buffer is not None, "Buffer must be initialized before combine" + + handle = _handle_cache.get(cache_id.item()) + assert handle is not None, f"Handle not found for cache_id={cache_id.item()}" + + previous_event = _create_event_if_async(True) + + combined, _, after_event = buffer.combine( + x=x, + handle=handle, + previous_event=previous_event, + async_finish=True, + allocate_on_comm_stream=True, + ) + + _sync_stream_if_async(True, after_event) + + return combined + + +def _dispatch_backward( + ctx, grad_recv_x, grad_recv_indices, grad_recv_scores, grad_num_recv, grad_cache_id +): + """Backward for dispatch: performs combine on gradients.""" + global _buffer + + if grad_recv_x is None: + return None, None, None, None, None, None, None + + handle = _handle_cache.get(ctx.cache_id_int) + assert handle is not None, f"Handle not found for cache_id={ctx.cache_id_int}" + + previous_event = _create_event_if_async(True) + + grad_x, grad_scores, after_event = _buffer.combine( + x=grad_recv_x, + handle=handle, + topk_weights=grad_recv_scores.float() if grad_recv_scores is not None else None, + previous_event=previous_event, + async_finish=True, + allocate_on_comm_stream=True, + ) + + _sync_stream_if_async(True, after_event) + _handle_cache.pop(ctx.cache_id_int, None) + + grad_x = grad_x.to(ctx.input_dtype) + grad_topk_weights = ( + grad_scores.to(ctx.input_dtype) if grad_scores is not None else None + ) + + return grad_x, None, grad_topk_weights, None, None, None, None + + +def _dispatch_setup_context(ctx, inputs, output): + x, topk_idx, topk_weights, *_ = inputs + recv_x, recv_indices, recv_scores, num_recv, cache_id = output + ctx.cache_id_int = cache_id.item() + ctx.input_dtype = x.dtype + + +def _combine_backward(ctx, grad_combined): + """Backward for combine: performs dispatch on gradients.""" + global _buffer + + handle = ctx.saved_handle + previous_event = _create_event_if_async(True) + + grad_x, _, _, _, _, after_event = _buffer.dispatch( + x=grad_combined, + topk_idx=None, + topk_weights=None, + num_tokens_per_rank=None, + num_tokens_per_rdma_rank=None, + is_token_in_rank=None, + num_tokens_per_expert=None, + handle=handle, + previous_event=previous_event, + async_finish=True, + allocate_on_comm_stream=True, + ) + + _sync_stream_if_async(True, after_event) + + return grad_x, None + + +def _combine_setup_context(ctx, inputs, output): + x, cache_id = inputs + ctx.cache_id_int = cache_id.item() + ctx.saved_handle = _handle_cache.get(ctx.cache_id_int) + + +torch.library.register_autograd( + "deepep::dispatch", _dispatch_backward, setup_context=_dispatch_setup_context +) +torch.library.register_autograd( + "deepep::combine", _combine_backward, setup_context=_combine_setup_context +) + + +def _create_event_if_async(async_finish: bool): + """Create EventOverlap handle if async mode is enabled.""" + return EventOverlap(EventHandle()) if async_finish else None + + +def _sync_stream_if_async(async_finish: bool, after_event): + """Synchronize current stream with communication stream if async mode is enabled.""" + if async_finish and after_event is not None: + after_event.current_stream_wait() + + +def get_hidden_bytes(x: torch.Tensor) -> int: + """Calculate the number of hidden bytes for a tensor.""" + return x.size(1) * max(x.element_size(), 2) + + +def get_buffer(group: ProcessGroup, hidden_bytes: int) -> Buffer: + """Get or create a buffer for all-to-all communication.""" + global _buffer + num_nvl_bytes, num_rdma_bytes = 0, 0 + for config in ( + Buffer.get_dispatch_config(group.size()), + Buffer.get_combine_config(group.size()), + ): + num_nvl_bytes = max( + config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes + ) + num_rdma_bytes = max( + config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes + ) + + if ( + _buffer is None + or _buffer.group != group + or _buffer.num_nvl_bytes < num_nvl_bytes + or _buffer.num_rdma_bytes < num_rdma_bytes + ): + _buffer = Buffer(group, num_nvl_bytes, num_rdma_bytes) + + return _buffer + + +def _indices_to_multihot( + indices: torch.Tensor, scores: torch.Tensor, num_local_experts: int +) -> Tuple[torch.Tensor, torch.Tensor]: + """Convert topk indices to multihot format for permutation.""" + batch_size = indices.shape[0] + multihot_routing_map = torch.zeros( + (batch_size, num_local_experts), dtype=torch.long, device=indices.device + ) + multihot_scores = torch.zeros( + (batch_size, num_local_experts), dtype=scores.dtype, device=indices.device + ) + + mask = indices != -1 + valid_indices = indices[mask] + row_indices = torch.arange(batch_size, device=indices.device).repeat_interleave( + mask.sum(dim=1) + ) + multihot_routing_map[row_indices, valid_indices] = 1 + multihot_scores[row_indices, valid_indices] = scores[mask] + + return multihot_routing_map.bool(), multihot_scores + + +def _permute_tokens( + tokens: torch.Tensor, + routing_map: torch.Tensor, + scores: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], torch.Tensor]: + """Permute tokens by expert for grouped_mm. + + Returns: + (permuted_tokens, permuted_scores, sorted_indices) + """ + num_tokens = tokens.shape[0] + num_experts = routing_map.shape[1] + + routing_map_t = routing_map.bool().T.contiguous() + token_indices = torch.arange(num_tokens, device=routing_map.device) + token_indices = token_indices.unsqueeze(0).expand(num_experts, -1) + sorted_indices = token_indices.masked_select(routing_map_t) + sorted_tokens = tokens.index_select(0, sorted_indices) + + if scores is not None: + sorted_scores = scores.T.contiguous().masked_select(routing_map_t) + else: + sorted_scores = None + + return sorted_tokens, sorted_scores, sorted_indices + + +def _unpermute_tokens( + permuted_tokens: torch.Tensor, + sorted_indices: torch.Tensor, + num_tokens: int, +) -> torch.Tensor: + """Reverse permutation applied by _permute_tokens.""" + hidden = permuted_tokens.shape[1] + output_tokens = torch.zeros( + (num_tokens, hidden), dtype=permuted_tokens.dtype, device=permuted_tokens.device + ) + output_tokens.scatter_add_( + 0, sorted_indices.unsqueeze(1).expand(-1, hidden), permuted_tokens + ) + return output_tokens + + +@dataclass +class DispatchState: + """State from dispatch needed for combine.""" + + cache_id: torch.Tensor # CPU tensor used to retrieve cached handle + sorted_indices: torch.Tensor + num_recv_tokens: int + permuted_scores: Optional[torch.Tensor] = None + + +def dispatch_tokens( + hidden_states: torch.Tensor, + selected_experts_indices: torch.Tensor, + top_scores: torch.Tensor, + num_local_experts: int, + num_experts: int, + group: ProcessGroup, + score_before_experts: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor, DispatchState]: + """Dispatch tokens to experts via DeepEP. + + Args: + hidden_states: Input tokens [num_tokens, hidden_dim] + selected_experts_indices: Expert indices for each token [num_tokens, top_k] + top_scores: Routing scores for each token [num_tokens, top_k] + num_local_experts: Number of experts on this rank + num_experts: Total number of experts across all ranks + group: EP process group + score_before_experts: If True, apply routing scores before expert computation. + + Returns: + (permuted_tokens, tokens_per_expert, state_for_combine) + """ + # Ensure contiguous and proper shape + router_topk = ( + selected_experts_indices.shape[1] if selected_experts_indices.dim() == 2 else 1 + ) + if selected_experts_indices.dim() != 2: + selected_experts_indices = selected_experts_indices.view( + -1, router_topk + ).contiguous() + top_scores = top_scores.view(-1, router_topk).contiguous() + else: + selected_experts_indices = selected_experts_indices.contiguous() + top_scores = top_scores.contiguous() + + # Mask out zero-score tokens + selected_experts_indices = selected_experts_indices.masked_fill(top_scores == 0, -1) + + # Ensure float32 scores (DeepEP requirement) + if top_scores.dtype != torch.float32: + top_scores = top_scores.float() + + buffer = get_buffer(group, get_hidden_bytes(hidden_states)) + + ( + num_tokens_per_rank, + num_tokens_per_rdma_rank, + num_tokens_per_expert_dispatch, + is_token_in_rank, + _, + ) = buffer.get_dispatch_layout( + topk_idx=selected_experts_indices, num_experts=num_experts + ) + + ( + hidden_states, + dispatched_indices, + dispatched_expert_scores, + tokens_per_expert, + cache_id, + ) = torch.ops.deepep.dispatch( + hidden_states, + selected_experts_indices, + top_scores, + num_tokens_per_rank, + num_tokens_per_rdma_rank, + is_token_in_rank, + num_tokens_per_expert_dispatch, + ) + + dispatched_routing_map, dispatched_expert_scores_multihot = _indices_to_multihot( + dispatched_indices, dispatched_expert_scores, num_local_experts + ) + + num_recv_tokens = hidden_states.shape[0] + + # Sort tokens by expert for grouped_mm + hidden_states, permuted_scores, sorted_indices = _permute_tokens( + hidden_states, dispatched_routing_map, scores=dispatched_expert_scores_multihot + ) + + # Compute tokens_per_expert from routing_map (matches the sorted tokens) + tokens_per_expert = ( + dispatched_routing_map.sum(dim=0).to(torch.int32).to(hidden_states.device) + ) + + if score_before_experts and permuted_scores is not None: + # Avoid float32 conversion to save memory + hidden_states = hidden_states * permuted_scores.to(hidden_states.dtype).reshape( + -1, 1 + ) + permuted_scores_for_state = None + else: + permuted_scores_for_state = permuted_scores + + state = DispatchState( + cache_id=cache_id, + sorted_indices=sorted_indices, + num_recv_tokens=num_recv_tokens, + permuted_scores=permuted_scores_for_state, + ) + + return hidden_states, tokens_per_expert, state + + +def combine_tokens( + hidden_states: torch.Tensor, + state: DispatchState, +) -> torch.Tensor: + """Combine tokens from experts via DeepEP.""" + if state.permuted_scores is not None: + # In-place multiplication to save memory + hidden_states = hidden_states * state.permuted_scores.to( + hidden_states.dtype + ).reshape(-1, 1) + + hidden_states = _unpermute_tokens( + hidden_states, state.sorted_indices, state.num_recv_tokens + ) + + hidden_states = torch.ops.deepep.combine(hidden_states, state.cache_id) + + return hidden_states diff --git a/torchtitan/distributed/dual_pipe_v.py b/torchtitan/distributed/dual_pipe_v.py new file mode 100644 index 0000000000..9f13b7f958 --- /dev/null +++ b/torchtitan/distributed/dual_pipe_v.py @@ -0,0 +1,324 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import threading +from typing import cast, Optional + +import torch +import torch.nn as nn +from torch import Tensor + +from torch.distributed.pipelining.schedules import ( + _Action, + _PipelineContext, + _PipelineScheduleRuntime, + _wait_batch_p2p, +) +from torch.distributed.pipelining.stage import _PipelineStageBase +from torch.distributed.tensor import DeviceMesh, distribute_module +from torch.profiler import record_function + +from torchtitan.distributed.expert_parallel import BaseExpertParallel + +from torchtitan.tools.utils import get_device_info + +""" +Below are optimizations related to pipeline parallelism with expert parallelism +""" + + +def get_dual_pipe_v_flag(job_config, parallel_dims) -> bool: + """ + Determine if DualPipeV should be enabled based on config and + validates that incompatible features (EP + DualPipeV + AC) are not used together. + """ + if not parallel_dims.ep_enabled or not parallel_dims.pp_enabled: + return False + + dual_pipe_v = ( + job_config.parallelism.pipeline_parallel_expert_parallel_overlap + and job_config.parallelism.pipeline_parallel_schedule.lower() == "dualpipev" + ) + + if dual_pipe_v and job_config.activation_checkpoint.mode != "none": + raise NotImplementedError( + "Expert Parallel with DualPipeV and Activation Checkpointing " + "cannot be used together. Please disable one of them." + ) + + return dual_pipe_v + + +class DualPipeExpertParallel(BaseExpertParallel): + """ + Wrapper that adds dual-pipe synchronization hooks to any BaseExpertParallel. + Wraps dispatch/combine with sync hooks for overlapping EP communication + with PP computation in DualPipe scheduling. + + The execution order becomes: + A -> dispatch -> B -> module -> C -> combine -> D + """ + + def __init__(self, inner_ep: BaseExpertParallel): + super().__init__() + self.inner_ep = inner_ep + + def _partition_fn(self, name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None: + return self.inner_ep._partition_fn(name, mod, device_mesh) + + def _token_dispatch( + self, mod: nn.Module, inputs: tuple, device_mesh: DeviceMesh + ) -> tuple[Tensor, Tensor]: + """A -> dispatch -> B""" + inputs = (cast(Tensor, SyncHook.apply(inputs[0], "A")),) + inputs[1:] + outputs = self.inner_ep._token_dispatch(mod, inputs, device_mesh) + outputs = (cast(Tensor, SyncHook.apply(outputs[0], "B")),) + outputs[1:] + return outputs + + def _token_combine( + self, mod: nn.Module, routed_output: Tensor, device_mesh: DeviceMesh + ) -> Tensor: + """C -> combine -> D""" + routed_output = cast(Tensor, SyncHook.apply(routed_output, "C")) + combine_output = self.inner_ep._token_combine(mod, routed_output, device_mesh) + combine_output = cast(Tensor, SyncHook.apply(combine_output, "D")) + return combine_output + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + partition_fn=self._partition_fn, + # pyrefly: ignore [bad-argument-type] + input_fn=self._token_dispatch, + # pyrefly: ignore [bad-argument-type] + output_fn=self._token_combine, + ) + + +class HookCoordinator: + def __init__(self): + # Barrier for 2 threads (forward and backward) to synchronize + # This ensures that we always alternate at executing one compute and one comm op together + self._execution_barrier = threading.Barrier(2) + + self._coordination_enabled = False + self._cycle_count = 0 + self._num_layers = None + + def barrier(self): + """Barrier for 2 threads to synchronize""" + if not self.is_coordination_enabled(): + return + + try: + self._execution_barrier.wait() + except threading.BrokenBarrierError: + pass + + def enable_coordination(self, num_layers: Optional[int] = None): + if num_layers is not None and num_layers > 0: + self._coordination_enabled = True + self._cycle_count = 0 + + # Reset barrier + self._execution_barrier = threading.Barrier(2) + self._num_layers = num_layers # pyrefly: ignore[bad-assignment] + + def disable_coordination(self): + self._coordination_enabled = False + self._cycle_count = 0 + self._execution_barrier.abort() # Break barrier to unblock threads + + def check_should_continue_coordination(self): + if self._num_layers is not None and self._cycle_count >= self._num_layers: + return False + return True + + def is_coordination_enabled(self): + return self._coordination_enabled + + +# Global coordinator +_hook_coordinator = HookCoordinator() + + +class SyncHook(torch.autograd.Function): + @staticmethod + # pyrefly: ignore [bad-override] + def forward(ctx, x, hook_name=""): + ctx.hook_name = hook_name + # handle edge case for transformer level boundary + if _hook_coordinator._coordination_enabled and hook_name == "D": + _hook_coordinator._cycle_count += 1 + if not _hook_coordinator.check_should_continue_coordination(): + _hook_coordinator.disable_coordination() + return x + + _hook_coordinator.barrier() + return x + + @staticmethod + # pyrefly: ignore [bad-override] + def backward(ctx, grad_output): + hook_name = ctx.hook_name + + # Edge case, skip initial barrier, all subsequent backward hooks will acquire + if hook_name == "D" and _hook_coordinator._cycle_count == 0: + return grad_output, None + + _hook_coordinator.barrier() + return grad_output, None + + +def _count_moe_modules(model): + """Count MoE modules directly""" + from torchtitan.models.moe import MoE + + moe_count = 0 + for _, module in model.named_modules(): + if isinstance(module, MoE): + moe_count += 1 + return moe_count + + +device_type, device_module = get_device_info() + + +def overlap_callback(action: _Action, ctx: _PipelineContext): + """ + Custom callback for OVERLAP_F_B computation that allows expert parallel communication + and pipeline parallel computation to overlap. + """ + schedule = ctx.schedule_ref + assert isinstance(schedule, _PipelineScheduleRuntime) + stage_index_to_stage: dict[int, _PipelineStageBase] = { + stage.stage_index: stage for stage in schedule._stages + } + assert action.sub_actions is not None + fwd_action = action.sub_actions[0] + bwd_action = action.sub_actions[1] + + # Get stages + forward_stage_index = fwd_action.stage_index + forward_mb_index = fwd_action.microbatch_index + assert forward_mb_index is not None + backward_stage_index = bwd_action.stage_index + backward_stage = stage_index_to_stage[backward_stage_index] + + # Forward setup + arg_mbs = ctx.arg_mbs + kwarg_mbs = ctx.kwarg_mbs + assert arg_mbs is not None and kwarg_mbs is not None + fwd_recv_ops = schedule.fwd_recv_ops + forward_stage = stage_index_to_stage[forward_stage_index] + forward_is_next_stage_on_this_rank = forward_stage_index + 1 in stage_index_to_stage + forward_is_prev_stage_on_this_rank = forward_stage_index - 1 in stage_index_to_stage + + # Backward setup + backward_is_next_stage_on_this_rank = ( + backward_stage.stage_index + 1 in stage_index_to_stage + ) + backward_is_prev_stage_on_this_rank = ( + backward_stage.stage_index - 1 in stage_index_to_stage + ) + backward_mb_index = bwd_action.microbatch_index + assert backward_mb_index is not None + bwd_recv_ops = schedule.bwd_recv_ops + + # Fwd receives + if ( + not forward_stage.is_first + # no recv op expected for V-schedule special case + and not forward_is_prev_stage_on_this_rank + ): + assert ( + forward_stage_index, + forward_mb_index, + ) in fwd_recv_ops, f"Computing {action=} before receiving input" + _wait_batch_p2p(fwd_recv_ops.pop((forward_stage_index, forward_mb_index))) + + # Bwd receives + if ( + not backward_stage.is_last + # no recv op expected for V-schedule special case + and not backward_is_next_stage_on_this_rank + ): + assert ( + backward_stage_index, + backward_mb_index, + ) in bwd_recv_ops, f"Attempted to run compute {action=} before receiving input" + _wait_batch_p2p(bwd_recv_ops.pop((backward_stage_index, backward_mb_index))) + + # We count num layers in case the stage layers differ + # If they differ than we only want coordination to happen for the min amount of layers + min_num_layers = min( + _count_moe_modules(forward_stage.submod), + _count_moe_modules(backward_stage.submod), + ) + # PP computation ======================================================== + _hook_coordinator.enable_coordination(num_layers=min_num_layers) + main_stream = torch.accelerator.current_stream(device_type) + + # Shared container for exception from backward thread + def run_backward(): + # pyrefly: ignore [missing-attribute] + schedule._assert_unsharded(backward_stage) + # Set the backward thread to use the same stream as forward + device_module.set_stream(main_stream) + with record_function( + f"backward_stage_{backward_stage_index}_mb_{backward_mb_index}" + ): + loss = schedule._maybe_get_loss(backward_stage, backward_mb_index) + # pyrefly: ignore [missing-attribute] + schedule.backward_counter[backward_stage_index] += 1 + last_backward = ( + # pyrefly: ignore [missing-attribute] + schedule.backward_counter[backward_stage_index] + == schedule._n_microbatches + ) + backward_stage.backward_one_chunk( + # pyrefly: ignore [bad-argument-type] + backward_mb_index, + loss=loss, + full_backward=True, + last_backward=last_backward, + ) + + if backward_is_prev_stage_on_this_rank: + stage_index_to_stage[backward_stage_index - 1].set_local_bwd_input( + backward_stage.get_local_bwd_output(backward_mb_index), + # pyrefly: ignore [bad-argument-type] + backward_mb_index, + ) + + def run_forward(): + # pyrefly: ignore [missing-attribute] + schedule._assert_unsharded(forward_stage) + output = forward_stage.forward_one_chunk( + # pyrefly: ignore [bad-argument-type] + forward_mb_index, + # pyrefly: ignore[bad-index, unsupported-operation] + arg_mbs[forward_mb_index], + # pyrefly: ignore[bad-index, unsupported-operation] + kwarg_mbs[forward_mb_index], + ) + schedule._maybe_compute_loss( + forward_stage, output, ctx.target_mbs, forward_mb_index + ) + if forward_is_next_stage_on_this_rank: + stage_index_to_stage[forward_stage_index + 1].set_local_fwd_input( + output, + # pyrefly: ignore [bad-argument-type] + forward_mb_index, + ) + + # Run forward and backward in parallel + thread = threading.Thread(target=run_backward, daemon=True) + thread.start() + run_forward() + thread.join() + + _hook_coordinator.disable_coordination() diff --git a/torchtitan/distributed/expert_parallel.py b/torchtitan/distributed/expert_parallel.py index 461df40d9a..06185ddedb 100644 --- a/torchtitan/distributed/expert_parallel.py +++ b/torchtitan/distributed/expert_parallel.py @@ -4,8 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from abc import ABC, abstractmethod + import torch import torch.nn as nn +from torch import Tensor from torch.distributed._functional_collectives import ( all_to_all_single, all_to_all_single_autograd, @@ -24,6 +27,24 @@ from torchtitan.models.moe.utils import _permute, _unpermute +class BaseExpertParallel(ParallelStyle, ABC): + @abstractmethod + def _partition_fn(self, name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None: + ... + + @abstractmethod + def _token_dispatch( + self, mod: nn.Module, inputs: tuple, device_mesh: DeviceMesh + ) -> tuple[Tensor, Tensor]: + ... + + @abstractmethod + def _token_combine( + self, mod: nn.Module, routed_output: Tensor, device_mesh: DeviceMesh + ) -> Tensor: + ... + + # implementation of Tensor Parallel for the GroupedExperts in MoE class TensorParallel(ParallelStyle): def _prepare_input_fn(self, mod, inputs, device_mesh): @@ -98,11 +119,12 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: module, device_mesh, self._partition_fn, + # pyrefly: ignore [bad-argument-type] self._prepare_input_fn, ) -class ExpertParallel(ParallelStyle): +class ExpertParallel(BaseExpertParallel): def __init__(self): super().__init__() self.input_splits = None @@ -110,8 +132,14 @@ def __init__(self): self.input_shape = None self.permuted_indices = None - # performing all-to-all dispatch on the input - def _token_dispatch(self, mod, inputs, device_mesh): + def _partition_fn(self, name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None: + for param_name, param in mod.named_parameters(recurse=False): + dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])) + mod.register_parameter(param_name, dist_param) + + def _token_dispatch( + self, mod: nn.Module, inputs: tuple, device_mesh: DeviceMesh + ) -> tuple[Tensor, Tensor]: # annotate module input placements/sharding with input_layouts routed_input, num_tokens_per_expert = inputs ep_degree = device_mesh.shape[0] @@ -165,9 +193,9 @@ def _token_dispatch(self, mod, inputs, device_mesh): # of GroupedExperts, as it does not need padding. ( - self.input_shape, + self.input_shape, # pyrefly: ignore[bad-assignment] routed_input, - self.permuted_indices, + self.permuted_indices, # pyrefly: ignore[bad-assignment] num_tokens_per_expert_group, ) = _permute( routed_input, num_tokens_per_expert_group, ep_degree, num_local_experts @@ -175,15 +203,9 @@ def _token_dispatch(self, mod, inputs, device_mesh): return routed_input, num_tokens_per_expert_group - @staticmethod - def _partition_fn(name, mod, device_mesh): - # shard on the expert dimension - for name, param in mod.named_parameters(recurse=False): - dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])) - mod.register_parameter(name, dist_param) - - # performing all-to-all combine on the output - def _token_combine(self, mod, routed_output, device_mesh): + def _token_combine( + self, mod: nn.Module, routed_output: Tensor, device_mesh: DeviceMesh + ) -> Tensor: routed_output = _unpermute( routed_output, self.input_shape, self.permuted_indices ) @@ -200,8 +222,10 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: return distribute_module( module, device_mesh, - partition_fn=ExpertParallel._partition_fn, + partition_fn=self._partition_fn, + # pyrefly: ignore [bad-argument-type] input_fn=self._token_dispatch, + # pyrefly: ignore [bad-argument-type] output_fn=self._token_combine, ) @@ -214,8 +238,14 @@ def _token_dispatch(self, mod, inputs, device_mesh): # NOTE: Currently in MoE TP, experts multiplication runs in plain Tensors. # The grad_placements on inputs is set to Partial so that necessary # reductions are performed during backward. + + # NOTE: The mesh used here should be dense_mesh["tp"] as routed_input is + # technically wrapped with the dense_mesh["tp"] but this complicates + # the interface of ExpertTensorParallel and it doesn't matter as etp + # is almost always the same as tp or is 1. To avoid the complexity, + # we use the etp mesh here. routed_input = DTensor.from_local( - routed_input, device_mesh["tp"], (Replicate(),) + routed_input, device_mesh["etp"], (Replicate(),) ).to_local(grad_placements=(Partial(),)) inputs = (routed_input, num_tokens_per_expert) @@ -223,23 +253,26 @@ def _token_dispatch(self, mod, inputs, device_mesh): # token dispatch happens on the EP mesh, whereas device_mesh is [ep, tp] mesh return super()._token_dispatch(mod, inputs, device_mesh["ep"]) - def _partition_fn_2d(self, name, mod, ep_tp_mesh): + def _partition_fn(self, name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None: # w1 shape = (experts, out_dim, in_dim) mod.register_parameter( "w1", - nn.Parameter(distribute_tensor(mod.w1, ep_tp_mesh, [Shard(0), Shard(1)])), + # pyrefly: ignore [bad-argument-type] + nn.Parameter(distribute_tensor(mod.w1, device_mesh, [Shard(0), Shard(1)])), ) # Column-wise sharding # w2 shape = (experts, in_dim, out_dim) mod.register_parameter( "w2", - nn.Parameter(distribute_tensor(mod.w2, ep_tp_mesh, [Shard(0), Shard(2)])), + # pyrefly: ignore [bad-argument-type] + nn.Parameter(distribute_tensor(mod.w2, device_mesh, [Shard(0), Shard(2)])), ) # Row-wise sharding # w3 shape = (experts, out_dim, in_dim) mod.register_parameter( "w3", - nn.Parameter(distribute_tensor(mod.w3, ep_tp_mesh, [Shard(0), Shard(1)])), + # pyrefly: ignore [bad-argument-type] + nn.Parameter(distribute_tensor(mod.w3, device_mesh, [Shard(0), Shard(1)])), ) # Column-wise sharding def _token_combine(self, mod, routed_output, device_mesh): @@ -250,8 +283,10 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: return distribute_module( module, device_mesh, - partition_fn=self._partition_fn_2d, + partition_fn=self._partition_fn, + # pyrefly: ignore [bad-argument-type] input_fn=self._token_dispatch, + # pyrefly: ignore [bad-argument-type] output_fn=self._token_combine, ) @@ -302,12 +337,9 @@ def _prepare_output_fn(self, mod, outputs, device_mesh): # NOTE: As we shard routed tokens along bs*slen dim across the TP ranks, # the MoE gather and scatter still require global token indices. local_rank = device_mesh.get_local_rank() - # fact: top_scores.shape[0] // mod.top_k = batch_size * seq_len // ep_degree - if not hasattr(mod, "top_k"): - raise ValueError( - "TokenReorderer class in MoE should always have top_k attribute." - ) - token_indices_experts_sorted += top_scores.shape[0] // mod.top_k * local_rank + token_indices_experts_sorted = ( + token_indices_experts_sorted + top_scores.shape[0] * local_rank + ) return top_scores, token_indices_experts_sorted, num_tokens_per_expert @@ -316,58 +348,76 @@ def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: module, device_mesh, partition_fn=None, + # pyrefly: ignore [bad-argument-type] input_fn=self._prepare_inputput_fn, + # pyrefly: ignore [bad-argument-type] output_fn=self._prepare_output_fn, ) -class ExpertParallelDeepEP(ParallelStyle): - def __init__(self): +class DeepEPExpertParallel(BaseExpertParallel): + """Expert Parallel using DeepEP for efficient token dispatch/combine. + + Expects inputs as: + (hidden_states, num_tokens_per_expert, selected_experts_indices, top_scores, num_experts) + + Args: + score_before_experts: If True, apply routing scores before expert computation. + """ + + def __init__(self, score_before_experts: bool = True): super().__init__() - self.input_splits = None - self.output_splits = None + self._state = None # State preserved between dispatch and combine + self.score_before_experts = score_before_experts - # performing all-to-all dispatch on the input def _token_dispatch(self, mod, inputs, device_mesh): - # annotate module input placements/sharding with input_layouts - routed_input, num_tokens_per_expert = inputs - - routed_input, routed_prob = mod.deepep_dispatcher.token_dispatch( - routed_input, group=device_mesh.get_group() + """Dispatch tokens via DeepEP.""" + from torchtitan.distributed.deepep import dispatch_tokens + + hidden_states, _, selected_experts_indices, top_scores, num_experts = inputs + if isinstance(mod.w1, DTensor): + num_local_experts = mod.w1.to_local().shape[0] + else: + num_local_experts = mod.w1.shape[0] + ep_group = device_mesh.get_group() + + # pyrefly: ignore[bad-assignment] + hidden_states, tokens_per_expert, self._state = dispatch_tokens( + hidden_states, + selected_experts_indices, + top_scores, + num_local_experts, + num_experts, + ep_group, + score_before_experts=self.score_before_experts, ) - ( - routed_input, - num_tokens_per_expert, - routed_prob, - ) = mod.deepep_dispatcher.dispatch_postprocess(routed_input, None) - # NOTE: routed_prob is returned and passed to GroupedExperts.forward(). - # When fused_weighted_scatter_add=True, probs are also stored in dispatcher - # for use in unpermute(). When False, GroupedExperts.forward() handles - # the multiplication directly. - return routed_input, num_tokens_per_expert, routed_prob + return hidden_states, tokens_per_expert @staticmethod def _partition_fn(name, mod, device_mesh): - # shard on the expert dimension - for name, param in mod.named_parameters(recurse=False): - dist_param = nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])) - mod.register_parameter(name, dist_param) + """Shard expert weights on expert dimension.""" + for param_name, param in mod.named_parameters(recurse=False): + mod.register_parameter( + param_name, + nn.Parameter(distribute_tensor(param, device_mesh, [Shard(0)])), + ) - # performing all-to-all combine on the output def _token_combine(self, mod, routed_output, device_mesh): - routed_output = mod.deepep_dispatcher.combine_preprocess(routed_output) - routed_output = mod.deepep_dispatcher.token_combine( - routed_output, group=device_mesh.get_group() - ) - # TODO: combine post process? + """Combine tokens via DeepEP.""" + from torchtitan.distributed.deepep import combine_tokens + + # pyrefly: ignore [bad-argument-type] + routed_output = combine_tokens(routed_output, self._state) + self._state = None return routed_output def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + """Apply DeepEP parallelization.""" return distribute_module( module, device_mesh, - partition_fn=self._partition_fn, - input_fn=self._token_dispatch, - output_fn=self._token_combine, + partition_fn=DeepEPExpertParallel._partition_fn, + input_fn=self._token_dispatch, # pyrefly: ignore [bad-argument-type] + output_fn=self._token_combine, # pyrefly: ignore [bad-argument-type] ) diff --git a/torchtitan/distributed/parallel_dims.py b/torchtitan/distributed/parallel_dims.py index 44822039a6..86173ba78a 100644 --- a/torchtitan/distributed/parallel_dims.py +++ b/torchtitan/distributed/parallel_dims.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from dataclasses import dataclass +from dataclasses import dataclass, field from torch.distributed.device_mesh import DeviceMesh, init_device_mesh @@ -26,7 +26,8 @@ class ParallelDims: etp: int world_size: int - _world_mesh: DeviceMesh = None + _meshes: dict[str, DeviceMesh] = field(default_factory=dict) + _world_mesh: DeviceMesh | None = None def __post_init__(self): self._validate() @@ -56,143 +57,253 @@ def _validate(self): if ep > 1: assert etp == tp or etp == 1, "Currently we only support ETP=TP or ETP=1" - if etp == tp: - # EP would borrow all cp and some dp_shard degree - assert ep % cp == 0 and (dp_shard * cp) % ep == 0 - elif etp == 1: - # EP would borrow all cp and tp and some dp_shard degree - assert ep % (cp * tp) == 0 and (dp_shard * cp * tp) % ep == 0 + + def _mesh_exist(self, name: str, degree: int) -> bool: + if name == "efsdp": + # We always keep the efsdp if EP is larger than 1 because we need + # FSDP wrapping to help the MoE layers do mixed precision training. + return True if self.ep > 1 else False + return degree > 1 def build_mesh(self) -> DeviceMesh: - # TODO: Current implementation of ParallelDims for dp2ep Expert Parallel - # is not very clean, due to the limited support from DeviceMesh - # for creating two staggered meshes. Will improve. - if self.ep > 1: - return self._build_mesh_with_ep() - else: - return self._build_mesh_without_ep() - - def _build_mesh_with_ep(self) -> DeviceMesh: - # With ep, dp_shard and ep are derived submeshes: - # dp_shard = dp_shard_mod_ep * dp_shard_in_ep - if self.etp == self.tp: - # ep = dp_shard_in_ep * cp - dp_shard_mod_ep = self.dp_shard * self.cp // self.ep - dp_shard_in_ep = self.ep // self.cp - else: - assert self.etp == 1 - # ep = dp_shard_in_ep * cp * tp - dp_shard_mod_ep = self.dp_shard * self.cp * self.tp // self.ep - dp_shard_in_ep = self.ep // (self.cp * self.tp) - - dims = [] - names = [] - for d, name in zip( - [ - self.pp, - self.dp_replicate, - dp_shard_mod_ep, - dp_shard_in_ep, - self.cp, - self.tp, - ], - ["pp", "dp_replicate", "dp_shard_mod_ep", "dp_shard_in_ep", "cp", "tp"], + """ + Build the device mesh with the required mesh dimensions. + + The following mesh dimensions will be created: + + pp: Pipeline Parallelism (PP). + batch: Used by data loading to determine the global batch size and which + part of the data each rank should read. This dimension includes both + ``dp_replicate`` and ``dp_shard``. The backend is set to ``fake`` for + this dimension to avoid unnecessary process group creation. + loss: Used by all-reduce when computing the loss. Includes ``dp_replicate``, + ``dp_shard``, and ``cp`` degrees, as all of them parallelize the data, + essentially require the weight gradients reduction. + dp_replicate: For DDP or HSDP replicate dimension. + fsdp: For FSDP dimension. This includes ``dp_shard`` and ``cp``. Note that + we always assume that when ``cp`` is used, FSDP is also applied to + utilize its weight all-gather and gradients reduce_scatter even if + there may be no data parallelism (e.g., global batch size is 1). + cp: Context Parallelism (CP). + tp: Tensor Parallelism (TP). + ep: Expert Parallelism (EP). + efsdp: FSDP in the EP region. + etp: TP in the EP region. + + Note: Most dimensions above are created by unflattening the world mesh, except for loss, + which is created by flattening the batch and cp dimensions. + This API performs the following unflatten operations from the world mesh: + + ["pp", "batch", "cp", "tp"] # dataloading_mesh + ["pp", "dp_replicate", "fsdp", "tp"] # dense_mesh + ["pp", "dp_replicate", "efsdp", "ep", "etp"] # sparse_mesh + + Note: DeviceMesh currently recreates the process group for each dimension. + It should share the process group for the same dim group to avoid unnecessary + process group creation. We can also use Fake to achieve a similar goal. + However, using Fake to avoid redundancy messing up the code. We only use Fake + when it is necessary. For now, we just let DeviceMesh create redundant process + group and wait for DeviceMesh to fix the issue. + """ + + def unflatten_mesh( + world_mesh: DeviceMesh, + dim_names: tuple[str, ...], + dim_degrees: tuple[int, ...], ): - # dp_shard_mod_ep is needed even if it's 1, whose FSDP wrapping - # helps the MoE layers do mixed precision training - if d > 1 or name == "dp_shard_mod_ep": - dims.append(d) - names.append(name) - - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) - - # Create all the submesh here to ensure all required process groups are - # initialized: - # Mesh for data loading (no communication on this mesh) - dp_mesh_dim_names = [] - # Mesh for param sharding - dp_shard_cp_mesh_dim_names = [] - # Mesh for loss all-reduce - dp_cp_mesh_dim_names = [] - # Mesh for ep - ep_mesh_dim_names = [] - - if self.dp_replicate_enabled: - dp_mesh_dim_names.append("dp_replicate") - dp_cp_mesh_dim_names.append("dp_replicate") - # dp_shard_mod_ep is always needed, even if it's 1 - dp_mesh_dim_names.append("dp_shard_mod_ep") - dp_shard_cp_mesh_dim_names.append("dp_shard_mod_ep") - dp_cp_mesh_dim_names.append("dp_shard_mod_ep") - if "dp_shard_in_ep" in names: - dp_mesh_dim_names.append("dp_shard_in_ep") - dp_shard_cp_mesh_dim_names.append("dp_shard_in_ep") - dp_cp_mesh_dim_names.append("dp_shard_in_ep") - ep_mesh_dim_names.append("dp_shard_in_ep") - if self.cp_enabled: - dp_shard_cp_mesh_dim_names.append("cp") - dp_cp_mesh_dim_names.append("cp") - ep_mesh_dim_names.append("cp") - if self.etp == 1 and self.tp_enabled: - ep_mesh_dim_names.append("tp") - - mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") - mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_shard_cp") - mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") - mesh[tuple(ep_mesh_dim_names)]._flatten(mesh_dim_name="ep") + """Unflatten the world mesh to create the required mesh dimensions. + + Uses fake backend for dimensions with degree 1 or for 'batch' dimension + to avoid unnecessary process group creation. + """ + backend_override = {} + for name, degree in zip(dim_names, dim_degrees, strict=True): + if (not self._mesh_exist(name, degree)) or name == "batch": + backend_override[name] = "fake" + + return world_mesh._unflatten( + 0, dim_degrees, dim_names, backend_override=backend_override + ) - return mesh + logger.info( + f"Building device mesh with parallelism: " + f"pp={self.pp}, dp_replicate={self.dp_replicate}, dp_shard={self.dp_shard}, " + f"cp={self.cp}, tp={self.tp}, ep={self.ep}, etp={self.etp}" + ) - def _build_mesh_without_ep(self) -> DeviceMesh: - dims = [] - names = [] - for d, name in zip( - [self.pp, self.dp_replicate, self.dp_shard, self.cp, self.tp], - ["pp", "dp_replicate", "dp_shard", "cp", "tp"], - ): - if d > 1: - dims.append(d) - names.append(name) - - logger.info(f"Building {len(dims)}-D device mesh with {names}, {dims}") - mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) - - # Create all the submesh here to ensure all required process groups are - # initialized: - # Mesh for data loading (no communication on this mesh) - dp_mesh_dim_names = [] - # Mesh for param sharding - dp_shard_cp_mesh_dim_names = [] - # Mesh for loss all-reduce - dp_cp_mesh_dim_names = [] - - if self.dp_replicate_enabled: - dp_mesh_dim_names.append("dp_replicate") - dp_cp_mesh_dim_names.append("dp_replicate") - if self.dp_shard_enabled: - dp_mesh_dim_names.append("dp_shard") - dp_shard_cp_mesh_dim_names.append("dp_shard") - dp_cp_mesh_dim_names.append("dp_shard") - if self.cp_enabled: - dp_shard_cp_mesh_dim_names.append("cp") - dp_cp_mesh_dim_names.append("cp") - - if dp_mesh_dim_names != []: - mesh[tuple(dp_mesh_dim_names)]._flatten(mesh_dim_name="dp") - if dp_shard_cp_mesh_dim_names != []: - mesh[tuple(dp_shard_cp_mesh_dim_names)]._flatten( - mesh_dim_name="dp_shard_cp" + batch = self.dp_replicate * self.dp_shard + fsdp = self.dp_shard * self.cp + efsdp = fsdp * self.tp // (self.etp * self.ep) + + self._world_mesh = init_device_mesh( + device_type, (self.world_size,), mesh_dim_names=("world",) + ) + dataloading_mesh = unflatten_mesh( + self._world_mesh, + ("pp", "batch", "cp", "tp"), + (self.pp, batch, self.cp, self.tp), + ) + loss_mesh = dataloading_mesh["batch", "cp"]._flatten("loss_mesh") + dense_mesh = unflatten_mesh( + self._world_mesh, + ("pp", "dp_replicate", "fsdp", "tp"), + (self.pp, self.dp_replicate, fsdp, self.tp), + ) + sparse_mesh = unflatten_mesh( + self._world_mesh, + ("pp", "dp_replicate", "efsdp", "ep", "etp"), + (self.pp, self.dp_replicate, efsdp, self.ep, self.etp), + ) + + self._global_meshes = { + "dataloading": dataloading_mesh, + "loss": loss_mesh, + "dense": dense_mesh, + "sparse": sparse_mesh, + } + + self._meshes = { + "pp": dataloading_mesh["pp"], + "batch": dataloading_mesh["batch"], + "loss": loss_mesh, + "dp_replicate": dense_mesh["dp_replicate"], + "fsdp": dense_mesh["fsdp"], + "cp": dataloading_mesh["cp"], + "tp": dataloading_mesh["tp"], + "ep": sparse_mesh["ep"], + "efsdp": sparse_mesh["efsdp"], + "etp": sparse_mesh["etp"], + } + + # Validate mesh sizes + self._validate_meshes() + + logger.info( + f"Successfully created meshes with active dimensions: " + f"{list(self.get_all_one_dimensional_meshes().keys())}" + ) + + return self._world_mesh + + def _validate_meshes(self): + """Validate that created meshes have the expected sizes.""" + expected_sizes = { + "pp": self.pp, + "batch": self.dp_replicate * self.dp_shard, + "loss": self.dp_replicate * self.dp_shard * self.cp, + "dp_replicate": self.dp_replicate, + "fsdp": self.dp_shard * self.cp, + "cp": self.cp, + "tp": self.tp, + "ep": self.ep, + "efsdp": self.dp_shard * self.cp * self.tp // (self.etp * self.ep), + "etp": self.etp, + } + + for mesh_name, expected_size in expected_sizes.items(): + actual_size = self._meshes[mesh_name].size() + assert actual_size == expected_size, ( + f"Mesh '{mesh_name}' has unexpected size: " + f"expected {expected_size}, got {actual_size}" ) - if dp_cp_mesh_dim_names != []: - mesh[tuple(dp_cp_mesh_dim_names)]._flatten(mesh_dim_name="dp_cp") + def get_optional_mesh(self, dims: str | list[str]) -> DeviceMesh | None: + """Get a device mesh by dimension name(s), returning None if not enabled. + + Args: + dims: Names of the mesh dimension. Valid options include: + 'pp', 'batch', 'loss', 'dp_replicate', 'fsdp', + 'cp', 'tp', 'ep', 'etp', 'efsdp'. + + Returns: + DeviceMesh for the requested dimension(s), or None if: + - The dimension size is 1 (parallelism not enabled) + - The dimension doesn't exist (except efsdp which can exist even if size is 1 when ep > 1) + + Raises: + ValueError: If the requested dimension name(s) is not valid. + """ + if not self._meshes: + self.build_mesh() + + if isinstance(dims, str): + dims = [dims] + + for mesh_name in dims: + if mesh_name not in self._meshes: + raise ValueError( + f"Invalid mesh dim: '{mesh_name}'. " + f"Valid dimensions are: {list(self._meshes.keys())}" + ) + + if any(not self._mesh_exist(dim, self._meshes[dim].size()) for dim in dims): + return None + + if len(dims) == 1: + return self._meshes[dims[0]] + else: + for global_mesh in self._global_meshes.values(): + assert global_mesh.mesh_dim_names is not None + if not set(dims).issubset(set(global_mesh.mesh_dim_names)): + continue + return global_mesh[tuple(dims)] + raise ValueError(f"Invalid mesh name combinations {dims}.") + + def get_mesh(self, dims: str | list[str]) -> DeviceMesh: + """Get a device mesh by dimension name(s), raising if not available. + + Args: + dims: Names of the mesh dimension. Valid options include: + 'pp', 'batch', 'loss', 'dp_replicate', 'fsdp', + 'cp', 'tp', 'ep', 'etp', 'efsdp'. + + Returns: + DeviceMesh for the requested dimension(s). + + Raises: + ValueError: If the mesh is not available (dimension size = 1 or not enabled), + or if the requested dimension name(s) is not valid. + """ + mesh = self.get_optional_mesh(dims) + if mesh is None: + enabled_str = ( + "enabled (size > 1)" if isinstance(dims, str) else "all enabled" + ) + raise ValueError( + f"Mesh '{dims}' is not available. " + f"Ensure the corresponding parallelism dimension is {enabled_str}." + ) return mesh + def get_all_one_dimensional_meshes(self) -> dict[str, DeviceMesh]: + """Get all enabled one-dimensional device meshes. + + Returns a dictionary of enabled one-dimensional device meshes, allowing you to + access their process groups. + + Note: + Device meshes created with the Fake backend are still included in the results. + + Returns: + dict[str, DeviceMesh]: A dictionary mapping mesh dimension names to their + corresponding DeviceMesh objects. Only includes meshes where: + - ndim == 1 (one-dimensional) + - parallelism is enabled (size > 1) + + Example: + >>> parallel_dims = ParallelDims( + ... dp_replicate=2, dp_shard=2, cp=1, tp=2, pp=1, ep=1, etp=1, world_size=8 + ... ) + >>> meshes = parallel_dims.get_all_one_dimensional_meshes() + >>> print(meshes.keys()) + dict_keys(['dp_replicate', 'fsdp', 'tp', 'batch', 'loss', 'efsdp']) + """ + if not self._meshes: + self.build_mesh() + return {k: v for k, v in self._meshes.items() if v.ndim == 1 and v.size() > 1} + @property def world_mesh(self) -> DeviceMesh: - # doing late init so ParallelDims can still be used as a lightweight - # dataclass without having to initialize the world mesh if self._world_mesh is None: self._world_mesh = self.build_mesh() return self._world_mesh diff --git a/torchtitan/distributed/pipeline_parallel.py b/torchtitan/distributed/pipeline_parallel.py index 7840b29879..799eaa0279 100644 --- a/torchtitan/distributed/pipeline_parallel.py +++ b/torchtitan/distributed/pipeline_parallel.py @@ -14,33 +14,21 @@ from torch.distributed.device_mesh import DeviceMesh from torch.distributed.pipelining import PipelineStage -try: - from torch.distributed.pipelining.schedules import ( - _PipelineSchedule, - _PipelineScheduleRuntime, - get_schedule_class, - PipelineScheduleMulti, - PipelineScheduleSingle, - ScheduleDualPipeV, - ScheduleZBVZeroBubble, - ) -except ImportError: - print("Not using 2.9 or nightly, ScheduleDualPipeV not available.") - from torch.distributed.pipelining.schedules import ( - _PipelineSchedule, - _PipelineScheduleRuntime, - get_schedule_class, - PipelineScheduleMulti, - PipelineScheduleSingle, - ScheduleZBVZeroBubble, - ) - - ScheduleDualPipeV = None - +from torch.distributed.pipelining.schedules import ( + _PipelineSchedule, + _PipelineScheduleRuntime, + get_schedule_class, + OVERLAP_F_B, + PipelineScheduleMulti, + PipelineScheduleSingle, + ScheduleDualPipeV, + ScheduleZBVZeroBubble, +) from torchtitan.components.loss import LossFunction, rescale_accumulated_loss from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims +from torchtitan.distributed.dual_pipe_v import overlap_callback from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction from torchtitan.tools.logging import logger @@ -61,7 +49,7 @@ def pipeline_llm( parallelize_fn: ParallelizeFunction, loss_fn: LossFunction, ) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: - pp_mesh = parallel_dims.world_mesh["pp"] + pp_mesh = parallel_dims.get_mesh("pp") # Determine the number of virtual stages based on schedule type schedule_class = get_schedule_class( @@ -212,7 +200,9 @@ def build_pipeline_schedule( f"of stages ({num_total_stages}) which may result in a bubble in the pipeline." ) + # pyrefly: ignore [bad-instantiation] schedule = schedule_class( + # pyrefly: ignore [bad-argument-type] stages if looped_schedule else stages[0], n_microbatches=n_microbatches, loss_fn=rescale_accumulated_loss(loss_fn, n_microbatches), @@ -223,6 +213,11 @@ def build_pipeline_schedule( f"with {n_microbatches} microbatches and {num_total_stages} stages." ) + if job_config.parallelism.pipeline_parallel_expert_parallel_overlap and isinstance( + schedule, ScheduleDualPipeV + ): + schedule.register_custom_function(OVERLAP_F_B, overlap_callback) + if pp_schedule_csv: assert schedule_class in [ PipelineScheduleSingle, @@ -232,6 +227,7 @@ def build_pipeline_schedule( "Only PipelineScheduleSingle (single stage), PipelineScheduleMulti (multistage), " "and _PipelineScheduleRuntime support csv schedules" ) + # pyrefly: ignore [missing-attribute] schedule._load_csv(pp_schedule_csv) return schedule @@ -457,7 +453,7 @@ def _build_stage_from_modules( else: style = "v" if schedule_class is ScheduleZBVZeroBubble else "loop" - def _get_stage_indices() -> tuple[int]: + def _get_stage_indices() -> tuple[int, ...]: """ Compute the stage ids for the stages that will run on this pp rank for either a looped or V style schedule @@ -476,6 +472,8 @@ def _get_stage_indices() -> tuple[int]: zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1)) ) return stage_v_pairs[pp_rank] + else: + raise ValueError(f"Unknown style {style}") for stage_idx in _get_stage_indices(): module_names = module_names_per_stage[stage_idx] diff --git a/torchtitan/distributed/tensor_parallel.py b/torchtitan/distributed/tensor_parallel.py index a2749f4c11..60101a2862 100644 --- a/torchtitan/distributed/tensor_parallel.py +++ b/torchtitan/distributed/tensor_parallel.py @@ -6,6 +6,7 @@ import torch +import torch._inductor.config from torch.distributed.device_mesh import DeviceMesh from torchtitan.config import JobConfig @@ -17,11 +18,10 @@ def maybe_enable_async_tp(job_config: JobConfig, tp_mesh: DeviceMesh): return if not (job_config.compile.enable and "model" in job_config.compile.components): - raise RuntimeError("Async TP requires --training.compile") - - from torch.distributed._symmetric_memory import enable_symm_mem_for_group + raise RuntimeError( + "Async TP requires 'model' in --compile.components and --compile.enable" + ) torch._inductor.config._micro_pipeline_tp = True - enable_symm_mem_for_group(tp_mesh.get_group().group_name) logger.info("Async TP is enabled") diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py index 2eaa7d7b95..3c9366cc95 100644 --- a/torchtitan/distributed/utils.py +++ b/torchtitan/distributed/utils.py @@ -7,12 +7,16 @@ import contextlib import math import os -from collections.abc import Generator, Iterable +from abc import abstractmethod +from collections.abc import Iterable from datetime import timedelta +from typing import Protocol import torch import torch.distributed._functional_collectives as funcol import torch.distributed.distributed_c10d as c10d +import torch.distributed.tensor._random +import torch.distributed.tensor.parallel from torch import distributed as dist from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import DTensor @@ -26,7 +30,7 @@ def _dist_reduce( x: torch.Tensor, reduceOp: str, - mesh: DeviceMesh, + mesh: DeviceMesh | None, extra_pg: dist.ProcessGroup | None, ) -> float: """Perform distributed reduction on a tensor. @@ -34,7 +38,8 @@ def _dist_reduce( Args: x (torch.Tensor): Input tensor. reduceOp (str): Reduce operation to perform. - mesh (DeviceMesh): Device mesh to use for reduction. + mesh (DeviceMesh | None): Device mesh to use for reduction. + If None, no reduction is performed but simply convert the tensor to a float. extra_pg (dist.ProcessGroup, optional): Extra process group to use for reduction. Defaults to None. If provided, this all_reduce will be called for the extra process group, and then the result will be all_reduced for the mesh. @@ -46,13 +51,17 @@ def _dist_reduce( if extra_pg is not None: x = funcol.all_reduce(x, reduceOp=reduceOp, group=extra_pg) + if mesh is None: + return x.item() + assert x.numel() == 1 # required by `.item()` return funcol.all_reduce(x, reduceOp=reduceOp, group=mesh).item() +# TODO: rename this to maybe_dist_max def dist_max( x: torch.Tensor, - mesh: DeviceMesh, + mesh: DeviceMesh | None = None, extra_pg: dist.ProcessGroup | None = None, ) -> float: return _dist_reduce( @@ -62,7 +71,7 @@ def dist_max( def dist_sum( x: torch.Tensor, - mesh: DeviceMesh, + mesh: DeviceMesh | None = None, extra_pg: dist.ProcessGroup | None = None, ) -> float: return _dist_reduce( @@ -72,7 +81,7 @@ def dist_sum( def dist_mean( x: torch.Tensor, - mesh: DeviceMesh, + mesh: DeviceMesh | None = None, extra_pg: dist.ProcessGroup | None = None, ) -> float: return _dist_reduce( @@ -81,7 +90,7 @@ def dist_mean( def set_determinism( - world_mesh: DeviceMesh | None, + parallel_dims: ParallelDims, device: torch.device, debug_config: DebugConfig, distinct_seed_mesh_dims: list[str], @@ -99,9 +108,8 @@ def set_determinism( Args: world_mesh: Device mesh for distributed training device: Device to use + debug_config: Debug config to use distinct_seed_mesh_dims: List of mesh dimension names to have distinct seeds across. - seed: Base seed value (if None, will be determined automatically) - deterministic: Whether to enable deterministic algorithms """ if debug_config.deterministic: logger.info("Deterministic algorithm enabled (expect perf degradation).") @@ -124,7 +132,7 @@ def set_determinism( FlexAttentionWrapper._compiled_flex_attn = torch.compile(flex_attention) seed = debug_config.seed - if not world_mesh: + if parallel_dims.world_size == 1: if seed is not None: torch.manual_seed(seed) os.environ["PYTHONHASHSEED"] = str(seed % 2**32) @@ -139,25 +147,25 @@ def set_determinism( seed_tensor = torch.get_rng_state()[:8].to(device) torch.distributed.broadcast(seed_tensor, src=0) seed = seed_tensor.to("cpu").view(torch.uint64).item() + assert isinstance(seed, int) # Set distinct seed for each rank in mesh dimensions, with dimension names provided by `distinct_seed_mesh_dims` # For PP + SPMD cases, we want to separate the world into the SPMD mesh and the PP mesh, # and choose a unique seed for each rank on the PP mesh. # We support multiple distinct dimensions by adding each distinct dimension's local rank to the seed. - distinct_dims_in_mesh = [ - dim - for dim in distinct_seed_mesh_dims - if world_mesh.mesh_dim_names and dim in world_mesh.mesh_dim_names + distinct_seed_meshes = [ + parallel_dims.get_optional_mesh(dim) for dim in distinct_seed_mesh_dims ] + distinct_seed_meshes = [mesh for mesh in distinct_seed_meshes if mesh is not None] + assert all(mesh is not None for mesh in distinct_seed_meshes) - if c10d.get_world_size() > 1 and distinct_dims_in_mesh: + if distinct_seed_meshes: # Each dimension contributes: local_rank * (product of all previous dimension sizes) # This guarantees uniqueness like multi-dimensional array indexing seed_offset = 0 cumulative_size = 1 - for dim in distinct_dims_in_mesh: - distinct_mesh = world_mesh[dim] + for distinct_mesh in distinct_seed_meshes: local_rank = distinct_mesh.get_local_rank() # Add contribution from this dimension seed_offset += local_rank * cumulative_size @@ -168,20 +176,10 @@ def set_determinism( seed %= 2**64 logger.debug( - f"Distinct dims {distinct_dims_in_mesh}, Global rank {c10d.get_rank()} using seed: {seed}" + f"Distinct dims {distinct_seed_mesh_dims}, Global rank {c10d.get_rank()} using seed: {seed}" ) - # Filter out all distinct dimensions to get duplicate_seed_mesh - duplicate_seed_mesh_dims = [ - name - for name in world_mesh.mesh_dim_names - if name not in distinct_dims_in_mesh - ] - duplicate_seed_mesh = ( - world_mesh[duplicate_seed_mesh_dims] if duplicate_seed_mesh_dims else None - ) else: - duplicate_seed_mesh = world_mesh logger.debug(f"Global Rank {c10d.get_rank()} using seed: {seed}") # The native RNGs and python RNG may not be important, except for the 1-D PP case, but we seed them for consistency. @@ -189,10 +187,14 @@ def set_determinism( # PYTHONHASHSEED can be a decimal number in the range [0, 2**32 - 1] os.environ["PYTHONHASHSEED"] = str(seed % 2**32) - # As long as we are not in the 1-D (PP-only) case, we will have a seed to use for all ranks of the SPMD mesh. - # IF PP is also used, this seed is unique per PP rank. - if duplicate_seed_mesh and duplicate_seed_mesh.get_coordinate() is not None: - torch.distributed.tensor._random.manual_seed(seed, duplicate_seed_mesh) + # As long as we are not in the 1-D (PP-only) case, we will have a seed to use for + # all ranks of the SPMD mesh. If PP is also used, this seed is unique per PP rank. + # TODO: remove the need of passing in a mesh once + # torch.distributed.tensor._random.manual_seed doesn't require a mesh input. + if parallel_dims.world_size > parallel_dims.pp: + # We just need to pass the world_mesh as the device_id is the only information + # this API uses. + torch.distributed.tensor._random.manual_seed(seed, parallel_dims.world_mesh) def create_context_parallel_ctx( @@ -205,11 +207,11 @@ def create_context_parallel_ctx( try: from torch.distributed.tensor.experimental import context_parallel from torch.distributed.tensor.experimental._attention import set_rotate_method - except ImportError: - print( + except ImportError as e: + raise ValueError( f"PyTorch version {torch.__version__} does not include the experimental " "Context Parallel API. Please update to a newer version." - ) + ) from e set_rotate_method(cp_rotate_method) return context_parallel( @@ -220,24 +222,27 @@ def create_context_parallel_ctx( ) -def get_train_context(enable_loss_parallel: bool) -> Generator[None, None, None]: +class TrainContext(Protocol): + @abstractmethod + def __call__(self) -> contextlib.AbstractContextManager[None]: + pass + + +def get_train_context(enable_loss_parallel: bool) -> TrainContext: @contextlib.contextmanager - def context(cp_context: Generator[None, None, None] | None = None): + def context(): with contextlib.ExitStack() as stack: if enable_loss_parallel: stack.enter_context(torch.distributed.tensor.parallel.loss_parallel()) - if cp_context: - stack.enter_context(cp_context) - yield return context def maybe_enable_amp( - parallel_dims: ParallelDims, mixed_precision_param: str, device_type: torch.device -) -> Generator[None, None, None]: + parallel_dims: ParallelDims, mixed_precision_param: str, device_type: str +) -> contextlib.AbstractContextManager[None]: if parallel_dims.fsdp_enabled: # FSDP handles mixed precision internally logger.info("Mixed precision training is handled by fully_shard") @@ -252,15 +257,58 @@ def maybe_enable_amp( else: # the following code will only be executed for DDP or single-device training logger.info("Mixed precision training is handled by AMP") + # pyrefly: ignore [bad-return] return torch.autocast( device_type, dtype=TORCH_DTYPE_MAP[mixed_precision_param], ) +def init_fake_mode(world_size: int, comm_mode: str = "fake_backend"): + """Initialize fake backend + + Args: + world_size: The number of GPUs to simulate + comm_mode: Communication mode ("fake_backend" or "local_tensor") + + Returns: + The world size + """ + torch.distributed.init_process_group( + "fake", + rank=0, + world_size=world_size, + ) + + # If local_tensor mode is enabled, initialize LocalTensorMode context + if comm_mode == "local_tensor": + from torch.distributed import _local_tensor + + lm = _local_tensor.LocalTensorMode(world_size) + lm.__enter__() + + def init_distributed( - comm_config: CommConfig, enable_cpu_backend: bool = False, base_folder: str = "" -): + comm_config: CommConfig, + enable_cpu_backend: bool = False, + base_folder: str = "", + ranks: list[int] | None = None, +) -> int: + if comm_config.mode in ("fake_backend", "local_tensor"): + ngpu_str = os.environ.get("NGPU") + if ngpu_str is None: + raise ValueError( + f"NGPU environment variable must be set when using comm_mode={comm_config.mode}" + ) + try: + world_size = int(ngpu_str) + except ValueError as e: + raise ValueError( + f"NGPU environment variable must be a valid integer, got: {ngpu_str}" + ) from e + init_fake_mode(world_size, comm_config.mode) + return world_size + def _warn_overwrite_env(env, val): if env in os.environ: logger.warning( @@ -303,10 +351,16 @@ def _get_distributed_backend(enable_cpu_backend): torch.distributed.init_process_group( backend=_get_distributed_backend(enable_cpu_backend), timeout=timedelta(seconds=comm_config.init_timeout_seconds), + _ranks=ranks if ranks is not None else [], ) + return torch.distributed.get_world_size() + -def set_pg_timeouts(timeout, world_mesh): +def set_pg_timeouts( + timeout: timedelta, + parallel_dims: ParallelDims, +): """ Sets the timeout for all PGs in the provided mesh, and the default (world) group. @@ -325,10 +379,11 @@ def set_pg_timeouts(timeout, world_mesh): torch.distributed.barrier(device_ids=[device_module.current_device()]) device_module.synchronize() - groups = [world_mesh.get_group(mesh_dim) for mesh_dim in range(world_mesh.ndim)] - # None represents the 'default' PG, not part of the mesh - groups.append(None) + groups: list[torch.distributed.ProcessGroup | None] = [ + mesh.get_group() + for mesh in parallel_dims.get_all_one_dimensional_meshes().values() + ] + [None] for group in groups: torch.distributed.distributed_c10d._set_pg_timeout(timeout, group) @@ -406,9 +461,9 @@ def clip_grad_norm_( if math.isinf(norm_type): dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group()) else: - total_norm **= norm_type + total_norm **= norm_type # pyrefly: ignore[unsupported-operation] dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) - total_norm **= 1.0 / norm_type + total_norm **= 1.0 / norm_type # pyrefly: ignore[unsupported-operation] if max_norm > 0: torch.nn.utils.clip_grads_with_norm_(parameters, max_norm, total_norm, foreach) @@ -433,6 +488,7 @@ def _clip_grad_norm_with_ep( if p.grad is None: continue assert isinstance(p, DTensor) and isinstance(p.grad, DTensor) + # pyrefly: ignore[not-iterable] if "ep" in p.device_mesh.mesh_dim_names: ep_params.append(p) ep_grads.append(p.grad) @@ -447,6 +503,7 @@ def _clip_grad_norm_with_ep( if isinstance(ep_grads_total_norm, DTensor): ep_grads_total_norm = ep_grads_total_norm.full_tensor() + # pyrefly: ignore [missing-attribute] non_ep_grads_total_norm = torch.nn.utils.get_total_norm( non_ep_grads, norm_type, error_if_nonfinite, foreach ).full_tensor() @@ -455,17 +512,19 @@ def _clip_grad_norm_with_ep( total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm) else: total_norm = ( - ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type + # pyrefly: ignore[unsupported-operation] + ep_grads_total_norm**norm_type + + non_ep_grads_total_norm**norm_type ) - total_norm **= 1.0 / norm_type + total_norm **= 1.0 / norm_type # pyrefly: ignore[unsupported-operation] if pp_mesh is not None: if math.isinf(norm_type): dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group()) else: - total_norm **= norm_type + total_norm **= norm_type # pyrefly: ignore[unsupported-operation] dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) - total_norm **= 1.0 / norm_type + total_norm **= 1.0 / norm_type # pyrefly: ignore[unsupported-operation] if max_norm > 0: torch.nn.utils.clip_grads_with_norm_(ep_params, max_norm, total_norm, foreach) diff --git a/torchtitan/experiments/README.md b/torchtitan/experiments/README.md index 0d6db0d2a1..53df45dd84 100644 --- a/torchtitan/experiments/README.md +++ b/torchtitan/experiments/README.md @@ -10,7 +10,7 @@ We provide this `experiments/` folder to host experiments that add significant v 3. An experiment should reuse existing `torchtitan` code as much as possible, such as modules in [`components/`](../components/) (via a new [`TrainSpec`](../protocols/train_spec.py)) and [`train.py`](../train.py). For a list of extension points we provide, please refer to [docs/extension.md](../../docs/extension.md). - The extension points are subject to change. We kindly request that contributors provide feedback if they encounter issues reusing any components, rather than simply using a copy-and-paste approach. - The degree to which existing components are reused and whether duplications are legit will also be a criteria of whether an experiment would be accepted. -4. Each experiment is independent from other experiments, and can have its own dependencies (on top of [core dependencies](../../requirements.txt)), and its own tests. +4. Each experiment is independent from other experiments, and can have its own dependencies (on top of [core dependencies](../../requirements.txt)), and its own tests. An experiment should not contain vendor-specific code, such as kernels written in a proprietary language. Those can be hosted outside as dependency. 5. The dependency from `experiments` to `core` is one-way. Anything in `experiments` is optional for `core` to run successfully. In particular, development in `core` is not blocked by breakage in `experiments`. We will utilize GitHub's [CI mechanism](https://docs.github.com/en/actions/writing-workflows/workflow-syntax-for-github-actions#onpushpull_requestpull_request_targetpathspaths-ignore) to help test an experiment periodically and only if the experiment itself is affected by a PR. 6. Each experiment needs to have an owner. The owner is responsible to work with `torchtitan` team to maintain the quality and healthiness of an experiment, which includes - adapting an experiment to changes in `core` and fix broken tests, no later than the next official `torchtitan` release; @@ -27,7 +27,10 @@ We provide this `experiments/` folder to host experiments that add significant v | [simple_fsdp](./simple_fsdp/) | [![SimpleFSDP 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml?query=branch%3Amain) | [@ruisizhang123](https://github.com/ruisizhang123) [@tianyu-l](https://github.com/tianyu-l) | | [vlm](./vlm/) | [![VLM 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_vlm.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_vlm.yaml?query=branch%3Amain) | [@lkhphuc](https://github.com/lkhphuc) | | [forge](./forge/) | TBA | [@allenwang28](https://github.com/allenwang28) [@ebsmothers](https://github.com/ebsmothers) [@joecummings](https://github.com/joecummings) [@pbontrager](https://github.com/pbontrager) | -| [torchcomms](./torchcomms/) | TBA | [@d4l3k](https://https://github.com/d4l3k) [@fduwjj](https://github.com/fduwjj) [@mori360 ](https://github.com/mori360) | +| [torchcomms](./torchcomms/) | [![TorchComms 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_torchcomms.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_torchcomms.yaml?query=branch%3Amain) | [@d4l3k](https://https://github.com/d4l3k) [@fduwjj](https://github.com/fduwjj) [@mori360 ](https://github.com/mori360) | | [moe_symm_mem_kernels](./moe_symm_mem_kernels/) | TBA | [@kwen2501](https://github.com/kwen2501) | -| [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) | +| [gpt_oss](./gpt_oss/) | TBA | [@wwwjn](https://github.com/wwwjn) | | [compiler_toolkit](./compiler_toolkit/) | [![Compiler Toolkit 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_compiler_toolkit.yaml?query=branch%3Amain) | [@SherlockNoMad](https://github.com/SherlockNoMad) [@yiming0416](https://github.com/yiming0416) | +| [transformers_modeling_backend](./transformers_modeling_backend/) | [![Transformers modeling backend 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_transformers_modeling_backend.yaml?query=branch%3Amain) | [@3outeille](https://github.com/3outeille) | +| [rl](./rl/) | TBA | [@bwasti](https://github.com/bwasti) [@wwwjn](https://github.com/wwwjn) | +| [autoparallel](./autoparallel/) | [![Auto Parallel 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_autoparallel.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_autoparallel.yaml?query=branch%3Amain) | [@wconstab](https://github.com/wconstab) [@xmfan](https://github.com/xmfan) | diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 673c48fc31..7c6f034e65 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -6,12 +6,15 @@ _supported_experiments = frozenset( [ - "gpt_oss", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm", "compiler_toolkit.deepseek_v3", "compiler_toolkit.llama3", + "transformers_modeling_backend", + "autoparallel.llama3", + "autoparallel.deepseek_v3", + "autoparallel.local_map_deepseek_v3", "qwen3_next", "kimi_linear", ] diff --git a/torchtitan/experiments/autoparallel/README.md b/torchtitan/experiments/autoparallel/README.md new file mode 100644 index 0000000000..54f3c95fb7 --- /dev/null +++ b/torchtitan/experiments/autoparallel/README.md @@ -0,0 +1,25 @@ +## Auto Parallel + +### Overview + +The Auto Parallel experiment integrates PyTorch's AutoParallel framework with TorchTitan to automatically optimize distributed training parallelism strategies given a device mesh. Instead of manually configuring parallelism layouts, AutoParallel uses cost-based analysis to determine optimal sharding placements for model parameters, activations, and gradients. + +### Requirements + +Requires installing [git@github.com:meta-pytorch/autoparallel.git](https://github.com/meta-pytorch/autoparallel) + +### Single Node + +**Llama3** + +`CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name autoparallel.llama3 --parallelism.tensor_parallel_degree 4 --job.custom_config_module=torchtitan.experiments.autoparallel.job_config` + +**DeepSeekv3** + +`NGPU=2 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name autoparallel.deepseek_v3 --job.custom_config_module=torchtitan.experiments.autoparallel.job_config` + +**DeepSeekv3 local_map** + +This is a variant of titan's DSv3, which uses a local_map for the expert parallel region. This only supports 2D mesh right now. NOTE: the mesh provided are just to reuse torchtitan's trainer mesh setup code. Autoparallel is not bound to use dp2ep. + +`NGPU=2 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml tlp ./run_train.sh --model.name autoparallel.local_map_deepseek_v3 --job.custom_config_module=torchtitan.experiments.autoparallel.job_config --parallelism.data_parallel_shard_degree 2 --parallelism.expert_parallel_degree 2` diff --git a/torchtitan/experiments/autoparallel/deepseek_v3/__init__.py b/torchtitan/experiments/autoparallel/deepseek_v3/__init__.py new file mode 100644 index 0000000000..b90583c86b --- /dev/null +++ b/torchtitan/experiments/autoparallel/deepseek_v3/__init__.py @@ -0,0 +1,50 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +import copy + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.distributed.pipeline_parallel import pipeline_llm +from torchtitan.hf_datasets.text_datasets import build_text_dataloader + +from torchtitan.models.deepseek_v3 import deepseekv3_args, DeepSeekV3Model +from torchtitan.models.deepseek_v3.model.args import DeepSeekV3ModelArgs +from torchtitan.models.deepseek_v3.model.state_dict_adapter import ( + DeepSeekV3StateDictAdapter, +) +from torchtitan.protocols.train_spec import TrainSpec + +from .parallelize_deepseekv3 import parallelize_deepseekv3 + + +def get_train_spec() -> TrainSpec: + model_args = copy.deepcopy(deepseekv3_args) + + default_args = DeepSeekV3ModelArgs() + for config, args in model_args.items(): + if "flex_attn" in config: + continue + + args.attn_type = default_args.attn_type + args.attn_mask_type = default_args.attn_mask_type + + return TrainSpec( + model_cls=DeepSeekV3Model, + model_args=model_args, + parallelize_fn=parallelize_deepseekv3, + pipelining_fn=pipeline_llm, + build_optimizers_fn=build_optimizers_with_moe_load_balancing, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + state_dict_adapter=DeepSeekV3StateDictAdapter, + ) diff --git a/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py b/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py new file mode 100644 index 0000000000..68adb3c038 --- /dev/null +++ b/torchtitan/experiments/autoparallel/deepseek_v3/parallelize_deepseekv3.py @@ -0,0 +1,446 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import time +import types +from typing import Callable, Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from autoparallel.api import AutoParallel +from autoparallel.auto_bucketing import configure_inductor_for_autobucketing + +from torch.distributed.tensor.placement_types import Replicate, Shard +from torchtitan.config import JobConfig +from torchtitan.distributed import ParallelDims +from torchtitan.models.moe.moe import _run_experts_grouped_mm + +from torchtitan.tools.logging import logger + + +def create_functional_router_forward( + self: nn.Module, +) -> Callable: # TokenChoiceTopKRouter + def functional_router_forward( + x: torch.Tensor, gate_weight: torch.nn.Parameter, expert_bias: torch.Tensor + ): + # scores shape (bs*slen, num_experts) + scores = F.linear(x, gate_weight) + + # By default, sigmoid or softmax is performed in float32 to avoid loss explosion + if self.score_func == "sigmoid": + scores = torch.sigmoid(scores.to(torch.float32)) + elif self.score_func == "softmax": + scores = F.softmax(scores.to(torch.float32), dim=1) + else: + raise NotImplementedError(f"Unknown score function {self.score_func}") + + # top scores shape (bs*slen, top_k) + # NOTE: The expert_bias is only used for routing. The gating value + # top_scores is still derived from the original scores. + if expert_bias is not None: + _, selected_experts_indices = torch.topk( + scores + expert_bias, k=self.top_k, dim=1 + ) + top_scores = scores.gather(dim=1, index=selected_experts_indices) + else: + top_scores, selected_experts_indices = torch.topk( + scores, k=self.top_k, dim=1 + ) + + # debug override: balanced round-robin routing + if self._debug_force_load_balance: + ( + selected_experts_indices, + top_scores, + ) = self._debug_force_load_balance_routing(scores) + + if self.route_norm: + denominator = top_scores.sum(dim=-1, keepdim=True) + 1e-20 + top_scores = top_scores / denominator + top_scores = top_scores * self.route_scale + + # group tokens together by expert indices from 0 to num_experts and pass that to experts forward + num_tokens_per_expert = torch.histc( + selected_experts_indices.view(-1), + bins=self.num_experts, + min=0, + max=self.num_experts, + ) + + return top_scores, selected_experts_indices, num_tokens_per_expert + + return functional_router_forward + + +def _moe_forward( + x: torch.Tensor, + router_gate_weight: torch.nn.Parameter, + expert_bias: Optional[torch.Tensor], + experts_w1: torch.Tensor, + experts_w3: torch.Tensor, + experts_w2: torch.Tensor, + shared_w1_weight: torch.Tensor, + shared_w3_weight: torch.Tensor, + shared_w2_weight: torch.Tensor, + functional_router_forward: Callable, + reorderer: nn.Module, # TokenReorderer + top_k: int, +): + bs, slen, dim = x.shape + x = x.view(-1, dim) + + # top_scores and selected_experts_indices shape (bs*slen, top_k) + # num_tokens_per_expert shape (num_experts,) + ( + top_scores, + selected_experts_indices, + num_tokens_per_expert, + ) = functional_router_forward(x, router_gate_weight, expert_bias) + num_tokens_per_expert_update = num_tokens_per_expert + + # top_scores_experts_sorted and token_indices_experts_sorted shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + # NOTE: the reason we need to compute num_tokens_per_expert again is: + # 1st computation in router is to update self.tokens_per_expert + # which would be the same across all TP ranks. + # 2nd computation in reorderer is for the actual routing and experts computation + # which would be sharded over TP ranks if expert_tensor_parallel_degree==1. + # If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree. + ( + top_scores_experts_sorted, + token_indices_experts_sorted, + num_tokens_per_expert, + ) = reorderer(top_scores, selected_experts_indices) + + # shape (bs*slen*top_k, dim) + routed_input = x[token_indices_experts_sorted // top_k] + + # DSv3 score_before_experts is always False + # if score_before_experts: + # routed_input = ( + # routed_input.to(torch.float32) * top_scores_experts_sorted.reshape(-1, 1) + # ).to(x.dtype) + + # shape (bs*slen*top_k, dim) + # routed_output = experts(routed_input, num_tokens_per_expert) + routed_output = _run_experts_grouped_mm( + experts_w1, experts_w2, experts_w3, routed_input, num_tokens_per_expert + ) + + # always has shared expert + # Note: we execute the shared expert before scoring the output of the routed expert + # to "implicitly" overlap the shared expert compute with token combine communication + _h1 = F.linear(x, shared_w1_weight) + _h3 = F.linear(x, shared_w3_weight) + out = F.linear(F.silu(_h1) * _h3, shared_w2_weight) + + # Unsort routed outputs + routed_output_unsorted = torch.zeros( + (bs * slen * top_k, dim), + dtype=routed_output.dtype, + device=routed_output.device, + ) + routed_output_unsorted[token_indices_experts_sorted] = routed_output + routed_output_unsorted = routed_output_unsorted.reshape(-1, top_k, dim) + # DSv3 score_before_experts is False + # if not self.score_before_experts: + out_experts = ( + torch.bmm( + top_scores.reshape(-1, 1, top_k), + routed_output_unsorted.float(), + ) + .to(x.dtype) + .squeeze(1) + ) + # else: + # out_experts = routed_output_unsorted.sum(dim=1) + + # always has shared experts + # if out is None: + return (out + out_experts).reshape(bs, slen, dim), num_tokens_per_expert_update + + +def moe_forward(self, x: torch.Tensor) -> torch.Tensor: + functional_router_forward = create_functional_router_forward(self.router) + out, num_tokens_per_expert = _moe_forward( + x, + self.router.gate.weight, + self.expert_bias, + self.experts.w1, + self.experts.w3, + self.experts.w2, + self.shared_experts.w1.weight, + self.shared_experts.w3.weight, + self.shared_experts.w2.weight, + functional_router_forward, + self.reorderer, + self.router.top_k, + ) + # HOPs don't support buffer mutations, keep this outside + # tokens_per_expert will be used to update the expert bias for load balancing. + # and also to count the expert usage + # TODO: Activation Checkpointing has the side effect of double counting tokens_per_expert -- + # first in the forward pass, and then in the backward pass. However, this has no + # effect on the expert bias update thanks to the torch.sign() operator. + with torch.no_grad(): + self.tokens_per_expert.add_(num_tokens_per_expert) + return out + + +def monkey_patch_checks(moe): + # causes data-dependent issue, hardcoded into monkey patch + assert not moe.score_before_experts + assert moe.router.gate.bias is None + assert moe.experts.use_grouped_mm + assert moe.shared_experts is not None + assert moe.shared_experts.w1.bias is None + assert moe.shared_experts.w2.bias is None + assert moe.shared_experts.w3.bias is None + assert not list(moe.reorderer.parameters()) + assert not list(moe.reorderer.buffers()) + + +def monkey_patch_local_map_moe(model, sparse_mesh): + """ + TODO: fix HOPs not restoring the original signature. + TODO: fix tracing with local shapes so that we can use Shard placements + + Current HOP signature we get: + """ + from torch.distributed._tensor.experimental import local_map + + # from torchtitan.models.moe import moe + global _moe_forward + _moe_forward = local_map( + _moe_forward, + out_placements=( + (Replicate(),), # out: torch.Tensor + (Replicate(),), # num_tokens_per_expert_update: torch.Tensor + ), + in_placements=( + (Replicate(),), # x: torch.Tensor, + (Replicate(),), # router_gate_weight: torch.nn.Parameter, + (Replicate(),), # expert_bias: Optional[torch.Tensor], + (Replicate(),), # experts_w1: torch.Tensor, + (Replicate(),), # experts_w3: torch.Tensor, + (Replicate(),), # experts_w2: torch.Tensor, + (Replicate(),), # shared_w1: torch.Tensor, + (Replicate(),), # shared_w3: torch.Tensor, + (Replicate(),), # shared_w2: torch.Tensor, + None, # functional_router_forward: Callable, + None, # reorderer: TokenReorderer, + None, # top_k + ), + redistribute_inputs=True, + in_grad_placements=None, + device_mesh=sparse_mesh, + ) + + for block in model.layers.children(): + if not block.moe_enabled: + continue + block.moe.forward = types.MethodType(moe_forward, block.moe) + monkey_patch_checks(block.moe) + + +# TODO: Autoparallel should transparently wrap the original nn.Module +# but I don't know how to do that. +def set_torchtitan_fields(orig, new): + assert isinstance(new.layers, torch.nn.ModuleDict) + for block in new.layers.values(): + block.moe_enabled = hasattr(block, "moe") + + +def parallelize_deepseekv3( + model, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply Autoparallel to the model + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + # TODO(whc) + # I do this because otherwise sometimes inductor will skip re-running passes like comms reordering + torch._inductor.config.force_disable_caches = True + # this is necessary for working with reordering passes. Just leave it set for all the jobs for now. + torch._inductor.config.allow_buffer_reuse = False + + # allow configuring inductor comms optimizations from torchtitan commandline + configure_inductor_for_autobucketing( + job_config.experimental.comms_bucket_reorder_strategy + ) + + sparse_names = ["dp_replicate", "efsdp", "ep", "etp"] + sparse_names = [ + name + for name in sparse_names + if parallel_dims.get_optional_mesh(name) is not None + ] + sparse_mesh = parallel_dims.get_mesh(sparse_names) + + def input_fn(): + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. + dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard + global_batch_size = job_config.training.local_batch_size * dp_degree + return ( + torch.randint( + 0, + model.model_args.vocab_size, + (global_batch_size, job_config.training.seq_len), + device=torch.device("cuda"), + ), + ) + + # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP + assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" + assert parallel_dims.cp_enabled is False, "CP not supported yet" + assert parallel_dims.pp_enabled is False, "PP not supported yet" + + # apply local_map to MoE + monkey_patch_local_map_moe(model, sparse_mesh) + + # torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = ( + # lambda bucket_idx: 500 / parallel_dims.tp + # ) + # torch._inductor.config.bucket_reduce_scatters_fx_bucket_size_determinator = ( + # lambda bucket_idx: 1000 / parallel_dims.tp + # ) + + # if job_config.experimental.autop_force_bf16: + # logger.info("Forcing bf16 on model") + # model = model.bfloat16() + + # param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] + # reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce] + # mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + mp_policy = None + with AutoParallel( + model, + input_fn, + sparse_mesh, + mp_policy=mp_policy, + compile=job_config.compile, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + + possible_input_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_replicate": Shard(0), + "efsdp": Shard(0), + "ep": Shard(0), + "etp": Replicate(), + } + # only used if loss parallel is enabled + possible_output_shardings = { + # maps relative to mesh dim names used in torchtitan + "efsdp": Shard(0), + "etp": Shard(2), + } + assert all( + name in possible_input_shardings for name in sparse_mesh.mesh_dim_names + ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" + x_sharding = tuple( + possible_input_shardings[name] for name in sparse_mesh.mesh_dim_names + ) + out_sharding = x_sharding + loss_parallel_enabled = ( + parallel_dims.tp_enabled + and not job_config.parallelism.disable_loss_parallel + ) + if loss_parallel_enabled: + out_sharding = tuple( + possible_output_shardings[name] + for name in sparse_mesh.mesh_dim_names + if name != "dp_replicate" + ) + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([out_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement() + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) + + set_torchtitan_fields(model, parallel_mod) + + if loss_parallel_enabled: + + # current PyTorch's implementation of loss parallel assumes + # that the DTensor has a 1d device mesh. This is not true + # in our case, but we can work around it by adding + # casting the output to a DTensor on a 1d device mesh. + # We should just use AutoParallel to do this for us, but + # it would require putting the loss inside the model as well + def _return_as_dtensor_for_loss_parallel(module, args, output): + return torch.distributed.tensor.DTensor.from_local( + output, sparse_mesh["etp"], (Shard(2),) + ) + + # not keeping a reference to the hook, don't plan on + # removing it at any point + parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel) + + _preserve_moe_attributes(model, parallel_mod) + + return parallel_mod + + +def _preserve_moe_attributes(original_model, parallel_model): + """ + Preserve MoE custom attributes from the original model to the parallel model. + This is only needed for attributes that aren't used in the graph, so they aren't + lifted as graph inputs and fetched by the pre-graph runtime wrapper. + + `moe_enabled` and `load_balance_coeff` are used later in the optimizer to identify + this block as a moe block. This should be safe as they are read-only. + """ + + def get_moe_modules(model): + """Extract all MoE modules from the model.""" + moe_modules = [] + if hasattr(model, "layers"): + if isinstance(model.layers, torch.nn.ModuleDict): + # regular torchtitan structure + blocks = model.layers.values() + else: + # autoparallel might change structure + blocks = ( + model.layers.children() if hasattr(model.layers, "children") else [] + ) + + for block in blocks: + if ( + hasattr(block, "moe_enabled") + and block.moe_enabled + and hasattr(block, "moe") + ): + moe_modules.append(block.moe) + elif hasattr(block, "moe"): # fallback for autoparallel + moe_modules.append(block.moe) + return moe_modules + + original_moe_modules = get_moe_modules(original_model) + parallel_moe_modules = get_moe_modules(parallel_model) + + # Copy custom attributes from original to parallel MoE modules + # This is fine to do since these attributes are read only + for orig_moe, par_moe in zip(original_moe_modules, parallel_moe_modules): + if hasattr(orig_moe, "moe_enabled"): + par_moe.load_balance_coeff = orig_moe.load_balance_coeff + + # Copy load_balance_coeff + if hasattr(orig_moe, "load_balance_coeff"): + par_moe.load_balance_coeff = orig_moe.load_balance_coeff diff --git a/torchtitan/experiments/autoparallel/job_config.py b/torchtitan/experiments/autoparallel/job_config.py new file mode 100644 index 0000000000..b481318562 --- /dev/null +++ b/torchtitan/experiments/autoparallel/job_config.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + + +""" +Use --job.custom_config_module=torchtitan.experiments.autoparallel.job_config +""" + + +@dataclass +class Experimental: + # "aten" (default), "inductor", "none" + comms_bucket_reorder_strategy: str = "aten" + + autop_force_bf16: bool = False + + +@dataclass +class JobConfig: + experimental: Experimental = field(default_factory=Experimental) diff --git a/torchtitan/experiments/autoparallel/llama3/__init__.py b/torchtitan/experiments/autoparallel/llama3/__init__.py new file mode 100644 index 0000000000..ea38ac631a --- /dev/null +++ b/torchtitan/experiments/autoparallel/llama3/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.components.validate import build_validator +from torchtitan.distributed.pipeline_parallel import pipeline_llm +from torchtitan.hf_datasets.text_datasets import build_text_dataloader + +from torchtitan.models.llama3 import llama3_args, Transformer +from torchtitan.models.llama3.model.state_dict_adapter import Llama3StateDictAdapter +from torchtitan.protocols.train_spec import TrainSpec + +from .parallelize_llama import parallelize_llama + + +def get_train_spec() -> TrainSpec: + return TrainSpec( + model_cls=Transformer, + model_args=llama3_args, + parallelize_fn=parallelize_llama, + pipelining_fn=pipeline_llm, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + build_validator_fn=build_validator, + state_dict_adapter=Llama3StateDictAdapter, + ) diff --git a/torchtitan/experiments/autoparallel/llama3/parallelize_llama.py b/torchtitan/experiments/autoparallel/llama3/parallelize_llama.py new file mode 100644 index 0000000000..27149f67f0 --- /dev/null +++ b/torchtitan/experiments/autoparallel/llama3/parallelize_llama.py @@ -0,0 +1,157 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import time + +import torch + +from autoparallel.api import AutoParallel +from autoparallel.auto_bucketing import configure_inductor_for_autobucketing + +from torch.distributed.fsdp import MixedPrecisionPolicy +from torch.distributed.tensor.placement_types import Replicate, Shard + +from torchtitan.config import JobConfig, TORCH_DTYPE_MAP +from torchtitan.distributed import ParallelDims + +from torchtitan.tools.logging import logger + + +def parallelize_llama( + model, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + # TODO(whc) + # I do this because otherwise sometimes inductor will skip re-running passes like comms reordering + torch._inductor.config.force_disable_caches = True + # this is necessary for working with reordering passes. Just leave it set for all the jobs for now. + torch._inductor.config.allow_buffer_reuse = False + + # allow configuring inductor comms optimizations from torchtitan commandline + configure_inductor_for_autobucketing( + job_config.experimental.comms_bucket_reorder_strategy + ) + + dense_names = ["dp_replicate", "fsdp", "tp"] + dense_names = [ + name + for name in dense_names + if parallel_dims.get_optional_mesh(name) is not None + ] + dense_mesh = parallel_dims.get_mesh(dense_names) + + def input_fn(): + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. + dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard + global_batch_size = job_config.training.local_batch_size * dp_degree + return ( + torch.randint( + 0, + # job_config.training.vocab_size, + model.vocab_size, + (global_batch_size, job_config.training.seq_len), + device=torch.device("cuda"), + ), + ) + + # TODO make autop work correctly with different combinations of DP, DP+TP, TP, and support DDP / HSDP + assert parallel_dims.dp_replicate_enabled is False, "DDP not supported yet" + assert parallel_dims.cp_enabled is False, "CP not supported yet" + assert parallel_dims.pp_enabled is False, "PP not supported yet" + + torch._inductor.config.bucket_all_gathers_fx_bucket_size_determinator = ( + lambda bucket_idx: 500 / parallel_dims.tp + ) + torch._inductor.config.bucket_reduce_scatters_fx_bucket_size_determinator = ( + lambda bucket_idx: 1000 / parallel_dims.tp + ) + + # bail out + # model = model_fn() + # return model + if job_config.experimental.autop_force_bf16: + logger.info("Forcing bf16 on model") + model = model.bfloat16() + + param_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_param] + reduce_dtype = TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce] + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + with AutoParallel( + model, + input_fn, + dense_mesh, + mp_policy=mp_policy, + compile=job_config.compile, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + + possible_input_shardings = { + # maps relative to mesh dim names used in torchtitan + "dp_replicate": Shard(0), + "fsdp": Shard(0), + "tp": Replicate(), + } + # only used if loss parallel is enabled + possible_output_shardings = { + # maps relative to mesh dim names used in torchtitan + "fsdp": Shard(0), + "tp": Shard(2), + } + assert all( + name in possible_input_shardings for name in dense_mesh.mesh_dim_names + ), f"Unsupported mesh dim in world mesh, only {possible_input_shardings.keys()} are supported by AutoParallel" + x_sharding = tuple( + possible_input_shardings[name] for name in dense_mesh.mesh_dim_names + ) + out_sharding = x_sharding + loss_parallel_enabled = ( + parallel_dims.tp_enabled + and not job_config.parallelism.disable_loss_parallel + ) + if loss_parallel_enabled: + out_sharding = tuple( + possible_output_shardings[name] + for name in dense_mesh.mesh_dim_names + if name != "dp_replicate" + ) + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([out_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement(verbose=False) + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) + + if loss_parallel_enabled: + + # current PyTorch's implementation of loss parallel assumes + # that the DTensor has a 1d device mesh. This is not true + # in our case, but we can work around it by adding + # casting the output to a DTensor on a 1d device mesh. + # We should just use AutoParallel to do this for us, but + # it would require putting the loss inside the model as well + def _return_as_dtensor_for_loss_parallel(module, args, output): + return torch.distributed.tensor.DTensor.from_local( + output, dense_mesh["tp"], (Shard(2),) + ) + + # not keeping a reference to the hook, don't plan on + # removing it at any point + parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel) + + return parallel_mod diff --git a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/__init__.py b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/__init__.py new file mode 100644 index 0000000000..fdd8435ebc --- /dev/null +++ b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/__init__.py @@ -0,0 +1,57 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +import copy + +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.distributed.pipeline_parallel import pipeline_llm +from torchtitan.hf_datasets.text_datasets import build_text_dataloader + +from torchtitan.models.deepseek_v3 import deepseekv3_args +from torchtitan.models.deepseek_v3.model.state_dict_adapter import ( + DeepSeekV3StateDictAdapter, +) +from torchtitan.protocols.train_spec import TrainSpec + +from .args import DeepSeekV3ModelArgs, get_sample_config + +from .model import DeepSeekV3Model +from .parallelize_deepseekv3 import parallelize_deepseekv3 + + +def get_model_args() -> DeepSeekV3ModelArgs: + model_args = copy.deepcopy(deepseekv3_args) + # TODO: Align configs between AP and Titan + for config in model_args.keys(): + # Just override the configs + override = get_sample_config() + override.update_from_config = model_args[config].update_from_config + override.get_nparams_and_flops = model_args[config].get_nparams_and_flops + model_args[config] = override + + return model_args + + +def get_train_spec() -> TrainSpec: + model_args = get_model_args() + + return TrainSpec( + model_cls=DeepSeekV3Model, + model_args=model_args, + parallelize_fn=parallelize_deepseekv3, + pipelining_fn=pipeline_llm, + build_optimizers_fn=build_optimizers_with_moe_load_balancing, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + state_dict_adapter=DeepSeekV3StateDictAdapter, + ) diff --git a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/args.py b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/args.py new file mode 100644 index 0000000000..7f1f84f45a --- /dev/null +++ b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/args.py @@ -0,0 +1,49 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from dataclasses import dataclass + +from autoparallel._testing.models.dsv3 import ( + DeepSeekV3ModelArgs as _DeepSeekV3ModelArgs, + MoEArgs as _MoEArgs, +) +from torchtitan.protocols.model import BaseModelArgs + + +# Need to share same base class with torchtitan models +@dataclass +class DeepSeekV3ModelArgs(_DeepSeekV3ModelArgs, BaseModelArgs): + pass + + +def get_sample_config() -> DeepSeekV3ModelArgs: + return DeepSeekV3ModelArgs( + vocab_size=2048, + max_seq_len=2048, + dim=256, + inter_dim=1024, + moe_inter_dim=256, + n_layers=4, + n_dense_layers=0, + n_heads=16, + moe_args=_MoEArgs( + num_experts=4, + num_shared_experts=2, + top_k=2, + score_func="softmax", + route_norm=False, + score_before_experts=False, + mesh=None, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + ) diff --git a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py new file mode 100644 index 0000000000..f4915fb708 --- /dev/null +++ b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/model.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +# +# Copyright (c) Meta Platforms, Inc. All Rights Reserved. + +from autoparallel._testing.models.dsv3 import DeepSeekV3Model as _DeepSeekV3Model +from torchtitan.protocols.train_spec import ModelProtocol + +from .args import DeepSeekV3ModelArgs + + +# Need to share same base class with torchtitan models +class DeepSeekV3Model(_DeepSeekV3Model, ModelProtocol): + def __init__(self, model_args: DeepSeekV3ModelArgs): + super().__init__(model_args) diff --git a/torchtitan/experiments/autoparallel/local_map_deepseek_v3/parallelize_deepseekv3.py b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/parallelize_deepseekv3.py new file mode 100644 index 0000000000..5db38e841f --- /dev/null +++ b/torchtitan/experiments/autoparallel/local_map_deepseek_v3/parallelize_deepseekv3.py @@ -0,0 +1,186 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import time + +import torch +from autoparallel.api import AutoParallel +from autoparallel.auto_bucketing import configure_inductor_for_autobucketing + +from torch.distributed.tensor.placement_types import Shard +from torchtitan.config import JobConfig +from torchtitan.distributed import ParallelDims + +from torchtitan.tools.logging import logger + + +# TODO: Autoparallel should transparently wrap the original nn.Module +# but I don't know how to do that. +def set_torchtitan_fields(orig, new): + assert isinstance(new.layers, torch.nn.ModuleDict) + for block in new.layers.values(): + block.moe_enabled = hasattr(block, "moe") + + +def parallelize_deepseekv3( + model, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply Autoparallel to the model + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + + # TODO(whc) + # I do this because otherwise sometimes inductor will skip re-running passes like comms reordering + torch._inductor.config.force_disable_caches = True + # this is necessary for working with reordering passes. Just leave it set for all the jobs for now. + torch._inductor.config.allow_buffer_reuse = False + + # allow configuring inductor comms optimizations from torchtitan commandline + configure_inductor_for_autobucketing( + job_config.experimental.comms_bucket_reorder_strategy + ) + + # Build the sparse mesh for MoE expert parallelism + # Filter to only include enabled mesh dimensions + sparse_names = ["dp_replicate", "efsdp", "ep", "etp"] + sparse_names = [ + name + for name in sparse_names + if parallel_dims.get_optional_mesh(name) is not None + ] + sparse_mesh = parallel_dims.get_mesh(sparse_names) + + # Update me when changing dsv3.py + assert sparse_mesh.ndim == 2, "AP dsv3.py's local_map is specialized on 2 dims" + + # Provide AP MoE with mesh + for layer in model.layers.values(): + if layer.moe_enabled: + layer.moe.mesh = sparse_mesh + layer.moe.axis_name = "ep" + + def input_fn(): + global_batch_size = job_config.training.global_batch_size + if global_batch_size < 0: + # This global batch size results in 1 gradient accumulation + # step. + dp_degree = parallel_dims.dp_replicate * parallel_dims.dp_shard + global_batch_size = job_config.training.local_batch_size * dp_degree + return ( + torch.randint( + 0, + model.model_args.vocab_size, + (global_batch_size, job_config.training.seq_len), + device=torch.device("cuda"), + ), + ) + + should_compile = job_config.compile.enable + if should_compile: + # TODO: support more options in AP API + assert job_config.compile.components == ["model"] + assert job_config.compile.backend == "inductor" + + mp_policy = None + with AutoParallel( + model, + input_fn, + sparse_mesh, + mp_policy=mp_policy, + compile=should_compile, + dynamic=True, + ) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) + + x_sharding = (Shard(0), Shard(0)) + loss_parallel_enabled = ( + parallel_dims.tp_enabled + and not job_config.parallelism.disable_loss_parallel + ) + assert not loss_parallel_enabled + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([x_sharding]) + t0 = time.time() + sharding_placement = autop.optimize_placement() + t1 = time.time() + logger.info(f"AutoParallel took {t1 - t0} seconds") + parallel_mod = autop.apply_placement(sharding_placement) + + set_torchtitan_fields(model, parallel_mod) + + if loss_parallel_enabled: + + # current PyTorch's implementation of loss parallel assumes + # that the DTensor has a 1d device mesh. This is not true + # in our case, but we can work around it by adding + # casting the output to a DTensor on a 1d device mesh. + # We should just use AutoParallel to do this for us, but + # it would require putting the loss inside the model as well + def _return_as_dtensor_for_loss_parallel(module, args, output): + return torch.distributed.tensor.DTensor.from_local( + output, sparse_mesh["etp"], (Shard(2),) + ) + + # not keeping a reference to the hook, don't plan on + # removing it at any point + parallel_mod.register_forward_hook(_return_as_dtensor_for_loss_parallel) + + _preserve_moe_attributes(model, parallel_mod) + + return parallel_mod + + +def _preserve_moe_attributes(original_model, parallel_model): + """ + Preserve MoE custom attributes from the original model to the parallel model. + This is only needed for attributes that aren't used in the graph, so they aren't + lifted as graph inputs and fetched by the pre-graph runtime wrapper. + + `moe_enabled` and `load_balance_coeff` are used later in the optimizer to identify + this block as a moe block. This should be safe as they are read-only. + """ + + def get_moe_modules(model): + """Extract all MoE modules from the model.""" + moe_modules = [] + if hasattr(model, "layers"): + if isinstance(model.layers, torch.nn.ModuleDict): + # regular torchtitan structure + blocks = model.layers.values() + else: + # autoparallel might change structure + blocks = ( + model.layers.children() if hasattr(model.layers, "children") else [] + ) + + for block in blocks: + if ( + hasattr(block, "moe_enabled") + and block.moe_enabled + and hasattr(block, "moe") + ): + moe_modules.append(block.moe) + elif hasattr(block, "moe"): # fallback for autoparallel + moe_modules.append(block.moe) + return moe_modules + + original_moe_modules = get_moe_modules(original_model) + parallel_moe_modules = get_moe_modules(parallel_model) + + # Copy custom attributes from original to parallel MoE modules + # This is fine to do since these attributes are read only + for orig_moe, par_moe in zip(original_moe_modules, parallel_moe_modules): + if hasattr(orig_moe, "moe_enabled"): + par_moe.load_balance_coeff = orig_moe.load_balance_coeff + + # Copy load_balance_coeff + if hasattr(orig_moe, "load_balance_coeff"): + par_moe.load_balance_coeff = orig_moe.load_balance_coeff diff --git a/torchtitan/experiments/autoparallel/tests/__init__.py b/torchtitan/experiments/autoparallel/tests/__init__.py new file mode 100644 index 0000000000..2e41cd717f --- /dev/null +++ b/torchtitan/experiments/autoparallel/tests/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. diff --git a/torchtitan/experiments/autoparallel/tests/integration_tests.py b/torchtitan/experiments/autoparallel/tests/integration_tests.py new file mode 100644 index 0000000000..8425d23254 --- /dev/null +++ b/torchtitan/experiments/autoparallel/tests/integration_tests.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +from tests.integration_tests import OverrideDefinitions +from tests.integration_tests.run_tests import run_tests + + +def build_autoparallel_test_list() -> list[OverrideDefinitions]: + """ + returns a list of OverrideDefinitions that is used to generate + variations of integration tests based on the same root config file. + """ + integration_tests_flavors = [ + # llama3 tests + OverrideDefinitions( + [ + [ + "--model.name autoparallel.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--job.custom_config_module=torchtitan.experiments.autoparallel.job_config", + ], + ], + "llama3 AutoParallel FSDP+TP", + "llama3_autoparallel_fsdp_tp", + ngpu=4, + ), + # TODO: Re-enable this once we fix the test + # deepseek_v3 tests + # OverrideDefinitions( + # [ + # [ + # "--model.name autoparallel.deepseek_v3", + # "--parallelism.data_parallel_shard_degree 2", + # "--parallelism.expert_parallel_degree 2", + # "--job.custom_config_module=torchtitan.experiments.autoparallel.job_config", + # "--activation_checkpoint.mode none", + # ], + # ], + # "deepseek_v3 AutoParallel FSDP+TP+EP", + # "deepseekv3_autoparallel_fsdp_tp_ep", + # ngpu=4, + # ), + ] + return integration_tests_flavors + + +_TEST_SUITES_FUNCTION = { + "autoparallel": build_autoparallel_test_list, +} + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("output_dir") + parser.add_argument( + "--config_path", + default="./tests/integration_tests/base_config.toml", + help="Base config path for integration tests. This is the config that will be used as a base for all tests.", + ) + parser.add_argument( + "--test_name", + default="all", + help="test to run, acceptable values: `test_name` in `build_test_list` (default: all)", + ) + parser.add_argument("--ngpu", default=8, type=int) + args = parser.parse_args() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + if os.listdir(args.output_dir): + raise RuntimeError("Please provide an empty output directory.") + + test_list = _TEST_SUITES_FUNCTION["autoparallel"]() + run_tests(args, test_list) + + +if __name__ == "__main__": + main() diff --git a/torchtitan/experiments/compiler_toolkit/README.md b/torchtitan/experiments/compiler_toolkit/README.md index a75b3a17b2..9fc6660245 100644 --- a/torchtitan/experiments/compiler_toolkit/README.md +++ b/torchtitan/experiments/compiler_toolkit/README.md @@ -14,22 +14,56 @@ Joint Graph based Training Prototype: **SimpleFSDP + TP + EP** ```shell -NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none +NGPU=4 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none ``` **SimpleFSDP + TP + EP + FlexAttention** ```shell -NGPU=4 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --model.flavor=debugmodel_flex_attn +NGPU=4 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.deepseek_v3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=2 --parallelism.expert_parallel_degree=2 --activation_checkpoint.mode none --model.flavor=debugmodel_flex_attn ``` ## llama3 **SimpleFSDP + TP** ```shell -NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 +NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 +``` + +**SimpleFSDP + TP + auto-bucketing** +```shell +NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering +``` + +**SimpleFSDP + TP + transformer-block-bucketing** +```shell +NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing ``` **SimpleFSDP + TP + FlexAttention** ```shell -NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn +NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --model.flavor=debugmodel_flex_attn +``` + +**SimpleFSDP + TP + FlexAttention + auto-bucketing + regional-inductor** + +```shell +NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes autobucketing_reordering,regional_inductor --model.flavor=debugmodel_flex_attn +``` + +**SimpleFSDP + TP + FlexAttention + transformer-block-bucketing + regional-inductor** + +```shell +NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor --model.flavor=debugmodel_flex_attn +``` + +**SimpleFSDP + TP + FlexAttention + transformer-block-bucketing + regional-inductor + cudagraph** + +```shell +NCCL_GRAPH_REGISTER=0 NGPU=8 TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.passes transformer_block_bucketing,regional_inductor,cudagraph --model.flavor=debugmodel_flex_attn +``` + +**SimpleFSDP + TP + Full Inductor compilation** + +```shell +NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml TRAIN_FILE=torchtitan.experiments.compiler_toolkit.train ./run_train.sh --model.name $MODEL_NAME compiler_toolkit.llama3 --parallelism.data_parallel_shard_degree=2 --parallelism.tensor_parallel_degree=4 --job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config --compile.joint_passes inductor_decomposition --compile.passes full_inductor_compilation ``` diff --git a/torchtitan/experiments/compiler_toolkit/common_utils.py b/torchtitan/experiments/compiler_toolkit/common_utils.py index 965e027bdb..2b2a1f5244 100644 --- a/torchtitan/experiments/compiler_toolkit/common_utils.py +++ b/torchtitan/experiments/compiler_toolkit/common_utils.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. from contextlib import contextmanager +from typing import Callable import torch from torch.distributed.tensor import DTensor, Replicate @@ -24,10 +25,12 @@ def disable_compile(job_config: JobConfig): job_config.compile.enable = original_value -def parallelize_inputs(world_mesh, args, kwargs): +def parallelize_inputs(parallel_dims, args, kwargs): def to_dtensor(tensor): if isinstance(tensor, torch.Tensor): - return DTensor.from_local(tensor, world_mesh["tp"], [Replicate()]) + return DTensor.from_local( + tensor, parallel_dims.get_mesh("tp"), [Replicate()] + ) return tensor dt_args = tree_map(to_dtensor, args) @@ -53,3 +56,11 @@ def register_blockmask_pytree_node(): flatten_with_keys_fn=BlockMask._flatten_with_keys, serialized_type_name="torch.nn.attention.flex_attention.BlockMask", ) + + +def end_with_pass(passes: list[Callable], names: list[str]) -> bool: + return ( + len(passes) > 0 + and (last_pass_name := getattr(passes[-1], "__name__", None)) + and (last_pass_name in names) + ) diff --git a/torchtitan/experiments/compiler_toolkit/cudagraph.py b/torchtitan/experiments/compiler_toolkit/cudagraph.py new file mode 100644 index 0000000000..d008e5d455 --- /dev/null +++ b/torchtitan/experiments/compiler_toolkit/cudagraph.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +CUDAGraph pass for the compiler toolkit. + +This module provides a cudagraph pass that can be applied to graph modules +during compilation. +""" + +import warnings +from typing import Any, Callable, Optional, Sequence + +import torch +from torch._inductor.cudagraph_trees import _use_cuda_memory_pool_manager +from torch.utils._ordered_set import OrderedSet + + +def init_global_graph_pool() -> tuple[ + torch.cuda.CUDAGraph, torch.cuda._POOL_HANDLE, torch.cuda.Stream +]: + dummy_graph = torch.cuda.CUDAGraph() + + # create a global cudagraph memory pool to allow memory reuse across cudagraphs. + graph_pool = torch.cuda.graph_pool_handle() + + # create a global cuda stream for graph capture. we need to use a single stream + # for all allocations to the memory pool, otherwise the allocations to separate streams + # will not be used. + graph_capture_stream = torch.cuda.Stream() + + # use a dummy graph to keep the global graph pool alive + with ( + # suppress an empty cudagraph warning, since we intentionally create + # an empty cudagraph here + warnings.catch_warnings(record=True), + torch.cuda.graph( + dummy_graph, + pool=graph_pool, + stream=graph_capture_stream, + capture_error_mode="thread_local", + ), + ): + pass + + return dummy_graph, graph_pool, graph_capture_stream + + +( + _global_dummy_graph, + _global_graph_pool, + _global_graph_capture_stream, +) = init_global_graph_pool() + + +class CUDAGraphWrapper: + def __init__( + self, + runnable: Callable, + example_inputs: Sequence[Any], + static_input_indices: Optional[tuple[int]] = None, + should_check_address: bool = False, + ): + self.runnable = runnable + self.graph_pool = _global_graph_pool + self.stream = _global_graph_capture_stream + self.static_input_indices = OrderedSet( + static_input_indices if static_input_indices is not None else [] + ) + self.input_indices_to_copy = [ + i + for i, inp in enumerate(example_inputs) + if isinstance(inp, torch.Tensor) and i not in self.static_input_indices + ] + self.cudagraph: Optional[torch.cuda.CUDAGraph] = None + self.has_warmup = False + + self.args = None + self.output = None + + # (debug only) whether check static input tensor addresses during runtime + self.should_check_address = should_check_address + + def copy_non_static_inputs(self, *args): + for i in self.input_indices_to_copy: + self.args[i].copy_(args[i]) + + def check_input_types(self, inputs) -> None: + for inp in inputs: + assert isinstance(inp, (torch.Tensor, int, torch._C.Generator)), ( + "args must be tensor, integer (for dynamic shapes), " + "or Generator (for random number generator), " + f"but found {type(inp)}" + ) + + def check_static_inputs_address(self) -> None: + for i in self.static_input_indices: + actual = self.args[i].data_ptr() + expected = self.input_addresses[i] + assert expected == actual, ( + "Expected the same static tensor address but found " + f"{expected} != {actual}" + ) + + def __call__(self, *args): + if not self.has_warmup: + self.has_warmup = True + device = torch.cuda.current_device() + + # warmup in cudagraph memory pool to avoid fragmentation + # across eager memory pool and cudagraph memory pool. + with _use_cuda_memory_pool_manager(device, self.graph_pool, self.stream): + out = self.runnable(*args) + return out + + if self.cudagraph is None: + self.check_input_types(args) + self.args = args + self.input_addresses = [ + x.data_ptr() if isinstance(x, torch.Tensor) else None for x in args + ] + + self.cudagraph = torch.cuda.CUDAGraph() + + with torch.cuda.graph( + self.cudagraph, pool=self.graph_pool, stream=self.stream + ): + # `output` is managed by pytorch's cudagraph pool + self.output = self.runnable(*args) + + if self.should_check_address: + self.check_static_inputs_address() + + self.copy_non_static_inputs(*args) + self.cudagraph.replay() + return self.output + + +def get_static_input_indices(gm: torch.fx.GraphModule, is_forward: bool) -> list[int]: + """ + Get indices of gm inputs that are static input tensors whose tensor addresses do not + change across runs. Example of static input tensors include weights, buffers, and + outputs of previous cudagraph wrapped functions. + """ + from torch._inductor.utils import count_tangents + + static_input_indices = [] + if ( + is_forward + and (tracing_context := torch._guards.TracingContext.try_get()) + and hasattr(tracing_context, "fw_metadata") + ): + # for forward, we rely on graph capture (i.e., dynamo or export) to provide + # the correct static input indices stored in tracing context. Typical examples + # include weights and buffers. + static_input_indices = tracing_context.fw_metadata.static_input_indices + + elif not is_forward: + # for backward, we identify saved tensors as static inputs, since saved tensors + # are outputs of cudagraph-wrapped forward run. In PT2-generated backward gm, + # saved tensors are always the leading args. So we can get the number of saved + # tensors and generate static input indices. + fixed = count_tangents(gm) + static_input_indices = list(range(fixed)) + + return static_input_indices diff --git a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py index 5c8ffb45c5..13e6689563 100644 --- a/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/deepseek_v3/parallelize.py @@ -21,33 +21,15 @@ from torchtitan.experiments.compiler_toolkit.graph_utils import ( CompiledModule, + get_compiler_passes_from_config, + get_joint_custom_passes_from_config, joint_graph_builder, + make_compiler_with_passes, ) from torchtitan.experiments.simple_fsdp.deepseek_v3.parallelize import ( parallelize_deepseekv3 as simple_fsdp_parallelize_deepseekv3, ) -from torchtitan.tools.logging import logger - - -def compiler(name: str, gm: torch.fx.GraphModule, example_inputs): - logger.info(f"{name} before compiler:") - logger.info(gm.print_readable(print_output=False)) - - # TODO: regional_inductor should work with deepseek_v3 - # gm = regional_inductor(gm, example_inputs) - - logger.info(f"{name} after compiler:") - logger.info(gm.print_readable(print_output=False)) - return gm - - -def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: - return compiler("fwd_gm", gm, example_inputs) - - -def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: - return compiler("bwd_gm", gm, example_inputs) def annotate_deepseekv3() -> None: @@ -75,7 +57,17 @@ def parallelize_deepseekv3( parallel_dims: ParallelDims, job_config: JobConfig, ) -> CompiledModule: + """ + Parallelize and compile a DeepSeek v3 model with optional custom compiler passes. + + Args: + model: The model to parallelize + parallel_dims: Parallel dimensions configuration + job_config: Job configuration + Returns: + CompiledModule wrapping the parallelized and compiled model + """ annotate_deepseekv3() register_blockmask_pytree_node() @@ -84,11 +76,25 @@ def parallelize_deepseekv3( with disable_compile(job_config): model = simple_fsdp_parallelize_deepseekv3(model, parallel_dims, job_config) + # Get joint custom passes from config + joint_custom_passes = get_joint_custom_passes_from_config(parallel_dims, job_config) + + # Get compiler passes from config + compiler_passes = get_compiler_passes_from_config(model, job_config) + + # Create compilers with specified passes + fw_compiler, bw_compiler = make_compiler_with_passes( + compiler_passes, dump_folder=job_config.job.dump_folder + ) + + # Create custom joint_graph_builder with deepseekv3-specific compilers deepseekv3_joint_graph_builder = functools.partial( joint_graph_builder, fw_compiler=fw_compiler, bw_compiler=bw_compiler, - joint_custom_pass=None, + joint_custom_passes=joint_custom_passes, + dump_folder=job_config.job.dump_folder, + job_config=job_config, ) # TODO: CompiledModule should take sample input as well, so that we can diff --git a/torchtitan/experiments/compiler_toolkit/graph_utils.py b/torchtitan/experiments/compiler_toolkit/graph_utils.py index 4ff6c8187b..64dc03c312 100644 --- a/torchtitan/experiments/compiler_toolkit/graph_utils.py +++ b/torchtitan/experiments/compiler_toolkit/graph_utils.py @@ -5,7 +5,9 @@ # LICENSE file in the root directory of this source tree. import contextlib -from typing import Callable, Optional +import functools +from pathlib import Path +from typing import Any, Callable, List, Optional import torch from torch._dynamo.functional_export import dynamo_graph_capture_for_export @@ -16,26 +18,36 @@ ) from torch._guards import tracing, TracingContext from torch.distributed.tensor import DTensor +from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims +from torchtitan.experiments.compiler_toolkit.common_utils import end_with_pass from torchtitan.tools.logging import logger -def _clear_traced_params_buffers( - traced_module: torch.fx.GraphModule, const_keys: list[str] -) -> None: - """Remove all parameters and buffers from traced module before restoring.""" - for key in const_keys: - assert key in traced_module._buffers.keys() - # We don't want constants to show up as a buffer in the state dict. - # Instead they should just be a direct attribute. - buffer = getattr(traced_module, key) - torch.fx.graph_module._del_attr(traced_module, key) - setattr(traced_module, key, buffer) +def _dump_gm(dump_folder: str | None, gm: torch.fx.GraphModule, name: str) -> None: + # TODO: make the dump rank configurable + if not dump_folder or torch.distributed.get_rank() != 0: + return + + output_path = Path(dump_folder) / "compiler" / f"{name}.txt" + output_path.parent.mkdir(parents=True, exist_ok=True) + output_path.write_text( + gm.print_readable(print_output=False, include_stride=True, include_device=True) + ) def export_joint( - model, args, kwargs=None + model, args, kwargs=None, dump_folder: str | None = None ) -> tuple[JointWithDescriptors, TracingContext]: + """ + Export joint forward-backward graph with AOT Autograd. + + Args: + model: The model to export + args: Tuple of input arguments + kwargs: Dict of keyword arguments for the model + dump_folder: Optional folder to dump the graph to + """ if kwargs is None: kwargs = {} assert isinstance(args, tuple) @@ -47,8 +59,14 @@ def export_joint( torch.fx.traceback.preserve_node_meta(), ): gm = dynamo_graph_capture_for_export(model)(*args, **kwargs) - logger.info("Dynamo gm:") - logger.info(gm.print_readable(print_output=False)) + logger.debug("Dynamo gm:") + logger.debug( + gm.print_readable( + print_output=False, include_stride=True, include_device=True + ) + ) + _dump_gm(dump_folder, gm, "dynamo_gm") + tracing_context = gm.meta["tracing_context"] with tracing(tracing_context): @@ -59,6 +77,14 @@ def export_joint( def aot_export_joint_with_descriptors_alone(model, args, kwargs=None): + """ + Export joint forward-backward graph with AOT Autograd. + + Args: + model: The model to export + args: Tuple of input arguments + kwargs: Dict of keyword arguments for the model + """ if kwargs is None: kwargs = {} assert isinstance(args, tuple) @@ -70,6 +96,7 @@ def aot_export_joint_with_descriptors_alone(model, args, kwargs=None): args, kwargs, ) + return joint_with_descriptors @@ -79,7 +106,9 @@ def joint_graph_builder( model_kwargs: dict, fw_compiler: Optional[Callable] = None, bw_compiler: Optional[Callable] = None, - joint_custom_pass: Optional[Callable] = None, + joint_custom_passes: Optional[List[Callable]] = None, + dump_folder: str | None = None, + job_config: Optional["JobConfig"] = None, ): """ Build a joint forward-backward graph for the model with optional custom compilers. @@ -90,21 +119,50 @@ def joint_graph_builder( model_kwargs: Dict of model input keyword arguments fw_compiler: Optional custom forward compiler function bw_compiler: Optional custom backward compiler function - joint_custom_pass: Optional custom pass to run on the joint graph + joint_custom_passes: list of custom passes to run on the joint graph + dump_folder: Optional folder to dump the graph to + job_config: Job configuration """ assert isinstance(model_args, tuple) - for arg in model_args: - assert isinstance(arg, DTensor) + for idx, arg in enumerate(model_args): + assert isinstance(arg, DTensor), f"Argument {idx} is of type {type(arg)}" # get joint graph - ( - joint_with_descriptors, - tracing_context, - ) = export_joint(model, model_args, model_kwargs) + (joint_with_descriptors, tracing_context,) = export_joint( + model, + model_args, + model_kwargs, + dump_folder=dump_folder, + ) - # Optional validation - if joint_custom_pass is not None: - joint_custom_pass(joint_with_descriptors) + # Check if inductor_decomposition is configured and create the pass with proper context + if job_config is not None: + joint_pass_names = getattr(job_config.compile, "joint_passes", []) + if "inductor_decomposition" in joint_pass_names: + from torchtitan.experiments.compiler_toolkit.passes import ( + inductor_decomposition_pass, + ) + + # Create the decomposition pass with context + decomp_pass = functools.partial( + inductor_decomposition_pass, + model=model, + joint_with_descriptors=joint_with_descriptors, + forward_inputs=model_args, + tracing_context=tracing_context, + ) + + # Prepend to joint_custom_passes + if joint_custom_passes is None: + joint_custom_passes = [] + joint_custom_passes = [decomp_pass] + joint_custom_passes + + # run custom passes on joint-graph before partitioner + if joint_custom_passes is not None: + for joint_custom_pass in joint_custom_passes: + joint_with_descriptors.graph_module = joint_custom_pass( + joint_with_descriptors.graph_module + ) with tracing(tracing_context): fn = aot_compile_joint_with_descriptors( @@ -165,12 +223,22 @@ def __delattr__(self, name: str) -> None: else: super().__delattr__(name) + def state_dict(self, *args, **kwargs) -> Any: + return self.inner.state_dict(*args, **kwargs) + + def load_state_dict(self, *args, **kwargs) -> Any: + return self.inner.load_state_dict(*args, **kwargs) + + def name_parameters(self, *args, **kwargs) -> Any: + return self.inner.named_parameters(*args, **kwargs) + + def parameters(self, *args, **kwargs) -> Any: + return self.inner.parameters(*args, **kwargs) + def forward(self, *args, **kwargs): assert "forward" not in self._overrides, "forward cannot be overridden" - dt_args, dt_kwargs = self.parallelize_inputs( - self.parallel_dims.world_mesh, args, kwargs - ) + dt_args, dt_kwargs = self.parallelize_inputs(self.parallel_dims, args, kwargs) if self.joint_graph_module is None: self.joint_graph_module = self.joint_graph_builder( @@ -180,3 +248,263 @@ def forward(self, *args, **kwargs): # calling the line below returns control to torchtitan's runner # letting it call the backward, and optimizer. return self.joint_graph_module(args, kwargs) + + +# Default compiler pass configuration - no passes by default +DEFAULT_COMPILER_PASSES = [] + + +def compiler( + name: str, + gm: torch.fx.GraphModule, + example_inputs, + passes: List[Callable] = None, + dump_folder: str | None = None, + is_forward: bool = True, +): + """ + Compile a graph module by applying a sequence of compiler passes. + + Args: + name: Name for logging purposes + gm: The graph module to compile + example_inputs: Example inputs for the graph module + passes: List of compiler pass functions to apply. Each function should take + (gm, example_inputs) and return a transformed gm. If None, uses + DEFAULT_COMPILER_PASSES. + dump_folder: Optional folder to dump the graph to + """ + if passes is None: + passes = DEFAULT_COMPILER_PASSES + + logger.debug(f"{name} before compiler:") + logger.debug( + gm.print_readable(print_output=False, include_stride=True, include_device=True) + ) + _dump_gm(dump_folder, gm, f"{name}_before_compiler") + + if end_with_pass(passes, ["cudagraph_pass"]): + # cudagraph pass is always the last pass if it is applied + cg_pass = passes[-1] + + # to identify static input indices, cudagraph passes behaves differently for + # forward and backward pass. so we explicitly pass the info. + _cg_pass = functools.partial(cg_pass, is_forward=is_forward) + + # keep the function name for debug log + passes[-1] = functools.wraps(cg_pass)(_cg_pass) + + for pass_fn in passes: + pass_name = ( + pass_fn.func.__name__ + if isinstance(pass_fn, functools.partial) + else pass_fn.__name__ + ) + logger.info(f"Applying pass: {pass_name}") + gm = pass_fn(gm, example_inputs) + + # Only try to print/dump if gm is still a GraphModule + # (compile_fx_inner returns a CompiledFxGraph which doesn't have print_readable) + if hasattr(gm, "print_readable"): + logger.debug(f"{name} after compiler:") + logger.debug( + gm.print_readable( + print_output=False, include_stride=True, include_device=True + ) + ) + _dump_gm(dump_folder, gm, f"{name}_after_compiler") + + return gm + + +def make_compiler_with_passes( + passes: List[Callable] = None, + dump_folder: str | None = None, +): + """ + Create forward and backward compilers with specified passes. + + Args: + passes: List of compiler pass functions to apply. If None, uses DEFAULT_COMPILER_PASSES. + dump_folder: Optional folder to dump graphs + + Returns: + Tuple of (fw_compiler, bw_compiler) functions + """ + + def fw_compiler(gm: torch.fx.GraphModule, example_inputs): + return compiler( + "fwd_gm", + gm, + example_inputs, + passes=passes, + dump_folder=dump_folder, + is_forward=True, + ) + + def bw_compiler(gm: torch.fx.GraphModule, example_inputs): + return compiler( + "bwd_gm", + gm, + example_inputs, + passes=passes, + dump_folder=dump_folder, + is_forward=False, + ) + + return fw_compiler, bw_compiler + + +def validate_pass_names(pass_names: list[str], joint_pass_names: list[str]) -> None: + """ + Validate compiler and joint pass names and their dependencies. + + Args: + pass_names: List of compiler pass names + joint_pass_names: List of joint custom pass names + + Raises: + ValueError: If pass configuration is invalid + """ + if "cudagraph" in pass_names: + assert ( + pass_names[-1] == "cudagraph" + ), "cudagraph has to be the last pass to apply" + + if ( + "autobucketing_reordering" in pass_names + and "transformer_block_bucketing" in pass_names + ): + raise ValueError( + "Cannot apply autobucketing_reordering and transformer_block_bucketing at the same time!" + ) + + # Validate that full_inductor_compilation requires inductor_decomposition + if "full_inductor_compilation" in pass_names: + if "inductor_decomposition" not in joint_pass_names: + raise ValueError( + "full_inductor_compilation pass requires inductor_decomposition to be " + "specified in joint_passes. Please add --compile.joint_passes inductor_decomposition" + ) + + +def get_compiler_passes_from_config(model: torch.nn.Module, job_config: JobConfig): + """ + Extract and validate compiler passes from job config. + + Args: + model: The model being compiled + job_config: Job configuration containing compile.passes and compile.joint_passes + + Returns: + List of compiler pass functions + """ + from torchtitan.experiments.compiler_toolkit.passes import AVAILABLE_COMPILER_PASSES + from torchtitan.experiments.simple_fsdp.llama3.parallelize import ( + get_transformer_block_buckets, + ) + + pass_names = getattr(job_config.compile, "passes", []) + joint_pass_names = getattr(job_config.compile, "joint_passes", []) + + validate_pass_names(pass_names, joint_pass_names) + compiler_passes = [] + + # Warn if full Inductor compilation is enabled + if "full_inductor_compilation" in pass_names: + logger.warning( + "Full Inductor compilation is enabled. Note that Inductor may change numerics " + "and does not guarantee bitwise equivalent results compared to eager mode." + ) + + for pass_name in pass_names: + if pass_name not in AVAILABLE_COMPILER_PASSES: + raise ValueError( + f"Unknown compiler pass: {pass_name}. " + f"Available compiler passes: {list(AVAILABLE_COMPILER_PASSES.keys())}" + ) + if pass_name == "transformer_block_bucketing": + compiler_passes.append( + functools.partial( + AVAILABLE_COMPILER_PASSES[pass_name], + fsdp_manual_buckets=get_transformer_block_buckets(model), + ) + ) + else: + compiler_passes.append(AVAILABLE_COMPILER_PASSES[pass_name]) + + if pass_names: + logger.info(f"Using compiler passes from config: {pass_names}") + + return compiler_passes + + +def get_joint_custom_passes_from_config( + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Extract and validate joint custom passes from job config. + + Note: The inductor_decomposition pass is handled separately in joint_graph_builder + because it requires context (model, joint_with_descriptors, etc.) that's only + available at graph capture time. + + Args: + parallel_dims: Parallelism dimensions + job_config: Job configuration containing parallelism.fsdp_reshard_after_forward + and compile.joint_passes + + Returns: + List of joint custom pass functions + """ + from torchtitan.experiments.compiler_toolkit.passes import ( + AVAILABLE_JOINT_PASSES, + fsdp_reshard_after_fwd_pass, + validate_flex_attn_annotation_pass, + ) + + joint_custom_passes = [] + joint_custom_passes.append(validate_flex_attn_annotation_pass) + + # Handle joint passes from config (excluding inductor_decomposition) + joint_pass_names = getattr(job_config.compile, "joint_passes", []) + for pass_name in joint_pass_names: + if pass_name not in AVAILABLE_JOINT_PASSES: + raise ValueError( + f"Unknown joint pass: {pass_name}. " + f"Available joint passes: {list(AVAILABLE_JOINT_PASSES.keys())}" + ) + + # Skip inductor_decomposition - it's handled in joint_graph_builder + if pass_name == "inductor_decomposition": + continue + + joint_custom_passes.append(AVAILABLE_JOINT_PASSES[pass_name]) + + if joint_pass_names: + logger.info(f"Using joint passes from config: {joint_pass_names}") + + # Handle FSDP reshard after forward + match job_config.parallelism.fsdp_reshard_after_forward: + case "always": + fsdp_reshard_after_forward = True + case "never": + fsdp_reshard_after_forward = False + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + fsdp_reshard_after_forward = not parallel_dims.pp_enabled + case _: + raise ValueError( + f"Invalid fsdp_reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}." + ) + + joint_custom_passes.append( + functools.partial( + fsdp_reshard_after_fwd_pass, + reshard_after_forward=fsdp_reshard_after_forward, + ) + ) + + return joint_custom_passes diff --git a/torchtitan/experiments/compiler_toolkit/job_config.py b/torchtitan/experiments/compiler_toolkit/job_config.py new file mode 100644 index 0000000000..7db461b984 --- /dev/null +++ b/torchtitan/experiments/compiler_toolkit/job_config.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + + +@dataclass +class Compile: + """ + Compiler configuration for the compiler toolkit workflow. + + - joint_passes: List of joint graph pass names to apply on the joint forward-backward + graph before partitioning. + + Example: --compile.joint_passes inductor_decomposition + + - passes: List of compiler pass names to apply to the partitioned forward/backward graphs. + + Example: --compile.passes full_inductor_compilation + + Note: If "full_inductor_compilation" is specified, "inductor_decomposition" must + be included in joint_passes. + """ + + joint_passes: list[str] = field(default_factory=list) + passes: list[str] = field(default_factory=list) + + +@dataclass +class JobConfig: + compile: Compile = field(default_factory=Compile) diff --git a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py index 0ed8452148..c955dc02f0 100644 --- a/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py +++ b/torchtitan/experiments/compiler_toolkit/llama3/parallelize.py @@ -8,9 +8,6 @@ import functools import torch -from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing - -from torch.fx.passes.regional_inductor import regional_inductor from torch.fx.traceback import annotate_fn from torchtitan.config import JobConfig @@ -23,52 +20,15 @@ from torchtitan.experiments.compiler_toolkit.graph_utils import ( CompiledModule, + get_compiler_passes_from_config, + get_joint_custom_passes_from_config, joint_graph_builder, + make_compiler_with_passes, ) from torchtitan.experiments.simple_fsdp.llama3.parallelize import ( parallelize_llama as simple_fsdp_parallelize_llama, ) -from torchtitan.tools.logging import logger - - -# TODO: support passing configs into schedule_overlap_bucketing -def autobucketing_reordering_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: - schedule_overlap_bucketing(gm, collective_bucketing=True) - gm.recompile() - return gm - - -def compiler(name: str, gm: torch.fx.GraphModule, example_inputs): - logger.info(f"{name} before compiler:") - logger.info(gm.print_readable(print_output=False)) - - gm = autobucketing_reordering_pass(gm) - - gm = regional_inductor(gm, example_inputs) - - logger.info(f"{name} after compiler:") - logger.info(gm.print_readable(print_output=False)) - return gm - - -def fw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: - return compiler("fwd_gm", gm, example_inputs) - - -def bw_compiler(gm: torch.fx.GraphModule, example_inputs) -> None: - return compiler("bwd_gm", gm, example_inputs) - - -def validate_flex_attention_annotation(joint_with_descriptors): - """Verify user annotations show up in the graph.""" - for node in joint_with_descriptors.graph_module.graph.nodes: - if node.target in { - torch.ops.higher_order.flex_attention, - torch.ops.higher_order.flex_attention_backward, - }: - assert "compile_with_inductor" in node.meta.get("custom", {}) - def annotate_llama() -> None: from torchtitan.models.attention import FlexAttentionWrapper @@ -84,7 +44,17 @@ def parallelize_llama( parallel_dims: ParallelDims, job_config: JobConfig, ) -> CompiledModule: + """ + Parallelize and compile a Llama model with optional custom compiler passes. + Args: + model: The model to parallelize + parallel_dims: Parallel dimensions configuration + job_config: Job configuration + + Returns: + CompiledModule wrapping the parallelized and compiled model + """ annotate_llama() register_blockmask_pytree_node() @@ -93,12 +63,25 @@ def parallelize_llama( with disable_compile(job_config): model = simple_fsdp_parallelize_llama(model, parallel_dims, job_config) - # Create custom joint_graph_builder with llama-specific compilers and validation + # Get joint custom passes from config + joint_custom_passes = get_joint_custom_passes_from_config(parallel_dims, job_config) + + # Get compiler passes from config + compiler_passes = get_compiler_passes_from_config(model, job_config) + + # Create compilers with specified passes + fw_compiler, bw_compiler = make_compiler_with_passes( + compiler_passes, dump_folder=job_config.job.dump_folder + ) + + # Create custom joint_graph_builder with llama-specific compilers llama_joint_graph_builder = functools.partial( joint_graph_builder, fw_compiler=fw_compiler, bw_compiler=bw_compiler, - joint_custom_pass=validate_flex_attention_annotation, + joint_custom_passes=joint_custom_passes, + dump_folder=job_config.job.dump_folder, + job_config=job_config, ) # TODO: CompiledModule should take sample input as well, so that we can diff --git a/torchtitan/experiments/compiler_toolkit/passes.py b/torchtitan/experiments/compiler_toolkit/passes.py new file mode 100644 index 0000000000..1e7354deff --- /dev/null +++ b/torchtitan/experiments/compiler_toolkit/passes.py @@ -0,0 +1,279 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Compiler passes for the compiler toolkit. + +This module provides various compiler passes that can be applied to graph modules +during compilation. Passes can be selected and configured via job config. + +Pass Types: +- Joint custom passes: Applied to the joint forward-backward graph before partitioning +- Compiler passes: Applied to the partitioned forward/backward graphs +""" + +from typing import Any, Sequence + +import torch +from torch._functorch.aot_autograd import JointWithDescriptors +from torch._guards import TracingContext +from torch._inductor.compile_fx import compile_fx_inner +from torch._inductor.fx_passes.overlap_manual_scheduling import manual_overlap_bucketing +from torch._inductor.fx_passes.overlap_scheduling import schedule_overlap_bucketing +from torch.fx.passes.regional_inductor import regional_inductor +from torchtitan.experiments.compiler_toolkit.cudagraph import ( + CUDAGraphWrapper, + get_static_input_indices, +) +from torchtitan.experiments.simple_fsdp.reshard_after_forward import ( + annotate_fsdp_all_gather, +) +from torchtitan.tools.logging import logger + + +def autobucketing_reordering_pass( + gm: torch.fx.GraphModule, example_inputs=None +) -> torch.fx.GraphModule: + """ + Apply autobucketing and reordering optimization. + + This pass applies schedule_overlap_bucketing with collective_bucketing enabled + to optimize comm/compute overlap patterns in the graph. + """ + schedule_overlap_bucketing(gm, collective_bucketing=True) + gm.recompile() + return gm + + +def transformer_block_bucketing_reordering_pass( + gm: torch.fx.GraphModule, example_inputs, fsdp_manual_buckets +) -> torch.fx.GraphModule: + """ + Apply aten-level manual bucketing and reordering optimization. + """ + manual_overlap_bucketing( + gm, module_bucket_plans=fsdp_manual_buckets, insert_overlap_deps=False + ) + gm.recompile() + return gm + + +def regional_inductor_pass( + gm: torch.fx.GraphModule, example_inputs +) -> torch.fx.GraphModule: + """ + Apply regional inductor compilation based on user annotation. + """ + return regional_inductor(gm, example_inputs) + + +def cudagraph_pass( + gm: torch.fx.GraphModule, example_inputs: Sequence[Any], is_forward: bool +) -> torch.fx.GraphModule: + """ + Apply cudagraph. + + This pass wraps the forward function with cudagraph during compilation and does + not record cudagraph until runtime. + - For the first run, it will warm up operators such as nccl. + - For the second run, it will record cudagraph and replay cudagraph. + - For the following runs, it will replay cudagraph. + """ + static_input_indices = get_static_input_indices(gm, is_forward) + gm.forward = CUDAGraphWrapper(gm.forward, example_inputs, static_input_indices) + return gm + + +def validate_flex_attn_annotation_pass( + gm: torch.fx.GraphModule, +) -> torch.fx.GraphModule: + """Verify user annotations show up in the graph.""" + for node in gm.graph.nodes: + if node.target in { + torch.ops.higher_order.flex_attention, + torch.ops.higher_order.flex_attention_backward, + }: + assert "compile_with_inductor" in node.meta.get("custom", {}) + return gm + + +# Apply activation checkpointing on joint graph before partitioner +def fsdp_reshard_after_fwd_pass( + gm: torch.fx.GraphModule, reshard_after_forward: bool +) -> torch.fx.GraphModule: + # this pass implements simplefsdp's fsdp_reshard_after_forward behavior + # when fsdp_reshard_after_forward set to True, it will annotate simple_fsdp AG + # to CheckpointPolicy.MUST_RECOMPUTE. + # when fsdp_reshard_after_forward set to False, it will annotate simple_fsdp AG + # to CheckpointPolicy.MUST_SAVE. + gm = annotate_fsdp_all_gather(gm, reshard_after_forward) + gm.recompile() + return gm + + +def inductor_decomposition_pass( + gm: torch.fx.GraphModule, + model: torch.nn.Module, + joint_with_descriptors: JointWithDescriptors, + forward_inputs: tuple, + tracing_context: TracingContext, +) -> torch.fx.GraphModule: + """ + Apply Inductor decompositions to the joint graph. + + This pass applies decompositions to the joint forward-backward graph using make_fx. + It unwraps tensor subclasses (like DTensor) and retraces the graph with decompositions + applied, while preserving metadata required by the partitioner. + + Args: + gm: The joint graph module + model: The parallelized model + joint_with_descriptors: The joint graph with descriptors + forward_inputs: Forward input arguments (may be DTensors) + tracing_context: The tracing context from original joint graph capture + + Returns: + The joint graph with decompositions applied + """ + from torch._functorch._aot_autograd.descriptors import DummyAOTInput + from torch._functorch._aot_autograd.subclass_utils import unwrap_tensor_subclasses + from torch._inductor.decomposition import select_decomp_table + from torch.fx.experimental.proxy_tensor import make_fx + + logger.info("Applying decompositions to joint graph") + + decomp_table = select_decomp_table() + + # Get traced tangents metadata + traced_tangents = joint_with_descriptors._aot_state.fw_metadata.traced_tangents + + # Collect all inputs: params, buffers, forward inputs, tangents + param_inputs = list(model.parameters()) + buffer_inputs = list(model.buffers()) + primals = param_inputs + buffer_inputs + list(forward_inputs) + tangents = list(traced_tangents) + + # Create dummy descriptors for unwrapping + primals_descs = [DummyAOTInput(i) for i in range(len(primals))] + tangents_descs = [DummyAOTInput(i + len(primals)) for i in range(len(tangents))] + + # Unwrap tensor subclasses (DTensor -> _local_tensor) + primals_unwrapped, _ = unwrap_tensor_subclasses( + primals, primals_descs, append_symints=False + ) + tangents_unwrapped, _ = unwrap_tensor_subclasses( + tangents, tangents_descs, append_symints=False + ) + + # Verify unwrapped tensor shapes match joint graph placeholders + all_inputs = primals_unwrapped + tangents_unwrapped + placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"] + + if len(all_inputs) != len(placeholders): + raise RuntimeError( + f"Input count mismatch: {len(all_inputs)} inputs vs {len(placeholders)} placeholders" + ) + + shape_mismatches = [] + for i, (inp, ph) in enumerate(zip(all_inputs, placeholders)): + if hasattr(inp, "shape") and "val" in ph.meta: + expected_shape = ph.meta["val"].shape + actual_shape = inp.shape + if expected_shape != actual_shape: + shape_mismatches.append( + f" {ph.target}: expected {expected_shape}, got {actual_shape}" + ) + + if shape_mismatches: + logger.error(f"Shape mismatches found ({len(shape_mismatches)}):") + for msg in shape_mismatches: + logger.error(msg) + raise RuntimeError( + "Unwrapped tensor shapes don't match joint graph placeholders." + ) + + # Get the FakeTensorMode from the original joint graph + fake_mode = None + for node in gm.graph.nodes: + if node.op == "placeholder" and "val" in node.meta: + val = node.meta["val"] + if hasattr(val, "fake_mode"): + fake_mode = val.fake_mode + break + + if fake_mode is None: + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(primals_unwrapped) + + # Use make_fx with the original fake mode to retrace with decompositions + with fake_mode: + decomposed_gm = make_fx( + gm, + decomposition_table=decomp_table, + _allow_non_fake_inputs=False, + )(primals_unwrapped, tangents_unwrapped) + + # Copy metadata from original placeholders to decomposed placeholders + orig_placeholders = [n for n in gm.graph.nodes if n.op == "placeholder"] + decomp_placeholders = [ + n for n in decomposed_gm.graph.nodes if n.op == "placeholder" + ] + + if len(orig_placeholders) != len(decomp_placeholders): + raise RuntimeError( + f"Placeholder count mismatch: {len(orig_placeholders)} vs {len(decomp_placeholders)}" + ) + + for orig, decomp in zip(orig_placeholders, decomp_placeholders): + # Copy all metadata from original to decomposed + for key, value in orig.meta.items(): + if key not in decomp.meta: + decomp.meta[key] = value + + # Rename decomposed placeholder to match original name + decomp.target = orig.target + decomp.name = orig.name + + decomposed_gm.recompile() + logger.info("Decompositions applied successfully to joint graph") + + return decomposed_gm + + +def full_inductor_compilation_pass( + gm: torch.fx.GraphModule, example_inputs +) -> torch.fx.GraphModule: + """ + Apply full Inductor compilation with code generation. + + This pass uses compile_fx_inner to generate optimized code for the graph. + + Args: + gm: The graph module (forward or backward) + example_inputs: Example inputs for compilation + + Returns: + The compiled graph module + """ + return compile_fx_inner(gm, example_inputs) + + +# Registry mapping pass names to pass functions +AVAILABLE_COMPILER_PASSES = { + "autobucketing_reordering": autobucketing_reordering_pass, + "transformer_block_bucketing": transformer_block_bucketing_reordering_pass, + "regional_inductor": regional_inductor_pass, + "cudagraph": cudagraph_pass, + "full_inductor_compilation": full_inductor_compilation_pass, +} + +# Registry for joint custom passes (applied before partitioning) +AVAILABLE_JOINT_PASSES = { + "inductor_decomposition": inductor_decomposition_pass, + "fsdp_reshard_after_fwd": fsdp_reshard_after_fwd_pass, + "validate_flex_attn_annotation": validate_flex_attn_annotation_pass, +} diff --git a/torchtitan/experiments/compiler_toolkit/scripts/check_numerics.py b/torchtitan/experiments/compiler_toolkit/scripts/check_numerics.py new file mode 100644 index 0000000000..06c1717957 --- /dev/null +++ b/torchtitan/experiments/compiler_toolkit/scripts/check_numerics.py @@ -0,0 +1,126 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import sys +from pathlib import Path + +# Add parent directory to path to import numerics_utils +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from tests.numerics_utils import run_numerics_test + + +def main(): + parser = argparse.ArgumentParser( + description="Run two training jobs and compare their tensorboard metrics" + ) + parser.add_argument( + "--ngpu", + type=int, + required=True, + help="Number of GPUs to use", + ) + parser.add_argument( + "--config-file", + type=str, + required=True, + help="Path to config file", + ) + parser.add_argument( + "--dp-shard-degree", + type=int, + default=1, + help="Data parallel shard degree", + ) + parser.add_argument( + "--tp-degree", + type=int, + default=1, + help="Tensor parallel degree", + ) + parser.add_argument( + "--cp-degree", + type=int, + default=1, + help="Context parallel degree", + ) + parser.add_argument( + "--ep-degree", + type=int, + default=1, + help="Expert parallel degree", + ) + parser.add_argument( + "--ac-mode", + type=str, + default="selective", + choices=["selective", "none", "full"], + help="Activation checkpoint mode", + ) + parser.add_argument( + "--steps", + type=int, + default=50, + help="Number of training steps", + ) + parser.add_argument( + "--seed", + type=int, + default=42, + help="Random seed for deterministic training", + ) + parser.add_argument( + "--eager-tb-folder", + type=str, + default="tb/eager_run", + help="Tensorboard folder for eager run", + ) + parser.add_argument( + "--compiled-tb-folder", + type=str, + default="tb/compiled_run", + help="Tensorboard folder for compiled run", + ) + parser.add_argument( + "--metrics", + nargs="+", + default=["loss_metrics/global_avg_loss", "grad_norm"], + help="Metrics to compare", + ) + parser.add_argument( + "--passes", + type=str, + default=None, + help=( + "Comma-separated list of compiler passes to apply " + "(e.g., 'autobucketing_reordering' or 'autobucketing_reordering,regional_inductor')" + ), + ) + + args = parser.parse_args() + + success = run_numerics_test( + ngpu=args.ngpu, + config_file=args.config_file, + dp_shard_degree=args.dp_shard_degree, + tp_degree=args.tp_degree, + cp_degree=args.cp_degree, + ep_degree=args.ep_degree, + ac_mode=args.ac_mode, + steps=args.steps, + seed=args.seed, + eager_tb_folder=args.eager_tb_folder, + compiled_tb_folder=args.compiled_tb_folder, + metrics=args.metrics, + passes=args.passes, + ) + + return 0 if success else 1 + + +if __name__ == "__main__": + exit(main()) diff --git a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py index bb64160db2..8053efe3d4 100644 --- a/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py +++ b/torchtitan/experiments/compiler_toolkit/tests/integration_tests.py @@ -24,13 +24,54 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "--model.name compiler_toolkit.llama3", "--parallelism.data_parallel_shard_degree 2", "--parallelism.tensor_parallel_degree 2", - "--activation_checkpoint.mode none", ], ], "llama3 FSDP+TP", "llama3_fsdp_tp", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes autobucketing_reordering", + ], + ], + "llama3 FSDP+TP autobucketing", + "llama3_fsdp_tp_autobucketing", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes transformer_block_bucketing", + ], + ], + "llama3 FSDP+TP manualbucketing", + "llama3_fsdp_tp_manualbucketing", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes cudagraph", + ], + ], + "llama3 FSDP+TP+cudagraph", + "llama3_fsdp_tp_cudagraph", + ngpu=4, + ), OverrideDefinitions( [ [ @@ -38,13 +79,72 @@ def build_compiler_toolkit_test_list() -> list[OverrideDefinitions]: "--parallelism.data_parallel_shard_degree 2", "--parallelism.tensor_parallel_degree 2", "--model.flavor debugmodel_flex_attn", - "--activation_checkpoint.mode none", ], ], "llama3 FSDP+TP+FlexAttn", "llama3_fsdp_tp_flexattn", ngpu=4, ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--model.flavor debugmodel_flex_attn", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes autobucketing_reordering,regional_inductor", + ], + ], + "llama3 FSDP+TP+FlexAttn autobucketing regional_inductor", + "llama3_fsdp_tp_flexattn_autobucketing_regional_inductor", + ngpu=4, + ), + # TODO: enable this when cudagraph is fixed + # OverrideDefinitions( + # [ + # [ + # "--model.name compiler_toolkit.llama3", + # "--parallelism.data_parallel_shard_degree 2", + # "--parallelism.tensor_parallel_degree 2", + # "--model.flavor debugmodel_flex_attn", + # "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + # "--compile.passes autobucketing_reordering,regional_inductor,cudagraph", + # ], + # ], + # "llama3 FSDP+TP+FlexAttn autobucketing regional_inductor+cudagraph", + # "llama3_fsdp_tp_flexattn_autobucketing_regional_inductor_cudagraph", + # ngpu=4, + # ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.joint_passes inductor_decomposition", + "--compile.passes full_inductor_compilation", + ], + ], + "llama3 full_inductor_compilation", + "llama3_full_inductor_compilation", + ngpu=4, + ), + OverrideDefinitions( + [ + [ + "--model.name compiler_toolkit.llama3", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--job.custom_config_module=torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes transformer_block_bucketing,regional_inductor", + ], + ], + "llama3 FSDP+TP+FlexAttn manualbucketing regional_inductor", + "llama3_fsdp_tp_flexattn_manualbucketing_regional_inductor", + ngpu=4, + ), # deepseek_v3 tests OverrideDefinitions( [ diff --git a/torchtitan/experiments/compiler_toolkit/tests/numerics_utils.py b/torchtitan/experiments/compiler_toolkit/tests/numerics_utils.py new file mode 100644 index 0000000000..0d7741b1a2 --- /dev/null +++ b/torchtitan/experiments/compiler_toolkit/tests/numerics_utils.py @@ -0,0 +1,270 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""Shared utilities for numerics testing.""" + +import glob +import os +import subprocess + +import torch +from tensorboard.backend.event_processing.event_accumulator import EventAccumulator + + +def load_metrics(event_path, metric_names): + """Load metrics from tensorboard event files.""" + event_acc = EventAccumulator(event_path) + event_acc.Reload() + + metrics = {} + for metric_name in metric_names: + try: + scalars = event_acc.Scalars(metric_name) + metrics[metric_name] = {scalar.step: scalar.value for scalar in scalars} + except KeyError: + print(f"Warning: Metric {metric_name!r} not found in event file") + metrics[metric_name] = {} + + return metrics + + +def compare_metrics(metrics1, metrics2, label1="Eager", label2="Compiled"): + """Compare two sets of metrics and verify bitwise equivalence using torch.equal().""" + + all_metrics = set(metrics1.keys()) | set(metrics2.keys()) + all_match = True + + for metric_name in sorted(all_metrics): + + steps1 = set(metrics1[metric_name].keys()) + steps2 = set(metrics2[metric_name].keys()) + + if steps1 != steps2: + print(" ERROR: Step mismatch!") + print(f" {label1} steps: {sorted(steps1)}") + print(f" {label2} steps: {sorted(steps2)}") + all_match = False + continue + + # Convert values to tensors for each step and compare + values1 = [metrics1[metric_name][step] for step in sorted(steps1)] + values2 = [metrics2[metric_name][step] for step in sorted(steps2)] + + tensor1 = torch.tensor(values1) + tensor2 = torch.tensor(values2) + + if torch.equal(tensor1, tensor2): + print( + f" ✓ PASS: All {len(steps1)} steps match exactly (bitwise equivalent)" + ) + else: + # Find and report mismatches + mismatches = [] + for idx, step in enumerate(sorted(steps1)): + val1 = values1[idx] + val2 = values2[idx] + if val1 != val2: + mismatches.append((step, val1, val2, abs(val1 - val2))) + + print( + f" ERROR: Found {len(mismatches)} mismatches out of {len(steps1)} steps" + ) + + return all_match + + +def find_latest_event_dir(base_path): + """Find the latest timestamped directory in the base path.""" + if not os.path.exists(base_path): + raise ValueError(f"Path does not exist: {base_path}") + + subdirs = [d for d in glob.glob(os.path.join(base_path, "*")) if os.path.isdir(d)] + if not subdirs: + return base_path + + latest = max(subdirs, key=os.path.getmtime) + return latest + + +def run_training( + ngpu, + config_file, + model_name, + dp_shard_degree, + tp_degree, + cp_degree, + ep_degree, + ac_mode, + steps, + seed, + deterministic, + tb_folder, + passes=None, +): + """Run a training job with the specified configuration.""" + print(f"\nStarting training: {model_name}") + + env = os.environ.copy() + env["NGPU"] = str(ngpu) + env["CONFIG_FILE"] = config_file + + cmd = [ + "./run_train.sh", + "--model.name", + model_name, + "--parallelism.data_parallel_shard_degree", + str(dp_shard_degree), + "--parallelism.tensor_parallel_degree", + str(tp_degree), + ] + + if cp_degree > 1: + cmd.extend(["--parallelism.context_parallel_degree", str(cp_degree)]) + if ep_degree > 1: + cmd.extend(["--parallelism.expert_parallel_degree", str(ep_degree)]) + + cmd.extend( + [ + "--activation_checkpoint.mode", + ac_mode, + "--training.steps", + str(steps), + "--debug.seed", + str(seed), + "--debug.deterministic", + "--metrics.enable_tensorboard", + "--metrics.save_tb_folder", + tb_folder, + ] + ) + + if passes: + cmd.extend( + [ + "--job.custom_config_module", + "torchtitan.experiments.compiler_toolkit.job_config", + "--compile.passes", + passes, + ] + ) + + print(f"Environment: NGPU={env['NGPU']}, CONFIG_FILE={env['CONFIG_FILE']}") + print(f"Running command: {' '.join(cmd)}") + + try: + result = subprocess.run( + cmd, + env=env, + check=True, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + text=True, + ) + print(f"✓ Training completed: {model_name}") + return True + except subprocess.CalledProcessError as e: + print(f"✗ Training failed: {model_name}") + print(f"Error output:\n{e.stdout}") + return False + + +def determine_model_names(config_file): + """Determine model names based on config file.""" + if "deepseek" in config_file: + model_name = "deepseek_v3" + elif "llama3" in config_file: + model_name = "llama3" + else: + raise ValueError( + f"Unable to determine model names from config file: {config_file}" + ) + + eager_model = f"simple_fsdp.{model_name}" + compiled_model = f"compiler_toolkit.{model_name}" + + return eager_model, compiled_model + + +def run_numerics_test( + ngpu, + config_file, + dp_shard_degree, + tp_degree, + cp_degree, + ep_degree, + ac_mode, + steps, + seed, + eager_tb_folder, + compiled_tb_folder, + metrics, + passes=None, +): + """ + Run numerics test by training both eager and compiled models and comparing metrics. + + Returns: + bool: True if all metrics match, False otherwise. + """ + # Determine model names + eager_model, compiled_model = determine_model_names(config_file) + + # Run eager training + eager_success = run_training( + ngpu=ngpu, + config_file=config_file, + model_name=eager_model, + dp_shard_degree=dp_shard_degree, + tp_degree=tp_degree, + cp_degree=cp_degree, + ep_degree=ep_degree, + ac_mode=ac_mode, + steps=steps, + seed=seed, + deterministic=True, + tb_folder=eager_tb_folder, + ) + + if not eager_success: + print("✗ Eager training failed") + return False + + # Run compiled training + compiled_success = run_training( + ngpu=ngpu, + config_file=config_file, + model_name=compiled_model, + dp_shard_degree=dp_shard_degree, + tp_degree=tp_degree, + cp_degree=cp_degree, + ep_degree=ep_degree, + ac_mode=ac_mode, + steps=steps, + seed=seed, + deterministic=True, + tb_folder=compiled_tb_folder, + passes=passes, + ) + + if not compiled_success: + print("✗ Compiled training failed") + return False + + # Compare metrics + eager_path = find_latest_event_dir(f"./outputs/{eager_tb_folder}") + compiled_path = find_latest_event_dir(f"./outputs/{compiled_tb_folder}") + + eager_metrics = load_metrics(eager_path, metrics) + compiled_metrics = load_metrics(compiled_path, metrics) + + all_match = compare_metrics(eager_metrics, compiled_metrics) + + if all_match: + print("✓ SUCCESS: All metrics are bitwise equivalent") + else: + print("✗ FAILURE: Metrics differ between runs") + + return all_match diff --git a/torchtitan/experiments/compiler_toolkit/tests/test_numerics.py b/torchtitan/experiments/compiler_toolkit/tests/test_numerics.py new file mode 100644 index 0000000000..1421ca3bca --- /dev/null +++ b/torchtitan/experiments/compiler_toolkit/tests/test_numerics.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import unittest + +from .numerics_utils import run_numerics_test + + +class TestNumerics(unittest.TestCase): + """Test numerics equivalence between simple_fsdp and compiler_toolkit implementations.""" + + def test_llama3_fsdp_tp(self): + """Test Llama3 with FSDP + TP configuration.""" + result = run_numerics_test( + ngpu=4, + config_file="./torchtitan/models/llama3/train_configs/debug_model.toml", + dp_shard_degree=2, + tp_degree=2, + cp_degree=1, + ep_degree=1, + ac_mode="selective", + steps=10, + seed=42, + eager_tb_folder="tb/test_llama3_fsdp_tp_eager", + compiled_tb_folder="tb/test_llama3_fsdp_tp_compiled", + metrics=["loss_metrics/global_avg_loss", "grad_norm"], + ) + self.assertTrue(result, "Llama3 FSDP+TP numerics test failed") + + def test_llama3_fsdp_tp_autobucketing(self): + result = run_numerics_test( + ngpu=4, + config_file="./torchtitan/models/llama3/train_configs/debug_model.toml", + dp_shard_degree=2, + tp_degree=2, + cp_degree=1, + ep_degree=1, + ac_mode="selective", + steps=10, + seed=42, + eager_tb_folder="tb/test_llama3_fsdp_tp_autobucketing_eager", + compiled_tb_folder="tb/test_llama3_fsdp_tp_autobucketing_compiled", + metrics=["loss_metrics/global_avg_loss", "grad_norm"], + passes="autobucketing_reordering", + ) + self.assertTrue(result, "Llama3 FSDP+TP+autobucketing numerics test failed") + + def test_llama3_fsdp_tp_manualbucketing(self): + result = run_numerics_test( + ngpu=4, + config_file="./torchtitan/models/llama3/train_configs/debug_model.toml", + dp_shard_degree=2, + tp_degree=2, + cp_degree=1, + ep_degree=1, + ac_mode="selective", + steps=10, + seed=42, + eager_tb_folder="tb/test_llama3_fsdp_tp_manualbucketing_eager", + compiled_tb_folder="tb/test_llama3_fsdp_tp_manualbucketing_compiled", + metrics=["loss_metrics/global_avg_loss", "grad_norm"], + passes="transformer_block_bucketing", + ) + self.assertTrue(result, "Llama3 FSDP+TP+manualbucketing numerics test failed") + + def test_deepseek_v3_fsdp_tp_ep(self): + """Test DeepSeek V3 with FSDP + TP + EP configuration.""" + result = run_numerics_test( + ngpu=4, + config_file="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml", + dp_shard_degree=2, + tp_degree=2, + cp_degree=1, + ep_degree=4, + ac_mode="none", + steps=10, + seed=42, + eager_tb_folder="tb/test_deepseek_v3_fsdp_tp_ep_eager", + compiled_tb_folder="tb/test_deepseek_v3_fsdp_tp_ep_compiled", + metrics=["loss_metrics/global_avg_loss", "grad_norm"], + ) + self.assertTrue(result, "DeepSeek V3 FSDP+TP+EP numerics test failed") + + +if __name__ == "__main__": + unittest.main() diff --git a/torchtitan/experiments/compiler_toolkit/train.py b/torchtitan/experiments/compiler_toolkit/train.py new file mode 100644 index 0000000000..7b0d58aa5a --- /dev/null +++ b/torchtitan/experiments/compiler_toolkit/train.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import gc + +from torchtitan.train import main, Trainer + + +class CompilerToolkitTrainer(Trainer): + def close(self) -> None: + super().close() + + # Note [explicit cudagraph close] + # cudagraph holds reference to nccl which prevents destroy nccl + # group. so we need to explicitly delete cudagraph which is held + # in joint_graph_module. An explicit gc.collect() is necessary + # to clean up reference cycles. + for part in self.model_parts: + if hasattr(part, "joint_graph_module"): + part.joint_graph_module = None + gc.collect() + + +if __name__ == "__main__": + main(CompilerToolkitTrainer) diff --git a/torchtitan/experiments/dion_optimizer/muon.py b/torchtitan/experiments/dion_optimizer/muon.py index 432ab1399f..0ac1602f71 100644 --- a/torchtitan/experiments/dion_optimizer/muon.py +++ b/torchtitan/experiments/dion_optimizer/muon.py @@ -609,27 +609,17 @@ def muon_update_batch_dim_sharded_async( - This is mathematically equivalent to orthogonalizing each expert's weights independently This function processes all params locally without all-to-all or all-gather. + + Optimized for CPU offloading with: + - Double-buffered CUDA streams to overlap transfer and compute + - Batched Newton-Schulz for fewer kernel launches + - Single sync point at end (no intermediate cuda.synchronize()) """ - U = muon_update_pre_orthogonalize( - G=G, - M=M, - momentum=momentum, - nesterov=nesterov, - ) - - # Orthogonalize each tensor locally - # Newton-Schulz treats dim 0 as batch, processing each slice independently - U = [ - muon_update_newton_schulz( - u, - newton_schulz_func=newton_schulz_func, - flatten=flatten, - epsilon=epsilon, - ) - for u in U - ] + # Check if we need CPU offloading (tensors are on CPU) + original_device = G[0].device + needs_gpu_transfer = original_device.type != "cuda" - # Compute scaled learning rate + # Compute scaled learning rate upfront # Use the first tensor's shape (they should all be the same shape within a batch) if adjust_lr is None: adjusted_lr = lr @@ -640,16 +630,132 @@ def muon_update_batch_dim_sharded_async( else: raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") - # Update model parameters with orthogonalized output - muon_update_post_orthogonalize( - X=X, - U=U, - base_lr=lr, - adjusted_lr=adjusted_lr, - weight_decay=weight_decay, - ) + if needs_gpu_transfer: + # PIPELINED MODE: Double-buffered streams for maximum overlap + # Timeline: transfer[i+1] overlaps with compute[i] overlaps with writeback[i-1] + cuda_device = torch.device("cuda") + dtype = M[0].dtype + n_tensors = len(X) + + # Mini-batch size for batched Newton-Schulz (fewer kernel launches) + BATCH_SIZE = 4 + + # Create streams: one for H2D transfers, one for compute, one for D2H transfers + h2d_stream = torch.cuda.Stream() + compute_stream = torch.cuda.Stream() + d2h_stream = torch.cuda.Stream() + + # Double buffer: prefetch next batch while computing current + prefetch_data = None # Will hold (g_batch, m_batch, x_batch, indices) for next iteration + + def prefetch_batch(start_idx): + """Prefetch a batch of tensors to GPU (non-blocking).""" + end_idx = min(start_idx + BATCH_SIZE, n_tensors) + indices = list(range(start_idx, end_idx)) + with torch.cuda.stream(h2d_stream): + g_batch = [G[i].to(dtype=dtype).to(cuda_device, non_blocking=True) for i in indices] + m_batch = [M[i].to(cuda_device, non_blocking=True) for i in indices] + x_batch = [X[i].to(cuda_device, non_blocking=True) for i in indices] + return (g_batch, m_batch, x_batch, indices) + + def compute_batch(g_batch, m_batch, x_batch, indices): + """Compute momentum update and Newton-Schulz on GPU.""" + with torch.cuda.stream(compute_stream): + # Wait for H2D transfer to complete (lightweight stream sync) + compute_stream.wait_stream(h2d_stream) + + u_batch = [] + for j in range(len(indices)): + g_gpu, m_gpu = g_batch[j], m_batch[j] + # Update momentum: M = mu * M + G + m_gpu.mul_(momentum) + m_gpu.add_(g_gpu) + # Compute U + if nesterov: + u_gpu = m_gpu * momentum + g_gpu + else: + u_gpu = m_gpu.clone() + u_batch.append(u_gpu.to(dtype=torch.bfloat16)) + + # Batched Newton-Schulz: stack same-shape tensors for single kernel + if len(u_batch) > 1 and all(u.shape == u_batch[0].shape for u in u_batch): + u_stacked = torch.stack(u_batch, dim=0) + u_stacked = muon_update_newton_schulz(u_stacked, newton_schulz_func, flatten, epsilon) + u_batch = list(u_stacked.unbind(0)) + else: + u_batch = [muon_update_newton_schulz(u, newton_schulz_func, flatten, epsilon) for u in u_batch] + + # Apply weight decay and update + for j in range(len(indices)): + x_batch[j].mul_(1 - lr * weight_decay) + x_batch[j].sub_(u_batch[j] * adjusted_lr) + + return m_batch, x_batch + + def writeback_batch(m_batch, x_batch, indices): + """Write results back to CPU (non-blocking).""" + with torch.cuda.stream(d2h_stream): + # Wait for compute to complete + d2h_stream.wait_stream(compute_stream) + for j, i in enumerate(indices): + M[i].copy_(m_batch[j], non_blocking=True) + X[i].copy_(x_batch[j], non_blocking=True) + + # Pipeline: prefetch first batch + if n_tensors > 0: + prefetch_data = prefetch_batch(0) + + # Main loop with double buffering + for batch_start in range(0, n_tensors, BATCH_SIZE): + # Get current batch (already prefetched) + g_batch, m_batch, x_batch, indices = prefetch_data + + # Start prefetching NEXT batch (overlaps with current compute) + next_start = batch_start + BATCH_SIZE + if next_start < n_tensors: + prefetch_data = prefetch_batch(next_start) + + # Compute current batch + m_batch, x_batch = compute_batch(g_batch, m_batch, x_batch, indices) + + # Writeback current batch (overlaps with next iteration's prefetch/compute) + writeback_batch(m_batch, x_batch, indices) + + # Single sync at end to ensure all D2H transfers complete + torch.cuda.synchronize() + + yield # Single yield to make this a generator + else: + # STANDARD GPU MODE: Process all tensors together (original behavior) + U = muon_update_pre_orthogonalize( + G=G, + M=M, + momentum=momentum, + nesterov=nesterov, + ) + + # Orthogonalize each tensor locally + # Newton-Schulz treats dim 0 as batch, processing each slice independently + U = [ + muon_update_newton_schulz( + u, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) + for u in U + ] - yield # Single yield to make this a generator + # Update model parameters with orthogonalized output + muon_update_post_orthogonalize( + X=X, + U=U, + base_lr=lr, + adjusted_lr=adjusted_lr, + weight_decay=weight_decay, + ) + + yield # Single yield to make this a generator def muon_update_batch_async( @@ -673,146 +779,333 @@ def muon_update_batch_async( Batched version of Muon update. Batch size should be equal to number of GPUs. All tensors in a batch should have identical shape, sharding, and dtype. Identical hyperparameters are used for all tensors in the batch. + + Memory-optimized for CPU offloading: when tensors are on CPU, moves ALL computation + to GPU (momentum update, all_to_all, Newton-Schulz, weight update) then copies back. """ assert len(X) == len(G) assert len(X) == len(M) assert len(X) == world_size - # Update momentum and compute the inputs for orthogonalization - U = muon_update_pre_orthogonalize( - G=to_local(G), - M=to_local(M), - momentum=momentum, - nesterov=nesterov, - ) - - # Get one whole matrix for each device to orthogonalize - if shard_dim is not None: - # Use all-to-all to transform from a batch of shards to a single whole matrix - # https://www.essential.ai/blog/infra - assert ( - process_group is not None - ), "process_group must be provided for sharded DTensors" - assert isinstance(X[0], DTensor), "X should contain DTensors" - assert not isinstance(U[0], DTensor), "U should contain local shards" - - # Debug: print full tensor info before the divisibility check - x0 = X[0] - x0_mesh = x0.device_mesh - x0_mesh_sizes = {name: x0_mesh.size(i) for i, name in enumerate(x0_mesh.mesh_dim_names)} - - assert ( - X[0].size(shard_dim) % world_size == 0 - ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}. " \ - f"Tensor info: global_shape={tuple(X[0].shape)}, local_shape={X[0].to_local().shape}, " \ - f"mesh={X[0].device_mesh.mesh_dim_names}, mesh_sizes={x0_mesh_sizes}, placements={X[0].placements}" - - # Allocate buffers to receive shards of one whole matrix from other devices - single_matrix_shards = [torch.empty_like(u) for u in U] - - # Redistribute the shards to form one unique full tensor on each device - # Sync CUDA before collective to ensure all prior GPU ops are complete - # This can prevent NCCL hangs due to async GPU operations + # Check early if we're in CPU offloading mode + G_local = to_local(G) + M_local = to_local(M) + X_local = to_local(X) + original_device = M_local[0].device + needs_gpu_transfer = original_device.type != "cuda" + + if needs_gpu_transfer: + # ====== CPU OFFLOADING PATH: Do ALL computation on GPU ====== + # This avoids slow CPU foreach operations for momentum and weight updates + cuda_device = torch.device("cuda") + dtype = M_local[0].dtype + + # Transfer G, M to GPU for momentum update + G_gpu = [g.to(dtype=dtype).to(cuda_device, non_blocking=True) for g in G_local] + M_gpu = [m.to(cuda_device, non_blocking=True) for m in M_local] torch.cuda.synchronize() - # N sequential all_gathers - only keep result for our assigned param - single_matrix_shards = None - for param_idx in range(world_size): - # Allocate output buffer for this all_gather - gathered = [torch.empty_like(U[param_idx]) for _ in range(world_size)] + # Momentum update on GPU (equivalent to muon_update_pre_orthogonalize) + torch._foreach_mul_(M_gpu, momentum) + torch._foreach_add_(M_gpu, G_gpu) - # All ranks send their shard of param_idx - dist.all_gather(gathered, U[param_idx].contiguous(), group=process_group) + if nesterov: + U_gpu = torch._foreach_mul(M_gpu, momentum) + torch._foreach_add_(U_gpu, G_gpu) + else: + # U shares memory with M when not using nesterov + U_gpu = M_gpu + + # Free G_gpu - no longer needed + del G_gpu + + # Convert to bfloat16 for communication + U_gpu = [u.to(dtype=torch.bfloat16) for u in U_gpu] + + # Get one whole matrix for each device to orthogonalize + if shard_dim is not None: + # Use all-to-all to transform from a batch of shards to a single whole matrix + assert process_group is not None, "process_group must be provided for sharded DTensors" + assert isinstance(X[0], DTensor), "X should contain DTensors" + + # Validation + x0 = X[0] + x0_mesh = x0.device_mesh + x0_mesh_sizes = {name: x0_mesh.size(i) for i, name in enumerate(x0_mesh.mesh_dim_names)} + assert ( + X[0].size(shard_dim) % world_size == 0 + ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}. " \ + f"Tensor info: global_shape={tuple(X[0].shape)}, local_shape={X[0].to_local().shape}, " \ + f"mesh={X[0].device_mesh.mesh_dim_names}, mesh_sizes={x0_mesh_sizes}, placements={X[0].placements}" + + # Make contiguous for all_to_all + U_gpu = [u.contiguous() for u in U_gpu] + + # First all_to_all: batch of shards -> single whole matrix + single_matrix_shards = [torch.empty_like(U_gpu[0]) for _ in range(world_size)] + dist.all_to_all(single_matrix_shards, U_gpu, group=process_group) + del U_gpu - # Only keep if this is our assigned parameter - if param_idx == device_rank: - single_matrix_shards = gathered - # Otherwise 'gathered' goes out of scope and memory can be freed + yield - yield + # Concatenate shards to form whole matrix + single_matrix = torch.cat(single_matrix_shards, dim=shard_dim) + del single_matrix_shards - # Concatentate shards to form a whole matrix to orthogonalize - single_matrix = torch.cat(single_matrix_shards, dim=shard_dim) - single_matrix = muon_update_newton_schulz( - single_matrix, - newton_schulz_func=newton_schulz_func, - flatten=flatten, - epsilon=epsilon, - ) + # Newton-Schulz orthogonalization (on GPU) + single_matrix = muon_update_newton_schulz( + single_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) - # Split result back into shards - # Contiguous is needed for communication to work correctly - orth_shards = [ - x.contiguous() - for x in torch.tensor_split(single_matrix, world_size, dim=shard_dim) - ] + # Split result back into shards + orth_shards = [ + x.contiguous() + for x in torch.tensor_split(single_matrix, world_size, dim=shard_dim) + ] + del single_matrix - # N sequential all_gathers - collect results as we go - for shard_idx in range(world_size): - # Allocate output buffer for this all_gather - gathered = [torch.empty_like(orth_shards[shard_idx]) for _ in range(world_size)] + # Second all_to_all to redistribute orthogonalized shards + U_orth_gpu = [torch.empty_like(orth_shards[0]) for _ in range(world_size)] + dist.all_to_all(U_orth_gpu, orth_shards, group=process_group) + del orth_shards - # All ranks send their shard at index shard_idx - dist.all_gather(gathered, orth_shards[shard_idx].contiguous(), group=process_group) + yield - # gathered[r] = rank r's orth_shards[shard_idx] = O^r_{shard_idx} - # We need U[r] = O^r_{device_rank} - # So when shard_idx == device_rank: U[r] = gathered[r] for all r - if shard_idx == device_rank: - for r in range(world_size): - U[r].copy_(gathered[r]) + else: + # Matrices are not sharded, orthogonalize directly + single_matrix = U_gpu[device_rank] + + single_matrix = muon_update_newton_schulz( + single_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) + + if process_group is not None and process_group.size() > 1: + U_orth_gpu = [torch.empty_like(single_matrix) for _ in range(world_size)] + work = dist.all_gather( + U_orth_gpu, single_matrix.contiguous(), group=process_group, async_op=True + ) + yield + work.wait() + del single_matrix + else: + assert world_size == 1 + U_orth_gpu = [single_matrix] + + # Compute scaled learning rate (use full tensor shape from X[0]) + if adjust_lr is None: + adjusted_lr = lr + elif adjust_lr == "spectral_norm": + adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape) + elif adjust_lr == "rms_norm": + adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape) + else: + raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") + + # Transfer X to GPU for weight update + X_gpu = [x.to(cuda_device, non_blocking=True) for x in X_local] + torch.cuda.synchronize() + + # Weight update on GPU (equivalent to muon_update_post_orthogonalize) + torch._foreach_mul_(X_gpu, 1 - lr * weight_decay) + U_scaled = torch._foreach_mul(U_orth_gpu, adjusted_lr) + torch._foreach_sub_(X_gpu, U_scaled) + del U_scaled, U_orth_gpu + + # Copy M and X back to CPU + for i in range(world_size): + M_local[i].copy_(M_gpu[i], non_blocking=True) + X_local[i].copy_(X_gpu[i], non_blocking=True) - yield + torch.cuda.synchronize() + del M_gpu, X_gpu else: - # Matrices are not sharded, so we can directly orthogonalize - # Get a single matrix corresponding to this device - single_matrix = U[device_rank] - assert not isinstance(single_matrix, DTensor) - - single_matrix = muon_update_newton_schulz( - single_matrix, - newton_schulz_func=newton_schulz_func, - flatten=flatten, - epsilon=epsilon, + # ====== STANDARD GPU PATH ====== + # Update momentum and compute the inputs for orthogonalization + U = muon_update_pre_orthogonalize( + G=G_local, + M=M_local, + momentum=momentum, + nesterov=nesterov, ) - if process_group is not None and process_group.size() > 1: - # Allocate empty tensors to receive updates from other devices - U = [torch.empty_like(u) for u in U] + # Get one whole matrix for each device to orthogonalize + # JQ: This is the N sequential gather version + # if shard_dim is not None: + # # Use all-to-all to transform from a batch of shards to a single whole matrix + # # https://www.essential.ai/blog/infra + # assert ( + # process_group is not None + # ), "process_group must be provided for sharded DTensors" + # assert isinstance(X[0], DTensor), "X should contain DTensors" + # assert not isinstance(U[0], DTensor), "U should contain local shards" + + # # Debug: print full tensor info before the divisibility check + # x0 = X[0] + # x0_mesh = x0.device_mesh + # x0_mesh_sizes = {name: x0_mesh.size(i) for i, name in enumerate(x0_mesh.mesh_dim_names)} + + # assert ( + # X[0].size(shard_dim) % world_size == 0 + # ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}. " \ + # f"Tensor info: global_shape={tuple(X[0].shape)}, local_shape={X[0].to_local().shape}, " \ + # f"mesh={X[0].device_mesh.mesh_dim_names}, mesh_sizes={x0_mesh_sizes}, placements={X[0].placements}" + + # # Allocate buffers to receive shards of one whole matrix from other devices + # single_matrix_shards = [torch.empty_like(u) for u in U] + + # # Redistribute the shards to form one unique full tensor on each device + # # Sync CUDA before collective to ensure all prior GPU ops are complete + # # This can prevent NCCL hangs due to async GPU operations + # torch.cuda.synchronize() + + # # N sequential all_gathers - only keep result for our assigned param + # single_matrix_shards = None + # for param_idx in range(world_size): + # # Allocate output buffer for this all_gather + # gathered = [torch.empty_like(U[param_idx]) for _ in range(world_size)] + + # # All ranks send their shard of param_idx + # dist.all_gather(gathered, U[param_idx].contiguous(), group=process_group) + + # # Only keep if this is our assigned parameter + # if param_idx == device_rank: + # single_matrix_shards = gathered + # # Otherwise 'gathered' goes out of scope and memory can be freed + + # yield + + # # Concatentate shards to form a whole matrix to orthogonalize + # single_matrix = torch.cat(single_matrix_shards, dim=shard_dim) + # single_matrix = muon_update_newton_schulz( + # single_matrix, + # newton_schulz_func=newton_schulz_func, + # flatten=flatten, + # epsilon=epsilon, + # ) + + # # Split result back into shards + # # Contiguous is needed for communication to work correctly + # orth_shards = [ + # x.contiguous() + # for x in torch.tensor_split(single_matrix, world_size, dim=shard_dim) + # ] + + # # N sequential all_gathers - collect results as we go + # for shard_idx in range(world_size): + # # Allocate output buffer for this all_gather + # gathered = [torch.empty_like(orth_shards[shard_idx]) for _ in range(world_size)] + + # # All ranks send their shard at index shard_idx + # dist.all_gather(gathered, orth_shards[shard_idx].contiguous(), group=process_group) + + # # gathered[r] = rank r's orth_shards[shard_idx] = O^r_{shard_idx} + # # We need U[r] = O^r_{device_rank} + # # So when shard_idx == device_rank: U[r] = gathered[r] for all r + # if shard_idx == device_rank: + # for r in range(world_size): + # U[r].copy_(gathered[r]) + + # yield + + # Get one whole matrix for each device to orthogonalize + if shard_dim is not None: + assert process_group is not None, "process_group must be provided for sharded DTensors" + assert isinstance(X[0], DTensor), "X should contain DTensors" + assert not isinstance(U[0], DTensor), "U should contain local shards" + + x0 = X[0] + x0_mesh = x0.device_mesh + x0_mesh_sizes = {name: x0_mesh.size(i) for i, name in enumerate(x0_mesh.mesh_dim_names)} + assert ( + X[0].size(shard_dim) % world_size == 0 + ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}. " \ + f"Tensor info: global_shape={tuple(X[0].shape)}, local_shape={X[0].to_local().shape}, " \ + f"mesh={X[0].device_mesh.mesh_dim_names}, mesh_sizes={x0_mesh_sizes}, placements={X[0].placements}" + + # Sync CUDA before collective to prevent NCCL hangs from async GPU ops + torch.cuda.synchronize() + + single_matrix_shards = [torch.empty_like(U[0]) for _ in range(world_size)] + dist.all_to_all(single_matrix_shards, [u.contiguous() for u in U], group=process_group) - # All gather orthogonalized results from other devices into buffer - work = dist.all_gather( - U, single_matrix.contiguous(), group=process_group, async_op=True + yield + + single_matrix = torch.cat(single_matrix_shards, dim=shard_dim) + del single_matrix_shards + + single_matrix = muon_update_newton_schulz( + single_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, ) + + orth_shards = [ + x.contiguous() + for x in torch.tensor_split(single_matrix, world_size, dim=shard_dim) + ] + del single_matrix + + output_shards = [torch.empty_like(orth_shards[0]) for _ in range(world_size)] + dist.all_to_all(output_shards, orth_shards, group=process_group) + del orth_shards + + for i in range(world_size): + U[i].copy_(output_shards[i]) + del output_shards + yield - work.wait() else: - # Single GPU case, no need to gather - assert world_size == 1 - U = [single_matrix] - - # Compute scaled learning rate - # Do this before to_local(X) because we use the full tensor shape, not the shard shape - if adjust_lr is None: - adjusted_lr = lr - elif adjust_lr == "spectral_norm": - adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape) - elif adjust_lr == "rms_norm": - adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape) - else: - raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") + single_matrix = U[device_rank] + assert not isinstance(single_matrix, DTensor) + + single_matrix = muon_update_newton_schulz( + single_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) - # Update model parameters with orthogonalized output - muon_update_post_orthogonalize( - X=to_local(X), - U=U, - base_lr=lr, - adjusted_lr=adjusted_lr, - weight_decay=weight_decay, - ) + if process_group is not None and process_group.size() > 1: + U_gathered = [torch.empty_like(single_matrix) for _ in range(world_size)] + work = dist.all_gather( + U_gathered, single_matrix.contiguous(), group=process_group, async_op=True + ) + yield + work.wait() + del single_matrix + U = U_gathered + else: + assert world_size == 1 + U = [single_matrix] + + # Compute scaled learning rate + if adjust_lr is None: + adjusted_lr = lr + elif adjust_lr == "spectral_norm": + adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape) + elif adjust_lr == "rms_norm": + adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape) + else: + raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") + + # Update model parameters with orthogonalized output + muon_update_post_orthogonalize( + X=X_local, + U=U, + base_lr=lr, + adjusted_lr=adjusted_lr, + weight_decay=weight_decay, + ) def adamw_update_foreach_async( diff --git a/torchtitan/experiments/forge/engine.py b/torchtitan/experiments/forge/engine.py index a9284d1b27..040372d50c 100644 --- a/torchtitan/experiments/forge/engine.py +++ b/torchtitan/experiments/forge/engine.py @@ -87,10 +87,9 @@ def __init__(self, job_config: ForgeJobConfig): world_size=world_size, ) - world_mesh = parallel_dims.world_mesh if parallel_dims.dp_enabled: - dp_mesh = world_mesh["dp"] - dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() + batch_mesh = parallel_dims.get_mesh("batch") + dp_degree, dp_rank = batch_mesh.size(), batch_mesh.get_local_rank() else: dp_degree, dp_rank = 1, 0 self.dp_degree, self.dp_rank = dp_degree, dp_rank @@ -103,9 +102,10 @@ def __init__(self, job_config: ForgeJobConfig): # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). dist_utils.set_determinism( - world_mesh, + parallel_dims, self.device, job_config.debug, + distinct_seed_mesh_dims=["pp"], # same as `torchtitan/train.py` ) self.train_spec = get_train_spec(job_config.model.name) @@ -231,6 +231,7 @@ def __init__(self, job_config: ForgeJobConfig): if self.train_spec.state_dict_adapter else None ), + base_folder=job_config.job.dump_folder, ) loss_parallel_enabled = ( diff --git a/torchtitan/experiments/forge/example_train.py b/torchtitan/experiments/forge/example_train.py index 7b0b0c81e9..b00ec58a2a 100644 --- a/torchtitan/experiments/forge/example_train.py +++ b/torchtitan/experiments/forge/example_train.py @@ -19,6 +19,7 @@ from torchtitan.components.validate import build_validator from torchtitan.config import JobConfig from torchtitan.distributed import utils as dist_utils +from torchtitan.distributed.context_parallel import prepare_context_parallel_input from torchtitan.hf_datasets.text_datasets import build_text_dataloader from torchtitan.tools import utils from torchtitan.tools.logging import logger @@ -152,42 +153,58 @@ def batch_generator( yield input_dict, labels - def forward_backward_step( + def post_dataloading_process( self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor - ) -> torch.Tensor: - model_parts = self.model_parts - parallel_dims = self.parallel_dims - + ) -> tuple[torch.Tensor, torch.Tensor, dict[str, torch.Tensor], dict[str, Any]]: inputs = input_dict["input"] - extra_kwargs = {} - - if getattr(self.model_args, "use_flex_attn", False): - extra_kwargs["attention_masks"] = model_parts[0].get_attention_masks( + extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} + # For arguments, like attention_masks, we have to put them in a separate + # dict as extra_inputs are not forwarded to other stages in PP, but + # extra_kwargs are. + extra_kwargs: dict[str, Any] = {} + + try: + # pyrefly: ignore [not-callable] + extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks( input_batch=inputs, tokenizer=self.tokenizer, + extra_inputs=extra_inputs, ) - - optional_context_parallel_ctx = ( - dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], - cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts], - cp_seq_dims=[1, 1] + [0 for _ in model_parts], - cp_no_restore_buffers={inputs, labels}, - cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, + except TypeError: + pass + + if self.parallel_dims.cp_enabled: + inputs, labels, extra_kwargs = prepare_context_parallel_input( + inputs, + labels, + extra_kwargs, + self.parallel_dims.get_mesh("cp"), + self.device, + self.job_config.parallelism.context_parallel_load_balancer, ) - if parallel_dims.cp_enabled - else None + + return inputs, labels, extra_inputs, extra_kwargs + + def forward_backward_step( + self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor + ) -> torch.Tensor: + model_parts = self.model_parts + parallel_dims = self.parallel_dims + + inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process( + input_dict, labels ) if parallel_dims.pp_enabled: # Pipeline Parallel forward / backward inside step() call - with self.train_context(optional_context_parallel_ctx): + with self.train_context(): targets, losses = ( (labels, []) if self.pp_has_last_stage else (None, None) ) if self.pp_has_first_stage: self.pp_schedule.step( inputs, + **extra_inputs, **extra_kwargs, target=targets, losses=losses, @@ -211,10 +228,10 @@ def forward_backward_step( ) else: # Non-PP forward / backward - with self.train_context(optional_context_parallel_ctx): + with self.train_context(): assert len(model_parts) == 1 with self.maybe_enable_amp: - pred = model_parts[0](inputs, **extra_kwargs) + pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs) loss = self.loss_fn(pred, labels) # need to free to before bwd to avoid peaking memory del pred @@ -243,9 +260,7 @@ def train_step( [p for m in self.model_parts for p in m.parameters()], self.job_config.training.max_norm, foreach=True, - pp_mesh=( - parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None - ), + pp_mesh=parallel_dims.get_optional_mesh("pp"), ep_enabled=parallel_dims.ep_enabled, ) self.checkpointer.maybe_wait_for_staging() @@ -262,8 +277,8 @@ def train_step( if parallel_dims.dp_cp_enabled: loss = loss.detach() global_avg_loss, global_max_loss = ( - dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"]), - dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"]), + dist_utils.dist_mean(loss, parallel_dims.get_optional_mesh("loss")), + dist_utils.dist_max(loss, parallel_dims.get_optional_mesh("loss")), ) else: global_avg_loss = global_max_loss = loss.detach().item() @@ -329,7 +344,7 @@ def train(self): timeout=timedelta( seconds=job_config.comm.train_timeout_seconds ), - world_mesh=self.parallel_dims.world_mesh, + parallel_dims=self.parallel_dims, ) if torch.distributed.get_rank() == 0: diff --git a/torchtitan/experiments/ft/train.py b/torchtitan/experiments/ft/train.py new file mode 100644 index 0000000000..891f6c5554 --- /dev/null +++ b/torchtitan/experiments/ft/train.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os + +from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.train import main, Trainer + + +class FTTrainer(Trainer): + def init_distributed(self) -> ParallelDims: + job_config = self.job_config + + # determine the global ranks when fault tolerance is enabled + global_ranks = [] + ft_config = job_config.fault_tolerance + if ft_config.enable: + group_size = ft_config.group_size + replica_id = ft_config.replica_id + first_rank = replica_id * group_size + last_rank = first_rank + group_size - 1 + global_ranks = list(range(first_rank, last_rank + 1)) + + # init distributed and build meshes + dist_utils.init_distributed( + job_config.comm, + enable_cpu_backend=job_config.training.enable_cpu_offload, + base_folder=job_config.job.dump_folder, + ranks=global_ranks, + ) + + world_size = int(os.environ["WORLD_SIZE"]) + parallelism_config = job_config.parallelism + + return ParallelDims( + dp_shard=parallelism_config.data_parallel_shard_degree, + dp_replicate=parallelism_config.data_parallel_replicate_degree, + cp=parallelism_config.context_parallel_degree, + tp=parallelism_config.tensor_parallel_degree, + pp=parallelism_config.pipeline_parallel_degree, + ep=parallelism_config.expert_parallel_degree, + etp=parallelism_config.expert_tensor_parallel_degree, + world_size=world_size, + ) + + +if __name__ == "__main__": + main(FTTrainer) diff --git a/torchtitan/experiments/kimi_linear/infra/parallelize.py b/torchtitan/experiments/kimi_linear/infra/parallelize.py index f675c95476..a0ca38a176 100644 --- a/torchtitan/experiments/kimi_linear/infra/parallelize.py +++ b/torchtitan/experiments/kimi_linear/infra/parallelize.py @@ -4,15 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -""" -Parallelization for Kimi Linear model. - -Applies tensor parallelism, expert parallelism, FSDP, and other distributed training techniques. -""" - import torch import torch.nn as nn - from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import Replicate, Shard from torch.distributed.tensor.parallel import ( @@ -22,27 +15,36 @@ RowwiseParallel, SequenceParallel, ) - from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import NoParallel, ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.context_parallel import apply_cp_to_attention_module +from torchtitan.distributed.dual_pipe_v import get_dual_pipe_v_flag +from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp +from torchtitan.models.llama3.infra.parallelize import apply_ddp from torchtitan.models.llama4.infra.parallelize import ( apply_compile, apply_fsdp, apply_moe_ep_tp, ) -from torchtitan.models.llama3.infra.parallelize import apply_ddp from torchtitan.tools.logging import logger - -# Operations to save for selective activation checkpointing +# for selective op activation checkpointing _op_sac_save_list = { torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + torch.ops.aten._scaled_dot_product_attention_math.default, + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, + torch.ops._c10d_functional.all_to_all_single.default, + # for low precision training, it's useful to always save + # the result of max, since the absolute maximum is + # used to compute the scaling factor for quantization. torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, + torch._higher_order_ops.inductor_compiled_code, } @@ -51,8 +53,9 @@ def parallelize_kimi_linear( parallel_dims: ParallelDims, job_config: JobConfig, ): - """Apply parallelization to Kimi Linear model.""" - world_mesh = parallel_dims.world_mesh + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. assert ( job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 ), f""" @@ -60,75 +63,94 @@ def parallelize_kimi_linear( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") - - model_compile_enabled = ( - job_config.compile.enable and "model" in job_config.compile.components - ) + attn_type = getattr(model.model_args, "attn_type", "sdpa") + if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa": + raise NotImplementedError( + f"Context Parallel only supports SDPA attention. " + f"Got attn_type='{attn_type}'. " + f"FlexAttention and varlen attention are not supported with CP." + ) if parallel_dims.tp_enabled: - if ( - job_config.parallelism.enable_async_tensor_parallel - and not model_compile_enabled - ): - raise RuntimeError("Async TP requires torch.compile") + raise NotImplementedError("TP not supported for Kimi Linear") + + # Check if using DeepEP for MoE communication + if job_config.parallelism.expert_parallel_comm_backend == "deepep": + if not parallel_dims.ep_enabled: + raise ValueError( + "DeepEP requires expert parallelism (ep_degree > 1). " + "The DeepEP MoE model code does not support EP=1. " + "Please set expert_parallel_degree > 1 or use standard communication backend." + ) + if parallel_dims.etp_enabled: + raise NotImplementedError( + "DeepEP with Expert Tensor Parallelism (ETP) is not supported yet. " + "Please set expert_tensor_parallel_degree=1 or use standard communication backend." + ) + + use_deepep = True + + # Import deepep module to register custom ops before accessing them + import torchtitan.distributed.deepep # noqa: F401 - registers torch.ops.deepep + + _op_sac_save_list.add(torch.ops.deepep.dispatch.default) + _op_sac_save_list.add(torch.ops.deepep.combine.default) + else: + use_deepep = False - enable_float8_linear = "float8" in job_config.model.converters - float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( - "rowwise", - "rowwise_with_gw_hp", - ) - enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims) - apply_non_moe_tp( + apply_moe_ep_tp( model, - world_mesh["tp"], - loss_parallel=not job_config.parallelism.disable_loss_parallel, - enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, - enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, + tp_mesh=parallel_dims.get_optional_mesh("tp"), + ep_mesh=parallel_dims.get_optional_mesh("ep"), + etp_mesh=parallel_dims.get_optional_mesh("etp"), + ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]), + dual_pipe_v=dual_pipe_v, + use_deepep=use_deepep, ) - if parallel_dims.tp_enabled or parallel_dims.ep_enabled: - apply_moe_ep_tp( - model, - tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, - ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, - ep_tp_mesh=( - world_mesh["ep", "tp"] - if parallel_dims.tp_enabled - and parallel_dims.ep_enabled - and parallel_dims.etp_enabled - else None - ), - etp_enabled=parallel_dims.etp_enabled, + if parallel_dims.cp_enabled: + apply_cp_to_attention_module( + # pyrefly: ignore [missing-attribute, not-callable] + [block.attention.inner_attention for block in model.layers.values()], + parallel_dims.get_mesh("cp"), + attn_type, ) + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) + if job_config.activation_checkpoint.mode != "none": apply_ac( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, + # pyrefly: ignore [bad-argument-type] op_sac_save_list=_op_sac_save_list, + base_folder=job_config.job.dump_folder, ) if model_compile_enabled: - apply_compile(model, job_config.compile) + apply_compile(model, job_config.compile, parallel_dims.ep_enabled) - if parallel_dims.fsdp_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") - else: - dp_mesh_dim_names = ("dp_shard_cp",) - dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + dp_mesh: DeviceMesh | None = None + if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: + # apply FSDP or HSDP, potentially with Context Parallel + dp_mesh_names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(dp_mesh_names) - dp_mod_ep_mesh_dim_names = [] - if parallel_dims.ep_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh_dim_names.append("dp_replicate") - dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP + edp_mesh_names = ( + ["dp_replicate", "efsdp"] + if parallel_dims.dp_replicate_enabled + else ["efsdp"] + ) + edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names) apply_fsdp( model, @@ -139,11 +161,7 @@ def parallelize_kimi_linear( cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, - dp_mod_ep_mesh=( - world_mesh[tuple(dp_mod_ep_mesh_dim_names)] - if parallel_dims.ep_enabled - else None - ), + edp_mesh=edp_mesh, gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, ) @@ -152,132 +170,16 @@ def parallelize_kimi_linear( else: logger.info("Applied FSDP to the model") - if parallel_dims.cp_enabled: - logger.info("Applied Context Parallel to the model") - if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") - elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_mesh = parallel_dims.get_mesh("dp_replicate") + if dp_mesh.ndim > 1: raise RuntimeError("DDP has not supported > 1D parallelism") apply_ddp( model, - world_mesh, + dp_mesh, enable_compile=model_compile_enabled, - enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd, ) - # Enable weight tying after applying parallelisms - if model.model_args.enable_weight_tying: - model.output.weight = model.tok_embeddings.weight - return model - - -def apply_non_moe_tp( - model: nn.Module, - tp_mesh: DeviceMesh, - loss_parallel: bool, - enable_float8_tensorwise_tp: bool, - enable_async_tp: bool, -): - """Apply tensor parallelism to non-MoE parts of the model.""" - # Parallelize embeddings, final norm, and output - parallelize_module( - model, - tp_mesh, - { - "tok_embeddings": RowwiseParallel( - input_layouts=Replicate(), - output_layouts=Shard(1), - ), - "norm": SequenceParallel(), - "output": ColwiseParallel( - input_layouts=Shard(1), - output_layouts=Shard(-1) if loss_parallel else Replicate(), - use_local_output=not loss_parallel, - ), - }, - ) - - # Set up parallel styles - if enable_float8_tensorwise_tp: - from torchao.float8.float8_tensor_parallel import ( - Float8ColwiseParallel, - Float8RowwiseParallel, - PrepareFloat8ModuleInput, - ) - rowwise_parallel, colwise_parallel, prepare_module_input = ( - Float8RowwiseParallel, - Float8ColwiseParallel, - PrepareFloat8ModuleInput, - ) - else: - rowwise_parallel, colwise_parallel, prepare_module_input = ( - RowwiseParallel, - ColwiseParallel, - PrepareModuleInput, - ) - - # Apply TP to each transformer block - for transformer_block in model.layers.values(): - layer_plan = { - "attention_norm": SequenceParallel(), - "ffn_norm": SequenceParallel(), - } - - # Handle attention parallelization based on layer type - if transformer_block.is_linear_attn: - # KDA attention parallelization - layer_plan.update({ - "attention": prepare_module_input( - input_layouts=(Shard(1), Replicate()), - desired_input_layouts=(Replicate(), Replicate()), - ), - "attention.q_proj": colwise_parallel(use_local_output=False), - "attention.k_proj": colwise_parallel(use_local_output=False), - "attention.v_proj": colwise_parallel(use_local_output=False), - "attention.o_proj": rowwise_parallel(output_layouts=Shard(1)), - }) - else: - # MLA attention parallelization - layer_plan.update({ - "attention": prepare_module_input( - input_layouts=(Shard(1), Replicate()), - desired_input_layouts=(Replicate(), Replicate()), - ), - "attention.q_proj": colwise_parallel(use_local_output=False), - "attention.kv_a_proj_with_mqa": colwise_parallel(use_local_output=False), - "attention.kv_a_layernorm": NoParallel(use_local_output=False), - "attention.kv_b_proj": colwise_parallel(use_local_output=False), - "attention.wo": rowwise_parallel(output_layouts=Shard(1)), - }) - - # FFN parallelization (for non-MoE layers) - if not transformer_block.moe_enabled: - layer_plan.update({ - "feed_forward": prepare_module_input( - input_layouts=(Shard(1),), - desired_input_layouts=(Replicate(),), - ), - "feed_forward.w1": colwise_parallel(), - "feed_forward.w2": rowwise_parallel(output_layouts=Shard(1)), - "feed_forward.w3": colwise_parallel(), - }) - - parallelize_module( - module=transformer_block, - device_mesh=tp_mesh, - parallelize_plan=layer_plan, - ) - - if enable_async_tp: - from torch.distributed._symmetric_memory import enable_symm_mem_for_group - torch._inductor.config._micro_pipeline_tp = True - enable_symm_mem_for_group(tp_mesh.get_group().group_name) - - logger.info( - f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" - "Tensor Parallelism to the model" - ) diff --git a/torchtitan/experiments/kimi_linear/model/args.py b/torchtitan/experiments/kimi_linear/model/args.py index 753d72e2bf..af8c685552 100644 --- a/torchtitan/experiments/kimi_linear/model/args.py +++ b/torchtitan/experiments/kimi_linear/model/args.py @@ -42,7 +42,7 @@ class KimiLinearModelArgs(BaseModelArgs): max_seq_len: int = 8192 # Attention settings - use_flex_attn: bool = True + attn_type: str = "flex" attn_mask_type: str = "block_causal" depth_init: bool = True enable_weight_tying: bool = False @@ -85,8 +85,8 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: else: self.layer_types.append("linear_attention") - if not self.use_flex_attn: - raise ValueError("Kimi Linear requires FlexAttention") + if self.attn_type != "flex": + raise ValueError(f"Kimi Linear requires `attn_type` be 'flex' but got {self.attn_type}") if ( job_config.compile.enable and "model" in job_config.compile.components diff --git a/torchtitan/experiments/kimi_linear/model/model.py b/torchtitan/experiments/kimi_linear/model/model.py index 4e111892a0..06ce621940 100644 --- a/torchtitan/experiments/kimi_linear/model/model.py +++ b/torchtitan/experiments/kimi_linear/model/model.py @@ -94,11 +94,11 @@ def apply_rotary_emb( q: torch.Tensor, k: torch.Tensor, rope_cache: torch.Tensor, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """Apply rotary position embeddings to query and key tensors.""" - if position_ids is not None: - rope_cache = rope_cache[position_ids] + if positions is not None: + rope_cache = rope_cache[positions] rope_cache = rope_cache.unsqueeze(2) # [batch_size, seqlen, 1, head_dim * 2] else: # reshape for broadcast @@ -186,11 +186,8 @@ def __init__(self, model_args: KimiLinearModelArgs): # Output projection self.wo = nn.Linear(self.n_heads * self.v_head_dim, model_args.dim, bias=False) - self.use_flex_attn = model_args.use_flex_attn - if self.use_flex_attn: - self.inner_attention = FlexAttentionWrapper() - else: - self.inner_attention = ScaledDotProductAttentionWrapper() + self.inner_attention = FlexAttentionWrapper() + def init_weights(self, init_std: float): nn.init.trunc_normal_(self.q_proj.weight, mean=0.0, std=0.02) @@ -203,7 +200,7 @@ def forward( x: torch.Tensor, rope_cache: torch.Tensor, attention_masks: AttentionMasksType | None, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ): bs, seqlen, _ = x.shape @@ -234,7 +231,7 @@ def forward( k_rope = k_rope.expand(bs, seqlen, self.n_heads, self.qk_rope_head_dim) # Apply RoPE to rope parts only - q_rope, k_rope = apply_rotary_emb(q_rope, k_rope, rope_cache, position_ids) + q_rope, k_rope = apply_rotary_emb(q_rope, k_rope, rope_cache, positions) # Combine nope and rope parts q = torch.cat([q_nope, q_rope], dim=-1) @@ -246,12 +243,9 @@ def forward( v = v.transpose(1, 2) # Apply attention - if self.use_flex_attn: - output = self.inner_attention( - q, k, v, block_mask=attention_masks["flex_attn"], scale=self.scaling - ) - else: - output = self.inner_attention(q, k, v, scale=self.scaling) + output = self.inner_attention( + q, k, v, block_mask=attention_masks["flex_attn"], scale=self.scaling + ) output = output.transpose(1, 2).contiguous().view(bs, seqlen, -1) return self.wo(output) @@ -343,7 +337,7 @@ def forward( x: torch.Tensor, rope_cache: torch.Tensor, attention_masks: AttentionMasksType | None, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ): batch_size, seq_len, _ = x.shape @@ -482,14 +476,14 @@ def forward( x: torch.Tensor, rope_cache: torch.Tensor, attention_masks: AttentionMasksType | None, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ): # Attention with residual x = x + self.attention( self.attention_norm(x), rope_cache, attention_masks, - position_ids=position_ids, + positions=positions, ) # FFN with residual @@ -510,7 +504,7 @@ def init_weights(self, buffer_device: torch.device): self.feed_forward.init_weights(self.weight_init_std) -class KimiLinearModel(nn.Module, ModelProtocol): +class KimiLinearModel(ModelProtocol): """ Kimi Linear Model with hybrid MLA and KDA attention. @@ -521,7 +515,7 @@ class KimiLinearModel(nn.Module, ModelProtocol): """ def __init__(self, model_args: KimiLinearModelArgs): - super().__init__() + super().__init__(model_args) have_linear_attention = any( lt == "linear_attention" for lt in model_args.layer_types @@ -645,7 +639,7 @@ def forward( self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ): h = self.tok_embeddings(tokens) for layer in self.layers.values(): @@ -653,7 +647,7 @@ def forward( h, self.rope_cache, attention_masks=attention_masks, - position_ids=position_ids, + positions=positions, ) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h diff --git a/torchtitan/experiments/qwen3_next/infra/parallelize.py b/torchtitan/experiments/qwen3_next/infra/parallelize.py index bf319c2517..9c11249543 100644 --- a/torchtitan/experiments/qwen3_next/infra/parallelize.py +++ b/torchtitan/experiments/qwen3_next/infra/parallelize.py @@ -8,27 +8,28 @@ # training techniques (e.g. activation checkpointing and compile) to the Llama model. import torch +import torch._inductor.config import torch.nn as nn from torch.distributed.device_mesh import DeviceMesh from torch.distributed.tensor import Replicate, Shard from torch.distributed.tensor.parallel import ( - ColwiseParallel, parallelize_module, PrepareModuleInput, - RowwiseParallel, SequenceParallel, ) - +from torchtitan.components.peft.lora import LoraColwiseParallel, LoraRowwiseParallel from torchtitan.config import JobConfig, TORCH_DTYPE_MAP -from torchtitan.distributed import NoParallel, ParallelDims +from torchtitan.distributed import ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.context_parallel import apply_cp_to_attention_module +from torchtitan.distributed.dual_pipe_v import get_dual_pipe_v_flag +from torchtitan.models.llama3.infra.parallelize import apply_ddp from torchtitan.models.llama4.infra.parallelize import ( apply_compile, apply_fsdp, apply_moe_ep_tp, ) -from torchtitan.models.llama3.infra.parallelize import apply_ddp from torchtitan.tools.logging import logger @@ -37,12 +38,17 @@ torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + torch.ops.aten._scaled_dot_product_attention_math.default, + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, # for low precision training, it's useful to always save # the result of max, since the absolute maximum is # used to compute the scaling factor for quantization. torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, + torch.ops.torch_attn._varlen_attn.default, + torch._higher_order_ops.inductor_compiled_code, } @@ -51,7 +57,6 @@ def parallelize_qwen3next( parallel_dims: ParallelDims, job_config: JobConfig, ): - world_mesh = parallel_dims.world_mesh assert ( job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 ), f""" @@ -59,9 +64,13 @@ def parallelize_qwen3next( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") + attn_type = getattr(model.model_args, "attn_type", "sdpa") + if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa": + raise NotImplementedError( + f"Context Parallel only supports SDPA attention. " + f"Got attn_type='{attn_type}'. " + f"FlexAttention and varlen attention are not supported with CP." + ) model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components @@ -84,27 +93,34 @@ def parallelize_qwen3next( # all-gather happens in high precision. enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + tp_mesh = parallel_dims.get_mesh("tp") apply_non_moe_tp( model, - world_mesh["tp"], + tp_mesh, loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, + positions_enabled=parallel_dims.cp_enabled or job_config.training.dataset_type == "preprocessed", ) if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims) + apply_moe_ep_tp( model, - tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, - ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, - ep_tp_mesh=( - world_mesh["ep", "tp"] - if parallel_dims.tp_enabled - and parallel_dims.ep_enabled - and parallel_dims.etp_enabled - else None - ), - etp_enabled=parallel_dims.etp_enabled, + tp_mesh=parallel_dims.get_optional_mesh("tp"), + ep_mesh=parallel_dims.get_optional_mesh("ep"), + etp_mesh=parallel_dims.get_optional_mesh("etp"), + ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]), + dual_pipe_v=dual_pipe_v, + ) + + if parallel_dims.cp_enabled: + apply_cp_to_attention_module( + # pyrefly: ignore [missing-attribute, not-callable] + [block.attention.inner_attention for block in model.layers.values()], + parallel_dims.get_mesh("cp"), + attn_type, ) if job_config.activation_checkpoint.mode != "none": @@ -112,28 +128,29 @@ def parallelize_qwen3next( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, + # pyrefly: ignore [bad-argument-type] op_sac_save_list=_op_sac_save_list, + base_folder=job_config.job.dump_folder, ) # turn on per-TransformerBlock compile after AC wrapping and before FSDP if model_compile_enabled: - apply_compile(model, job_config.compile) + apply_compile(model, job_config.compile, parallel_dims.ep_enabled) if parallel_dims.fsdp_enabled: # apply FSDP or HSDP, potentially with Context Parallel - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") - else: - dp_mesh_dim_names = ("dp_shard_cp",) - dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + dp_mesh_names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(dp_mesh_names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - dp_mod_ep_mesh_dim_names = [] - if parallel_dims.ep_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh_dim_names.append("dp_replicate") - dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + edp_mesh_names = ( + ["dp_replicate", "efsdp"] + if parallel_dims.dp_replicate_enabled + else ["efsdp"] + ) + edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names) apply_fsdp( model, @@ -144,11 +161,7 @@ def parallelize_qwen3next( cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, - dp_mod_ep_mesh=( - world_mesh[tuple(dp_mod_ep_mesh_dim_names)] - if parallel_dims.ep_enabled - else None - ), + edp_mesh=edp_mesh, gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, ) @@ -157,23 +170,22 @@ def parallelize_qwen3next( else: logger.info("Applied FSDP to the model") - if parallel_dims.cp_enabled: - logger.info("Applied Context Parallel to the model") - if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_mesh = parallel_dims.get_mesh("dp_replicate") + if dp_mesh.ndim > 1: raise RuntimeError("DDP has not supported > 1D parallelism") apply_ddp( model, - world_mesh, + dp_mesh, enable_compile=model_compile_enabled, - enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd, ) # Enable weight tying after applying parallelisms + # pyrefly: ignore [missing-attribute] if model.model_args.enable_weight_tying: + # pyrefly: ignore [missing-attribute] model.output.weight = model.tok_embeddings.weight return model @@ -185,6 +197,7 @@ def apply_non_moe_tp( loss_parallel: bool, enable_float8_tensorwise_tp: bool, enable_async_tp: bool, + positions_enabled: bool, ): """Apply tensor parallelism.""" # 1. Parallelize the embedding and shard its outputs (which are the first @@ -195,12 +208,12 @@ def apply_non_moe_tp( model, tp_mesh, { - "tok_embeddings": RowwiseParallel( + "tok_embeddings": LoraRowwiseParallel( input_layouts=Replicate(), output_layouts=Shard(1), ), "norm": SequenceParallel(), - "output": ColwiseParallel( + "output": LoraColwiseParallel( input_layouts=Shard(1), output_layouts=Shard(-1) if loss_parallel else Replicate(), use_local_output=not loss_parallel, @@ -225,8 +238,8 @@ def apply_non_moe_tp( ) else: rowwise_parallel, colwise_parallel, prepare_module_input = ( - RowwiseParallel, - ColwiseParallel, + LoraRowwiseParallel, + LoraColwiseParallel, PrepareModuleInput, ) @@ -234,22 +247,30 @@ def apply_non_moe_tp( # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + positions_sharding = Replicate() if positions_enabled else None + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), "attention": prepare_module_input( - input_layouts=(Shard(1), Replicate()), - desired_input_layouts=(Replicate(), Replicate()), + input_layouts=(Shard(1), Replicate(), None, positions_sharding), + desired_input_layouts=( + Replicate(), + Replicate(), + None, + positions_sharding, + ), ), "attention.wq": colwise_parallel(use_local_output=False), "attention.wk": colwise_parallel(use_local_output=False), "attention.wv": colwise_parallel(use_local_output=False), - "attention.q_norm": NoParallel(use_local_output=False), - "attention.k_norm": NoParallel(use_local_output=False), + "attention.q_norm": SequenceParallel(sequence_dim=2), + "attention.k_norm": SequenceParallel(sequence_dim=2), "attention.wo": rowwise_parallel(output_layouts=Shard(1)), "ffn_norm": SequenceParallel(), } + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: layer_plan.update( { @@ -264,16 +285,14 @@ def apply_non_moe_tp( ) parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, parallelize_plan=layer_plan, ) if enable_async_tp: - from torch.distributed._symmetric_memory import enable_symm_mem_for_group - torch._inductor.config._micro_pipeline_tp = True - enable_symm_mem_for_group(tp_mesh.get_group().group_name) logger.info( f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" diff --git a/torchtitan/experiments/qwen3_next/model/args.py b/torchtitan/experiments/qwen3_next/model/args.py index fbfae3323e..9ad44216fc 100644 --- a/torchtitan/experiments/qwen3_next/model/args.py +++ b/torchtitan/experiments/qwen3_next/model/args.py @@ -46,7 +46,7 @@ class Qwen3NextModelArgs(BaseModelArgs): max_seq_len: int = 4096 depth_init: bool = True - use_flex_attn: bool = True + attn_type: str = "flex" attn_mask_type: str = "block_causal" enable_weight_tying: bool = False @@ -92,8 +92,8 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: for i in range(self.n_layers) ] - if not self.use_flex_attn: - raise ValueError("Qwen3-Next requires FlexAttention") + if self.attn_type != "flex": + raise ValueError(f"Qwen3-Next requires `attn_type` be 'flex' but got {self.attn_type}") if ( job_config.compile.enable and "model" in job_config.compile.components diff --git a/torchtitan/experiments/qwen3_next/model/model.py b/torchtitan/experiments/qwen3_next/model/model.py index 4b0470bba4..4922d7f39c 100644 --- a/torchtitan/experiments/qwen3_next/model/model.py +++ b/torchtitan/experiments/qwen3_next/model/model.py @@ -17,7 +17,6 @@ FlexAttentionWrapper, ScaledDotProductAttentionWrapper, create_attention_mask, - get_block_causal_mask_mod_by_seq_lens, get_causal_mask_mod, get_document_mask_mod, ) @@ -73,10 +72,10 @@ def apply_rotary_emb( xk: torch.Tensor, rope_cache: torch.Tensor, partial_ratio: float = 1.0, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: - if position_ids is not None: - rope_cache = rope_cache[position_ids] + if positions is not None: + rope_cache = rope_cache[positions] rope_cache = rope_cache.unsqueeze(2) # [batch_size, seqlen, 1, head_dim * 2] else: rope_cache = reshape_for_broadcast(rope_cache, xq) @@ -137,6 +136,7 @@ def __init__(self, model_args: Qwen3NextModelArgs): self.head_dim = model_args.head_dim self.scaling = self.head_dim**-0.5 self.partial_rotary_factor = model_args.partial_rotary_factor + self.attn_type = getattr(model_args, "attn_type", "sdpa") self.q_norm = ZeroCenteredRMSNorm(self.head_dim, eps=model_args.norm_eps) self.k_norm = ZeroCenteredRMSNorm(self.head_dim, eps=model_args.norm_eps) @@ -149,11 +149,11 @@ def __init__(self, model_args: Qwen3NextModelArgs): model_args.n_heads * self.head_dim, model_args.dim, bias=False ) - self.use_flex_attn = model_args.use_flex_attn - if self.use_flex_attn: - self.inner_attention = FlexAttentionWrapper() - else: - self.inner_attention = ScaledDotProductAttentionWrapper() + match self.attn_type: + case "flex": + self.inner_attention = FlexAttentionWrapper() + case _: + raise ValueError(f"Unknown attention type: {self.attn_type}") def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -167,7 +167,7 @@ def forward( x: torch.Tensor, rope_cache: torch.Tensor, attention_masks: AttentionMasksType | None, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ): bs, seqlen, _ = x.shape xq, gate = torch.chunk(self.wq(x), 2, dim=-1) @@ -187,7 +187,7 @@ def forward( xk, rope_cache, partial_ratio=self.partial_rotary_factor, - position_ids=position_ids, + positions=positions, ) keys = repeat_kv(xk, self.n_rep) @@ -197,12 +197,9 @@ def forward( xk = keys.transpose(1, 2) xv = values.transpose(1, 2) - if self.use_flex_attn: - output = self.inner_attention( - xq, xk, xv, block_mask=attention_masks["flex_attn"], scale=self.scaling - ) - else: - output = self.inner_attention(xq, xk, xv, scale=self.scaling) + output = self.inner_attention( + xq, xk, xv, block_mask=attention_masks["flex_attn"], scale=self.scaling + ) output = output.transpose(1, 2).contiguous().view(bs, seqlen, -1) output = output * torch.sigmoid(gate.view(bs, seqlen, -1)) @@ -299,7 +296,7 @@ def forward( hidden_states: torch.Tensor, rope_cache: torch.Tensor, attention_masks: AttentionMasksType | None, - position_ids: torch.Tensor | None, + positions: torch.Tensor | None, ): # hidden_states = apply_mask_to_padding_states(hidden_states, attention_mask) @@ -628,13 +625,13 @@ def forward( x: torch.Tensor, rope_cache: torch.Tensor, attention_masks: AttentionMasksType | None, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ): x = x + self.attention( self.attention_norm(x), rope_cache, attention_masks, - position_ids=position_ids, + positions=positions, ) if self.moe_enabled: x = x + self.moe(self.ffn_norm(x)) @@ -652,9 +649,9 @@ def init_weights(self, buffer_device: torch.device): self.feed_forward.init_weights(self.weight_init_std) -class Qwen3NextModel(nn.Module, ModelProtocol): +class Qwen3NextModel(ModelProtocol): def __init__(self, model_args: Qwen3NextModelArgs): - super().__init__() + super().__init__(model_args) have_linear_attention = any( lt == "linear_attention" for lt in model_args.layer_types @@ -763,18 +760,6 @@ def get_attention_masks( cu_seqlens = torch.cat(cu_seqlens) ret["cu_seqlens"] = cu_seqlens - # case "block_causal_by_sequence_lengths": - # sequence_lengths = extra_inputs.pop("sequence_lengths", None) - # if sequence_lengths is None: - # raise RuntimeError( - # "`sequence_lengths` required for `block_causal_by_sequence_lengths`" - # ) - # B = input_batch.shape[0] - # mask_mods.append( - # get_block_causal_mask_mod_by_seq_lens(sequence_lengths) - # ) - - # TODO: calculate seq_idx and cu_seqlens case _: raise ValueError( f"Unknown attention mask type: {self.model_args.attn_mask_type}" @@ -788,7 +773,7 @@ def forward( self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ): h = self.tok_embeddings(tokens) for layer in self.layers.values(): @@ -796,7 +781,7 @@ def forward( h, self.rope_cache, attention_masks=attention_masks, - position_ids=position_ids, + positions=positions, ) h = self.norm(h) if self.norm else h output = self.output(h) if self.output else h diff --git a/torchtitan/experiments/rl/README.md b/torchtitan/experiments/rl/README.md new file mode 100644 index 0000000000..72b3d2ad11 --- /dev/null +++ b/torchtitan/experiments/rl/README.md @@ -0,0 +1,12 @@ +# Deterministic RL Training with vLLM + +This package provides two approaches for integrating TorchTitan models with vLLM: + +1. vllm_compat/ - vLLM-Compatible approach + - Separate model definition matching vLLM's weight format + - Support batch-invariant and bit-wise identity between train and inference + - Custom backward passes for attention gradient computation + +2. unified/ - Unified approach + - Uses canonical TorchTitan model definition for inference directly + - Replaces attention with vLLM Compatible attention for inference diff --git a/torchtitan/experiments/rl/unified/README.md b/torchtitan/experiments/rl/unified/README.md new file mode 100644 index 0000000000..27550e977c --- /dev/null +++ b/torchtitan/experiments/rl/unified/README.md @@ -0,0 +1,88 @@ +# Run vLLM inference with TorchTitan Qwen3 Model + +This directory contains code to run a single canonical model definition (TorchTitan model definition) with vLLM inference engine (not batch-invariant yet, working in progress). This work is actively developing and only supports inference for now. + +This work is inspired by https://github.com/vllm-project/vllm/pull/28685. + +## Overview +The integration consists of two main components: + +1. **Model Adapter** (`model/qwen3.py`): A custom model class that extends vLLM's `Qwen3ForCausalLM` to handle TorchTitan checkpoint naming conventions +2. **Inference Script** (`infer.py`): A simple script to register the model and run inference + + +## Quick Start +### Prerequisites + +1. Install PyTorch nightly for torchtitan: +``` +pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall +``` + + +2. Install vLLM from source [vllm-use-an-existing-pytorch-installation](https://docs.vllm.ai/en/latest/getting_started/installation/gpu/index.html#use-an-existing-pytorch-installation): +```bash +# install PyTorch first, either from PyPI or from source +git clone https://github.com/vllm-project/vllm.git +cd vllm +python use_existing_torch.py +uv pip install -r requirements/build.txt +uv pip install --no-build-isolation -e . +``` + + +NOTE: If `flash_attn_varlen_func` hits error "torch.AcceleratorError: CUDA error: the provided PTX was compiled with an unsupported toolchain" during forward path, this is due to GPU driver version is not compatible with vLLM/PyTorch compiled version. Use the following command to recompile vLLM. + +``` +# Set CUDA version environment variable +export CUDA_HOME=/usr/local/cuda-12.4 +export PATH=/usr/local/cuda-12.4/bin:$PATH +export LD_LIBRARY_PATH=/usr/local/cuda-12.4/lib64:$LD_LIBRARY_PATH + +# Clean previous build +rm -rf build dist *.egg-info +uv pip uninstall -y vllm + +# Rebuild vLLM from source with CUDA 12.4 +uv pip install -e . + +``` + +3. Download Qwen/Qwen3-0.6B checkpoint from HuggingFace and put into `torchtitan/experiments/rl/example_checkpoint` folder. +``` +python scripts/download_hf_assets.py --repo_id Qwen/Qwen3-0.6B --local_dir torchtitan/experiments/rl/example_checkpoint --all --hf_token=... +``` + +4. Run inference: +``` +python torchtitan/experiments/rl/unified/infer.py --model-ckpt-path +``` + +Run with TP: (work in progress) +``` +python torchtitan/experiments/rl/unified/infer.py --model-ckpt-path --tensor-parallel-size 2 + +``` + +5. Run simple rl loop +``` +VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py +``` +Right now we only support VLLM_COMPAT mode, which could achieve trainer and generator bitwise identical. We are working on support UNIFIED mode, +which uses a unified model definition for trainer and generator. + +## TODO +Work on batch invariance: +1. Integrate with simple_rl_multiprocess.py to run end-to-end RL with one canonical model definition(UNIFIED mode). +2. Rewrite attention part to use vllm.Attention() with backward as the only attention path. +3. Leverage batch-invariant kernels into model definition. + +Work on the RL loop: +1. Design trainer API and integrate with [train.py](https://github.com/pytorch/torchtitan/blob/main/torchtitan/train.py#L475) +2. Remove hardcoded configs and dependency on Qwen3 Model. Use torchtitan's config/TrainSpec instead, to work with any model. +3. Need to load the gsm8k dataset using TorchTitan dataset. +4. Need to properly implement weight saving and loading using TorchTitan's checkpoint mechanism, or use TorchStore. Also need to + replace `vllm_to_torchtitan` and `torchtitan_to_vllm` calls to TorchTitan [state dict adaptor](https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/qwen3/model/state_dict_adapter.py). +5. Right now we only support trainer run on multiple processes using DDP, and generator using TP, need to onboard more parallelism. +6. Right now we only support VLLM_COMPAT mode to achieve batch invariance and bitwise determinism, need to support UNIFIED mode. +7. In the longer term, need to add trajectory queue to achieve async, right now trainer and generator are running synchronously. diff --git a/torchtitan/experiments/rl/unified/__init__.py b/torchtitan/experiments/rl/unified/__init__.py new file mode 100644 index 0000000000..430df3f268 --- /dev/null +++ b/torchtitan/experiments/rl/unified/__init__.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Unified approach for running TorchTitan models with vLLM inference. + +This module automatically registers TorchTitan models with vLLM when imported. +Uses the canonical TorchTitan model definition directly with vLLM inference engine. +""" + +from torchtitan.experiments.rl.unified.infra.parallelize import parallelize_qwen3 +from torchtitan.protocols.train_spec import get_train_spec, TrainSpec +from vllm.logger import init_logger + +from .infra.parallelism_utils import create_parallel_dims_from_vllm_config + +from .models.vllm_wrapper import TorchTitanVLLMModelWrapper + +logger = init_logger(__name__) + + +def register_torchtitan_model_from_train_spec( + train_spec: TrainSpec, + model_name: str, + model_flavor: str, +) -> None: + """ + Register a TorchTitan model with vLLM using a TrainSpec. + + Args: + train_spec: TorchTitan TrainSpec containing model components + model_name: Name to register in vLLM (e.g., "Qwen3TorchTitanForCausalLM") + model_flavor: Model flavor key (e.g., "0.6B") to select from qwen3_args + + """ + from vllm.model_executor.models.registry import ModelRegistry + + # Get model_args directly from TrainSpec.model_args dict using flavor key + if isinstance(train_spec.model_args, dict): + if model_flavor not in train_spec.model_args: + raise ValueError( + f"Model flavor '{model_flavor}' not found in train_spec.model_args. " + f"Available flavors: {list(train_spec.model_args.keys())}" + ) + model_args = train_spec.model_args[model_flavor] + else: + raise ValueError( + "train_spec.model_args must be a dict mapping flavor names to ModelArgs" + ) + + # Create dynamic model class directly from TrainSpec components + class TorchTitanVLLMModelFromSpec(TorchTitanVLLMModelWrapper): + def __init__(self, *, vllm_config, prefix=""): + super().__init__( + model_cls=train_spec.model_cls, + model_args=model_args, + state_dict_adapter=train_spec.state_dict_adapter, + # NOTE: This should be replaced with qwen3 parallelization plan in torchtitan core + parallelize_fn=parallelize_qwen3, + vllm_config=vllm_config, + prefix=prefix, + ) + + # Set the class name + TorchTitanVLLMModelFromSpec.__name__ = model_name + TorchTitanVLLMModelFromSpec.__qualname__ = model_name + + # Register with vLLM + ModelRegistry.register_model(model_name, TorchTitanVLLMModelFromSpec) + + logger.info( + f"Successfully registered {model_name} with vLLM using TrainSpec " + f"(model_cls={train_spec.model_cls.__name__}, flavor={model_flavor})" + ) + + +# Auto-register TorchTitan models with vLLM when this module is imported +register_torchtitan_model_from_train_spec( + train_spec=get_train_spec("qwen3"), + model_name="Qwen3TorchTitanForCausalLM", + # TODO: Remove the model_flavor args when registering model, + # allow passing model flavor option from config system. Now we have to specify + # model_flavor during registration because we can not pass torchtitan job_config from LLM() Api + model_flavor="0.6B", +) + + +__all__ = [ + "TorchTitanVLLMModelWrapper", + "create_parallel_dims_from_vllm_config", + "register_torchtitan_model_from_train_spec", +] diff --git a/torchtitan/experiments/rl/unified/actors/generator.py b/torchtitan/experiments/rl/unified/actors/generator.py new file mode 100644 index 0000000000..45d89095b7 --- /dev/null +++ b/torchtitan/experiments/rl/unified/actors/generator.py @@ -0,0 +1,462 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import asyncio +import logging +import os + +from dataclasses import dataclass +from typing import List + +import torch +from monarch.actor import Actor, endpoint +from safetensors.torch import save_file +from torchtitan.config.job_config import Comm +from torchtitan.distributed import utils as dist_utils + +# Import unified module - this automatically registers TorchTitan models with vLLM +from torchtitan.experiments.rl import unified # noqa: F401 + +from torchtitan.experiments.rl.vllm_compat.simple_rl import ( + compute_grpo_advantages, + compute_grpo_advantages_stable, + math_reward_function, + trivial_reward_function, +) +from torchtitan.experiments.rl.vllm_compat.weights.converter import torchtitan_to_vllm +from vllm import LLM, SamplingParams + +logger = logging.getLogger(__name__) + + +@dataclass +class TrajectoryData: + """ + Data from one generation batch. + + Attributes: + policy_version: Version of policy that produced this batch + completions: List of completion strings + vllm_token_ids: List of token ID lists for each completion + vllm_token_log_probs: List of per-token log prob lists + prompt_token_ids: List of prompt token ID lists + rewards: Computed rewards for each completion + advantages: Computed advantages for each completion + """ + + policy_version: int + completions: List[str] + vllm_token_ids: List[List[int]] + vllm_token_log_probs: List[List[float]] + prompt_token_ids: List[List[int]] + rewards: torch.Tensor + advantages: torch.Tensor + + +class VLLMRolloutEngine: + """ + vLLM engine for fast rollouts with weight updates. + + Note: vLLM loads from model_config.model path, so we create a temporary + directory with updated weights and restart the engine. This is faster than + recreating temp dirs repeatedly and handles config/tokenizer files properly. + + Args: + model_path: Path to HuggingFace model (for config/tokenizer) + temp_checkpoint_dir: Directory to save temporary weight checkpoints + """ + + def __init__( + self, + model_path: str, + temp_checkpoint_dir: str = "./converted", + tp_size: int = 1, + ): + self.base_model_path = model_path + self.temp_model_dir = os.path.abspath( + os.path.join(temp_checkpoint_dir, "vllm_temp_model") + ) + os.makedirs(self.temp_model_dir, exist_ok=True) + + import glob + + # Copy config/tokenizer files from base model to temp dir + import shutil + + for file in [ + "config.json", + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", + "merges.txt", + "vocab.json", + ]: + src = os.path.join(model_path, file) + if os.path.exists(src): + shutil.copy2(src, self.temp_model_dir) + + # Copy the original model shard files if they exist + # We'll overwrite these with our single model.safetensors later + for shard_file in glob.glob(os.path.join(model_path, "model-*.safetensors")): + dst = os.path.join(self.temp_model_dir, os.path.basename(shard_file)) + shutil.copy2(shard_file, dst) + + # Copy index file if it exists + index_file = os.path.join(model_path, "model.safetensors.index.json") + if os.path.exists(index_file): + shutil.copy2(index_file, self.temp_model_dir) + + self.llm = None + self.tp_size = tp_size + logger.info("vLLM rollout engine initialized (will load on first use)") + + def update_weights(self, vllm_compat_state: dict) -> None: + """ + Update vLLM model weights from vLLM-compat state dict. + + This converts weights to vLLM format, saves them, and reloads using + vLLM's reload_weights() API after updating the model path config. + + Args: + vllm_compat_state: vLLM-compat model state dict (with gate_up_proj/down_proj) + """ + # Convert vLLM-compat -> vLLM (torchtitan_to_vllm handles both formats) + vllm_state = torchtitan_to_vllm(vllm_compat_state) + + # Save to temp model directory + import os + + checkpoint_path = os.path.join(self.temp_model_dir, "model.safetensors") + + # Update the shard files that vLLM will actually load + # We need to split our weights to match the original 2-shard structure + import glob + import json + + shard_files = sorted( + glob.glob(os.path.join(self.temp_model_dir, "model-*.safetensors")) + ) + index_file = os.path.join(self.temp_model_dir, "model.safetensors.index.json") + + # TODO: need to replace this with Torchtitan's checkpoint save and load + # right now we hardcoded to work with 2 safe tensor files which we only + # tested on Qwen3 0.6B model. In the longer term, need to use TorchStore + # to achieve the weight communication. + # only generator rank 0 saves the weight + if torch.distributed.get_rank() == 0: + logger.info(f"Saving weights to {checkpoint_path}") + if len(shard_files) == 2 and os.path.exists(index_file): + # Load the index to see which weights go in which shard + with open(index_file, "r") as f: + index_data = json.load(f) + + weight_map = index_data["weight_map"] + + # Split weights according to the index + shard1_weights = {} + shard2_weights = {} + + for key, value in vllm_state.items(): + shard_file = weight_map.get(key, shard_files[0]) + if "model-00001-of-00002" in shard_file: + shard1_weights[key] = value + else: + shard2_weights[key] = value + + # Ensure weights stay in bfloat16 + shard1_weights = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in shard1_weights.items() + } + shard2_weights = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in shard2_weights.items() + } + + # Save to the shard files + save_file(shard1_weights, shard_files[0]) + save_file(shard2_weights, shard_files[1]) + else: + # Ensure weights stay in bfloat16 + vllm_state = { + k: v.to(torch.bfloat16) if v.dtype == torch.float32 else v + for k, v in vllm_state.items() + } + # Fallback: save as single file + save_file(vllm_state, checkpoint_path) + + # Synchronize all ranks before reloading to ensure rank 0 finished writing + torch.distributed.barrier() + logger.info( + f"[Rank {torch.distributed.get_rank()}] Synchronized after weight save" + ) + + # First time: create the engine + if self.llm is None: + self.llm = LLM( + model=self.temp_model_dir, + hf_overrides={ + # Override architectures to use our registered TorchTitan model class + "architectures": ["Qwen3TorchTitanForCausalLM"], + }, + trust_remote_code=True, + max_model_len=2048, + dtype="bfloat16", + gpu_memory_utilization=0.1, # Reduced from 0.5 + distributed_executor_backend="external_launcher", # vllm do not spawn processes + seed=42, # Fixed seed for determinism + enforce_eager=True, + tensor_parallel_size=self.tp_size, # Explicitly single GPU + ) + logger.info("Created new vLLM engine") + else: + # Use collective_rpc to call reload_weights on all workers + # This reloads weights from temp_model_dir without recreating the engine + self.llm.collective_rpc("reload_weights") + + @torch.no_grad() + def generate( + self, + prompt_texts: list[str], + max_new_tokens: int = 20, + temperature: float = 1.0, + n_samples_per_prompt: int = 4, + ) -> tuple[ + list[str], torch.Tensor, list[list[int]], list[list[float]], list[list[int]] + ]: + """ + Generate samples using vLLM. + + Args: + prompt_texts: List of prompt strings + max_new_tokens: Max tokens to generate + temperature: Sampling temperature + n_samples_per_prompt: Number of samples per prompt + + Returns: + completions: List of completion strings + log_probs: [batch] - Sum of log probs for each completion + token_ids: List of token ID lists for each completion (generated tokens only) + token_log_probs: List of per-token log prob lists for each completion + prompt_token_ids: List of prompt token ID lists for each completion + """ + sampling_params = SamplingParams( + temperature=temperature, + max_tokens=max_new_tokens, + n=n_samples_per_prompt, + seed=42, + logprobs=1, + prompt_logprobs=1, # Also get prompt log probs to access prompt token IDs + ) + + outputs = self.llm.generate(prompt_texts, sampling_params) + + # Extract completions and log probs + completions = [] + log_probs_list = [] + token_ids_list = [] + token_log_probs_list = [] + prompt_token_ids_list = [] + + for output in outputs: + # Extract prompt token IDs from the output + prompt_token_ids = output.prompt_token_ids + + for sample in output.outputs: + completions.append(sample.text) + + # Store prompt tokens for this sample + prompt_token_ids_list.append(prompt_token_ids) + + # Extract token IDs (generated tokens only) + token_ids = sample.token_ids + token_ids_list.append(token_ids) + + # Extract per-token log probs + per_token_log_probs = [ + list(logprob_dict.values())[0].logprob + for logprob_dict in sample.logprobs + ] + token_log_probs_list.append(per_token_log_probs) + + # Sum log probs across generated tokens + total_log_prob = sum(per_token_log_probs) + log_probs_list.append(total_log_prob) + + log_probs = torch.tensor(log_probs_list, dtype=torch.float32) + + return ( + completions, + log_probs, + token_ids_list, + token_log_probs_list, + prompt_token_ids_list, + ) + + def __del__(self): + """Cleanup vLLM engine.""" + if hasattr(self, "llm"): + del self.llm + torch.cuda.empty_cache() + + +class GeneratorState: + """States for the Generator's state machine.""" + + READY_TO_GENERATE = "READY_TO_GENERATE" + READY_TO_UPDATE = "READY_TO_UPDATE" + + +class Generator(Actor): + """ + Generates rollouts using vLLM engine. + + Maintains a vLLM engine that is synchronized with the Trainer + via weight sync. Generates completions for given prompts and + computes rewards/advantages. + + Args: + model_path: Path to HuggingFace model + prompt_texts: List of prompt strings + expected_answers: List of expected answers + group_size: Number of samples per prompt + max_new_tokens: Max tokens to generate + temperature: Sampling temperature + use_real_dataset: Whether using real dataset (GSM8K) + grpo_beta: Beta for GRPO advantages + use_stable_grpo: Whether to use stable GRPO + tp_size: Tensor Parallel size + """ + + def __init__( + self, + model_path: str, + prompt_texts: List[str], + expected_answers: List[str], + group_size: int = 8, + max_new_tokens: int = 20, + temperature: float = 1.0, + use_real_dataset: bool = False, + grpo_beta: float = 0.1, + use_stable_grpo: bool = False, + tp_size: int = 1, + ): + self.model_path = model_path + self.prompt_texts = prompt_texts + self.expected_answers = expected_answers + self.group_size = group_size + self.max_new_tokens = max_new_tokens + self.temperature = temperature + self.use_real_dataset = use_real_dataset + self.grpo_beta = grpo_beta + self.use_stable_grpo = use_stable_grpo + self.tp_size = tp_size + + # Initialize distributed environment for SPMD generator + world_size = dist_utils.init_distributed( + Comm(), + ) + # Initialize vLLM engine + self.vllm_engine = VLLMRolloutEngine(model_path, tp_size=self.tp_size) + + # State machine + self.state = GeneratorState.READY_TO_UPDATE + self.cond = asyncio.Condition() + self.policy_version = 0 + + # Reward function + self.reward_fn = ( + math_reward_function if use_real_dataset else trivial_reward_function + ) + + logger.info("Generator initialized with vLLM engine") + + @endpoint + async def generate(self) -> None: + """Generate trajectories and compute rewards/advantages.""" + logger.info( + f"{os.getpid()=} Generating start generate (policy v{self.policy_version})..." + ) + async with self.cond: + # Wait until ready to generate (weights have been updated) + await self.cond.wait_for( + lambda: self.state == GeneratorState.READY_TO_GENERATE + ) + + # Generate samples using vLLM + ( + completions, + vllm_log_probs, + vllm_token_ids, + vllm_token_log_probs, + prompt_token_ids, + ) = self.vllm_engine.generate( + self.prompt_texts, + self.max_new_tokens, + self.temperature, + n_samples_per_prompt=self.group_size, + ) + + # Compute rewards + rewards = self.reward_fn( + completions, self.expected_answers, self.group_size + ) + + # Normalize rewards + reward_mean = rewards.mean() + reward_std = rewards.std() + if reward_std > 1e-8: + rewards_normalized = (rewards - reward_mean) / reward_std + else: + rewards_normalized = rewards - reward_mean + + # Compute advantages using GRPO + if self.use_stable_grpo: + advantages = compute_grpo_advantages_stable( + rewards_normalized, self.group_size + ) + else: + advantages = compute_grpo_advantages( + rewards_normalized, self.group_size, beta=self.grpo_beta + ) + + # Create trajectory data + trajectory = TrajectoryData( + policy_version=self.policy_version, + completions=completions, + vllm_token_ids=vllm_token_ids, + vllm_token_log_probs=vllm_token_log_probs, + prompt_token_ids=prompt_token_ids, + rewards=rewards, + advantages=advantages, + ) + + # Signal ready for update + self.state = GeneratorState.READY_TO_UPDATE + self.cond.notify_all() + + logger.info( + f"{os.getpid()=} Generating finish generate (policy v{self.policy_version})..." + ) + return trajectory + + @endpoint + async def update(self, version: int, vllm_compat_state: dict) -> None: + """Update generate weights. + + Args: + version: New policy version number + vllm_compat_state: vLLM-compatible state dict + """ + async with self.cond: + self.vllm_engine.update_weights(vllm_compat_state) + # Update version and state + self.policy_version = version + self.state = GeneratorState.READY_TO_GENERATE + self.cond.notify_all() + logger.info( + f"{os.getpid()=} Generator updating weights to policy v{version}..." + ) diff --git a/torchtitan/experiments/rl/unified/actors/trainer.py b/torchtitan/experiments/rl/unified/actors/trainer.py new file mode 100644 index 0000000000..9ffb9f0f0a --- /dev/null +++ b/torchtitan/experiments/rl/unified/actors/trainer.py @@ -0,0 +1,136 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os +from typing import Any, Optional + +import torch +from monarch.actor import Actor, endpoint +from torchtitan.experiments.rl.unified.actors.generator import TrajectoryData +from torchtitan.experiments.rl.unified.models.parallelism_utils import ( + create_trainer_parallel_dims, +) +from torchtitan.experiments.rl.unified.models.utils import load_model, ModelMode +from torchtitan.experiments.rl.vllm_compat.simple_rl import ( + compute_policy_gradient_loss_vllm, +) +from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( + torchtitan_to_vllm_compat, +) + +logger = logging.getLogger(__name__) + + +class Trainer(Actor): + """ + Updates policy based on collected trajectories. + + Run model forward on trajectories, computes loss, and run backward. + + Args: + titan_checkpoint_path: Path to TorchTitan checkpoint + model_path: Path to HuggingFace model + learning_rate: Learning rate for optimizer + model_mode: Indicates which model to use. Train inferece unified model, batch invariant Torchtitan model, + or plain Torchtitan model + """ + + def __init__( + self, + titan_checkpoint_path: str, + model_path: str, + learning_rate: float = 1e-5, + model_mode: str = ModelMode.VLLM_COMPAT, + ddp_size: int = 1, + tp_size: int = 1, + ): + # Explicitly set cuda device for each trainer, otherwise different processes will use the same CUDA device + local_rank = int(os.environ["LOCAL_RANK"]) + device = torch.device(f"cuda:{local_rank}") + torch.cuda.set_device(local_rank) + + self.model = load_model( + titan_checkpoint_path, model_path, model_mode=model_mode + ) + self.ddp_size = ddp_size + self.tp_size = tp_size + self.parallel_dims = create_trainer_parallel_dims(self.ddp_size, self.tp_size) + + # apply PT-D Parallelism + # TODO: right now it only works for qwen3 model, need to formalize this to use parallize_fn from train_spec + from torchtitan.models.llama3.infra.parallelize import apply_ddp + + apply_ddp( + self.model, + self.parallel_dims.get_mesh("dp_replicate"), + enable_compile=False, + ) + + self.model = self.model.to(device) + self.model.train() + + # Optimizer + self.optimizer = torch.optim.AdamW(self.model.parameters(), lr=learning_rate) + self.policy_version = 0 + self.generator: Optional[Any] = None + + logger.info("Trainer initialized with TorchTitan model") + + @endpoint + async def get_weights(self) -> dict: + """Get vLLM-compatible weights for generator. + + Returns: + vLLM-compatible state dict + """ + titan_state = self.model.state_dict() + vllm_compat_state = torchtitan_to_vllm_compat(titan_state) + return vllm_compat_state + + @endpoint + async def step(self, trajectory: TrajectoryData) -> dict: + """Perform one training step. + + Returns: + Training metrics + """ + logger.info( + f"{os.getpid()=} Trainer starts to train {self.policy_version} on traj:" + ) + # Compute loss + loss, loss_metrics = compute_policy_gradient_loss_vllm( + self.model, + trajectory.vllm_token_ids, + trajectory.vllm_token_log_probs, + trajectory.prompt_token_ids, + trajectory.advantages, + kl_coef=0.1, + ) + + # Update weights + self.optimizer.zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) + self.optimizer.step() + + self.policy_version += 1 + + # TODO: save dcp checkpoint to file here instead of sending weight dicts + + # Return metrics + metrics = { + "loss": loss.item(), + "reward_mean": trajectory.rewards.mean().item(), + "reward_std": trajectory.rewards.std().item(), + "advantage_mean": trajectory.advantages.mean().item(), + "advantage_std": trajectory.advantages.std().item(), + "sample_completion": trajectory.completions[0][:80], + "policy_version": self.policy_version, + **loss_metrics, + } + logger.info(f"{os.getpid()=} Trainer finish step {self.policy_version}") + return metrics diff --git a/torchtitan/experiments/rl/unified/infer.py b/torchtitan/experiments/rl/unified/infer.py new file mode 100755 index 0000000000..3e9470bf5d --- /dev/null +++ b/torchtitan/experiments/rl/unified/infer.py @@ -0,0 +1,124 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os + +# Must set spawn method before any CUDA operations or vLLM imports +# CUDA cannot be re-initialized in forked subprocesses +# See also https://docs.vllm.ai/en/v0.8.3/design/multiprocessing.html#python-multiprocessing +os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + +import argparse + +# Import unified module - this automatically registers TorchTitan models with vLLM +from torchtitan.experiments.rl import unified # noqa: F401 +from vllm import LLM, SamplingParams +from vllm.logger import init_logger + + +logger = init_logger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser( + description="Run TorchTitan model inference with vLLM Engine", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) + parser.add_argument( + "--model-ckpt-path", + type=str, + default="torchtitan/experiments/rl/example_checkpoint", + help="Path to TorchTitan checkpoint directory", + ) + parser.add_argument( + "--prompt", + type=str, + default="Hello, my name is", + help="Prompt text for generation", + ) + parser.add_argument( + "--max-tokens", + type=int, + default=100, + help="Maximum number of tokens to generate", + ) + parser.add_argument( + "--temperature", + type=float, + default=0.8, + help="Sampling temperature", + ) + parser.add_argument( + "--tensor-parallel-size", + type=int, + default=1, + help="Number of GPUs for tensor parallelism (default: 1 for single GPU)", + ) + return parser.parse_args() + + +def infer(): + args = parse_args() + + logger.info("Initializing vLLM with TorchTitan model") + logger.info(f"Model: {args.model_ckpt_path}") + logger.info(f"Tensor Parallel Size: {args.tensor_parallel_size}") + + # Initialize vLLM with custom TorchTitan model + # The LLM initialization will internally: + # 1. Load TrainSpec for Qwen3 (from models/__init__.py register()) + # 2. Create TorchTitanVLLMModel instance + # 3. Create JobConfig and ParallelDims from vLLM config + # 4. Apply parallelization using parallelize_qwen3 + # 5. Load model weights and prepare for inference + # The tensor_parallel_size will be used by vLLM to configure parallelization + # and will be available in vllm_config in worker processes + logger.info("Creating vLLM LLM engine...") + + llm = LLM( + model=args.model_ckpt_path, # Model checkpoint path + hf_overrides={ + # Override architectures to use our registered TorchTitan model class + "architectures": ["Qwen3TorchTitanForCausalLM"], + }, + dtype="bfloat16", + trust_remote_code=True, + enforce_eager=True, # Use eager mode + tensor_parallel_size=args.tensor_parallel_size, + gpu_memory_utilization=0.5, + ) + + logger.info("vLLM engine initialized successfully") + logger.info(f"Prompt: {args.prompt}") + + # Prepare prompt and sampling parameters + prompts = [args.prompt] + sampling_params = SamplingParams( + temperature=args.temperature, + top_p=0.95, + max_tokens=args.max_tokens, + ) + + # Generate text + logger.info("Generating text...") + outputs = llm.generate( + prompts=prompts, + sampling_params=sampling_params, + ) + + # Print results + logger.info("Generation complete") + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + + print(f"\nPrompt: {prompt}") + print(f"Generated text: {generated_text!r}\n") + + +if __name__ == "__main__": + infer() diff --git a/torchtitan/experiments/rl/unified/infra/parallelism_utils.py b/torchtitan/experiments/rl/unified/infra/parallelism_utils.py new file mode 100644 index 0000000000..cc57b5a85f --- /dev/null +++ b/torchtitan/experiments/rl/unified/infra/parallelism_utils.py @@ -0,0 +1,139 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Parallelization utilities for vLLM + TorchTitan models. + +This module provides functions for setting up device mesh and applying +tensor parallelism to TorchTitan models in vLLM using TorchTitan's ParallelDims. +""" + + +import torch.distributed as dist +from torchtitan.config.job_config import Comm, JobConfig, Model, Parallelism, Training +from torchtitan.distributed import utils as dist_utils + +from torchtitan.distributed.parallel_dims import ParallelDims +from vllm.config import VllmConfig +from vllm.logger import init_logger + + +logger = init_logger(__name__) + + +def create_parallel_dims_from_vllm_config(vllm_config: VllmConfig) -> ParallelDims: + """ + Create ParallelDims from vLLM config and maps vLLM parallelism settings to TorchTitan's ParallelDims dataclass. + + This function is needed because vLLM doesn't separate model creation and + parallelism application - it requires parallelization to be done inside + the model constructor, so we are creating parallel_dims and apply parallelism + in TorchTitanVLLMModelWrapper.__init__ function. + + Args: + vllm_config: vLLM configuration object + + Returns: + ParallelDims object with parallelism settings validated + + Note: + vLLM doesn't use FSDP sharding (dp_shard=1) or expert parallelism (ep=1, etp=1) + in inference. These are set to default values. + """ + world_size = dist.get_world_size() + + # Map vLLM config to TorchTitan ParallelDims + parallel_dims = ParallelDims( + dp_replicate=vllm_config.parallel_config.data_parallel_size, + dp_shard=1, # vLLM doesn't use FSDP sharding + cp=vllm_config.parallel_config.decode_context_parallel_size, + tp=vllm_config.parallel_config.tensor_parallel_size, + pp=vllm_config.parallel_config.pipeline_parallel_size, + ep=1, # Expert parallelism not used in vLLM inference yet + etp=1, # Expert tensor parallelism not used in vLLM inference yet + world_size=world_size, + ) + + logger.info( + f"Created ParallelDims from vLLM config: " + f"DP={parallel_dims.dp_replicate}, TP={parallel_dims.tp}, " + f"CP={parallel_dims.cp}, PP={parallel_dims.pp}" + ) + + return parallel_dims + + +def create_trainer_parallel_dims(ddp_size, tp_size) -> ParallelDims: + """ + Create ParallelDims for trainer with specified DDP and TP sizes. + + This function initializes the distributed process group and creates a ParallelDims + object configured for for trainer SPMD workers. + + Args: + ddp_size: Data parallel (DDP) replicate size + tp_size: Tensor parallel size + + Returns: + ParallelDims object with trainer parallelism settings + """ + world_size = dist_utils.init_distributed( + Comm(), + ) + return ParallelDims( + dp_replicate=ddp_size, + dp_shard=1, + tp=tp_size, + cp=1, + pp=1, + ep=1, + etp=1, + world_size=world_size, + ) + + +def create_job_config_from_vllm_config( + vllm_config: VllmConfig, + model_name: str = "qwen3", + hf_assets_path: str = "/path/to/hf/assets", +) -> JobConfig: + """ + Create TorchTitan JobConfig from vLLM configuration. + + Args: + vllm_config: vLLM configuration object containing model, parallel, and cache configs + model_name: Model name to use (default: "qwen3") + hf_assets_path: Path to HuggingFace assets directory (default: "/path/to/hf/assets") + + Returns: + JobConfig object with settings mapped from vLLM config + """ + # Create JobConfig with defaults + job_config = JobConfig() + + model_config = vllm_config.model_config + job_config.model = Model( + name=model_name, + hf_assets_path=hf_assets_path, + ) + + parallel_config = vllm_config.parallel_config + job_config.parallelism = Parallelism( + data_parallel_replicate_degree=parallel_config.data_parallel_size, + data_parallel_shard_degree=1, # vLLM doesn't use FSDP sharding in inference + context_parallel_degree=parallel_config.decode_context_parallel_size, + tensor_parallel_degree=parallel_config.tensor_parallel_size, + pipeline_parallel_degree=parallel_config.pipeline_parallel_size, + expert_parallel_degree=1, # Not used in vLLM inference yet + expert_tensor_parallel_degree=1, # Not used in vLLM inference yet + ) + + job_config.training = Training( + local_batch_size=1, # Inference typically processes one batch at a time + steps=1, # Single step for inference + ) + + return job_config diff --git a/torchtitan/experiments/rl/unified/infra/parallelize.py b/torchtitan/experiments/rl/unified/infra/parallelize.py new file mode 100644 index 0000000000..8cbeeed783 --- /dev/null +++ b/torchtitan/experiments/rl/unified/infra/parallelize.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# This file applies the PT-D parallelisms (except pipeline parallelism) and various +# training techniques (e.g. activation checkpointing and compile) to the Llama model. + + +import torch.nn as nn + +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + PrepareModuleInputOutput, + RowwiseParallel, + SequenceParallel, +) + +from torchtitan.config import JobConfig +from torchtitan.distributed import ParallelDims + + +def parallelize_qwen3( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Temporary helper to apply tensor parallelism to the Qwen3 dense model so vLLM can run the torchtitan model. + """ + + if parallel_dims.tp_enabled: + tp_mesh = parallel_dims.get_mesh("tp") + apply_non_moe_tp( + model, + tp_mesh, + loss_parallel=not job_config.parallelism.disable_loss_parallel, + enable_float8_tensorwise_tp=False, + enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, + ) + + return model + + +def apply_non_moe_tp( + model: nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8_tensorwise_tp: bool, + enable_async_tp: bool, +): + """Apply tensor parallelism to the Qwen3 dense model. + + This is a temporary TP plan used while we resolve composability issues in the + main torchtitan codebase. Once DTensor is fully supported across the TP + region, this separate plan should be removed. + """ + + parallelize_module( + model, + tp_mesh, + { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + use_local_output=False, + ), + "norm": SequenceParallel( + use_local_output=False, + ), + "output": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Replicate(), + use_local_output=True, # return logits and plain tensor + ), + }, + ) + + # Apply tensor + sequence parallelism to every transformer block + # NOTE: At the cost of model code change, we can accelerate Sequence Parallel + # by folding (and unfolding) the batch dimension and the sequence dimension. + # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + # pyrefly: ignore [not-callable] + for transformer_block in model.layers.values(): + layer_plan = { + "attention_norm": SequenceParallel( + use_local_output=False, + ), + # NOTE: when the fourth argument (positions) is not None, its input layout + # and desired input layout should be Replicate() + "attention": PrepareModuleInput( + input_layouts=(Shard(1), Replicate(), None, Replicate()), + desired_input_layouts=(Replicate(), Replicate(), None, Replicate()), + ), + "attention.wq": ColwiseParallel(use_local_output=False), + "attention.wk": ColwiseParallel(use_local_output=False), + "attention.wv": ColwiseParallel(use_local_output=False), + "attention.q_norm": SequenceParallel( + sequence_dim=2, + use_local_output=False, + ), + "attention.k_norm": SequenceParallel( + sequence_dim=2, + use_local_output=False, + ), + # Apply on vllm.Attention() module to use local tensor + "attention.inner_attention": PrepareModuleInputOutput( + input_layouts=(Shard(1), Shard(1), Shard(1)), # xq, xk, xv + desired_input_layouts=(None, None, None), + use_local_input=True, # use local tensor for attention calculation + output_layouts=(Shard(1)), # output + desired_output_layouts=(Shard(1)), + use_local_output=False, + ), + "attention.wo": RowwiseParallel( + output_layouts=Shard(1), + use_local_output=False, + ), + "ffn_norm": SequenceParallel( + use_local_output=False, + ), + } + + # pyrefly: ignore [missing-attribute] + if not transformer_block.moe_enabled: + layer_plan.update( + { + "feed_forward": PrepareModuleInput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + "feed_forward.w1": ColwiseParallel(use_local_output=False), + "feed_forward.w2": RowwiseParallel( + output_layouts=Shard(1), use_local_output=False + ), + "feed_forward.w3": ColwiseParallel(use_local_output=False), + } + ) + else: + raise ValueError( + "Running vLLM inference with torchtitan Qwen3 MoE model is not supported yet." + ) + + parallelize_module( + # pyrefly: ignore [bad-argument-type] + module=transformer_block, + device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] + parallelize_plan=layer_plan, + ) diff --git a/torchtitan/experiments/rl/unified/models/attention.py b/torchtitan/experiments/rl/unified/models/attention.py new file mode 100644 index 0000000000..0492af2ffd --- /dev/null +++ b/torchtitan/experiments/rl/unified/models/attention.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from vllm.attention.layer import Attention + + +class VLLMAttention(torch.nn.Module): + """ + Wrapper around vLLM's Attention. Compatible with TorchTitan input shape. + """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + layer_name: str, + scale: float | None = None, + ) -> None: + super().__init__() + + self.hidden_size = hidden_size + self.layer_name = layer_name + + from vllm.config import get_current_vllm_config + + vllm_config = get_current_vllm_config() + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + + if scale is None: + self.scale = head_dim**-0.5 + else: + self.scale = scale + + cache_config = ( + vllm_config.cache_config if hasattr(vllm_config, "cache_config") else None + ) + + self.vllm_attn = Attention( + num_heads=num_heads, + head_size=head_dim, + scale=self.scale, + num_kv_heads=num_kv_heads, + cache_config=cache_config, + quant_config=None, + prefix=f"model.layers.{layer_name}.attention.inner_attention", + ) + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + scale: float | None = None, + ) -> torch.Tensor: + """ + Forward pass using vLLM's Attention layer for inference. + + Args: + q: Query tensor [batch, num_heads, seq_len, head_dim] + k: Key tensor [batch, num_kv_heads, seq_len, head_dim] + v: Value tensor [batch, num_kv_heads, seq_len, head_dim] + scale: Optional attention scale override (unused, vLLM uses internal scale) + + Returns: + output: [batch, num_heads, seq_len, head_dim] + """ + # Input is (batch, num_heads, seq_len, head_dim) + # TODO: may be good to use einops in future as we can explicitly reshape + # with dimension names - see https://github.com/arogozhnikov/einops + batch_size, num_heads, seq_len, head_dim = q.shape + _, num_kv_heads, _, _ = k.shape + + # vLLM expects (num_tokens, num_heads, head_dim) where num_tokens = batch * seq_len + # First transpose to (batch, seq_len, num_heads, head_dim) + q = q.transpose(1, 2) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + + # TODO: reimplement as a 4d tensor once vLLM fix has landed + # Then flatten batch and seq_len: (batch * seq_len, num_heads, head_dim) + q = q.reshape(batch_size * seq_len, num_heads, head_dim) + k = k.reshape(batch_size * seq_len, num_kv_heads, head_dim) + v = v.reshape(batch_size * seq_len, num_kv_heads, head_dim) + + # vLLM attention returns (num_tokens, hidden_size) where hidden_size = num_heads * head_dim + output_flat = self.vllm_attn(q, k, v) + + # Output is (batch * seq_len, num_heads * head_dim), reshape to (batch, seq_len, num_heads, head_dim) + output = output_flat.view(batch_size, seq_len, num_heads, head_dim) + + # Transpose back to TorchTitan format: (batch, num_heads, seq_len, head_dim) + output = output.transpose(1, 2) + + return output diff --git a/torchtitan/experiments/rl/unified/models/utils.py b/torchtitan/experiments/rl/unified/models/utils.py new file mode 100644 index 0000000000..0e5d6cde52 --- /dev/null +++ b/torchtitan/experiments/rl/unified/models/utils.py @@ -0,0 +1,181 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import logging +from enum import Enum + +import torch +from safetensors.torch import load_file + +from torchtitan.experiments.rl.unified.models.attention import VLLMAttention +from torchtitan.experiments.rl.vllm_compat.models.attention import ( + VLLMCompatibleFlashAttention, +) + +from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( + torchtitan_to_vllm_compat, +) +from torchtitan.models.qwen3.model.args import Qwen3ModelArgs +from transformers import AutoConfig + +logger = logging.getLogger(__name__) + + +class ModelMode(str, Enum): + """ + Enum defining which TorchTitan model to use. + + Attributes: + UNIFIED: Standard TorchTitan model replaced with vLLM attention for unified + training and inference. + VLLM_COMPAT: vLLM-compatible TorchTitan model using vLLM's batch invariant kernels, + ensuring bitwise determinism between training and inference. + STANDARD: Plain TorchTitan model without any modifications. + """ + + UNIFIED = "unified" + VLLM_COMPAT = "vllm_compat" + STANDARD = "standard" + + +def replace_with_vllm_attention(model, tp_degree=1): + """ + Replace TorchTitan attention with vLLM's Attention. + + Assumes model has .layers dict with .attention.inner_attention structure. + """ + if not hasattr(model, "layers"): + raise AttributeError( + f"Model {type(model).__name__} must have .layers attribute" + ) + + model_args = model.model_args + for layer_name, layer in model.layers.items(): + if not hasattr(layer, "attention"): + raise ValueError(f"Layer {layer_name} must have .attention attribute") + + vllm_attn = VLLMAttention( + hidden_size=model_args.dim, + num_heads=model_args.n_heads // tp_degree, + num_kv_heads=model_args.n_heads + // tp_degree, # Use n_heads (already replicated) + head_dim=model_args.head_dim, + layer_name=layer_name, + scale=model_args.head_dim**-0.5, + ) + + layer.attention.inner_attention = vllm_attn + + logger.info( + f"Successfully replaced TorchTitan attention with VLLMAttention " + f"({len(model.layers)} layers)" + ) + + +def replace_with_vllm_compatible_flash_attention(model): + """ + Replace TorchTitan attention with vLLM compatible flash attention. + + Assumes model has .layers dict with .attention.inner_attention structure. + """ + if not hasattr(model, "layers"): + raise AttributeError( + f"Model {type(model).__name__} must have .layers attribute" + ) + + model_args = model.model_args + for layer_name, layer in model.layers.items(): + if not hasattr(layer, "attention"): + raise ValueError(f"Layer {layer_name} must have .attention attribute") + + vllm_attn = VLLMCompatibleFlashAttention() + + layer.attention.inner_attention = vllm_attn + + logger.info( + f"Successfully replaced TorchTitan attention with VLLMCompatibleFlashAttention " + f"({len(model.layers)} layers)" + ) + + +def load_model( + checkpoint_path: str, model_path: str, model_mode: str = ModelMode.VLLM_COMPAT +): + """ + Load TorchTitan model from checkpoint for trainer. + + Args: + checkpoint_path: Path to TorchTitan checkpoint + model_path: Path to HuggingFace model (for config) + model_mode: Indicates which model to use. Train inferece unified model, batch invariant Torchtitan model, + or plain Torchtitan model + + Returns: + model: Loaded TorchTitan model for trainer. + """ + # Load HuggingFace config + # TODO: do not depend on transformers.AutoConfig, use qwen_args directly + hf_config = AutoConfig.from_pretrained(model_path, trust_remote_code=True) + + # Create model args + model_args = Qwen3ModelArgs( + dim=hf_config.hidden_size, + n_layers=hf_config.num_hidden_layers, + n_heads=hf_config.num_attention_heads, + n_kv_heads=hf_config.num_key_value_heads, + vocab_size=hf_config.vocab_size, + head_dim=getattr( + hf_config, + "head_dim", + hf_config.hidden_size // hf_config.num_attention_heads, + ), + hidden_dim=hf_config.intermediate_size, + norm_eps=hf_config.rms_norm_eps, + rope_theta=hf_config.rope_theta, + max_seq_len=getattr(hf_config, "max_position_embeddings", 32768), + qk_norm=True, + depth_init=True, + eos_id=getattr(hf_config, "eos_token_id", 151645), + ) + + # state_dict is in standard TorchTitan format (w1, w2, w3) + state_dict = load_file(checkpoint_path) + + if model_mode == ModelMode.UNIFIED: + from torchtitan.models.qwen3 import Qwen3Model + + model = Qwen3Model(model_args) + # Set global default dtype to bfloat16. This is needed because vLLM's Attention + # layer uses torch.get_default_dtype() and it doesn't support float32 + torch.set_default_dtype(torch.bfloat16) + # NOTE: Override attention to vllm compatible attention for backward capability. + # Only patch to vllm compatible attention for training. + replace_with_vllm_compatible_flash_attention(model) + + # Load standard TorchTitan format directly + model.load_state_dict(state_dict, strict=True) + elif model_mode == ModelMode.VLLM_COMPAT: + # Create and load model that has bitwise determinism between training and inference + from torchtitan.experiments.rl.vllm_compat.models.qwen3 import ( + Qwen3VLLMCompatModel, + ) + + model = Qwen3VLLMCompatModel(model_args) + # Convert to vLLM-compat format (merged gate_up_proj, down_proj) + vllm_compat_state = torchtitan_to_vllm_compat(state_dict) + model.load_state_dict(vllm_compat_state, strict=False) + else: + # Use standard TorchTitan model + from torchtitan.models.qwen3 import Qwen3Model + + model = Qwen3Model(model_args) + # Load standard TorchTitan format directly + model.load_state_dict(state_dict, strict=False) + + model.to(torch.bfloat16) + + return model diff --git a/torchtitan/experiments/rl/unified/models/vllm_wrapper.py b/torchtitan/experiments/rl/unified/models/vllm_wrapper.py new file mode 100644 index 0000000000..f3ae7f348a --- /dev/null +++ b/torchtitan/experiments/rl/unified/models/vllm_wrapper.py @@ -0,0 +1,323 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Base wrapper for TorchTitan models to work with vLLM V1 engine. + +This module provides TorchTitanVLLMModel: Core model class that adapts +TorchTitan models for vLLM. +""" + +from functools import partial + +import torch +import torch.nn as nn +from torch.distributed._tensor import DTensor, Replicate +from torch.distributed.checkpoint.state_dict import ( + set_model_state_dict, + StateDictOptions, +) + +from torchtitan.experiments.rl.unified.infra.parallelism_utils import ( + create_job_config_from_vllm_config, + create_parallel_dims_from_vllm_config, +) + +from torchtitan.experiments.rl.unified.models.utils import replace_with_vllm_attention +from torchtitan.models.qwen3.model.model import precompute_rope_cache +from torchtitan.protocols.model import BaseModelArgs, ModelProtocol +from torchtitan.protocols.state_dict_adapter import BaseStateDictAdapter +from torchtitan.protocols.train_spec import ParallelizeFunction + +from vllm.config import VllmConfig +from vllm.logger import init_logger + + +logger = init_logger(__name__) + + +class TorchTitanVLLMModelWrapper(nn.Module): + """ + Generic vLLM-compatible model wrapper for TorchTitan models. Implemented + required interface required by vLLM Engine. + Doc: https://docs.vllm.ai/en/latest/contributing/model/basic/ + Reference: https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py + + The wrapper handles: + - Direct usage of TorchTitan model args (no HF config mapping needed) + - Attention replacement with vLLM paged attention + - Parallelism setup and DTensor conversion between torchtitan and vLLM + - Weight loading from HF checkpoints + - vLLM forward/compute_logits interface + """ + + is_text_generation_model = True # Required for vLLM runner validation + supports_pp = False # Pipeline parallelism not supported yet + supports_multimodal = False + + def __init__( + self, + *, + model_cls: type[ModelProtocol], + model_args: BaseModelArgs, + state_dict_adapter: type[BaseStateDictAdapter], + parallelize_fn: ParallelizeFunction, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + + assert vllm_config is not None, "vllm_config is required" + + # Store components + self.model_cls = model_cls + self.state_dict_adapter = state_dict_adapter + self.parallelize_fn = parallelize_fn + + # Use TorchTitan model args directly (no HF config mapping) + self.config = model_args + logger.info(f"Creating {self.model_cls.__name__} with config: {model_args}") + self.model = self.model_cls(model_args) + + # Setup RoPE cache extension function if provided + self.rope_cache_extension_fn = partial( + precompute_rope_cache, + dim=self.config.head_dim, + base=self.config.rope_theta, + ) + + # Create ParallelDims and JobConfig from vLLM config at runtime + # vLLM config contains the tensor_parallel_size from command-line args + # and this will be consistent across all worker processes + self.parallel_dims = create_parallel_dims_from_vllm_config(vllm_config) + self.parallel_config = create_job_config_from_vllm_config( + vllm_config=vllm_config, + ) + # Replace attention with vLLM paged attention + tp_size = self.parallel_dims.tp + if tp_size > 1: + assert ( + model_args.n_heads % tp_size == 0 + ), "Only support when n_heads can be divided by tp_size" + + replace_with_vllm_attention(self.model, tp_degree=tp_size) + + # NOTE: We need to apply parallelize within model.__init__ because vllm + # doesn't separate model creation and parallelism application and instead + # requires parallelization to be done inside model constructor. + self.model = parallelize_fn( + model=self.model, + parallel_dims=self.parallel_dims, + job_config=self.parallel_config, + ) + + def _extend_rope_cache_if_needed( + self, rope_cache: torch.Tensor, max_position: int + ) -> torch.Tensor: + """ + Extend RoPE cache if needed during vLLM profiling stage. + + Args: + rope_cache: Current RoPE cache tensor + max_position: Maximum position index needed + + Returns: + Extended RoPE cache if needed, otherwise original cache + """ + required_len = max_position + 1 + + # No extension needed + if required_len <= rope_cache.shape[0]: + return rope_cache + + # If no extension function provided, return original cache + if self.rope_cache_extension_fn is None: + logger.warning( + f"RoPE cache extension needed (required_len={required_len}, " + f"current_len={rope_cache.shape[0]}) but no rope_cache_extension_fn provided. " + "Returning original cache." + ) + return rope_cache + + # Handle DTensor case + is_dtensor = isinstance(rope_cache, DTensor) + if is_dtensor: + device_mesh = rope_cache.device_mesh + local_rope_cache = rope_cache.to_local() + device = local_rope_cache.device + dtype = local_rope_cache.dtype + else: + device = rope_cache.device + dtype = rope_cache.dtype + + # Use provided extension function + try: + extended_cache = self.rope_cache_extension_fn(self.config, required_len) + extended_cache = extended_cache.to(device=device, dtype=dtype) + except Exception as e: + logger.warning( + f"Failed to extend RoPE cache using rope_cache_extension_fn: {e}. " + "Returning original cache." + ) + return rope_cache + + # Convert back to DTensor if needed + if is_dtensor: + rope_cache = DTensor.from_local( + extended_cache, + device_mesh=device_mesh, + placements=[Replicate()], + ) + else: + rope_cache = extended_cache + + return rope_cache + + def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: + """Convert input token IDs to embeddings.""" + return self.model.tok_embeddings(input_ids) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + """Convert input token IDs to embeddings (deprecated vLLM interface).""" + return self.embed_input_ids(input_ids) + + def forward( + self, + input_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, + inputs_embeds: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + """ + Forward pass with vLLM interface. + + Args: + input_ids: Token IDs [total_tokens] (1D varlen format) + positions: Position indices [total_tokens] (1D varlen format) + inputs_embeds: Pre-computed embeddings (optional) + **kwargs: Additional vLLM kwargs + + Returns: + hidden_states: Final hidden states [total_tokens, hidden_size] + """ + if inputs_embeds is not None: + raise NotImplementedError("inputs_embeds not yet supported") + + if input_ids is None: + raise ValueError("Either input_ids or inputs_embeds must be provided") + + # Convert vLLM interface to TorchTitan interface + # vLLM: [total_tokens] → TorchTitan: [batch_size, seq_len] + tokens_2d = input_ids.unsqueeze(0) + + # Get embeddings + h = self.model.tok_embeddings(tokens_2d) + + # Get RoPE cache (handle model-specific attribute names) + # Use hasattr to avoid ambiguous boolean value error with tensors + if hasattr(self.model, "rope_cache"): + rope_attr = self.model.rope_cache + elif hasattr(self.model, "freqs_cis"): + rope_attr = self.model.freqs_cis + else: + rope_attr = None + + # Extend RoPE cache if needed (vLLM profiling may use 2x max_seq_len) + if positions is not None: + max_position = positions.max().item() + else: + max_position = 0 + + rope_cache = self._extend_rope_cache_if_needed(rope_attr, max_position) + positions = positions.unsqueeze(0) + + # Pass through transformer layers + for layer in self.model.layers.values(): + h = layer(h, rope_cache, attention_masks=None, positions=positions) + + # When parallelism is applied, get full tensor before return to vLLM Engine + # The original placement is Shard(1) (shard on sequence dimension, as it will prepare for sequence parallel in `self.norm`). + # vLLM’s engine expects plain, non-distributed tensors to slice the last token for each request. + if isinstance(h, DTensor): + h = h.full_tensor() + + # Convert to vLLM format: [total_tokens, hidden_size] + if h.dim() == 3: + batch_size, seq_len, hidden_size = h.shape + h = h.view(batch_size * seq_len, hidden_size) + + return h + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata=None, + ) -> torch.Tensor | None: + """Compute logits from hidden states.""" + + # When TP is applied, we return the full tensor (plain tensor) to vLLM engine + # at the end of TorchTitanVLLMModelWrapper.forward(). + # We need to wrap the input from vLLM engine back to DTensor with Replicate() placement. + if self.parallel_dims.tp_enabled: + hidden_states = DTensor.from_local( + hidden_states, + device_mesh=self.parallel_dims.get_mesh("tp"), + placements=[ + Replicate(), + ], + ) + + h = self.model.norm(hidden_states) + logits = self.model.output(h) + + return logits + + def load_weights(self, weights_iter): + """ + Load weights from HF checkpoint using the provided state dict adapter. + vLLM engine would call this function to load model weights. + + Args: + weights_iter: Iterator of (name, tensor) pairs from HF checkpoint + + Returns: + Set of loaded parameter names + """ + # Collect weights from iterator + hf_state_dict = {} + for name, tensor in weights_iter: + hf_state_dict[name] = tensor + + # Use adapter to convert HF → TorchTitan format + adapter = self.state_dict_adapter( + model_args=self.config, + hf_assets_path=None, + ) + + torchtitan_state_dict = adapter.from_hf(hf_state_dict) + model_state_dict = {k: v for k, v in self.model.state_dict().items()} + + # Convert to DTensor if target is DTensor + for name, tensor in torchtitan_state_dict.items(): + if name in model_state_dict and isinstance(model_state_dict[name], DTensor): + target_dtensor = model_state_dict[name] + device_mesh = target_dtensor.device_mesh + torchtitan_state_dict[name] = DTensor.from_local( + tensor.to(device_mesh.device_type), + device_mesh=device_mesh, + placements=[Replicate()], + ) + + # Load state dict + set_model_state_dict( + model=self.model, + model_state_dict=torchtitan_state_dict, + options=StateDictOptions(strict=False), + ) + + loaded_params = {f"model.{name}" for name in torchtitan_state_dict.keys()} + + return loaded_params diff --git a/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py b/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py new file mode 100644 index 0000000000..3e914f3778 --- /dev/null +++ b/torchtitan/experiments/rl/unified/simple_rl_multiprocess.py @@ -0,0 +1,190 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Multiprocess RL training loop using Monarch Actors. + +This demonstrates: +1. Distributed actor architecture with Generator (vLLM) and Trainer (TorchTitan) components +2. File based weight synchronization between trainer and generator + +The architecture mirrors monarch's grpo_actor.py but adapted for vLLM rollouts + TorchTitan training. + +Command to run: +VLLM_BATCH_INVARIANT=1 VLLM_ATTENTION_BACKEND=FLASH_ATTN python3 torchtitan/experiments/rl/unified/simple_rl_multiprocess.py +""" +import asyncio +import logging + +import torch +from monarch.actor import this_host +from monarch.utils import setup_env_for_distributed +from torchtitan.experiments.rl.unified.actors.generator import Generator +from torchtitan.experiments.rl.unified.actors.trainer import Trainer +from torchtitan.experiments.rl.unified.models.utils import ModelMode +from torchtitan.experiments.rl.vllm_compat.simple_rl import ( + download_and_convert_model, + load_gsm8k_dataset, +) +from vllm.model_executor.layers.batch_invariant import ( + init_batch_invariance, + vllm_is_batch_invariant, +) + +logger = logging.getLogger(__name__) + + +async def main(): + """Run the distributed RL training loop using Monarch.""" + # Model Config + model_name = "Qwen/Qwen3-0.6B" + cache_dir = "./models" + output_dir = "./converted" + + # Training config + group_size = 8 + num_steps = 10 + learning_rate = 1e-5 + max_new_tokens = 20 + + # GRPO config + use_stable_grpo = False + grpo_beta = 0.1 + + # Dataset config + use_real_dataset = False + num_dataset_samples = 5 + + # Parallelism sizes + trainer_ddp_size = 2 + trainer_tp_size = 1 + generator_tp_size = 1 + + init_batch_invariance() + batch_invariant = vllm_is_batch_invariant() + mode = ModelMode.UNIFIED + + # Set up batch invariant + if batch_invariant: + logger.info("Batch invariance detected - using vLLM-compatible model") + from torchtitan.experiments.rl.vllm_compat.batch_invariant_backward import ( + enable_batch_invariant_backward_mode, + ) + + enable_batch_invariant_backward_mode() + else: + raise RuntimeError("Batch invariance NOT detected - using standard model") + + # Download and convert model + titan_checkpoint_path, model_path = download_and_convert_model( + model_name, cache_dir, output_dir + ) + + # Load dataset + if use_real_dataset: + logger.info(f"Loading GSM8K dataset ({num_dataset_samples} samples)...") + # TODO: Refactor into loading torchtitan dataset + prompt_texts, expected_answers = load_gsm8k_dataset( + split="train", num_samples=num_dataset_samples + ) + if prompt_texts is None or len(prompt_texts) == 0: + use_real_dataset = False + + if not use_real_dataset: + logger.info("Using default prompts") + prompts_with_answers = [ + ("The capital of France is", "paris"), + ("What is 7 times 8?", "56"), + ("The first president of the United States was", "washington"), + ("The chemical symbol for water is", "h2o"), + ("The largest planet in our solar system is", "jupiter"), + ] + prompt_texts = [p[0] for p in prompts_with_answers] + expected_answers = [p[1] for p in prompts_with_answers] + + logger.info(f"Loaded {len(prompt_texts)} prompts") + + # Create process meshes + trainer_mesh = this_host().spawn_procs(per_host={"gpus": 2}) + gen_mesh = this_host().spawn_procs(per_host={"gpus": 1}) + + # Set up distributed env vars so that actors are connected via c10d + await setup_env_for_distributed( + trainer_mesh, + master_addr="localhost", # TODO: figure out what to set + master_port=29500, # TODO: figure out what to set + ) + + # Set up distributed env vars so that actors are connected via c10d + await setup_env_for_distributed( + gen_mesh, + master_addr="localhost", # TODO: figure out what to set + master_port=29501, # TODO: figure out what to set + ) + + # Spawn actors on trainer and generator mesh + trainer = trainer_mesh.spawn( + "trainer", + Trainer, + titan_checkpoint_path, + model_path, + learning_rate, + mode, + trainer_ddp_size, + trainer_tp_size, + ) + + generator = gen_mesh.spawn( + "generator", + Generator, + model_path, + prompt_texts, + expected_answers, + group_size, + max_new_tokens, + 1.0, # temperature + use_real_dataset, + grpo_beta, + use_stable_grpo, + generator_tp_size, + ) + + # Initialize generator with trainer weights + initial_weights = trainer.get_weights.call().get().item(gpus=0) + await generator.update.call(0, initial_weights) + + # Training loop + logger.info("\n" + "=" * 80) + logger.info(f"Starting RL training for {num_steps} steps") + logger.info("=" * 80) + + for step in range(num_steps): + # Fully sync RL loop + batch = generator.generate.call().get().item(gpus=0) + metrics = trainer.step.call(batch).get().item(gpus=0) + weights = trainer.get_weights.call().get().item(gpus=0) + await generator.update.call(metrics["policy_version"], weights) + + logger.info( + f"\nStep {step:3d} | Loss: {metrics['loss']:.4f} | " + f"Reward: {metrics['reward_mean']:+.3f}" + ) + logger.info(f" Sample: {metrics['sample_completion']}...") + + # Check for divergence + if not torch.isfinite(torch.tensor(metrics["loss"])): + logger.info("\n" + "!" * 80) + logger.info("ERROR: Loss is NaN/Inf! Training diverged.") + logger.info("!" * 80) + break + + logger.info("\n" + "=" * 80) + logger.info("RL Training complete") + logger.info("=" * 80) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/torchtitan/experiments/deterministic_vllm_rl/README.md b/torchtitan/experiments/rl/vllm_compat/README.md similarity index 92% rename from torchtitan/experiments/deterministic_vllm_rl/README.md rename to torchtitan/experiments/rl/vllm_compat/README.md index d2ef719c0d..84df62d3ed 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/README.md +++ b/torchtitan/experiments/rl/vllm_compat/README.md @@ -51,8 +51,12 @@ Note: Currently supports single-device training only. ### Prerequisites ```bash -# Install vLLM with deterministic support -pip install vllm +# Install vLLM with deterministic support (from source) +git clone https://github.com/vllm-project/vllm.git +cd vllm +python use_existing_torch.py +uv pip install -r requirements/build.txt +uv pip install --no-build-isolation -e . # Install TorchTitan (from the repository root) pip install -e . @@ -75,15 +79,17 @@ init_batch_invariance() ### Quick Start ```python -import torch from vllm.model_executor.layers.batch_invariant import init_batch_invariance -from torchtitan.experiments.deterministic_vllm_rl import ( +from vllm.v1.attention.backends.registry import AttentionBackendEnum + +import torch +from torchtitan.experiments.rl.vllm_compat import ( enable_batch_invariant_backward_mode, Qwen3VLLMCompatModel, ) # 1. Enable deterministic mode -init_batch_invariance() +init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) enable_batch_invariant_backward_mode() # 2. Load model @@ -95,7 +101,7 @@ model_args = Qwen3ModelArgs( n_kv_heads=2, vocab_size=151936, ) -model = Qwen3VLLMCompatModel(model_args) +model = Qwen3VLLMCompatModel(model_args).to('cuda').to(torch.bfloat16) # 3. Forward pass (deterministic) input_ids = torch.randint(0, 151936, (2, 128), device='cuda') @@ -104,6 +110,9 @@ logits = model(input_ids) # 4. Backward pass loss = logits.sum() loss.backward() + +print("Done running simple model") + ``` ### Full RL Training @@ -111,7 +120,7 @@ loss.backward() Run the RL training loop: ```bash -VLLM_BATCH_INVARIANT=1 VLLM_FLASH_ATTN_VERSION=3 python -m torchtitan.experiments.deterministic_vllm_rl.simple_rl +VLLM_BATCH_INVARIANT=1 VLLM_FLASH_ATTN_VERSION=3 python -m torchtitan.experiments.rl.vllm_compat.simple_rl ``` This will: @@ -177,7 +186,7 @@ assert torch.equal(vllm_logprobs, titan_logprobs) Run the test suite: ```bash -cd torchtitan/experiments/deterministic_vllm_rl/tests +cd torchtitan/experiments/rl/vllm_compat/tests # Test backward passes python test_batch_invariant_backward.py @@ -214,7 +223,7 @@ This implementation uses the same kernels for both rollouts (vLLM) and training ## Project Structure ``` -deterministic_vllm_rl/ +rl/vllm_compat/ ├── README.md # Documentation ├── __init__.py # Package initialization ├── batch_invariant_backward.py # Backward passes for vLLM ops diff --git a/torchtitan/experiments/deterministic_vllm_rl/__init__.py b/torchtitan/experiments/rl/vllm_compat/__init__.py similarity index 53% rename from torchtitan/experiments/deterministic_vllm_rl/__init__.py rename to torchtitan/experiments/rl/vllm_compat/__init__.py index 067555251f..b86721fba5 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/__init__.py +++ b/torchtitan/experiments/rl/vllm_compat/__init__.py @@ -5,16 +5,10 @@ # LICENSE file in the root directory of this source tree. """ -Deterministic RL training with vLLM experiment. +vLLM-Compatible approach for deterministic RL training. -This experiment provides tools for bitwise-deterministic reinforcement learning -training using vLLM for fast rollouts and TorchTitan for training. - -Key components: -- VLLMCompatibleFlashAttention: Flash attention with custom backward pass -- Qwen3VLLMCompatModel: vLLM-compatible model with merged projections -- batch_invariant_backward: Gradient support for vLLM's deterministic operations -- simple_rl: End-to-end RL training loop +This module provides models that match vLLM's weight format (e.g., merged gate_up_proj) +with custom backward passes for gradient computation during training. """ from .batch_invariant_backward import ( @@ -22,9 +16,10 @@ rms_norm_with_gradients, silu_and_mul_with_gradients, ) -from .models import VLLMCompatibleFlashAttention +from .models.attention import VLLMCompatibleFlashAttention from .models.qwen3 import Qwen3VLLMCompatModel + __all__ = [ "VLLMCompatibleFlashAttention", "Qwen3VLLMCompatModel", diff --git a/torchtitan/experiments/deterministic_vllm_rl/batch_invariant_backward.py b/torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py similarity index 97% rename from torchtitan/experiments/deterministic_vllm_rl/batch_invariant_backward.py rename to torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py index faccf8265d..b67244478e 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/batch_invariant_backward.py +++ b/torchtitan/experiments/rl/vllm_compat/batch_invariant_backward.py @@ -62,10 +62,14 @@ def forward(ctx, x): Returns: output: silu(gate) * up, shape [..., hidden_dim] """ + from vllm.config import set_current_vllm_config, VllmConfig from vllm.model_executor.layers.activation import SiluAndMul as VLLMSiluAndMul # Use vLLM's implementation for forward - vllm_silu_and_mul = VLLMSiluAndMul() + # vLLM custom ops require a config context to be set + # Since these are parameter free we instantiate default config + with set_current_vllm_config(VllmConfig()): + vllm_silu_and_mul = VLLMSiluAndMul() output = vllm_silu_and_mul(x) # Save for backward diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py b/torchtitan/experiments/rl/vllm_compat/models/__init__.py similarity index 74% rename from torchtitan/experiments/deterministic_vllm_rl/models/__init__.py rename to torchtitan/experiments/rl/vllm_compat/models/__init__.py index c8c11a170a..2e7a5fa6af 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/__init__.py +++ b/torchtitan/experiments/rl/vllm_compat/models/__init__.py @@ -6,8 +6,13 @@ """ Models for deterministic vLLM RL training. + +This module provides vLLM-compatible model components. """ from .attention import VLLMCompatibleFlashAttention -__all__ = ["VLLMCompatibleFlashAttention"] + +__all__ = [ + "VLLMCompatibleFlashAttention", +] diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/attention.py b/torchtitan/experiments/rl/vllm_compat/models/attention.py similarity index 95% rename from torchtitan/experiments/deterministic_vllm_rl/models/attention.py rename to torchtitan/experiments/rl/vllm_compat/models/attention.py index 33dd5a140d..752b416922 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/attention.py +++ b/torchtitan/experiments/rl/vllm_compat/models/attention.py @@ -4,12 +4,11 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -""" -vLLM-compatible Flash Attention implementation for deterministic RL training. -""" + +import math import torch -from vllm.vllm_flash_attn import flash_attn_varlen_func +from vllm.v1.attention.backends.fa_utils import flash_attn_varlen_func class VLLMCompatibleFlashAttention(torch.nn.Module): @@ -18,8 +17,8 @@ class VLLMCompatibleFlashAttention(torch.nn.Module): def __init__(self) -> None: super().__init__() self.flash_attn_varlen_func = flash_attn_varlen_func - from vllm.attention.utils.fa_utils import get_flash_attn_version from vllm.model_executor.layers.batch_invariant import vllm_is_batch_invariant + from vllm.v1.attention.backends.fa_utils import get_flash_attn_version self.vllm_is_batch_invariant = vllm_is_batch_invariant self.fa_version = get_flash_attn_version() @@ -56,6 +55,10 @@ def forward( 0, (batch_size + 1) * seq_len, seq_len, dtype=torch.int32, device=q.device ) + # Scaling factor applied prior to softmax. If none, the default value is set to :math:`\frac{1}{\sqrt{E}}`. + if scale is None: + scale = 1.0 / math.sqrt(q.size(-1)) + # Wrap Flash Attention with manual backward pass class FlashAttnWithBackward(torch.autograd.Function): @staticmethod diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py b/torchtitan/experiments/rl/vllm_compat/models/qwen3/__init__.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/models/qwen3/__init__.py rename to torchtitan/experiments/rl/vllm_compat/models/qwen3/__init__.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py b/torchtitan/experiments/rl/vllm_compat/models/qwen3/model_vllm_compat.py similarity index 98% rename from torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py rename to torchtitan/experiments/rl/vllm_compat/models/qwen3/model_vllm_compat.py index dd84665091..8dab20908c 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/models/qwen3/model_vllm_compat.py +++ b/torchtitan/experiments/rl/vllm_compat/models/qwen3/model_vllm_compat.py @@ -13,7 +13,7 @@ from torchtitan.components.tokenizer import BaseTokenizer # Import gradient-enabled operations from experiment utilities -from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( +from torchtitan.experiments.rl.vllm_compat.batch_invariant_backward import ( rms_norm_with_gradients, silu_and_mul_with_gradients, ) @@ -277,14 +277,14 @@ def init_weights(self, buffer_device: torch.device): self.feed_forward.init_weights(self.weight_init_std) -class Qwen3VLLMCompatModel(nn.Module, ModelProtocol): +class Qwen3VLLMCompatModel(ModelProtocol): """ Qwen3 model with vLLM-compatible implementation. Uses merged gate_up projections and vLLM Flash Attention. """ def __init__(self, model_args: Qwen3ModelArgs): - super().__init__() + super().__init__(model_args) self.model_args = model_args self.vocab_size = model_args.vocab_size self.n_layers = model_args.n_layers diff --git a/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py b/torchtitan/experiments/rl/vllm_compat/simple_rl.py similarity index 98% rename from torchtitan/experiments/deterministic_vllm_rl/simple_rl.py rename to torchtitan/experiments/rl/vllm_compat/simple_rl.py index ffc7d52eb0..bd9225bbe6 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/simple_rl.py +++ b/torchtitan/experiments/rl/vllm_compat/simple_rl.py @@ -25,22 +25,24 @@ from huggingface_hub import snapshot_download from safetensors.torch import load_file, save_file from torch.utils.tensorboard import SummaryWriter -from transformers import AutoConfig, AutoTokenizer -from vllm import LLM, SamplingParams -from vllm.model_executor.layers.batch_invariant import init_batch_invariance - -from torchtitan.experiments.deterministic_vllm_rl.weights.converter import ( +from torchtitan.experiments.rl.vllm_compat.weights.converter import ( torchtitan_to_vllm, vllm_to_torchtitan, ) -from torchtitan.experiments.deterministic_vllm_rl.weights_vllm_compat import ( +from torchtitan.experiments.rl.vllm_compat.weights_vllm_compat import ( torchtitan_to_vllm_compat, ) from torchtitan.models.qwen3.model.args import Qwen3ModelArgs +from transformers import AutoConfig, AutoTokenizer + +from vllm import LLM, SamplingParams +from vllm.model_executor.layers.batch_invariant import init_batch_invariance +from vllm.v1.attention.backends.registry import AttentionBackendEnum -init_batch_invariance() + +init_batch_invariance(AttentionBackendEnum.FLASH_ATTN) class VLLMRolloutEngine: @@ -170,6 +172,7 @@ def update_weights(self, vllm_compat_state: dict) -> None: gpu_memory_utilization=0.3, # Reduced from 0.5 seed=42, # Fixed seed for determinism enforce_eager=True, + attention_config={"backend": AttentionBackendEnum.FLASH_ATTN}, ) print("✓ Created new vLLM engine") else: @@ -340,7 +343,7 @@ def load_model(checkpoint_path: str, model_path: str, use_vllm_compat: bool = Tr if use_vllm_compat: # Create and load model (using vLLM-compat for bitwise determinism) - from torchtitan.experiments.deterministic_vllm_rl.models.qwen3 import ( + from torchtitan.experiments.rl.vllm_compat.models.qwen3 import ( Qwen3VLLMCompatModel, ) @@ -481,7 +484,6 @@ def load_gsm8k_dataset(split: str = "train", num_samples: int = 100): def trivial_reward_function( completions: list[str], - tokenizer=None, expected_answers: list[str] | None = None, group_size: int = 4, ) -> torch.Tensor: @@ -494,7 +496,6 @@ def trivial_reward_function( Args: completions: List of completion strings - tokenizer: Tokenizer to count tokens expected_answers: List of expected answers (one per prompt, repeated for group_size) group_size: Number of samples per prompt @@ -891,12 +892,7 @@ def rl_update_step( ) # Compute rewards using provided reward function - if reward_fn == trivial_reward_function: - rewards = reward_fn(completions, tokenizer, expected_answers, group_size) - elif reward_fn == math_reward_function: - rewards = reward_fn(completions, expected_answers, group_size) - else: - rewards = reward_fn(completions, expected_answers, group_size) + rewards = reward_fn(completions, expected_answers, group_size) # Normalize rewards for stability (mean=0, std=1) reward_mean = rewards.mean() @@ -1058,7 +1054,7 @@ def main(): print("✓ Batch invariance detected - using vLLM-compatible model") # Add backward pass support to vLLM's batch_invariant mode print(" Adding gradient support to vLLM's batch_invariant mode...") - from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( + from torchtitan.experiments.rl.vllm_compat import ( enable_batch_invariant_backward_mode, ) diff --git a/torchtitan/experiments/deterministic_vllm_rl/tests/__init__.py b/torchtitan/experiments/rl/vllm_compat/tests/__init__.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/tests/__init__.py rename to torchtitan/experiments/rl/vllm_compat/tests/__init__.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/tests/test_batch_invariant_backward.py b/torchtitan/experiments/rl/vllm_compat/tests/test_batch_invariant_backward.py similarity index 97% rename from torchtitan/experiments/deterministic_vllm_rl/tests/test_batch_invariant_backward.py rename to torchtitan/experiments/rl/vllm_compat/tests/test_batch_invariant_backward.py index 3ed9604d10..ddf8b01514 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/tests/test_batch_invariant_backward.py +++ b/torchtitan/experiments/rl/vllm_compat/tests/test_batch_invariant_backward.py @@ -8,9 +8,11 @@ Test batch_invariant_backward module to ensure it works correctly. """ +import sys + import torch -from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( +from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.batch_invariant_backward import ( disable_batch_invariant_backward_mode, enable_batch_invariant_backward_mode, linear_batch_invariant_backward, diff --git a/torchtitan/experiments/deterministic_vllm_rl/tests/test_exact_determinism.py b/torchtitan/experiments/rl/vllm_compat/tests/test_exact_determinism.py similarity index 98% rename from torchtitan/experiments/deterministic_vllm_rl/tests/test_exact_determinism.py rename to torchtitan/experiments/rl/vllm_compat/tests/test_exact_determinism.py index 8d0ac3133e..2a9863ab2f 100644 --- a/torchtitan/experiments/deterministic_vllm_rl/tests/test_exact_determinism.py +++ b/torchtitan/experiments/rl/vllm_compat/tests/test_exact_determinism.py @@ -11,11 +11,11 @@ """ import torch -from vllm.model_executor.layers.batch_invariant import disable_batch_invariant_mode -from torchtitan.experiments.deterministic_vllm_rl.batch_invariant_backward import ( +from torchtitan.experiments.deterministic_vllm_rl.vllm_compat.batch_invariant_backward import ( enable_batch_invariant_backward_mode, ) +from vllm.model_executor.layers.batch_invariant import disable_batch_invariant_mode print("Enabling batch_invariant_backward mode...") disable_batch_invariant_mode() diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights/README.md b/torchtitan/experiments/rl/vllm_compat/weights/README.md similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/weights/README.md rename to torchtitan/experiments/rl/vllm_compat/weights/README.md diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights/__init__.py b/torchtitan/experiments/rl/vllm_compat/weights/__init__.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/weights/__init__.py rename to torchtitan/experiments/rl/vllm_compat/weights/__init__.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights/converter.py b/torchtitan/experiments/rl/vllm_compat/weights/converter.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/weights/converter.py rename to torchtitan/experiments/rl/vllm_compat/weights/converter.py diff --git a/torchtitan/experiments/deterministic_vllm_rl/weights_vllm_compat.py b/torchtitan/experiments/rl/vllm_compat/weights_vllm_compat.py similarity index 100% rename from torchtitan/experiments/deterministic_vllm_rl/weights_vllm_compat.py rename to torchtitan/experiments/rl/vllm_compat/weights_vllm_compat.py diff --git a/torchtitan/experiments/simple_fsdp/README.md b/torchtitan/experiments/simple_fsdp/README.md index a49fa8ad56..a1d40cf2b1 100644 --- a/torchtitan/experiments/simple_fsdp/README.md +++ b/torchtitan/experiments/simple_fsdp/README.md @@ -3,11 +3,13 @@ [![integration and numerics tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_simple_fsdp.yaml?query=branch%3Amain) [![arXiv](https://img.shields.io/badge/arXiv-2411.00284-b31b1b.svg)](https://arxiv.org/abs/2411.00284) -💡 **Note**: SimpleFSDP's composability with Mixed Precision Training and Tensor Parallel requires updates from latest PyTorch, which can be installed (e.g., for CUDA 12.6) via +💡 **Note 1**: SimpleFSDP's composability with Mixed Precision Training and Tensor Parallel requires updates from latest PyTorch, which can be installed (e.g., for CUDA 12.6) via ```bash pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126 --force-reinstall ``` +💡 **Note 2**: Some of SimpleFSDP's functionalities (e.g., reshard_after_forward) is implemented with torch.compile. It is always recommended to open compile (`--compile.enable`) to see desired correct functionality. + This folder includes an experimental frontend implementation for [SimpleFSDP: Simpler Fully Sharded Data Parallel with torch.compile](https://arxiv.org/abs/2411.00284). SimpleFSDP is a compiler-based Fully Sharded Data Parallel (FSDP) framework, which has a simple implementation for maintenance and composability, allows full computation-communication graph tracing, and brings performance enhancement via compiler backend optimizations. ### Run SimpleFSDP Training on Llama3 & DeepSeek_v3 @@ -50,14 +52,16 @@ SimpleFSDP relies on compiler backend to perform optimizations (i.e., bucketing 1. no optimization: default torch.compile backends (e.g., "inductor", "aot_eager", "eager") 2. auto optimization: perform auto-bucketing & reordering without user inputs. **Note: it is not guaranteed that users will get the most optimized training performance** - - "aot_eager_autobucketing": perform autobucketing at aten fx-level, and perform code execution with aot_eager backend. - - -users can specify the pass (e.g., "aot_eager_autobucketing") via additional configs: - -```bash ---compile.model_backend_override "aot_eager_autobucketing" -``` + - "auto_bucketing": perform autobucketing at aten fx-level, and perform code execution with aot_eager backend. (We also support `inductor` backend). + ```bash + --compile.backend "aot_eager" --compile.graph_passes "auto_bucketing" + ``` + +3. manual optimization: perform manual bucketing & reordering with user FQN inputs. + - "transformer_block_bucketing": perform bucketing by transformer blocks at aten fx-level, and perform code execution with aot_eager backend. (We also support `inductor` backend). + ```bash + --compile.backend "aot_eager" --compile.graph_passes "transformer_block_bucketing" + ``` ### Citation diff --git a/torchtitan/experiments/simple_fsdp/backend.py b/torchtitan/experiments/simple_fsdp/backend.py index 36abe4ad0b..7fc9d13bf4 100644 --- a/torchtitan/experiments/simple_fsdp/backend.py +++ b/torchtitan/experiments/simple_fsdp/backend.py @@ -4,46 +4,152 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Union +from typing import Any import torch +import torch._functorch.config as functorch_config +from torchtitan.tools.logging import logger +from .job_config import Compile as CompileConfig -def get_compile_backend(backend_name: str) -> Union[str, callable]: - # return the compile backends used in SimpleFSDP training - # Step1: check if backend_name is inside available torch.compile backends - # Step2: check if the backend_name has been registered as a customized backend - available_torch_backend = torch._dynamo.list_backends(exclude_tags=()) - if backend_name in available_torch_backend: - return backend_name +from .reshard_after_forward import annotate_fsdp_all_gather - if backend_name == "aot_eager_autobucketing": - # Perform auto optimization in aten fx-level and execute code in aot_eager backend - # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960 - from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend +def get_compile_backend_with_passes( + compile_config: CompileConfig, + fsdp_reshard_after_forward: bool, + fsdp_manual_buckets: list[list[str] | str] | None, +) -> callable: + """ + Apply compile backend and additional graph passes. + Args: + compile_config: compile configs to apply torch.compile. + fsdp_reshard_after_forward: whether to enable reshard_after_forward in SimpleFSDP, + which is implemented via a customized AC graph pass. + fsdp_manual_buckets: used in transformer_block_bucketing to define which modules should be bucketed. + Returns: + compile backend with applied graph passes. + """ + backend = torch._dynamo.lookup_backend(compile_config.backend) + + # Apply bucketing and overlapping pass on fwd and bwd graph separately + if compile_config.graph_passes == "auto_bucketing": + # Perform auto optimization in aten fx-level and execute code in aot_eager/inductor backend + # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960 from torch._inductor.config import aten_distributed_optimizations as dist_opts from torch._inductor.fx_passes.overlap_scheduling import ( schedule_overlap_bucketing, ) dist_opts.collective_bucketing = True - dist_opts.insert_overlap_deps = False torch._inductor.config.allow_buffer_reuse = False - def aten_autobucketing_reordering_pass( - gm: torch.fx.GraphModule, example_inputs: Any - ) -> torch.fx.GraphModule: - schedule_overlap_bucketing(gm) - gm.recompile() - return gm - - backend = aot_autograd_backend( - fw_compiler=aten_autobucketing_reordering_pass, - bw_compiler=aten_autobucketing_reordering_pass, - keep_inference_input_mutations=True, + if compile_config.backend == "aot_eager": + from torch._dynamo.backends.common import ( + aot_autograd as aot_autograd_backend, + ) + + def aot_eager_autobucketing_reordering_pass( + gm: torch.fx.GraphModule, example_inputs: Any + ) -> torch.fx.GraphModule: + schedule_overlap_bucketing(gm) + gm.recompile() + return gm + + dist_opts.insert_overlap_deps = False + backend = aot_autograd_backend( + fw_compiler=aot_eager_autobucketing_reordering_pass, + bw_compiler=aot_eager_autobucketing_reordering_pass, + keep_inference_input_mutations=True, + ) + elif compile_config.backend == "inductor": + + def inductor_autobucketing_reordering_pass( + gm: torch.fx.Graph, + ) -> torch.fx.GraphModule: + return schedule_overlap_bucketing(gm.owning_module) + + dist_opts.insert_overlap_deps = True + torch._inductor.config.reorder_for_peak_memory = False + torch._inductor.config.reorder_for_compute_comm_overlap = False + torch._inductor.config.post_grad_custom_post_pass = ( + inductor_autobucketing_reordering_pass + ) + else: + raise ValueError( + f"Unsupported backend {compile_config.backend} for auto_bucketing pass" + ) + logger.info("Auto bucketing pass is applied") + + elif compile_config.graph_passes == "transformer_block_bucketing": + # Perform manual optimization in aten fx-level and execute code in aot_eager/inductor backend + # The manualbucketing logic is here: https://github.com/pytorch/pytorch/pull/165487 + from functools import partial + + from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend + from torch._inductor.fx_passes.overlap_manual_scheduling import ( + manual_overlap_bucketing, + ) + + torch._inductor.config.allow_buffer_reuse = False + manual_overlap_bucketing = partial( + manual_overlap_bucketing, + module_bucket_plans=fsdp_manual_buckets, ) + + if compile_config.backend == "aot_eager": + + def aot_eager_transformer_block_bucketing_reordering_pass( + gm: torch.fx.GraphModule, example_inputs: Any + ) -> torch.fx.GraphModule: + manual_overlap_bucketing(gm, insert_overlap_deps=False) + return gm + + backend = aot_autograd_backend( + fw_compiler=aot_eager_transformer_block_bucketing_reordering_pass, + bw_compiler=aot_eager_transformer_block_bucketing_reordering_pass, + keep_inference_input_mutations=True, + ) + elif compile_config.backend == "inductor": + + def inductor_transformer_block_bucketing_reordering_pass( + gm: torch.fx.Graph, + ) -> torch.fx.GraphModule: + return manual_overlap_bucketing( + gm.owning_module, insert_overlap_deps=True + ) + + torch._inductor.config.reorder_for_peak_memory = False + torch._inductor.config.reorder_for_compute_comm_overlap = False + torch._inductor.config.post_grad_custom_post_pass = ( + inductor_transformer_block_bucketing_reordering_pass + ) + else: + raise ValueError( + f"Unsupported backend {compile_config.backend} for transformer_block_bucketing pass" + ) + logger.info("Transformer block bucketing pass is applied") + else: - raise AssertionError(f"Unsupported customized backend: {backend_name}") + logger.info("No bucketing or overlapping pass is applied") + + # Apply activation checkpointing on joint graph before partitioner + def joint_ac_pass( + gm: torch.fx.GraphModule, example_inputs: Any + ) -> torch.fx.GraphModule: + # this pass implements simplefsdp's fsdp_reshard_after_forward behavior + # when fsdp_reshard_after_forward set to True, it will annotate simple_fsdp AG + # to CheckpointPolicy.MUST_RECOMPUTE. + # when fsdp_reshard_after_forward set to False, it will annotate simple_fsdp AG + # to CheckpointPolicy.MUST_SAVE. + gm = annotate_fsdp_all_gather(gm, fsdp_reshard_after_forward) + gm.recompile() + return gm + + def simple_fsdp_custom_pass(*args, **kwargs): + # the ac pass has to operate in a joint graph before partitioner for ac + # annotation to take into effect. + with functorch_config.patch("joint_custom_pass", joint_ac_pass): + return backend(*args, **kwargs) - return backend + return simple_fsdp_custom_pass diff --git a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py index ac6f9bdc9b..9bbaba0ef5 100644 --- a/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py @@ -10,24 +10,50 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims + +from torchtitan.distributed.activation_checkpoint import apply_ac from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.models.deepseek_v3.infra.parallelize import ( - apply_ac, apply_moe_ep_tp, apply_non_moe_tp, ) from torchtitan.tools.logging import logger +from ..backend import get_compile_backend_with_passes + from ..simple_fsdp import data_parallel, MixedPrecisionPolicy +def get_transformer_block_buckets(model) -> list[list[str] | str]: + module_list = [ + model.tok_embeddings, + [model.norm, model.output], + ] + for layer_id, transformer_block in model.layers.items(): + # [TODO](ruisizhang123) add EP support for transformer block bucketing + module_list.append(transformer_block) + + def convert_modules_to_fqns(modules, module_to_fqn_mapping): + """Convert a (possibly nested) list of modules to FQN strings.""" + result = [] + for m in modules: + if isinstance(m, list): + result.append(convert_modules_to_fqns(m, module_to_fqn_mapping)) + else: + result.append(module_to_fqn_mapping.get(m, None)) + return result + + module_to_name = {m: n for n, m in model.named_modules()} + module_fqns = convert_modules_to_fqns(module_list, module_to_name) + return module_fqns + + # Adapted from llama4/infra/parallelize.py def parallelize_deepseekv3( model: nn.Module, parallel_dims: ParallelDims, job_config: JobConfig, ): - world_mesh = parallel_dims.world_mesh # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. @@ -40,9 +66,9 @@ def parallelize_deepseekv3( if ( job_config.parallelism.context_parallel_degree > 1 - and model.model_args.use_flex_attn + and model.model_args.attn_type != "sdpa" ): - raise NotImplementedError("CP support for FlexAttention is still in progress.") + raise NotImplementedError("CP support is only supported for SDPA.") if parallel_dims.tp_enabled: enable_float8_linear = "float8" in job_config.model.converters @@ -58,29 +84,22 @@ def parallelize_deepseekv3( "Currently, float8 tensorwise TP is not tested for deepseekv3" ) - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) apply_non_moe_tp( model, - world_mesh["tp"], + parallel_dims.get_mesh("tp"), loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, - use_flex_attn=use_flex_attn, + positions_enabled=parallel_dims.cp_enabled or job_config.training.dataset_type == "preprocessed", ) - maybe_enable_async_tp(job_config, world_mesh["tp"]) + maybe_enable_async_tp(job_config, parallel_dims.get_mesh("tp")) if parallel_dims.tp_enabled or parallel_dims.ep_enabled: apply_moe_ep_tp( model, - tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, - ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, - ep_tp_mesh=( - world_mesh["ep", "tp"] - if parallel_dims.tp_enabled - and parallel_dims.ep_enabled - and parallel_dims.etp_enabled - else None - ), - etp_enabled=parallel_dims.etp_enabled, + tp_mesh=parallel_dims.get_optional_mesh("tp"), + ep_mesh=parallel_dims.get_optional_mesh("ep"), + etp_mesh=parallel_dims.get_optional_mesh("etp"), + ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]), ) if job_config.activation_checkpoint.mode != "none": @@ -91,20 +110,6 @@ def parallelize_deepseekv3( reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], ) - match job_config.parallelism.fsdp_reshard_after_forward: - case "always": - reshard_after_forward = True - case "never": - reshard_after_forward = False - case "default": - # For PP, by default do not reshard after forward to avoid per-microbatch - # all-gathers, which can be expensive and non-overlapped - reshard_after_forward = not parallel_dims.pp_enabled - case _: - raise ValueError( - f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}." - ) - # apply data parallel dp_mesh: DeviceMesh | None = None if ( @@ -114,38 +119,38 @@ def parallelize_deepseekv3( ): if parallel_dims.dp_replicate_enabled: if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + dp_mesh_dim_names = ["dp_replicate", "fsdp"] dp_mode = "hybrid_shard" else: - dp_mesh_dim_names = ("dp_replicate",) + dp_mesh_dim_names = ["dp_replicate"] dp_mode = "replicate" else: - dp_mesh_dim_names = ("dp_shard_cp",) + dp_mesh_dim_names = ["fsdp"] dp_mode = "fully_shard" - dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] - # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - dp_mod_ep_mesh_dim_names = [] + dp_mesh = parallel_dims.get_mesh(dp_mesh_dim_names) - if parallel_dims.ep_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh_dim_names.append("dp_replicate") - dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") - dp_mod_ep_mesh = world_mesh[tuple(dp_mod_ep_mesh_dim_names)] + # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP + edp_mesh_names = ( + ["dp_replicate", "efsdp"] + if parallel_dims.dp_replicate_enabled + else ["efsdp"] + ) + edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names) for _, transformer_block in model.layers.items(): if transformer_block.moe_enabled and parallel_dims.ep_enabled: experts_shard_dim = 0 - assert dp_mod_ep_mesh is not None + assert edp_mesh is not None assert hasattr(transformer_block, "moe") if ( - dp_mod_ep_mesh.size() * parallel_dims.ep + edp_mesh["efsdp"].size() * parallel_dims.ep > transformer_block.moe.experts.num_experts ): experts_shard_dim = 1 # when EP is enable, the routed experts' gradient reduction is done over - # dp_mod_ep_mesh instead of whole dp_mesh. + # edp_mesh instead of whole dp_mesh. # we add a `fsdp_gradient_divide_factor` to scale gradient over dp_mesh # to be consistent with data. # TODO (ruisizhang123): update the logic following the link below instead @@ -153,11 +158,9 @@ def parallelize_deepseekv3( # https://github.com/pytorch/torchtitan/pull/1803#discussion_r2415190883 transformer_block.moe.experts = data_parallel( transformer_block.moe.experts, - dp_mod_ep_mesh, + edp_mesh, dp_mode, - ac_mode=job_config.activation_checkpoint.mode, mp_policy=mp_policy, - reshard_after_forward=reshard_after_forward, shard_dim=experts_shard_dim, reduction_divide_factor=parallel_dims.fsdp_gradient_divide_factor, ) @@ -166,9 +169,7 @@ def parallelize_deepseekv3( model, dp_mesh, dp_mode, - ac_mode=job_config.activation_checkpoint.mode, mp_policy=mp_policy, - reshard_after_forward=reshard_after_forward, ) logger.info( @@ -178,6 +179,30 @@ def parallelize_deepseekv3( if job_config.compile.enable: torch._inductor.config.reorder_for_peak_memory = False torch._dynamo.config.capture_scalar_outputs = True - model = torch.compile(model, backend=job_config.compile.backend, fullgraph=True) + + match job_config.parallelism.fsdp_reshard_after_forward: + case "always": + fsdp_reshard_after_forward = True + case "never": + fsdp_reshard_after_forward = False + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + fsdp_reshard_after_forward = not parallel_dims.pp_enabled + case _: + raise ValueError( + f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}." + ) + + backend = get_compile_backend_with_passes( + job_config.compile, + fsdp_reshard_after_forward, + get_transformer_block_buckets(model), + ) + model = torch.compile( + model, + backend=backend, + fullgraph=True, + ) return model diff --git a/torchtitan/experiments/simple_fsdp/job_config.py b/torchtitan/experiments/simple_fsdp/job_config.py index a7e7c4c22f..f752fa1170 100644 --- a/torchtitan/experiments/simple_fsdp/job_config.py +++ b/torchtitan/experiments/simple_fsdp/job_config.py @@ -5,12 +5,16 @@ # LICENSE file in the root directory of this source tree. from dataclasses import dataclass, field +from typing import Literal @dataclass class Compile: - model_backend_override: str | None = None - """Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing""" + graph_passes: Literal["auto_bucketing", "transformer_block_bucketing"] | None = None + """ + Bucketing and overlapping passes in simplefsdp. Additional passes include: + auto_bucketing, transformer_block_bucketing + """ @dataclass diff --git a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py index d61e74a5dd..d64a8b79fc 100644 --- a/torchtitan/experiments/simple_fsdp/llama3/parallelize.py +++ b/torchtitan/experiments/simple_fsdp/llama3/parallelize.py @@ -14,7 +14,7 @@ from torchtitan.models.llama3.infra.parallelize import apply_tp from torchtitan.tools.logging import logger -from ..backend import get_compile_backend +from ..backend import get_compile_backend_with_passes from ..simple_fsdp import data_parallel, MixedPrecisionPolicy @@ -24,15 +24,45 @@ torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + torch.ops.aten._scaled_dot_product_attention_math.default, + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, # for low precision training, it's useful to always save # the result of max, since the absolute maximum is # used to compute the scaling factor for quantization. torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, + torch.ops.torch_attn._varlen_attn, + torch._higher_order_ops.inductor_compiled_code, } +def get_transformer_block_buckets(model) -> list[list[str] | str]: + module_list = [ + model.tok_embeddings, + [model.norm, model.output], + ] + for layer_id, transformer_block in model.layers.items(): + module_list.append(transformer_block) + + def convert_modules_to_fqns(modules, module_to_fqn_mapping): + """Convert a (possibly nested) list of modules to FQN strings.""" + result = [] + for m in modules: + if isinstance(m, list): + if fqn_list := convert_modules_to_fqns(m, module_to_fqn_mapping): + result.append(fqn_list) + else: + if fqn := module_to_fqn_mapping.get(m): + result.append(fqn) + return result + + module_to_name = {m: n for n, m in model.named_modules()} + module_fqns = convert_modules_to_fqns(module_list, module_to_name) + return module_fqns + + def parallelize_llama( model: nn.Module, parallel_dims: ParallelDims, @@ -67,7 +97,7 @@ def parallelize_llama( # all-gather happens in high precision. enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise - tp_mesh = parallel_dims.world_mesh["tp"] + tp_mesh = parallel_dims.get_mesh("tp") apply_tp( model, tp_mesh, @@ -77,7 +107,6 @@ def parallelize_llama( maybe_enable_async_tp(job_config, tp_mesh) if job_config.activation_checkpoint.mode != "none": - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) @@ -85,7 +114,6 @@ def parallelize_llama( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) @@ -98,13 +126,13 @@ def parallelize_llama( ): if parallel_dims.dp_replicate_enabled: if parallel_dims.dp_shard_enabled or parallel_dims.cp_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + dp_mesh_dim_names = ["dp_replicate", "fsdp"] dp_mode = "hybrid_shard" else: - dp_mesh_dim_names = ("dp_replicate",) + dp_mesh_dim_names = ["dp_replicate"] dp_mode = "replicate" else: - dp_mesh_dim_names = ("dp_shard_cp",) + dp_mesh_dim_names = ["fsdp"] dp_mode = "fully_shard" mp_policy = MixedPrecisionPolicy( @@ -112,27 +140,11 @@ def parallelize_llama( reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], ) - match job_config.parallelism.fsdp_reshard_after_forward: - case "always": - reshard_after_forward = True - case "never": - reshard_after_forward = False - case "default": - # For PP, by default do not reshard after forward to avoid per-microbatch - # all-gathers, which can be expensive and non-overlapped - reshard_after_forward = not parallel_dims.pp_enabled - case _: - raise ValueError( - f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}." - ) - model = data_parallel( model, - parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], + parallel_dims.get_mesh(dp_mesh_dim_names), mode=dp_mode, - ac_mode=job_config.activation_checkpoint.mode, mp_policy=mp_policy, - reshard_after_forward=reshard_after_forward, ) logger.info( "Applied Data Parallel (simple_fsdp) (dp mode=%s) to the model", dp_mode @@ -140,13 +152,29 @@ def parallelize_llama( if job_config.compile.enable and "model" in job_config.compile.components: torch._inductor.config.reorder_for_peak_memory = False - backend = ( - getattr(job_config.compile, "model_backend_override", None) - or job_config.compile.backend + + match job_config.parallelism.fsdp_reshard_after_forward: + case "always": + fsdp_reshard_after_forward = True + case "never": + fsdp_reshard_after_forward = False + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + fsdp_reshard_after_forward = not parallel_dims.pp_enabled + case _: + raise ValueError( + f"Invalid fsdp_reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}." + ) + + backend = get_compile_backend_with_passes( + job_config.compile, + fsdp_reshard_after_forward, + get_transformer_block_buckets(model), ) model = torch.compile( model, - backend=get_compile_backend(backend), + backend=backend, fullgraph=True, ) diff --git a/torchtitan/experiments/simple_fsdp/reshard_after_forward.py b/torchtitan/experiments/simple_fsdp/reshard_after_forward.py new file mode 100644 index 0000000000..dac010bfcd --- /dev/null +++ b/torchtitan/experiments/simple_fsdp/reshard_after_forward.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch.utils.checkpoint import CheckpointPolicy + + +def is_graph_input(node: torch.fx.Node) -> bool: + return node.op == "placeholder" + + +def is_wait_tensor(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and node.target == torch.ops._c10d_functional.wait_tensor.default + ) + + +def is_all_gather_into_tensor(node: torch.fx.Node) -> bool: + return ( + node.op == "call_function" + and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default + ) + + +def is_wait_tensor_from_fsdp(node: torch.fx.Node) -> bool: + """ + Returns True if the node is a wait_tensor node that is the result of an all_gather + that can be arbitrarily prefetched, i.e., if all its recursive inputs are + single-input operators that leads to a graph input. + """ + if is_wait_tensor(node) and is_all_gather_into_tensor(node.args[0]): + n: torch.fx.Node = node.all_input_nodes[0] + while len(n.all_input_nodes) == 1: + if is_graph_input(n.all_input_nodes[0]): + return True + n = n.all_input_nodes[0] + return False + + +def annotate_fsdp_all_gather( + gm: torch.fx.GraphModule, reshard_after_forward: bool +) -> None: + """ + Force recompute all_gather nodes from simple fsdp in the graph. + This pass should be added in torch._inductor.config.joint_custom_post_pass + """ + graph = gm.graph + + def force_recompute_node(node): + if reshard_after_forward: + node.meta["recompute"] = CheckpointPolicy.MUST_RECOMPUTE + else: + node.meta["recompute"] = CheckpointPolicy.MUST_SAVE + # ac_graph_id is used in the partitioner to decide + # if two nodes which have AC applied come from a different + # AC regions. This is needed because nodes in the boundary + # of two AC regions are marked as MUST_SAVE. In our case + # we just add a large value of ac_graph_id so that + # all nodes we tag for recomputation do indeed get recomputed + # and are not influenced by other nodes in the graph with + # nearby ac_graph_id values + node.meta["ac_graph_id"] = 100000 + + # Make all-gather nodes (and related nodes) recomputable, to circumvent + # https://github.com/pytorch/pytorch/issues/136433 + for node in graph.nodes: + if is_wait_tensor_from_fsdp(node): + ag_node = node.args[0] + force_recompute_node(ag_node) # all_gather + force_recompute_node(node) # wait_tensor + # Force-recompute slice that comes after wait + for user in node.users: + if ( + user.op == "call_function" + and user.target == torch.ops.aten.slice.Tensor + ): + force_recompute_node(user) + # Force-recompute potential dtype casts from all_gather + if ( + ag_node.all_input_nodes[0].op == "call_function" + and ag_node.args[0].target + == torch.ops.prims.convert_element_type.default + ): + force_recompute_node(ag_node.all_input_nodes[0]) + + return gm diff --git a/torchtitan/experiments/simple_fsdp/simple_fsdp.py b/torchtitan/experiments/simple_fsdp/simple_fsdp.py index 737b6d3ec2..6597c45f9d 100644 --- a/torchtitan/experiments/simple_fsdp/simple_fsdp.py +++ b/torchtitan/experiments/simple_fsdp/simple_fsdp.py @@ -4,7 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from collections.abc import Sequence +from collections.abc import Generator, Sequence from contextlib import contextmanager from dataclasses import dataclass @@ -22,18 +22,12 @@ from torch.distributed.tensor._dtensor_spec import DTensorSpec from torch.distributed.tensor._redistribute import redistribute_local_tensor from torch.distributed.tensor.placement_types import _StridedShard, Placement -from torch.utils.checkpoint import ( - checkpoint, - CheckpointPolicy, - create_selective_checkpoint_contexts, -) - _active_parametrization = True @contextmanager -def disable_active_parametrization(): +def disable_active_parametrization() -> Generator[None, None, None]: global _active_parametrization try: _active_parametrization = False @@ -183,53 +177,32 @@ def _register_parametrization( module.__class__ = module_cls -def fsdp_policy(): - def _fsdp_recomp_policy(): - def _custom_policy(ctx, func, *args, **kwargs): - to_recompute = func in { - torch.ops._c10d_functional.all_gather_into_tensor.default, - torch.ops._c10d_functional.wait_tensor.default, - torch.ops.aten._to_copy.default, # for dtype cast in FSDP - } - return ( - CheckpointPolicy.MUST_RECOMPUTE - if to_recompute - else CheckpointPolicy.MUST_SAVE - ) - - return _custom_policy - - return create_selective_checkpoint_contexts(_fsdp_recomp_policy()) - - class ReplicateComputation(torch.nn.Module): def __init__( self, - device_mesh, - param_sharding, - mode, - regional_ac, - mp_policy, - reshard_after_forward, - reduction_divide_factor, - ): + device_mesh: DeviceMesh, + param_sharding: tuple[Placement, ...], + mode: str, + mp_policy: MixedPrecisionPolicy | None, + reduction_divide_factor: float | None, + full_dtensor: bool = False, + ) -> None: super().__init__() self.device_mesh = device_mesh self.param_sharding = param_sharding self.mode = mode - self.compute_placements = [Replicate()] * self.device_mesh.ndim - self.grad_placements = [ + self.compute_placements: list[Placement] = [Replicate()] * self.device_mesh.ndim + self.grad_placements: list[Placement] = [ _ScaledPartial( reduction_divide_factor=reduction_divide_factor, ) if reduction_divide_factor is not None else Partial(reduce_op="avg") ] * self.device_mesh.ndim - self.regional_ac = regional_ac mp_policy = mp_policy or MixedPrecisionPolicy() - self.param_dtype = mp_policy.param_dtype - self.reduce_dtype = mp_policy.reduce_dtype - self.reshard_after_forward = reshard_after_forward + self.param_dtype: torch.dtype | None = mp_policy.param_dtype + self.reduce_dtype: torch.dtype | None = mp_policy.reduce_dtype + self.full_dtensor = full_dtensor def replicate_compute(self, x: DTensor) -> torch.Tensor: # data parallel runtime replicate parameters and do local compute @@ -239,6 +212,10 @@ def replicate_compute(self, x: DTensor) -> torch.Tensor: non_dp_mesh_dims = x._spec.mesh.ndim - self.device_mesh.ndim assert non_dp_mesh_dims <= 2, "Only DP + EP/TP/EP+TP is supported" if non_dp_mesh_dims > 0: + if self.full_dtensor: + raise NotImplementedError( + "full_dtensor not implemented for nD parallelisms" + ) dp_mesh = self.device_mesh # re-wrap 2D DTensor to 1D DTensor on dp_mesh for efficient FSDP all-gather sharded_local_tensor = x.to_local() @@ -274,7 +251,10 @@ def replicate_compute(self, x: DTensor) -> torch.Tensor: placements=self.compute_placements, forward_dtype=self.param_dtype, backward_dtype=self.reduce_dtype, - ).to_local(grad_placements=self.grad_placements) + ) + + if not self.full_dtensor: + output = output.to_local(grad_placements=self.grad_placements) else: raise AssertionError( f"Unsupported replicate compute on placement {x._spec.placements} for DTensor {x}" @@ -292,21 +272,7 @@ def forward(self, x: DTensor) -> torch.Tensor: if not _active_parametrization: return x - if ( - self.regional_ac - and self.mode in ("fully_shard", "hybrid_shard") - and self.reshard_after_forward - ): - # apply checkpointing to implement reshard_after_forward - output = checkpoint( - self.replicate_compute, - x, - use_reentrant=False, - context_fn=fsdp_policy, - ) - else: - output = self.replicate_compute(x) - + output = self.replicate_compute(x) return output @@ -314,12 +280,12 @@ def data_parallel( model: nn.Module, device_mesh: DeviceMesh, mode: str = "replicate", - ac_mode: str = "none", mp_policy: MixedPrecisionPolicy | None = None, - reshard_after_forward: bool = True, shard_dim: int = 0, reduction_divide_factor: float | None = None, -): + full_dtensor: bool = False, +) -> nn.Module: + param_sharding: tuple[Placement, ...] if mode == "replicate": param_sharding = (Replicate(),) elif mode == "fully_shard": @@ -335,9 +301,6 @@ def data_parallel( modules = list(model.modules()) - # apply regional ac (with fsdp_policy) if no global ac is to be applied - regional_ac = ac_mode == "none" - for mod in modules: params_dict = dict(mod.named_parameters(recurse=False)) # we shouldn't apply data parallel to the modules that are already @@ -366,7 +329,6 @@ def data_parallel( # device_mesh, # param_sharding, # mode, - # regional_ac, # mp_policy=mp_policy, # ), # unsafe=True, @@ -379,10 +341,9 @@ def data_parallel( device_mesh, param_sharding, mode, - regional_ac, mp_policy=mp_policy, - reshard_after_forward=reshard_after_forward, reduction_divide_factor=reduction_divide_factor, + full_dtensor=full_dtensor, ), ) return model diff --git a/torchtitan/experiments/simple_fsdp/tests/integration_tests.py b/torchtitan/experiments/simple_fsdp/tests/integration_tests.py index f18ee95528..c3cee7b52f 100755 --- a/torchtitan/experiments/simple_fsdp/tests/integration_tests.py +++ b/torchtitan/experiments/simple_fsdp/tests/integration_tests.py @@ -35,11 +35,25 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]: "--model.name simple_fsdp.llama3", "--compile.enable", "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config", - "--compile.model_backend_override aot_eager_autobucketing", + "--compile.backend aot_eager", + "--compile.graph_passes auto_bucketing", ], ], - "1D+aot_eager_autobucketing", - "1d_aot_eager_autobucketing", + "1D+autobucketing", + "1d_autobucketing", + ), + OverrideDefinitions( + [ + [ + "--model.name simple_fsdp.llama3", + "--compile.enable", + "--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config", + "--compile.backend aot_eager", + "--compile.graph_passes transformer_block_bucketing", + ], + ], + "1D+transformer_block_bucketing", + "1d_transformer_block_bucketing", ), OverrideDefinitions( [ diff --git a/torchtitan/experiments/simple_fsdp/tests/test_numerics.py b/torchtitan/experiments/simple_fsdp/tests/test_numerics.py index 76233aeb87..aaf94a5023 100644 --- a/torchtitan/experiments/simple_fsdp/tests/test_numerics.py +++ b/torchtitan/experiments/simple_fsdp/tests/test_numerics.py @@ -20,13 +20,13 @@ def init_test(self): self.loss_fn = cross_entropy_loss data_parallel_shard_degree = -1 if self.mode == "replicate": - self.dp_mesh_dim_names = ("dp_replicate",) + self.dp_mesh_dim_names = ["dp_replicate"] data_parallel_replicate_degree = self.world_size elif self.mode == "fully_shard": - self.dp_mesh_dim_names = ("dp_shard_cp",) + self.dp_mesh_dim_names = ["fsdp"] data_parallel_replicate_degree = 1 elif self.mode == "hybrid_shard": - self.dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + self.dp_mesh_dim_names = ["dp_replicate", "fsdp"] data_parallel_replicate_degree = self.world_size // 2 else: raise ValueError(f"Unsupported mode {self.mode}") @@ -41,7 +41,6 @@ def init_test(self): etp=1, world_size=self.world_size, ) - self.device_mesh = self.parallel_dims.world_mesh def get_input(self): inputs = torch.randn(8, 8).cuda() @@ -50,7 +49,7 @@ def get_input(self): return model, inputs, labels def run_fsdp2(self, model, inputs, labels, epoch=20): - fully_shard(model, mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)]) + fully_shard(model, mesh=self.parallel_dims.get_mesh(self.dp_mesh_dim_names)) optim = self.optimizer(model.parameters(), lr=1e-4) losses = [] for _ in range(epoch): @@ -65,7 +64,7 @@ def run_fsdp2(self, model, inputs, labels, epoch=20): def run_simple_fsdp(self, model, inputs, labels, epoch=20): model = data_parallel( model, - device_mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)], + device_mesh=self.parallel_dims.get_mesh(self.dp_mesh_dim_names), mode=self.mode, ) optim = self.optimizer(model.parameters(), lr=1e-4) @@ -82,7 +81,7 @@ def run_simple_fsdp(self, model, inputs, labels, epoch=20): def run_simple_fsdp_compiled_aot_eager(self, model, inputs, labels, epoch=20): model = data_parallel( model, - device_mesh=self.device_mesh[tuple(self.dp_mesh_dim_names)], + device_mesh=self.parallel_dims.get_mesh(self.dp_mesh_dim_names), mode=self.mode, ) # TODO: Add "inductor" backend when it's numerical issues are fixed diff --git a/torchtitan/experiments/transformers_modeling_backend/README.md b/torchtitan/experiments/transformers_modeling_backend/README.md new file mode 100644 index 0000000000..fb70d03a1f --- /dev/null +++ b/torchtitan/experiments/transformers_modeling_backend/README.md @@ -0,0 +1,54 @@ +# Huggingface Transformers Modeling backend + +This enables HF transformers models to be trained with `4D parallelism + torch.compile` + +## Quick start + +- Requirements `transformers==4.57.1` + +- Config: `torchtitan/torchtitan/experiments/transformers_modeling_backend/configs/qwen3.toml` +```diff +... +[model] +- name = "llama3" ++ name = "transformers_modeling_backend" +flavor = "debugmodel" +hf_assets_path = "./tests/assets/tokenizer" + ++[hf_transformers] ++model = "Qwen/Qwen3-4B-Instruct-2507" +... +``` +- Train: `LOG_RANK=7 CONFIG_FILE=/torchtitan/experiments/transformers_modeling_backend/configs/qwen3.toml ./run_train.sh --job.custom_config_module=torchtitan.experiments.transformers_modeling_backend.job_config --compile.enable` + - Make sure you have created the tokenizers beforehand +image + +## Supported Features + +- The following models were tested: + - Dense (FSDP/CP/TP/PP/`torch.compile`) + - `meta-llama/Llama-3.2-1B` + - `microsoft/phi-2` + - `Qwen/Qwen2.5-7B` + - `mistralai/Mistral-7B-v0.1` + - `ByteDance-Seed/Seed-Coder-8B-Instruct` + - `Qwen/Qwen3-4B-Instruct-2507` + - `arcee-ai/AFM-4.5B` + - `ibm-granite/granite-3b-code-base-2k` + - `baidu/ERNIE-4.5-0.3B-Base-PT` + - `kyutai/helium-1-preview-2b` + - `allenai/OLMo-7B-hf` + - `mistralai/Ministral-8B-Instruct-2410` + - MoE (upcoming) + +## Known issues to address later + +- When using HF modeling, the test `FSDP=2 vs FSDP=2 + PP=2`, the `loss` and `grad_norm` not bitwise matching (but converging) while it is the case with Torchtitan modeling. This will be addressed in another PR but the culprit is probably `register_buffer` when loading `seed_checkpoint` +- the HF modeling has lower MFU than Torchtitan MFU + +## Further work + +- Missing `build_optimizers_with_moe_load_balancing` support for MoE +- Missing TP/PP/EP supports for MoE +- Load HF weights +- Add LORA support diff --git a/torchtitan/experiments/transformers_modeling_backend/__init__.py b/torchtitan/experiments/transformers_modeling_backend/__init__.py new file mode 100644 index 0000000000..aec28a0bdd --- /dev/null +++ b/torchtitan/experiments/transformers_modeling_backend/__init__.py @@ -0,0 +1,51 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.hf_datasets.text_datasets import build_text_dataloader +from torchtitan.protocols.train_spec import TrainSpec + +from .infra.parallelize import parallelize_hf_transformers + +from .infra.pipeline import pipeline_hf_transformers +from .model.args import HFTransformerModelArgs, TitanDenseModelArgs +from .model.model import HFTransformerModel + +__all__ = [ + "HFTransformerModelArgs", + "HFTransformerModel", +] + + +flavors = { + "debugmodel": HFTransformerModelArgs( + titan_dense_args=TitanDenseModelArgs( + dim=256, + n_layers=2, + n_heads=16, + n_kv_heads=16, + ), + ), + "full": HFTransformerModelArgs( + titan_dense_args=TitanDenseModelArgs(), + ), +} + + +def get_train_spec() -> TrainSpec: + return TrainSpec( + model_cls=HFTransformerModel, + model_args=flavors, + parallelize_fn=parallelize_hf_transformers, + pipelining_fn=pipeline_hf_transformers, + build_optimizers_fn=build_optimizers, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_text_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) diff --git a/torchtitan/experiments/transformers_modeling_backend/configs/debug_model.toml b/torchtitan/experiments/transformers_modeling_backend/configs/debug_model.toml new file mode 100644 index 0000000000..0775ead39b --- /dev/null +++ b/torchtitan/experiments/transformers_modeling_backend/configs/debug_model.toml @@ -0,0 +1,88 @@ +# torchtitan Config.toml + +[job] +dump_folder = "./outputs" +description = "Qwen 3 debug training" +print_config = true + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 5 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "transformers_modeling_backend" +flavor = "debugmodel" +# test folder with tokenizer.json, for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" +# converters = ["float8"] + +[hf_transformers] +model = "Qwen/Qwen3-4B-Instruct-2507" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 2 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 10 +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) +dataset_path = "./tests/assets/c4_test" + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +pipeline_parallel_schedule = "1F1B" +context_parallel_degree = 1 +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] + +[validation] +enable = false +dataset = "c4_validation" +freq = 5 +steps = 10 diff --git a/torchtitan/experiments/transformers_modeling_backend/configs/full.toml b/torchtitan/experiments/transformers_modeling_backend/configs/full.toml new file mode 100644 index 0000000000..34ec994fb1 --- /dev/null +++ b/torchtitan/experiments/transformers_modeling_backend/configs/full.toml @@ -0,0 +1,87 @@ +# torchtitan Config.toml + +[job] +dump_folder = "./outputs" +description = "Qwen 3 full training" +print_config = true + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 5 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "transformers_modeling_backend" +flavor = "full" +# test folder with tokenizer.json, for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" +# converters = ["float8"] + +[hf_transformers] +model = "Qwen/Qwen3-4B-Instruct-2507" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 2 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 10 +dataset = "c4" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +pipeline_parallel_schedule = "1F1B" +context_parallel_degree = 1 +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "selective" # ["none", "selective", "full"] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable=false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] + +[validation] +enable = false +dataset = "c4_validation" +freq = 5 +steps = 10 diff --git a/torchtitan/experiments/transformers_modeling_backend/infra/parallelize.py b/torchtitan/experiments/transformers_modeling_backend/infra/parallelize.py new file mode 100644 index 0000000000..fcdb31f27d --- /dev/null +++ b/torchtitan/experiments/transformers_modeling_backend/infra/parallelize.py @@ -0,0 +1,435 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.fsdp import CPUOffloadPolicy, fully_shard, MixedPrecisionPolicy +from torch.distributed.tensor import Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + RowwiseParallel, + SequenceParallel, +) +from torchtitan.config import TORCH_DTYPE_MAP +from torchtitan.distributed import NoParallel, ParallelDims + +from torchtitan.distributed.activation_checkpoint import apply_ac + +from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp +from torchtitan.experiments.transformers_modeling_backend.job_config import JobConfig +from torchtitan.models.llama3.infra.parallelize import apply_compile, apply_ddp +from torchtitan.tools.logging import logger + + +def parallelize_hf_transformers( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + """ + Apply tensor parallelism, activation checkpointing, torch.compile, and data + parallelism to the model. + + NOTE: The passed-in model preferably should be on meta device. Otherwise, + the model must fit on GPU or CPU memory. + """ + # TODO: TP currently cannot handle uneven seq_len because we set + # `use_local_output=True` to use plain Tensors for legacy reasons. + # Need to revisit this. + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). + """ + + if parallel_dims.tp_enabled: + enable_float8_linear = "float8" in job_config.model.converters + float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( + "rowwise", + "rowwise_with_gw_hp", + ) + + # For now, float8 all-gather with TP is only supported for tensorwise + # float8 scaling recipes. For rowwise recipes, we use regular TP and + # all-gather happens in high precision. + enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + + apply_non_moe_tp( + model, + parallel_dims.get_mesh("tp"), + loss_parallel=not job_config.parallelism.disable_loss_parallel, + enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, + ) + maybe_enable_async_tp(job_config, parallel_dims.get_mesh("tp")) + + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac(model, job_config.activation_checkpoint) + + # turn on per-TransformerBlock compile after AC wrapping and before FSDP + if model_compile_enabled: + apply_compile(model, job_config.compile) + + if parallel_dims.fsdp_enabled: + # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "fsdp") + else: + dp_mesh_dim_names = ("fsdp",) + + apply_fsdp( + model, + parallel_dims.get_mesh(list(dp_mesh_dim_names)), + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + pp_enabled=parallel_dims.pp_enabled, + cpu_offload=job_config.training.enable_cpu_offload, + reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if parallel_dims.cp_enabled: + model.set_cp_mesh(parallel_dims.get_mesh("cp")) + logger.info("Applied Context Parallel to the model") + + if job_config.training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") + elif parallel_dims.dp_replicate_enabled: + dp_replicate_mesh = parallel_dims.get_mesh("dp_replicate") + if parallel_dims.world_size != dp_replicate_mesh.size(): + raise RuntimeError("DDP has not supported > 1D parallelism") + apply_ddp( + model, + dp_replicate_mesh, + enable_compile=model_compile_enabled, + ) + + return model + + +def apply_non_moe_tp( + model: nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8_tensorwise_tp: bool, +): + """Apply tensor parallelism.""" + # 1. Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. Parallelize the root norm layer over the sequence dim + # 3. Parallelize the final linear output layer + + # skipping nn.Identity modules (which are added by pipeline parallelism for unused modules) + root_plan = {} + + if hasattr(model, "tok_embeddings"): + if isinstance(model.tok_embeddings, nn.Identity): + root_plan["tok_embeddings"] = NoParallel() + else: + root_plan["tok_embeddings"] = RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ) + + if hasattr(model, "norm"): + if isinstance(model.norm, nn.Identity): + root_plan["norm"] = NoParallel() + else: + root_plan["norm"] = SequenceParallel() + + if hasattr(model, "output"): + if isinstance(model.output, nn.Identity): + root_plan["output"] = NoParallel() + else: + root_plan["output"] = ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ) + if root_plan: # Only call if there's something to parallelize + parallelize_module(model, tp_mesh, root_plan) + + # Parallel styles used for transformer block linear weights and their + # inputs may be different for float8 linears with tensorwise scaling. + if enable_float8_tensorwise_tp: + # TODO(vkuzo): add the items below to __init__.py of torchao.float8 and import from there + from torchao.float8.float8_tensor_parallel import ( + Float8ColwiseParallel, + Float8RowwiseParallel, + PrepareFloat8ModuleInput, + ) + + rowwise_parallel, colwise_parallel, prepare_module_input = ( + Float8RowwiseParallel, + Float8ColwiseParallel, + PrepareFloat8ModuleInput, + ) + else: + rowwise_parallel, colwise_parallel, prepare_module_input = ( + RowwiseParallel, + ColwiseParallel, + PrepareModuleInput, + ) + + # Apply tensor + sequence parallelism to every transformer block + for transformer_block in model.layers: + layer_plan = { + "input_layernorm": SequenceParallel(), + "self_attn": prepare_module_input( + input_kwarg_layouts={"hidden_states": Shard(1)}, + desired_input_kwarg_layouts={"hidden_states": Replicate()}, + ), + "post_attention_layernorm": SequenceParallel(), + } + + if getattr(transformer_block.self_attn, "q_lora_rank", None) is None: + layer_plan.update( + { + "self_attn.q_proj": colwise_parallel(), + "self_attn.k_proj": colwise_parallel(), + "self_attn.v_proj": colwise_parallel(), + } + ) + else: + layer_plan.update( + { + "self_attn.q_a_proj": NoParallel(), + "self_attn.q_a_layernorm": NoParallel(), + "self_attn.q_b_proj": colwise_parallel(), + "self_attn.kv_a_proj_with_mqa": NoParallel(), + "self_attn.kv_a_layernorm": NoParallel(), + "self_attn.kv_b_proj": colwise_parallel(), + } + ) + + # Handle different names for the output projection layer, e.g. o_proj vs dense + o_proj_name = ( + "o_proj" if hasattr(transformer_block.self_attn, "o_proj") else "dense" + ) + layer_plan[f"self_attn.{o_proj_name}"] = rowwise_parallel( + output_layouts=Shard(1) + ) + # For model that uses RMSNorm on Q and K (i.e. Qwen3) + if hasattr(transformer_block.self_attn, "q_norm") and hasattr( + transformer_block.self_attn, "k_norm" + ): + layer_plan["self_attn.q_norm"] = SequenceParallel( + sequence_dim=2, use_local_output=True + ) + layer_plan["self_attn.k_norm"] = SequenceParallel( + sequence_dim=2, use_local_output=True + ) + + if not transformer_block.moe_enabled: + mlp_plan = { + "mlp": prepare_module_input( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + ), + } + # Handle different names for MLP layers, e.g. gate_proj vs fc1 + gate_proj_name = ( + "gate_proj" if hasattr(transformer_block.mlp, "gate_proj") else "fc1" + ) + mlp_plan[f"mlp.{gate_proj_name}"] = colwise_parallel() + + if hasattr(transformer_block.mlp, "up_proj"): + mlp_plan["mlp.up_proj"] = colwise_parallel() + + down_proj_name = ( + "down_proj" if hasattr(transformer_block.mlp, "down_proj") else "fc2" + ) + mlp_plan[f"mlp.{down_proj_name}"] = rowwise_parallel( + output_layouts=Shard(1) + ) + layer_plan.update(mlp_plan) + + # Some models like Phi-2 don't have post_attention_layernorm + if not hasattr(transformer_block, "post_attention_layernorm"): + layer_plan.pop("post_attention_layernorm") + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) + + logger.info( + f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}" + "Tensor Parallelism to the model" + ) + + +def apply_fsdp( + model: nn.Module, + dp_mesh: DeviceMesh, + param_dtype: torch.dtype, + reduce_dtype: torch.dtype, + pp_enabled: bool, + cpu_offload: bool = False, + reshard_after_forward_policy: str = "default", + ep_degree: int = 1, + dp_mod_ep_mesh: DeviceMesh | None = None, + gradient_divide_factor: int | None = None, +): + """ + Apply data parallelism (via FSDP2) to the model. + + Args: + model (nn.Module): The model to apply data parallelism to. + dp_mesh (DeviceMesh): The device mesh to use for data parallelism. + param_dtype (torch.dtype): The data type to use for model parameters. + reduce_dtype (torch.dtype): The data type to use for reduction operations. + pp_enabled (bool): Whether pipeline parallelism is enabled. + cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. + reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". + Other options: "never", "always". + - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. + - "always" will enable `reshard_after_forward` for all forward passes. + - "never" will disable `reshard_after_forward` for all forward passes. + + """ + mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) + fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + if cpu_offload: + fsdp_config["offload_policy"] = CPUOffloadPolicy() + + match reshard_after_forward_policy: + case "always": + reshard_after_forward = True + case "never": + reshard_after_forward = False + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = not pp_enabled + case _: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." + ) + + if model.tok_embeddings is not None: + fully_shard( + model.tok_embeddings, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + + for transformer_block in model.layers: + # NOTE: When EP is enabled, In an MoE layer, we use the following FSDP wrapping + # - the router and the shared experts are sharded together with the TransformerBlock + # - the routed experts are sharded with the remaining dp_mod_ep_mesh + if ( + hasattr(transformer_block, "moe_enabled") + and transformer_block.moe_enabled + and ep_degree > 1 + ): + fsdp_mod_ep_config = fsdp_config.copy() + fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh + moe_block = transformer_block.mlp + # NOTE: EP alreadys shards the routed experts on dim 0 (num_experts). + # When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding + # causes inefficiency, so we choose to do FSDP sharding on dim-1. + # Even when EP is not used, we may still want to shard the experts + # on non-0 dim. For now it may not be worth the complexity to support + # shard_placement_fn on the outer TransformerBlock-level FSDP. + _experts_shard_placement_fn = None + assert dp_mod_ep_mesh is not None + if dp_mod_ep_mesh.size() * ep_degree > moe_block.experts.num_experts: + _experts_shard_placement_fn = lambda param: Shard(1) + + fully_shard( + moe_block.experts, + **fsdp_mod_ep_config, + reshard_after_forward=reshard_after_forward, + shard_placement_fn=_experts_shard_placement_fn, + ) + + # NOTE: # Although the FSDP sharding of experts is done on a mesh of + # a different size than other parameters, the gradient division + # factor should be consistent with data. + moe_block.experts.set_gradient_divide_factor( + gradient_divide_factor, + ) + + fully_shard( + transformer_block, + **fsdp_config, + reshard_after_forward=reshard_after_forward, + ) + + # As an optimization, do not reshard_after_forward the last layers by default + # since FSDP would prefetch them immediately after the forward pass + if model.norm is not None and model.output is not None: + fully_shard( + [model.norm, model.output], + **fsdp_config, + reshard_after_forward=reshard_after_forward_policy == "always", + ) + + fully_shard(model, **fsdp_config) + + # NOTE: set up explicit prefetching when EP is enabled, as D2H syncs + # in EP could interfere with implicit prefetching in FSDP + if ep_degree == 1: + return + + # forward + transformer_blocks = list(model.layers.values()) + next_transformer_blocks = transformer_blocks[1:] + [None] + + if model.tok_embeddings is not None and model.layers is not None: + model.tok_embeddings.set_modules_to_forward_prefetch([transformer_blocks[0]]) + + for transformer_block, next_transformer_block in zip( + transformer_blocks, next_transformer_blocks + ): + if next_transformer_block is not None: + if next_transformer_block.moe_enabled: + transformer_block.set_modules_to_forward_prefetch( + [next_transformer_block, next_transformer_block.mlp.experts] + ) + else: + transformer_block.set_modules_to_forward_prefetch( + [next_transformer_block] + ) + elif model.norm is not None and model.output is not None: + transformer_block.set_modules_to_forward_prefetch( + [model.norm, model.output] + ) + + # backward + reversed_transformer_blocks = list(reversed(model.layers.values())) + prev_transformer_blocks = reversed_transformer_blocks[1:] + [None] + + if model.norm is not None and model.output is not None and model.layers is not None: + model.output.set_modules_to_backward_prefetch([reversed_transformer_blocks[0]]) + + for transformer_block, prev_transformer_block in zip( + reversed_transformer_blocks, prev_transformer_blocks + ): + if prev_transformer_block is not None: + if prev_transformer_block.moe_enabled: + transformer_block.set_modules_to_backward_prefetch( + [prev_transformer_block, prev_transformer_block.mlp.experts] + ) + else: + transformer_block.set_modules_to_backward_prefetch( + [prev_transformer_block] + ) + elif model.tok_embeddings is not None: + transformer_block.set_modules_to_backward_prefetch([model.tok_embeddings]) diff --git a/torchtitan/experiments/transformers_modeling_backend/infra/pipeline.py b/torchtitan/experiments/transformers_modeling_backend/infra/pipeline.py new file mode 100644 index 0000000000..f27f884014 --- /dev/null +++ b/torchtitan/experiments/transformers_modeling_backend/infra/pipeline.py @@ -0,0 +1,391 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import copy +import math + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.pipelining import PipelineStage +from torch.distributed.pipelining.schedules import ( + _PipelineSchedule, + get_schedule_class, + PipelineScheduleSingle, + ScheduleDualPipeV, + ScheduleZBVZeroBubble, +) + +from torchtitan.components.loss import LossFunction +from torchtitan.distributed import ParallelDims +from torchtitan.distributed.pipeline_parallel import build_pipeline_schedule +from torchtitan.experiments.transformers_modeling_backend.job_config import JobConfig +from torchtitan.protocols.train_spec import BaseModelArgs, ParallelizeFunction +from torchtitan.tools.logging import logger + +# NOTE(3outeille): the only modifications comes from replacing None to nn.Identity and adding rotary_emb per model_part + + +def generate_llm_fqn_per_model_part( + num_stages: int, + num_layers: int, + input_weight: int = 1, + output_weight: int = 1, +) -> list[list[str]]: + """ + Programmatically generates module names model part, focused on LLMs models. + Args: + num_stages: Number of pipeline stages + num_layers: Total number of transformer layers in the model + input_weight: Weight for input modules (embed_tokens) in layer calculation + output_weight: Weight for output modules (norm + output) in layer calculation + Returns: + List of lists containing module names for each model part + Example: + generate_llm_fqn_per_model_part(2, 3, input_weight=2, output_weight=2) + treats embeddings as 2 layers and norm+output as 2 layers for distribution + """ + if num_stages < 1: + raise ValueError("Number of stages must be at least 1") + + if num_stages == 1: + # Single stage gets everything + layer_names = [f"layers.{i}" for i in range(num_layers)] + return [["tok_embeddings"] + layer_names + ["norm", "output", "rotary_emb"]] + + # Calculate effective layers including weights + num_effective_layers = num_layers + input_weight + output_weight + + if num_stages > num_effective_layers: + raise ValueError( + f"Number of stages ({num_stages}) cannot be greater than effective layers ({num_effective_layers})" + ) + + # Calculate layers per stage (distribute evenly) + layers_per_stage = num_effective_layers // num_stages + extra_layers = num_effective_layers % num_stages + + # Feasibility check: Ensure at least 1 layer in each PP stage + if layers_per_stage == 0: + raise ValueError( + f"Configuration would result in empty stages. " + f"With {num_stages} stages and {num_effective_layers} effective layers " + f"(num_layers={num_layers} + input_weight={input_weight} + output_weight={output_weight}), " + f"each stage would get {layers_per_stage} layers on average. " + f"Reduce num_stages or increase num_layers/weights." + ) + + # Balance check: Ensure weights don't exceed minimum layers per stage + if input_weight > layers_per_stage: + raise ValueError( + f"input_weight ({input_weight}) exceeds minimum layers per stage ({layers_per_stage})." + ) + if output_weight > layers_per_stage: + raise ValueError( + f"output_weight ({output_weight}) exceeds minimum layers per stage ({layers_per_stage})." + ) + + module_names_per_stage = [] + current_layer = 0 + + for stage_idx in range(num_stages): + stage_modules = [] + + # Calculate effective layers for this stage + effective_layers_for_stage = layers_per_stage + if stage_idx < extra_layers: + effective_layers_for_stage += 1 + + # First stage: handle input modules with weighting + if stage_idx == 0: + stage_modules.append("tok_embeddings") + # Account for input weight in layer distribution + remaining_layers_for_stage = effective_layers_for_stage - input_weight + + # Add transformer layers + for _ in range(remaining_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + # Last stage: handle output modules with weighting + elif stage_idx == num_stages - 1: + # Account for output weight in layer distribution + remaining_layers_for_stage = effective_layers_for_stage - output_weight + + # Add transformer layers + for _ in range(remaining_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + # Add output modules + stage_modules.extend(["norm", "output"]) + + # Middle stages: only transformer layers + else: + for _ in range(effective_layers_for_stage): + if current_layer < num_layers: + stage_modules.append(f"layers.{current_layer}") + current_layer += 1 + + stage_modules.append("rotary_emb") + module_names_per_stage.append(stage_modules) + + return module_names_per_stage + + +def pipeline_module_split( + whole_model: nn.Module, + pp_mesh: DeviceMesh, + pp_schedule: str, + device: torch.device, + module_names_per_stage: list[list[str]], +) -> tuple[list[PipelineStage], list[nn.Module]]: + """ + This API creates pipeline stages based on specified module names for each stage. + + Some model restrictions include: + - forward() method should tolerate deleted layers + - weight initialization methods should tolerate deleted layers + - Does not support nested moduledict and modulelist structures + + Args: + whole_model: The complete model to be split + pp_mesh: Pipeline parallel device mesh + pp_schedule: Name of pipeline parallelism schedule + device: Device + module_names_per_stage: List of lists, where each inner list contains the module names + that should be included in that stage. Module names should be + dot-separated paths. Examples: + - "tok_embeddings" for token embeddings + - "layers.0", "layers.1" for specific transformer layers + - "norm" for the final normalization layer + - "output" for the output projection layer + + Returns: + Tuple of (stages, models) where stages are PipelineStage objects and models are the + corresponding model chunks + + Example usage: + module_names_per_stage = [ + ["tok_embeddings", "layers.0"], # Stage 0: embeddings + first layer + ["layers.1", "layers.2"], # Stage 1: middle layers + ["norm", "output"] # Stage 2: final norm + output + ] + """ + pp_rank = pp_mesh.get_local_rank() + pp_degree = pp_mesh.size() + + def _build_stage_from_modules( + stage_idx: int, module_names: list[str], num_stages: int + ) -> tuple[PipelineStage, nn.Module]: + model = copy.deepcopy(whole_model) + + # Create a set of modules to keep for faster lookup + modules_to_keep = set(module_names) + for module_name, module_value in model.named_children(): + # Handle layer-like structures (e.g., "layers.0", "layers.1") + if isinstance(module_value, (nn.ModuleDict, nn.ModuleList)): + layers_to_keep = { + name.split(".", 1)[1] + for name in modules_to_keep + if name.startswith(f"{module_name}.") + } + if layers_to_keep: + # Keep only specified layers + if isinstance(module_value, nn.ModuleDict): + for layer_name in list(module_value.keys()): + if layer_name not in layers_to_keep: + del module_value[layer_name] + elif isinstance(module_value, nn.ModuleList): + indices_to_keep = { + int(idx) for idx in layers_to_keep if idx.isdigit() + } + new_layers = nn.ModuleList( + [ + layer + for i, layer in enumerate(module_value) + if i in indices_to_keep + ] + ) + setattr(model, module_name, new_layers) + else: + # No layers from this structure needed, set to empty structure + if isinstance(module_value, nn.ModuleDict): + setattr(model, module_name, nn.ModuleDict()) + elif isinstance(module_value, nn.ModuleList): + setattr(model, module_name, nn.ModuleList()) + # Handle simple module attributes (e.g., "linear", "norm") + elif module_name not in modules_to_keep: + # Replace with Identity + setattr(model, module_name, nn.Identity()) + + stage = PipelineStage( + model, + stage_idx, + num_stages, + device, + group=pp_mesh.get_group("pp"), + ) + return stage, model + + num_stages = len(module_names_per_stage) + stages = [] + models = [] + + schedule_class = get_schedule_class(pp_schedule) + style = ( + "v" if schedule_class in (ScheduleZBVZeroBubble, ScheduleDualPipeV) else "loop" + ) + + def _get_stage_indices() -> tuple[int]: + """ + Compute the stage ids for the stages that will run on this pp rank + for either a looped or V style schedule + """ + assert ( + num_stages % pp_degree == 0 + ), f"num_stages {num_stages} must be evenly divisible by pp_degree {pp_degree}" + stages_per_rank = num_stages // pp_degree + if style == "loop": + return tuple(pp_rank + s * pp_degree for s in range(stages_per_rank)) + elif style == "v": + assert ( + stages_per_rank == 2 + ), f"v schedules assume 2 stages per rank, got {stages_per_rank}" + stage_v_pairs = list( + zip(range(pp_degree), range(num_stages - 1, pp_degree - 1, -1)) + ) + return stage_v_pairs[pp_rank] + + for stage_idx in _get_stage_indices(): + module_names = module_names_per_stage[stage_idx] + stage, model_chunk = _build_stage_from_modules( + stage_idx, + module_names, + num_stages, + ) + logger.info( + f"PP rank {pp_rank} is building stage_idx {stage_idx} " + f"with modules {module_names}" + ) + stages.append(stage) + models.append(model_chunk) + + return stages, models + + +def pipeline_hf_transformers( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, + device: torch.device, + model_args: BaseModelArgs, + parallelize_fn: ParallelizeFunction, + loss_fn: LossFunction, +) -> tuple[_PipelineSchedule, list[nn.Module], bool, bool]: + pp_mesh = parallel_dims.get_mesh("pp") + + # Determine the number of virtual stages based on schedule type + schedule_class = get_schedule_class( + job_config.parallelism.pipeline_parallel_schedule + ) + is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle) + layers_per_stage = job_config.parallelism.pipeline_parallel_layers_per_stage + if hasattr(model_args, "n_layers"): + num_layers = model_args.n_layers + else: + raise ValueError("Model does not have n_layers attribute.") + + # You can adjust these weights based on the computational cost of embeddings and output layers + # Higher weights mean these modules are treated as "heavier" in the distribution + input_weight = job_config.parallelism.pipeline_parallel_first_stage_less_layers + output_weight = job_config.parallelism.pipeline_parallel_last_stage_less_layers + + # Calculate number of virtual stages + if layers_per_stage is not None: + + # Calculate number of virtual stages needed (using ceiling division) + # This allows for unequal distribution where stages can differ by at most 1 layer + num_virtual_stages = math.ceil( + (num_layers + input_weight + output_weight) / layers_per_stage + ) + + # Validation: check stages per rank based on schedule type + model_config_info = f"Model has {num_layers} layers with pipeline_parallel_layers_per_stage={layers_per_stage}" + stage_distribution_info = ( + f"resulting in {num_virtual_stages=} across {parallel_dims.pp} PP ranks" + ) + + if num_virtual_stages % parallel_dims.pp != 0: + raise ValueError( + f"Number of virtual stages ({num_virtual_stages}) must be divisible by " + f"pipeline parallel size ({parallel_dims.pp}). " + f"{model_config_info}. " + f"Please adjust pipeline_parallel_layers_per_stage to a value that results in a number of stages " + f"divisible by {parallel_dims.pp}." + ) + + stages_per_rank = num_virtual_stages // parallel_dims.pp + + if is_single_stage_schedule and stages_per_rank != 1: + raise ValueError( + f"Single stage schedule requires exactly 1 stage per rank, but got {stages_per_rank} stages per rank. " + f"{model_config_info}, {stage_distribution_info}. " + f"Please increase pipeline_parallel_layers_per_stage to {num_layers // parallel_dims.pp} or higher " + f"to achieve 1 stage per rank." + ) + + if not is_single_stage_schedule and stages_per_rank < 2: + raise ValueError( + f"Multi-stage schedule requires at least 2 stages per rank, but got {stages_per_rank} stages per rank. " + f"{model_config_info}, {stage_distribution_info}. " + f"Please decrease pipeline_parallel_layers_per_stage to achieve at least 2 stages per rank." + ) + else: + # Fallback to default behavior when layers_per_stage is not provided + # For multi-stage schedules, default is 2 virtual stages per rank + # For single-stage schedules, default is 1 virtual stage per rank + stages_per_rank = 1 if is_single_stage_schedule else 2 + num_virtual_stages = parallel_dims.pp * stages_per_rank + + module_names_per_stage = job_config.parallelism.module_fqns_per_model_part + if module_names_per_stage is None: + module_names_per_stage = generate_llm_fqn_per_model_part( + num_virtual_stages, num_layers, input_weight, output_weight + ) + + stages, model_parts = pipeline_module_split( + model, + pp_mesh, + job_config.parallelism.pipeline_parallel_schedule, + device, + module_names_per_stage, + ) + + # For PP with looped schedules, each item in model_parts is one stage-model-chunk. + # We need to iterate through model_parts to apply SPMD parallelisms, compilation, + # optimizer, and checkpointing + for i, m in enumerate(model_parts): + # apply SPMD-style PT-D techniques + m = parallelize_fn(m, parallel_dims, job_config) + model_parts[i] = m + # NOTE: this is to update the model in the stage + # in case the model is modified e.g. by torch.compile + stages[i].submod = m + + pp_schedule = build_pipeline_schedule(job_config, stages, loss_fn) + + # This is used in the train loop to determine whether to pass in the input_ids and labels + has_first_stage = False + has_last_stage = False + for stage in stages: + if stage.is_first: + has_first_stage = True + if stage.is_last: + has_last_stage = True + + return pp_schedule, model_parts, has_first_stage, has_last_stage diff --git a/torchtitan/experiments/transformers_modeling_backend/job_config.py b/torchtitan/experiments/transformers_modeling_backend/job_config.py new file mode 100644 index 0000000000..f3b1667798 --- /dev/null +++ b/torchtitan/experiments/transformers_modeling_backend/job_config.py @@ -0,0 +1,18 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass, field + + +@dataclass +class HFTransformers: + model: str = "" + """HuggingFace model ID (e.g., 'Qwen/Qwen3-4B-Instruct-2507')""" + + +@dataclass +class JobConfig: + hf_transformers: HFTransformers = field(default_factory=HFTransformers) diff --git a/torchtitan/experiments/transformers_modeling_backend/model/args.py b/torchtitan/experiments/transformers_modeling_backend/model/args.py new file mode 100644 index 0000000000..25ab328f15 --- /dev/null +++ b/torchtitan/experiments/transformers_modeling_backend/model/args.py @@ -0,0 +1,199 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from dataclasses import dataclass + +from torch import nn +from torchtitan.config.job_config import JobConfig +from torchtitan.models.utils import get_dense_model_nparams_and_flops +from torchtitan.protocols import BaseModelArgs +from transformers import AutoConfig +from transformers.configuration_utils import PretrainedConfig +from transformers.integrations.sdpa_attention import sdpa_attention_forward +from transformers.modeling_utils import AttentionInterface + + +@dataclass +class TitanDenseModelArgs: + """Arguments for the base TorchTitan model.""" + + dim: int = 4096 + n_layers: int = 32 + n_heads: int = 32 + n_kv_heads: int | None = None + vocab_size: int | None = None + multiple_of: int = 256 + ffn_dim_multiplier: float | None = None + norm_eps: float = 1e-5 + rope_theta: float = 10000 + max_seq_len: int = 2048 + depth_init: bool = True + use_flex_attn: bool = False + attn_mask_type: str = "causal" + + +@dataclass +class HFTransformerModelArgs(PretrainedConfig, BaseModelArgs): + """ + Configuration class that bridges TorchTitan and HuggingFace Transformers naming conventions. + + Uses properties to provide TorchTitan-style access while maintaining HuggingFace compatibility. + Properties are created dynamically based on which arguments are provided. + """ + + # Define all possible mappings organized by argument type + _TT_TO_HF_MAPPINGS = { + "dense": { + # TorchTitan dense model mappings (always available) + "dim": "hidden_size", + "n_layers": "num_hidden_layers", + "n_heads": "num_attention_heads", + "n_kv_heads": "num_key_value_heads", + "norm_eps": "rms_norm_eps", + "max_seq_len": "max_position_embeddings", + "eos_id": "eos_token_id", + } + } + + # Declarative list of TorchTitan-only attributes (no HF equivalent) + _TT_SPECIFIC_ATTRIBUTES = [ + "multiple_of", + "ffn_dim_multiplier", + "depth_init", + "use_flex_attn", + "attn_mask_type", + ] + + def __init__( + self, + titan_dense_args, + # HuggingFace specific args + attn_implementation: str = "sdpa_torchtitan", + **kwargs, + ): + super().__init__(attn_implementation=attn_implementation, **kwargs) + assert titan_dense_args is not None, "titan_dense_args is required" + + # Create getter/setter dynamically for TT <-> HF attribute mappings + self._create_getter_setter_dynamically(has_moe=False) + + self._titan_injected_model_args = {} + self._configure_hf_attention(attn_implementation) + + self._initialize_dense_attributes(titan_dense_args) + + def _initialize_dense_attributes(self, titan_dense_args): + """Initialize all dense model attributes.""" + # Set mapped attributes (TorchTitan <-> HuggingFace) + for titan_name, hf_name in self._tt_to_hf_attribute_map.items(): + if hasattr(titan_dense_args, titan_name): + value = getattr(titan_dense_args, titan_name) + setattr(self, hf_name, value) + + # Set TorchTitan-only attributes + for attr_name in self._TT_SPECIFIC_ATTRIBUTES: + if hasattr(titan_dense_args, attr_name): + setattr(self, attr_name, getattr(titan_dense_args, attr_name)) + + # Update passed_args + self._titan_injected_model_args.update(titan_dense_args.__dict__) + + def _configure_hf_attention(self, attn_implementation: str): + """Configure HuggingFace attention settings.""" + self._titan_injected_model_args["attn_implementation"] = attn_implementation + self.attn_implementation = attn_implementation + # NOTE:(3outeille):This will force create_causal_mask to return None + AttentionInterface._global_mapping[attn_implementation] = sdpa_attention_forward + + def _create_getter_setter_dynamically(self, has_moe: bool): + """ + Create properties dynamically based on tt and hf attribute mappings. + For example, creates a property 'dim' that reads/writes to 'hidden_size'. + """ + + def _create_property(hf_name: str) -> property: + def getter(self): + return getattr(self, hf_name) + + def setter(self, value): + setattr(self, hf_name, value) + + return property(getter, setter) + + # Setup attribute mappings + self._tt_to_hf_attribute_map = dict(self._TT_TO_HF_MAPPINGS["dense"]) + if has_moe: + self._tt_to_hf_attribute_map.update(self._TT_TO_HF_MAPPINGS["moe"]) + + for titan_name, hf_name in self._tt_to_hf_attribute_map.items(): + # Create getter/setter for attribute that don't already exist + if not hasattr(self.__class__, titan_name): + setattr(self.__class__, titan_name, _create_property(hf_name)) + + def __repr__(self) -> str: + # HFTransformerModelArgs is a dataclass that also inherits from PretrainedConfig. + # PretrainedConfig has a __repr__ that serializes the object to JSON, but it + # doesn't work well with how HFTransformerModelArgs is initialized. + # This custom __repr__ provides a dataclass-like representation that correctly + # displays the arguments passed during initialization. + args_lines = [ + f"{k}={getattr(self, k)!r}" + for k in sorted(self._titan_injected_model_args.keys()) + if hasattr(self, k) + ] + args_str = "\n".join(args_lines) + return f"{self.__class__.__name__}(\n{args_str}\n)" + + def update_from_config(self, job_config: JobConfig): + # Load HF config (overwrites our HF attributes) + hf_model_config = AutoConfig.from_pretrained( + job_config.hf_transformers.model, + attn_implementation=self.attn_implementation, + trust_remote_code=True, + ) + + # Explicitly update attributes based on mappings + for titan_name, hf_name in self._tt_to_hf_attribute_map.items(): + if hasattr(hf_model_config, hf_name): + setattr(self, titan_name, getattr(hf_model_config, hf_name)) + + # Copy any other attributes that might not be in the mapping + for key, value in hf_model_config.to_dict().items(): + setattr(self, key, value) + + # Update our attributes with the passed args from flavors + for key, value in self._titan_injected_model_args.items(): + if hasattr(self, key) and value is not None: + setattr(self, key, value) + + self.max_seq_len = job_config.training.seq_len + + self.deterministic = job_config.debug.deterministic + + # Configure HF-specific settings to match TorchTitan settings + # TODO: false ? + self.attention_bias = False + self.mlp_bias = False + self.use_cache = False + self.initializer_range = 1.0 # use as std for normal init in embedding + + if not hasattr(self, "inter_dim"): # Only for llama model + ffn_hidden_size = 4 * self.dim + ffn_hidden_size = int(2 * ffn_hidden_size / 3) + if self.ffn_dim_multiplier is not None: + ffn_hidden_size = int(self.ffn_dim_multiplier * ffn_hidden_size) + self.intermediate_size = self.multiple_of * ( + (ffn_hidden_size + self.multiple_of - 1) // self.multiple_of + ) + + self.head_dim = self.dim // self.num_attention_heads + + return self + + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: + return get_dense_model_nparams_and_flops( + self, model, head_dims=self.head_dim, seq_len=seq_len + ) diff --git a/torchtitan/experiments/transformers_modeling_backend/model/model.py b/torchtitan/experiments/transformers_modeling_backend/model/model.py new file mode 100644 index 0000000000..b88fffc54b --- /dev/null +++ b/torchtitan/experiments/transformers_modeling_backend/model/model.py @@ -0,0 +1,477 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import importlib +import math + +import torch +from torch import nn +from torch.nn import init +from torchtitan.tools.logging import logger +from transformers.configuration_utils import PretrainedConfig +from transformers.modeling_utils import PreTrainedModel + +from .args import HFTransformerModelArgs + + +class SliceableModuleDict(nn.ModuleDict): + """ + A ModuleDict that supports slicing like ModuleList. + Keys are expected to be string representations of integers (e.g., "0", "1", "2"). + """ + + def __getitem__(self, key): + if isinstance(key, slice): + # Handle slicing: convert slice to list of keys + keys = sorted( + self.keys(), key=lambda x: int(x) if x.isdigit() else float("inf") + ) + sliced_keys = keys[key] + # Return a new SliceableModuleDict with the sliced items + return SliceableModuleDict({k: self[k] for k in sliced_keys}) + return super().__getitem__(key) + + def __iter__(self): + # Iterate over values in sorted order by key (as integers) + keys = sorted( + self.keys(), key=lambda x: int(x) if x.isdigit() else float("inf") + ) + for key in keys: + yield self[key] + + def __len__(self): + return len(self._modules) + + +class HFTransformerModel(nn.Module): + def __init__(self, model_args: HFTransformerModelArgs): + super().__init__() + + # NOTE(3outeille): This prevents Hugging Face modeling from initializing ROPE (inv_freq) buffers to NaN. + # Needed when loading from seed checkpoint. + if hasattr(model_args, "deterministic") and model_args.deterministic: + torch.utils.deterministic.fill_uninitialized_memory = False + + # Try to import the model class dynamically from the transformers library if not found in globals + model_class_name = model_args.architectures[0] + model_cls = globals().get(model_class_name, None) + if model_cls is None: + try: + transformers_mod = importlib.import_module("transformers") + model_cls = getattr(transformers_mod, model_class_name) + except (ImportError, AttributeError) as e: + raise ImportError( + f"Could not find model class '{model_class_name}' in globals or transformers. " + f"Make sure the class is available. Original error: {e}" + ) from e + + # Attempt to patch model weight initialization based on architecture type + try: + model_name_prefix = model_class_name.replace("ForCausalLM", "") + model_module = importlib.import_module(model_cls.__module__) + + attention_cls = getattr(model_module, f"{model_name_prefix}Attention", None) + mlp_cls = getattr(model_module, f"{model_name_prefix}MLP", None) + decoder_layer_cls = getattr( + model_module, f"{model_name_prefix}DecoderLayer", None + ) + + required_classes = { + "Attention": attention_cls, + "DecoderLayer": decoder_layer_cls, + } + + if all(required_classes.values()): + logger.info(f"Applying Llama-like patch for {model_name_prefix}") + self._patch_hf_llama_like( + decoder_layer_cls=decoder_layer_cls, + attention_cls=attention_cls, + mlp_cls=mlp_cls, # mlp_cls can be None + ) + else: + missing = [name for name, cls in required_classes.items() if not cls] + logger.warning( + f"Could not find required classes ({', '.join(missing)}) for {model_name_prefix}. " + "Skipping Llama-like patch." + ) + + except Exception as e: + logger.warning( + f"Failed to apply agnostic patch for {model_class_name} due to: {e}. " + "Weight initialization might not match TorchTitan." + ) + + self.model = model_cls(config=model_args) + self.max_seq_len = model_args.max_seq_len + self.cp_mesh = None + + # Convert ModuleList to ModuleDict to preserve original indices + # This ensures state dict keys match checkpoint keys + if isinstance(self.model.model.layers, nn.ModuleList): + self.model.model.layers = SliceableModuleDict( + {str(i): layer for i, layer in enumerate(self.model.model.layers)} + ) + + for layer in self.model.model.layers.values(): + layer.moe_enabled = False + + def set_cp_mesh(self, mesh): + self.cp_mesh = mesh + + def _patch_hf_llama_like(self, decoder_layer_cls, attention_cls, mlp_cls=None): + """ + This patch modifies a Hugging Face Llama-like model's weight initialization to match + the initialization scheme used in TorchTitan. This is crucial for ensuring + bit-for-bit reproducibility when converting checkpoints between the native + TorchTitan format and the Hugging Face format. + + The patch targets the following aspects of the model: + - `PreTrainedModel._initialize_weights`: Handles meta device initialization correctly. + - `PreTrainedModel._init_weights`: Implements TorchTitan's specific initialization + for attention, MLP, embedding, and layer norm layers. This includes depth-dependent + initialization for attention and MLP layers. + - `DecoderLayer.__init__`: Adds `layer_idx` to attention and MLP modules within + each decoder layer, which is required for the depth-dependent initialization. + """ + + _original_decoder_layer_init = decoder_layer_cls.__init__ + + def _decoder_layer_init_patched(self, config: PretrainedConfig, layer_idx: int): + _original_decoder_layer_init(self, config, layer_idx) + self.layer_idx = layer_idx + # Ensure both attention and mlp modules have layer_idx for depth-based init + if hasattr(self, "self_attn"): + self.self_attn.layer_idx = layer_idx + # some models might not have mlp in each layer + if hasattr(self, "mlp") and self.mlp is not None: + self.mlp.layer_idx = layer_idx + + def _initialize_weights_patched(self, module): + # NOTE(3outeille): monkey-patch PreTrainedModel to handle meta device initialization correctly + # The default _initialize_weights sets _is_hf_initialized = True even on a meta device, + # which prevents subsequent proper initialization. + if getattr(module, "_is_hf_initialized", False): + return + + for param in module.parameters(recurse=True): + if param.device.type == "meta": + return + + # If not on a meta device, call the original weight initialization + self._init_weights(module) + module._is_hf_initialized = True + + def _init_weights_patched(self, module): + """ + Patched version of _init_weights to match TorchTitan's initialization for Llama-like models. + `self` is a PreTrainedModel instance. + """ + config = self.config + # Build tuple of classes to check for layer_idx-based init_std calculation + layer_idx_classes = [attention_cls] + if mlp_cls: + layer_idx_classes.append(mlp_cls) + layer_idx_classes = tuple(layer_idx_classes) + + if isinstance(module, layer_idx_classes): + if not hasattr(module, "layer_idx"): + raise ValueError( + f"Module {module} does not have a layer_idx attribute" + ) + + layer_idx = module.layer_idx + + if hasattr(config, "depth_init") and config.depth_init: + init_std = 0.02 / (2 * (layer_idx + 1)) ** 0.5 + else: + init_std = 0.02 / (2 * config.num_hidden_layers) ** 0.5 + + if isinstance(module, attention_cls): + # Initialize weights and biases for q, k, v projections + for proj_name in ["q_proj", "k_proj", "v_proj"]: + proj = getattr(module, proj_name) + nn.init.trunc_normal_(proj.weight, mean=0.0, std=0.02) + if proj.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(proj.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(proj.bias, -bound, bound) + + # Handle different names for the output projection layer + o_proj = getattr(module, "o_proj", getattr(module, "dense", None)) + if o_proj is not None: + nn.init.trunc_normal_(o_proj.weight, mean=0.0, std=init_std) + if o_proj.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(o_proj.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(o_proj.bias, -bound, bound) + + elif mlp_cls and isinstance(module, mlp_cls): + # Handle different names for MLP layers + gate_proj = getattr(module, "gate_proj", getattr(module, "fc1", None)) + up_proj = getattr(module, "up_proj", None) + down_proj = getattr(module, "down_proj", getattr(module, "fc2", None)) + + # gate_proj (or fc1) should always use std=0.02 for numerical stability. + if gate_proj is not None: + nn.init.trunc_normal_(gate_proj.weight, mean=0.0, std=0.02) + if gate_proj.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(gate_proj.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(gate_proj.bias, -bound, bound) + # up_proj and down_proj (or fc2) use the depth-dependent init_std. + if up_proj is not None: + nn.init.trunc_normal_(up_proj.weight, mean=0.0, std=init_std) + if up_proj.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(up_proj.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(up_proj.bias, -bound, bound) + if down_proj is not None: + nn.init.trunc_normal_(down_proj.weight, mean=0.0, std=init_std) + if down_proj.bias is not None: + fan_in, _ = init._calculate_fan_in_and_fan_out(down_proj.weight) + bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 + init.uniform_(down_proj.bias, -bound, bound) + + elif module is getattr( + self, "lm_head", None + ): # TODO(3outeille): find a better way to detect lm_head + final_out_std = config.hidden_size**-0.5 + cutoff_factor = 3 + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + if module.bias is not None: + module.bias.data.zero_() + + elif isinstance(module, nn.Embedding): + # When tie_word_embeddings is True, use lm_head initialization + if ( + hasattr(config, "tie_word_embeddings") + and config.tie_word_embeddings + ): + final_out_std = config.hidden_size**-0.5 + cutoff_factor = 3 + nn.init.trunc_normal_( + module.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + else: + std = config.initializer_range + module.weight.data.normal_(mean=0.0, std=std) + + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + elif ( + isinstance( + module, + (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d), + ) + or "LayerNorm" in module.__class__.__name__ + or "RMSNorm" in module.__class__.__name__ + ): + # Norms can exist without weights (in which case they are None from torch primitives) + if hasattr(module, "weight") and module.weight is not None: + module.weight.data.fill_(1.0) + if hasattr(module, "bias") and module.bias is not None: + module.bias.data.zero_() + + decoder_layer_cls.__init__ = _decoder_layer_init_patched + PreTrainedModel._init_weights = _init_weights_patched + PreTrainedModel._initialize_weights = _initialize_weights_patched + + @property + def tok_embeddings(self): + """Returns the model's embed_tokens, handling different Hugging Face model structures.""" + if hasattr(self.model, "model") and hasattr( + self.model.model, "embed_tokens" + ): # Llama-like + return self.model.model.embed_tokens + else: + raise AttributeError( + "Could not find embed_tokens in the model. Please check the model structure." + ) + + @tok_embeddings.setter + def tok_embeddings(self, value): + if hasattr(self.model, "model") and hasattr( + self.model.model, "embed_tokens" + ): # Llama-like + self.model.model.embed_tokens = value + else: + raise AttributeError( + "Could not find embed_tokens in the model. Please check the model structure." + ) + + @property + def layers(self): + """Returns the model's layers, handling different Hugging Face model structures.""" + if hasattr(self.model, "model") and hasattr( + self.model.model, "layers" + ): # Llama-like + return self.model.model.layers + else: + # Add more cases here if needed for other model architectures + raise AttributeError( + "Could not find layers in the model. Please check the model structure." + ) + + @layers.setter + def layers(self, value): + if hasattr(self.model, "model") and hasattr( + self.model.model, "layers" + ): # Llama-like + self.model.model.layers = value + else: + raise AttributeError( + "Could not find layers in the model. Please check the model structure." + ) + + @property + def norm(self): + """Returns the model's norm, handling different Hugging Face model structures.""" + if hasattr(self.model, "model") and hasattr( + self.model.model, "norm" + ): # Llama-like + return self.model.model.norm + elif hasattr(self.model, "model") and hasattr( + self.model.model, "final_layernorm" + ): # Phi-like + return self.model.model.final_layernorm + else: + raise AttributeError( + "Could not find norm in the model. Please check the model structure." + ) + + @norm.setter + def norm(self, value): + if hasattr(self.model, "model") and hasattr( + self.model.model, "norm" + ): # Llama-like + self.model.model.norm = value + elif hasattr(self.model, "model") and hasattr( + self.model.model, "final_layernorm" + ): # Phi-like + self.model.model.final_layernorm = value + else: + raise AttributeError( + "Could not find norm in the model. Please check the model structure." + ) + + @property + def output(self): + """Returns the model's output layer, handling different Hugging Face model structures.""" + if hasattr(self.model, "lm_head"): # For models like LlamaForCausalLM + return self.model.lm_head + else: + # Add more cases here if needed for other model architectures + raise AttributeError( + "Could not find output (lm_head) in the model. Please check the model structure." + ) + + @output.setter + def output(self, value): + if hasattr(self.model, "lm_head"): # For models like LlamaForCausalLM + self.model.lm_head = value + else: + raise AttributeError( + "Could not find output (lm_head) in the model. Please check the model structure." + ) + + @property + def rotary_emb(self): + """Returns the model's rotary_emb, handling different Hugging Face model structures.""" + if hasattr(self.model, "model") and hasattr( + self.model.model, "rotary_emb" + ): # Llama-like + return self.model.model.rotary_emb + else: + raise AttributeError( + "Could not find rotary_emb in the model. Please check the model structure." + ) + + @rotary_emb.setter + def rotary_emb(self, value): + if hasattr(self.model, "model") and hasattr( + self.model.model, "rotary_emb" + ): # Llama-like + self.model.model.rotary_emb = value + else: + raise AttributeError( + "Could not find rotary_emb in the model. Please check the model structure." + ) + + def forward(self, *args, **kwargs): + local_seq_len = self.max_seq_len + local_seq_len //= ( + self.cp_mesh.size() + if self.cp_mesh is not None and self.cp_mesh.size() > 1 + else 1 + ) + kwargs["position_ids"] = torch.arange( + local_seq_len, device=args[0].device + ).unsqueeze(0) + output = self.model.model(*args, **kwargs) + output = self.model.lm_head(output.last_hidden_state) + return output + + def init_weights(self, *args, **kwargs): + # This method replicates the behavior of the original PreTrainedModel.init_weights, + # but with a custom weight initialization function that skips nn.Identity modules (when PP is enabled) + + if self.model.config.pruned_heads: + logger.info("Pruning heads as per model configuration.") + self.model.prune_heads(self.model.config.pruned_heads) + + original_init_weights_fn = self.model._init_weights + + def selective_init(module): + # For pipeline parallel, we need to skip nn.Identity modules + if not isinstance(module, nn.Identity): + original_init_weights_fn(module) + else: + logger.info("Skipping nn.Identity module during weight initialization.") + + self.model.apply(selective_init) + + # TODO(3outeille): For pipeline parallel, only tie weights if both input and output embeddings are on the same device + # Maybe better way of handling this? + if not isinstance(self.tok_embeddings, nn.Identity) and not isinstance( + self.output, nn.Identity + ): + self.model.tie_weights() + + def named_children(self): + """ + Provides a flattened view of the model's main components, + making it compatible with TorchTitan's expectations. + """ + yield "tok_embeddings", self.tok_embeddings + yield "layers", self.layers + yield "norm", self.norm + yield "output", self.output + yield "rotary_emb", self.rotary_emb + + def __setattr__(self, name, value): + # If a property with a setter exists for this name, use it. + # This is to bypass the nn.Module.__setattr__ logic that + # directly registers modules and skips property setters. + cls = self.__class__ + if hasattr(cls, name): + prop = getattr(cls, name) + if isinstance(prop, property) and prop.fset is not None: + prop.fset(self, value) + return + + # Otherwise, fall back to the default nn.Module behavior. + super().__setattr__(name, value) diff --git a/torchtitan/experiments/transformers_modeling_backend/tests/integration_tests.py b/torchtitan/experiments/transformers_modeling_backend/tests/integration_tests.py new file mode 100644 index 0000000000..35df7bb86a --- /dev/null +++ b/torchtitan/experiments/transformers_modeling_backend/tests/integration_tests.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +from tests.integration_tests import OverrideDefinitions +from tests.integration_tests.run_tests import run_tests + + +def build_transformers_modeling_backend_test_list() -> list[OverrideDefinitions]: + """ + key is the config file name and value is a list of OverrideDefinitions + that is used to generate variations of integration tests based on the + same root config file. + """ + integration_tests_flavors = [ + OverrideDefinitions( + [ + [ + "--model.name transformers_modeling_backend", + "--job.custom_config_module=torchtitan.experiments.transformers_modeling_backend.job_config", + "--hf_transformers.model Qwen/Qwen2.5-7B", + "--parallelism.data_parallel_shard_degree 2", + "--parallelism.tensor_parallel_degree 2", + "--parallelism.pipeline_parallel_degree 2", + "--parallelism.pipeline_parallel_schedule 1F1B", + ], + ], + "Transformers Backend FSDP+TP+PP", + "transformers_modeling_backend_fsdp+tp+pp", + ngpu=8, + ), + ] + return integration_tests_flavors + + +_TEST_SUITES_FUNCTION = { + "transformers_modeling_backend": build_transformers_modeling_backend_test_list, +} + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("output_dir") + parser.add_argument( + "--config_path", + default="./tests/integration_tests/base_config.toml", + help="Base config path for integration tests. This is the config that will be used as a base for all tests.", + ) + parser.add_argument( + "--test_name", + default="all", + help="test to run, acceptable values: `test_name` in `build_test_list` (default: all)", + ) + parser.add_argument("--ngpu", default=8, type=int) + args = parser.parse_args() + + if not os.path.exists(args.output_dir): + os.makedirs(args.output_dir) + if os.listdir(args.output_dir): + raise RuntimeError("Please provide an empty output directory.") + + test_list = _TEST_SUITES_FUNCTION["transformers_modeling_backend"]() + run_tests(args, test_list) + + +if __name__ == "__main__": + main() diff --git a/torchtitan/experiments/vlm/datasets/mm_datasets.py b/torchtitan/experiments/vlm/datasets/mm_datasets.py index 2a6a2000b9..2234496983 100644 --- a/torchtitan/experiments/vlm/datasets/mm_datasets.py +++ b/torchtitan/experiments/vlm/datasets/mm_datasets.py @@ -11,6 +11,7 @@ It supports both streaming and non-streaming datasets from HuggingFace. """ +from dataclasses import asdict from typing import Any, Callable import torch @@ -381,14 +382,14 @@ def build_mm_dataloader( """Build a data loader for multimodal datasets. Args: - dp_world_size: Data parallel world size - dp_rank: Data parallel rank - tokenizer: Tokenizer for text processing - job_config: Job configuration - infinite: Whether to loop infinitely + dp_world_size: Data parallel world size. + dp_rank: Data parallel rank. + tokenizer: Tokenizer for text processing. + job_config: Job configuration containing dataset and DataLoader settings. + infinite: Whether to loop infinitely. Returns: - DataLoader with appropriate parallelism handling + DataLoader with appropriate parallelism handling. """ dataset_path = job_config.training.dataset_path batch_size = job_config.training.local_batch_size @@ -429,12 +430,17 @@ def build_mm_dataloader( special_tokens=special_tokens, ) + dataloader_kwargs = { + **asdict(job_config.training.dataloader), + "batch_size": batch_size, + "collate_fn": collate_fn, + } + base_dataloader = ParallelAwareDataloader( dataset=dataset, dp_rank=dp_rank, dp_world_size=dp_world_size, - batch_size=batch_size, - collate_fn=collate_fn, + **dataloader_kwargs, ) return base_dataloader diff --git a/torchtitan/experiments/vlm/infra/loss.py b/torchtitan/experiments/vlm/infra/loss.py index bba51f2819..7a3a490fb7 100644 --- a/torchtitan/experiments/vlm/infra/loss.py +++ b/torchtitan/experiments/vlm/infra/loss.py @@ -104,7 +104,7 @@ def build_token_imbalance_ce_loss( # NOTE: The device mesh where the input tokens w/ shape BSD can be sliced: # DP split the batch dim B # CP split the sequence dim S - token_mesh = parallel_dims.world_mesh["dp_cp"] + token_mesh = parallel_dims.get_mesh("loss") ft_pg = ft_manager.loss_sync_pg loss_fn = partial(token_imbalance_ce_loss, token_mesh=token_mesh, ft_pg=ft_pg) if job_config.compile.enable and "loss" in job_config.compile.components: diff --git a/torchtitan/experiments/vlm/infra/parallelize.py b/torchtitan/experiments/vlm/infra/parallelize.py index 6a97e4ece1..d87070bee6 100644 --- a/torchtitan/experiments/vlm/infra/parallelize.py +++ b/torchtitan/experiments/vlm/infra/parallelize.py @@ -38,7 +38,6 @@ def parallelize_vlm( the model must fit on GPU or CPU memory. """ assert isinstance(model.encoder, nn.Module) - world_mesh = parallel_dims.world_mesh # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. @@ -48,9 +47,9 @@ def parallelize_vlm( Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") + attn_type = getattr(model.model_args, "attn_type", "sdpa") + if job_config.parallelism.context_parallel_degree > 1 and attn_type != "sdpa": + raise NotImplementedError("CP support is only supported for SDPA.") if parallel_dims.tp_enabled: raise NotImplementedError("TP support for VLM training is still in progress.") @@ -63,7 +62,6 @@ def parallelize_vlm( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, op_sac_save_list=_op_sac_save_list, ) apply_ac(model.encoder, job_config.activation_checkpoint) @@ -75,14 +73,13 @@ def parallelize_vlm( if parallel_dims.fsdp_enabled: # apply FSDP or HSDP, potentially with Context Parallel - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") - else: - dp_mesh_dim_names = ("dp_shard_cp",) + names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) apply_fsdp( model, - world_mesh[tuple(dp_mesh_dim_names)], + parallel_dims.get_mesh(names), param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], pp_enabled=parallel_dims.pp_enabled, @@ -101,11 +98,12 @@ def parallelize_vlm( if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_mesh = parallel_dims.get_mesh("dp_replicate") + if dp_mesh is not None and dp_mesh.ndim > 1: raise RuntimeError("DDP has not supported > 1D parallelism") apply_ddp( model, - world_mesh, + dp_mesh, enable_compile=job_config.compile.enable, ) diff --git a/torchtitan/experiments/vlm/model/args.py b/torchtitan/experiments/vlm/model/args.py index 11b6439ddd..49ba31246b 100644 --- a/torchtitan/experiments/vlm/model/args.py +++ b/torchtitan/experiments/vlm/model/args.py @@ -53,7 +53,7 @@ class Siglip2ModelArgs: spatial_merge_size: int = 1 layer_norm_eps: float = 1e-6 - use_flex_attn: bool = True + attn_type: str = "flex" attn_mask_type: str = "causal" diff --git a/torchtitan/hf_datasets/text_datasets.py b/torchtitan/hf_datasets/text_datasets.py index 4c88a36f07..12bc708dd6 100644 --- a/torchtitan/hf_datasets/text_datasets.py +++ b/torchtitan/hf_datasets/text_datasets.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from dataclasses import asdict from functools import partial from typing import Any, Callable @@ -166,7 +167,7 @@ def load_state_dict(self, state_dict): self._data.load_state_dict(state_dict["data"]) def state_dict(self): - _state_dict = {"token_buffer": self._token_buffer} + _state_dict: dict[str, Any] = {"token_buffer": self._token_buffer} if isinstance(self._data, Dataset): _state_dict["sample_idx"] = self._sample_idx @@ -185,7 +186,15 @@ def build_text_dataloader( job_config: JobConfig, infinite: bool = True, ) -> ParallelAwareDataloader: - """Build a data loader for HuggingFace datasets.""" + """Build a data loader for HuggingFace datasets. + + Args: + dp_world_size: Data parallelism world size. + dp_rank: Data parallelism rank. + tokenizer: Tokenizer to use for encoding text. + job_config: Job configuration containing dataset and DataLoader settings. + infinite: Whether to loop the dataset infinitely. + """ dataset_name = job_config.training.dataset dataset_path = job_config.training.dataset_path batch_size = job_config.training.local_batch_size @@ -201,11 +210,16 @@ def build_text_dataloader( infinite=infinite, ) + dataloader_kwargs = { + **asdict(job_config.training.dataloader), + "batch_size": batch_size, + } + return ParallelAwareDataloader( - dataset=hf_ds, + hf_ds, dp_rank=dp_rank, dp_world_size=dp_world_size, - batch_size=batch_size, + **dataloader_kwargs, ) @@ -216,7 +230,15 @@ def build_text_validation_dataloader( job_config: JobConfig, infinite: bool = False, ) -> ParallelAwareDataloader: - """Build a validation data loader for HuggingFace datasets.""" + """Build a validation data loader for HuggingFace datasets. + + Args: + dp_world_size: Data parallelism world size. + dp_rank: Data parallelism rank. + tokenizer: Tokenizer to use for encoding text. + job_config: Job configuration containing dataset and DataLoader settings. + infinite: Whether to loop the dataset infinitely. + """ dataset_name = job_config.validation.dataset dataset_path = job_config.validation.dataset_path batch_size = job_config.validation.local_batch_size @@ -232,9 +254,14 @@ def build_text_validation_dataloader( infinite=infinite, ) + dataloader_kwargs = { + **asdict(job_config.validation.dataloader), + "batch_size": batch_size, + } + return ParallelAwareDataloader( - dataset=hf_ds, + hf_ds, dp_rank=dp_rank, dp_world_size=dp_world_size, - batch_size=batch_size, + **dataloader_kwargs, ) diff --git a/torchtitan/models/__init__.py b/torchtitan/models/__init__.py index 527b45bc19..f372d29461 100644 --- a/torchtitan/models/__init__.py +++ b/torchtitan/models/__init__.py @@ -5,5 +5,5 @@ # LICENSE file in the root directory of this source tree. _supported_models = frozenset( - ["deepseek_v3", "flux", "llama3", "llama3_ft", "llama4", "qwen2", "qwen3"] + ["deepseek_v3", "flux", "gpt_oss", "llama3", "llama3_ft", "llama4", "qwen3"] ) diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 8873ef2f90..2493323ab7 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -6,33 +6,101 @@ # # Copyright (c) Meta Platforms, Inc. All Rights Reserved. -import functools from collections.abc import Callable -from typing import ClassVar +from typing import ClassVar, NamedTuple import torch import torch.nn.functional as F from torch.nn.attention import sdpa_kernel, SDPBackend from torch.nn.attention.flex_attention import ( _mask_mod_signature, + _score_mod_signature, BlockMask, create_block_mask, flex_attention, ) +from torch.nn.attention.varlen import varlen_attn +from torch.types import Number + __all__ = [ "FlexAttentionWrapper", "ScaledDotProductAttentionWrapper", + "VarlenAttentionWrapper", + "VarlenMetadata", "get_causal_mask_mod", "get_document_mask_mod", "get_sliding_window_mask_mod", - "get_fixed_block_mask_mod", "get_block_causal_mask_mod_by_seq_lens", + "get_fixed_block_mask_mod", "create_attention_mask", + "create_varlen_metadata_from_sequence_lengths", ] +class VarlenMetadata(NamedTuple): + """ + Cumulative sequence positions for queries and keys/values. + + """ + + cu_seq_q: torch.Tensor + cu_seq_k: torch.Tensor + max_q: Number + max_k: Number + + +class VarlenAttentionWrapper(torch.nn.Module): + _compiled_varlen_attn: ClassVar[Callable] = torch.compile( + varlen_attn, mode="max-autotune-no-cudagraphs" + ) + + def forward( + self, + xq: torch.Tensor, + xk: torch.Tensor, + xv: torch.Tensor, + head_dim: torch.Tensor, + attention_masks: VarlenMetadata, + scale: float | None = None, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + cu_seq_q = attention_masks.cu_seq_q + cu_seq_k = attention_masks.cu_seq_k + max_q = attention_masks.max_q + max_k = attention_masks.max_k + + n_local_heads = xq.shape[1] + # pyrefly: ignore [no-matching-overload] + xq_packed = xq.transpose(1, 2).reshape(-1, n_local_heads, head_dim) + # pyrefly: ignore [no-matching-overload] + xk_packed = xk.transpose(1, 2).reshape(-1, n_local_heads, head_dim) + # pyrefly: ignore [no-matching-overload] + xv_packed = xv.transpose(1, 2).reshape(-1, n_local_heads, head_dim) + + return VarlenAttentionWrapper._compiled_varlen_attn( + xq_packed, + xk_packed, + xv_packed, + cu_seq_q, + cu_seq_k, + max_q, + max_k, + scale=scale, + # window_size=(left, right) controls the attention window relative to each + # query position. 'left' is how many tokens before the query to attend to, + # and 'right' is how many tokens after. A value of -1 means unlimited. + # + # This replaces the is_causal flag: + # - (-1, 0): Causal attention - each token attends to all previous tokens + # and itself, but no future tokens. Equivalent to is_causal=True. + # - (-1, -1): Full bidirectional attention (no masking). Equivalent to + # is_causal=False. + # - (W, 0): Sliding window causal - attend to at most W previous tokens. + window_size=(-1, 0), + ) + + class FlexAttentionWrapper(torch.nn.Module): """Wrapper around `flex_attention` to make it torch.compile and CP compatible. @@ -42,13 +110,20 @@ class FlexAttentionWrapper(torch.nn.Module): 2) Being a wrapper allows us to apply _ContextParallel to it. Note: - The forward function must have q, k, v as the first three arguments, and - block_mask as a keyword argument to be compatible with _ContextParallel. + The forward function accepts q, k, v as the first three arguments, followed by + optional arguments (score_mod, block_mask, scale, return_lse) that can be passed + either positionally or as keywords to be compatible with _ContextParallel. """ _compiled_flex_attn: ClassVar[Callable] = torch.compile( flex_attention, - mode="max-autotune-no-cudagraphs", + # This options also encapsulate max-autotune-no-cudagraphs. + options={ + "wrap_inductor_compiled_regions": True, + "max_autotune": True, + "coordinate_descent_tuning": True, + "triton.cudagraphs": False, + }, ) def forward( @@ -56,8 +131,8 @@ def forward( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, - *, - block_mask: BlockMask, + score_mod: _score_mod_signature | None = None, + block_mask: BlockMask | None = None, scale: float | None = None, return_lse: bool = False, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: @@ -68,7 +143,6 @@ def forward( # `FlexAttentionWrapper._compiled_flex_attn` is correct. # 3. Used `return_lse` instead of `return_aux` because of easier TP module notation # to convert `lse` to be DTensor. - return FlexAttentionWrapper._compiled_flex_attn( q, k, @@ -92,7 +166,7 @@ class ScaledDotProductAttentionWrapper(torch.nn.Module): """ # TODO: remove sdpa_backends after PyTorch 2.9 is released. - sdpa_backends: ClassVar[list[SDPBackend]] = [] + sdpa_backends: list[SDPBackend] = [] def __init__(self) -> None: super().__init__() @@ -116,22 +190,19 @@ def forward( return F.scaled_dot_product_attention(q, k, v, scale=scale, is_causal=True) -# We cannot do inner function/closure because we won't be able to cache it -- -# if we an inner function, a new closure will be created every time -# `get_causal_mask_mod` is called. -def _causal_mask( - b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor -) -> torch.Tensor: - """Causal mask that prevents attention to future tokens.""" - return q_idx >= kv_idx - - def get_causal_mask_mod() -> _mask_mod_signature: """Returns a causal mask modifier for flex attention. Returns: A mask modifier function that implements causal masking. """ + + def _causal_mask( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor + ) -> torch.Tensor: + """Causal mask that prevents attention to future tokens.""" + return q_idx >= kv_idx + return _causal_mask @@ -268,11 +339,126 @@ def sliding_window_mod( _compiled_create_block_mask = torch.compile(create_block_mask) -@functools.lru_cache(4) def create_attention_mask(*args, **kwargs): - """Create an attention mask using compiled create_block_mask. + """Create an attention mask using compiled create_block_mask.""" + return _compiled_create_block_mask(*args, **kwargs) - This function is cached to avoid recreating BlockMasks for the same - arguments. + +def create_varlen_metadata_for_document( + input_batch: torch.Tensor, eos_id: int +) -> VarlenMetadata: """ - return _compiled_create_block_mask(*args, **kwargs) + Creates cumulative sequence length indices needed for variable length attention + + Args: + input_batch + eos_id: the EOS id marker + + Returns: + VarlenMetadata containing cumulative sequence length indices for q, k, and max_seq_len + """ + batch_size, seq_len = input_batch.shape + device = input_batch.device + cu_seqlens_list, all_seq_lengths = [], [] + offset = 0 + max_seqlen = 0 + + for b in range(batch_size): + tokens = input_batch[b] + eos_positions = (tokens == eos_id).nonzero(as_tuple=True)[0].to(torch.int32) + sample_cu_seqlens = torch.cat( + [ + torch.tensor([0], dtype=torch.int32, device=device), + eos_positions + 1, + torch.tensor([seq_len], dtype=torch.int32, device=device), + ] + ) + sample_cu_seqlens = torch.unique_consecutive(sample_cu_seqlens) + + seq_lengths = torch.diff(sample_cu_seqlens) + all_seq_lengths.append(seq_lengths) + + cu_seqlens_adjusted = sample_cu_seqlens[:-1] + offset + cu_seqlens_list.append(cu_seqlens_adjusted) + + offset += seq_len + + packed_cu_seqlens = torch.cat( + cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)] + ) + + max_seqlen = 0 + if len(all_seq_lengths) > 0: + all_seq_lengths = torch.cat(all_seq_lengths) + # device to host sync but only done once per model forward + max_seqlen = all_seq_lengths.max().item() + + return VarlenMetadata( + cu_seq_q=packed_cu_seqlens, + cu_seq_k=packed_cu_seqlens, + max_q=max_seqlen, + max_k=max_seqlen, + ) + + +def create_varlen_metadata_from_sequence_lengths( + sequence_lengths: list[torch.Tensor], + seq_len: int, + device: torch.device, +) -> VarlenMetadata: + """ + Creates cumulative sequence length indices needed for variable length attention + from explicit sequence lengths provided by the data loader. + + This is an alternative to `create_varlen_metadata_for_document` that doesn't + rely on EOS token detection, making it suitable for multi-turn chat data + that has multiple EOS tokens per sample. + + Args: + sequence_lengths: List of tensors, one per batch element, containing + the lengths of each document/turn within that batch element. + seq_len: The sequence length dimension of the batch. + device: The device to place the output tensors on. + + Returns: + VarlenMetadata containing cumulative sequence length indices for q, k, and max_seq_len + """ + batch_size = len(sequence_lengths) + cu_seqlens_list = [] + all_seq_lengths = [] + offset = 0 + + for b in range(batch_size): + sample_seq_lens = sequence_lengths[b] + # Compute cumulative sequence lengths for this sample + sample_cu_seqlens = torch.cat( + [ + torch.tensor([0], dtype=torch.int32, device=device), + torch.cumsum(sample_seq_lens.to(torch.int32), dim=0), + ] + ) + + all_seq_lengths.append(sample_seq_lens) + + # Adjust for batch offset (excluding the final cumulative sum) + cu_seqlens_adjusted = (sample_cu_seqlens[:-1] + offset).to(torch.int32) + cu_seqlens_list.append(cu_seqlens_adjusted) + + offset += seq_len + + packed_cu_seqlens = torch.cat( + cu_seqlens_list + [torch.tensor([offset], dtype=torch.int32, device=device)] + ).to(torch.int32) + + max_seqlen = 0 + if len(all_seq_lengths) > 0: + all_seq_lengths_cat = torch.cat(all_seq_lengths) + # device to host sync but only done once per model forward + max_seqlen = all_seq_lengths_cat.max().item() + + return VarlenMetadata( + cu_seq_q=packed_cu_seqlens, + cu_seq_k=packed_cu_seqlens, + max_q=max_seqlen, + max_k=max_seqlen, + ) diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index eedc20cbb5..d6ae7e2017 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -7,8 +7,8 @@ from torchtitan.components.loss import build_cross_entropy_loss from torchtitan.components.lr_scheduler import build_lr_schedulers from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing -from torchtitan.components.tokenizer import build_hf_tokenizer from torchtitan.distributed.pipeline_parallel import pipeline_llm +from torchtitan.experiments.kimi_linear.model.tokenizer import build_kimi_tokenizer from torchtitan.hf_datasets.dataloader import build_dataloader from torchtitan.models.moe import MoEArgs from torchtitan.protocols.train_spec import TrainSpec @@ -72,7 +72,7 @@ qk_rope_head_dim=64, v_head_dim=128, mscale=0.70, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), "16B": DeepSeekV3ModelArgs( @@ -97,7 +97,7 @@ qk_rope_head_dim=64, v_head_dim=128, mscale=0.70, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), "236B": DeepSeekV3ModelArgs( @@ -112,19 +112,19 @@ num_experts=160, num_shared_experts=2, top_k=6, + num_expert_groups=8, + num_limited_groups=3, score_func="softmax", route_norm=False, route_scale=16.0, score_before_experts=False, ), - n_expert_groups=8, - n_limited_groups=3, q_lora_rank=1536, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), "671B": DeepSeekV3ModelArgs( @@ -139,19 +139,19 @@ num_experts=256, num_shared_experts=1, top_k=8, + num_expert_groups=8, + num_limited_groups=4, score_func="sigmoid", route_norm=True, route_scale=2.5, score_before_experts=False, ), - n_expert_groups=8, - n_limited_groups=4, q_lora_rank=1536, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), "kimi_k2": DeepSeekV3ModelArgs( @@ -173,19 +173,47 @@ route_scale=2.827, score_before_experts=False, ), - n_expert_groups=1, - n_limited_groups=1, q_lora_rank=1536, kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", rope_theta=50000.0, rope_factor=32.0, beta_fast=1, ), + "kimi_k2_sft": DeepSeekV3ModelArgs( + vocab_size=163840, + dim=7168, + inter_dim=18432, + moe_inter_dim=2048, + # n_layers=9, #smaller for testing + n_layers=61, + n_dense_layers=1, + n_heads=64, + norm_eps=1e-6, + moe_args=MoEArgs( + num_experts=384, + num_shared_experts=1, + top_k=8, + score_func="sigmoid", + route_norm=True, + route_scale=2.827, + score_before_experts=False, + ), + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + attn_type="flex", + attn_mask_type="block_causal_by_sequence_lengths", + rope_theta=50000.0, + rope_factor=32.0, + beta_fast=1, + ), } @@ -198,7 +226,7 @@ def get_train_spec() -> TrainSpec: build_optimizers_fn=build_optimizers_with_moe_load_balancing, build_lr_schedulers_fn=build_lr_schedulers, build_dataloader_fn=build_dataloader, - build_tokenizer_fn=build_hf_tokenizer, + build_tokenizer_fn=build_kimi_tokenizer, # falls back to hf tokenizer if tiktoken.model not found build_loss_fn=build_cross_entropy_loss, state_dict_adapter=DeepSeekV3StateDictAdapter, ) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 7da79c361e..05c345adab 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -15,10 +15,11 @@ RowwiseParallel, SequenceParallel, ) - from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import NoParallel, ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.context_parallel import apply_cp_to_attention_module +from torchtitan.distributed.dual_pipe_v import get_dual_pipe_v_flag from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.models.llama3.infra.parallelize import apply_ddp from torchtitan.models.llama4.infra.parallelize import ( @@ -28,12 +29,14 @@ ) from torchtitan.tools.logging import logger - # for selective op activation checkpointing _op_sac_save_list = { torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + torch.ops.aten._scaled_dot_product_attention_math.default, + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, torch.ops._c10d_functional.all_to_all_single.default, # for low precision training, it's useful to always save @@ -41,6 +44,7 @@ # used to compute the scaling factor for quantization. torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, + torch._higher_order_ops.inductor_compiled_code, } @@ -50,7 +54,6 @@ def parallelize_deepseekv3( parallel_dims: ParallelDims, job_config: JobConfig, ): - world_mesh = parallel_dims.world_mesh # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. @@ -61,9 +64,13 @@ def parallelize_deepseekv3( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") + attn_type = getattr(model.model_args, "attn_type", "sdpa") + if job_config.parallelism.context_parallel_degree > 1 and attn_type == "varlen": + raise NotImplementedError( + f"Context Parallel only supports SDPA and FlexAttention." + f"Got attn_type='{attn_type}'. " + f"Varlen attention is not supported with CP." + ) if parallel_dims.tp_enabled: enable_float8_linear = "float8" in job_config.model.converters @@ -79,29 +86,59 @@ def parallelize_deepseekv3( "Currently, float8 tensorwise TP is not tested for deepseekv3" ) + tp_mesh = parallel_dims.get_mesh("tp") apply_non_moe_tp( model, - world_mesh["tp"], + tp_mesh, loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, - use_flex_attn=use_flex_attn, + positions_enabled=parallel_dims.cp_enabled or job_config.training.dataset_type == "preprocessed", ) - maybe_enable_async_tp(job_config, world_mesh["tp"]) + maybe_enable_async_tp(job_config, tp_mesh) + + # Check if using DeepEP for MoE communication + if job_config.parallelism.expert_parallel_comm_backend == "deepep": + if not parallel_dims.ep_enabled: + raise ValueError( + "DeepEP requires expert parallelism (ep_degree > 1). " + "The DeepEP MoE model code does not support EP=1. " + "Please set expert_parallel_degree > 1 or use standard communication backend." + ) + if parallel_dims.etp_enabled: + raise NotImplementedError( + "DeepEP with Expert Tensor Parallelism (ETP) is not supported yet. " + "Please set expert_tensor_parallel_degree=1 or use standard communication backend." + ) + + use_deepep = True + + # Import deepep module to register custom ops before accessing them + import torchtitan.distributed.deepep # noqa: F401 - registers torch.ops.deepep + + _op_sac_save_list.add(torch.ops.deepep.dispatch.default) + _op_sac_save_list.add(torch.ops.deepep.combine.default) + else: + use_deepep = False if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims) + apply_moe_ep_tp( model, - tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, - ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, - ep_tp_mesh=( - world_mesh["ep", "tp"] - if parallel_dims.tp_enabled - and parallel_dims.ep_enabled - and parallel_dims.etp_enabled - else None - ), - etp_enabled=parallel_dims.etp_enabled, - use_deepep=model.model_args.moe_args.use_deepep, + tp_mesh=parallel_dims.get_optional_mesh("tp"), + ep_mesh=parallel_dims.get_optional_mesh("ep"), + etp_mesh=parallel_dims.get_optional_mesh("etp"), + ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]), + dual_pipe_v=dual_pipe_v, + use_deepep=use_deepep, + ) + + if parallel_dims.cp_enabled: + apply_cp_to_attention_module( + # pyrefly: ignore [missing-attribute, not-callable] + [block.attention.inner_attention for block in model.layers.values()], + parallel_dims.get_mesh("cp"), + attn_type, ) model_compile_enabled = ( @@ -113,29 +150,29 @@ def parallelize_deepseekv3( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, + # pyrefly: ignore [bad-argument-type] op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) if model_compile_enabled: - apply_compile(model, job_config.compile) + apply_compile(model, job_config.compile, parallel_dims.ep_enabled) dp_mesh: DeviceMesh | None = None if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: # apply FSDP or HSDP, potentially with Context Parallel - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") - else: - dp_mesh_dim_names = ("dp_shard_cp",) - dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + dp_mesh_names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(dp_mesh_names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - dp_mod_ep_mesh_dim_names = [] - if parallel_dims.ep_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh_dim_names.append("dp_replicate") - dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + edp_mesh_names = ( + ["dp_replicate", "efsdp"] + if parallel_dims.dp_replicate_enabled + else ["efsdp"] + ) + edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names) apply_fsdp( model, @@ -146,12 +183,9 @@ def parallelize_deepseekv3( cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, - dp_mod_ep_mesh=( - world_mesh[tuple(dp_mod_ep_mesh_dim_names)] - if parallel_dims.ep_enabled - else None - ), + edp_mesh=edp_mesh, gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, + disable_prefetch=job_config.parallelism.fsdp_disable_prefetch, ) if parallel_dims.dp_replicate_enabled: @@ -159,15 +193,12 @@ def parallelize_deepseekv3( else: logger.info("Applied FSDP to the model") - if parallel_dims.cp_enabled: - logger.info("Applied Context Parallel to the model") - if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_mesh = parallel_dims.get_mesh("dp_replicate") + if dp_mesh.ndim > 1: raise RuntimeError("DDP has not supported > 1D parallelism") - dp_mesh = world_mesh apply_ddp( model, dp_mesh, @@ -182,7 +213,7 @@ def apply_non_moe_tp( tp_mesh: DeviceMesh, loss_parallel: bool, enable_float8_tensorwise_tp: bool, - use_flex_attn: bool, + positions_enabled: bool, ): """Apply tensor parallelism.""" # 1. Parallelize the embedding and shard its outputs (which are the first @@ -212,34 +243,28 @@ def apply_non_moe_tp( PrepareModuleInput, ) - if use_flex_attn: - attention_kernel_plan = prepare_module_input( - input_layouts=(Shard(1), Shard(1), Shard(1)), - desired_input_layouts=(Shard(1), Shard(1), Shard(1)), - use_local_output=True, - ) - else: - attention_kernel_plan = prepare_module_input( - input_layouts=(Shard(1), Shard(1), Shard(1)), - desired_input_layouts=(Shard(1), Shard(1), Shard(1)), - use_local_output=True, - ) + attention_kernel_plan = prepare_module_input( + input_layouts=(Shard(1), Shard(1), Shard(1)), + desired_input_layouts=(Shard(1), Shard(1), Shard(1)), + use_local_output=True, + ) # Apply tensor + sequence parallelism to every transformer block # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + positions_sharding = Replicate() if positions_enabled else None + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), "attention": prepare_module_input( - input_layouts=(Shard(1), Replicate(), None), - desired_input_layouts=(Replicate(), Replicate(), None), - input_kwarg_layouts={ - "position_ids": Replicate(), - }, - desired_input_kwarg_layouts={ - "position_ids": Replicate(), - }, + input_layouts=(Shard(1), Replicate(), None, positions_sharding), + desired_input_layouts=( + Replicate(), + Replicate(), + None, + positions_sharding, + ), ), # NOTE: use_local_output=False make the output to be a DTensor instead of a plain Tensor # so that the intermedidate results k is generated as a DTensor and its gradient is @@ -253,6 +278,7 @@ def apply_non_moe_tp( "ffn_norm": SequenceParallel(), } + # pyrefly: ignore [missing-attribute] if transformer_block.attention.q_lora_rank == 0: layer_plan.update( { @@ -270,6 +296,7 @@ def apply_non_moe_tp( } ) + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: layer_plan.update( { @@ -284,8 +311,10 @@ def apply_non_moe_tp( ) parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] parallelize_plan=layer_plan, ) diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 1a6ff3cf6e..d7ed015300 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -10,7 +10,6 @@ from dataclasses import dataclass, field from torch import nn - from torchtitan.config import JobConfig from torchtitan.models.moe import MoEArgs from torchtitan.models.utils import get_moe_model_nparams_and_flops @@ -37,14 +36,12 @@ class DeepSeekV3ModelArgs(BaseModelArgs): n_heads (int): Number of attention heads. norm_eps (float): Epsilon value used for RMSNorm. moe_args (MoEArgs): MoE configuration. - n_expert_groups (int): Number of expert groups. - n_limited_groups (int): Number of limited groups for MoE routing. q_lora_rank (int): LoRA rank for query projections. kv_lora_rank (int): LoRA rank for key-value projections. qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings. qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings. v_head_dim (int): Dimension for value projections. - use_flex_attn (bool): Whether to use FlexAttention. + attn_type (str): Attention type. attn_mask_type (str): Type of attention mask. original_seq_len (int): Original sequence length. rope_theta (float): Base for rotary positional encoding. @@ -66,9 +63,9 @@ class DeepSeekV3ModelArgs(BaseModelArgs): # MoE moe_args: MoEArgs = field(default_factory=MoEArgs) - # TODO: node-limited routing is not supported yet - n_expert_groups: int = 1 - n_limited_groups: int = 1 + + # Expert parallel communication backend (set from config) + expert_parallel_comm_backend: str = "standard" # "standard" or "deepep" # Multi-Head Latent Attention (MLA) q_lora_rank: int = 0 @@ -76,7 +73,7 @@ class DeepSeekV3ModelArgs(BaseModelArgs): qk_nope_head_dim: int = 128 qk_rope_head_dim: int = 64 v_head_dim: int = 128 - use_flex_attn: bool = False + attn_type: str = "sdpa" attn_mask_type: str = "causal" # yarn @@ -102,22 +99,14 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.moe_args.use_grouped_mm = False - if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: - raise NotImplementedError( - "CP support for FlexAttention is still in progress." - ) - self.moe_args._debug_force_load_balance = ( job_config.debug.moe_force_load_balance ) - # Pass DeepEP config to MoE layer and validate - self.moe_args.deepep_config = job_config.deepep - self.moe_args.validate_deepep_config() + # Configure expert parallel communication backend from config (defaults to "standard") + self.moe_impl = job_config.parallelism.expert_parallel_comm_backend - def get_nparams_and_flops( - self, model: nn.Module, seq_len: int - ) -> tuple[int, float]: + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: return get_moe_model_nparams_and_flops( self, model, diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 67bb37480e..d99b4b3247 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -8,7 +8,6 @@ import torch from torch import nn - from torch.nn.attention.flex_attention import and_masks, BlockMask from torchtitan.components.peft.lora import lora_or_linear, per_layer_config @@ -17,12 +16,12 @@ from torchtitan.models.attention import ( create_attention_mask, FlexAttentionWrapper, - get_block_causal_mask_mod_by_seq_lens, get_causal_mask_mod, get_document_mask_mod, + get_block_causal_mask_mod_by_seq_lens, ScaledDotProductAttentionWrapper, ) -from torchtitan.models.moe import FeedForward, MoE +from torchtitan.models.moe import build_moe, fast_init_trunc_normal_, fast_init_normal_, FeedForward from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol @@ -129,8 +128,56 @@ def linear_ramp_factor(min: float, max: float, dim: int) -> torch.Tensor: return freqs_cis +def reshape_for_broadcast( + freqs_cis: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None +) -> torch.Tensor: + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim // 2), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + freqs_cis (torch.Tensor): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. + Shape is (1, seqlen) or (bz, seqlen). Defaults to None. + + Returns: + torch.Tensor: Reshaped frequency tensor. + """ + ndim = x.ndim + assert ndim > 1 + seqlen = x.shape[1] + if positions is None: + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + elif positions.size(0) == 1: + assert positions.shape == (1, seqlen) + freqs_cis = freqs_cis[positions.squeeze(0)] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + else: + assert positions.shape == (x.shape[0], seqlen) + freqs_cis_expanded = freqs_cis[None, :, None, :].expand(x.shape[0], -1, -1, -1) + freqs_cis = torch.gather( + freqs_cis_expanded, + dim=1, + index=positions.view(x.shape[0], seqlen, 1, 1).expand( + x.shape[0], seqlen, 1, freqs_cis_expanded.shape[-1] + ), + ) + return freqs_cis + + def apply_rotary_emb( - x: torch.Tensor, freqs_cis: torch.Tensor, position_ids: torch.Tensor | None = None + x: torch.Tensor, freqs_cis: torch.Tensor, positions: torch.Tensor | None = None ) -> torch.Tensor: """ Applies rotary positional embeddings to the input tensor. @@ -138,16 +185,14 @@ def apply_rotary_emb( Args: x (torch.Tensor): Input tensor with positional embeddings to be applied. freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Tensor with rotary embeddings applied. """ dtype = x.dtype x = torch.view_as_complex(x.float().view(*x.shape[:-1], -1, 2)) - if position_ids is None: - freqs_cis = freqs_cis[: x.size(1)].view(1, x.size(1), 1, x.size(-1)) - else: - freqs_cis = freqs_cis[position_ids].view(x.shape[0], x.size(1), 1, x.size(-1)) + freqs_cis = reshape_for_broadcast(freqs_cis, x, positions) y = torch.view_as_real(x * freqs_cis).flatten(3) return y.to(dtype) @@ -235,18 +280,24 @@ def yarn_get_mscale(scale: float, mscale: float) -> float: self.softmax_scale * effective_mscale * effective_mscale ) - self.use_flex_attn = model_args.use_flex_attn - if self.use_flex_attn: - self.inner_attention = FlexAttentionWrapper() - else: - self.inner_attention = ScaledDotProductAttentionWrapper() + self.attn_type = model_args.attn_type + match self.attn_type: + case "flex": + self.inner_attention = FlexAttentionWrapper() + case "sdpa": + # pyrefly: ignore [bad-assignment] + self.inner_attention = ScaledDotProductAttentionWrapper() + case "varlen": + raise ValueError("Varlen attention is not supported with Deepseek V3.") + case _: + raise ValueError(f"Unknown attention type: {self.attn_type}") def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ): """ Forward pass for the Multi-Head Latent Attention (MLA) Layer. @@ -254,6 +305,8 @@ def forward( Args: x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor with the same shape as the input. @@ -273,7 +326,7 @@ def forward( q_nope, q_pe = torch.split( q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1 ) - q_pe = apply_rotary_emb(q_pe, freqs_cis, position_ids=position_ids) + q_pe = apply_rotary_emb(q_pe, freqs_cis, positions) q = torch.cat([q_nope, q_pe], dim=-1) # (bsz, seqlen, n_heads, qk_head_dim) # Key-value projection @@ -281,7 +334,7 @@ def forward( kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) k_pe = apply_rotary_emb( - k_pe.unsqueeze(2), freqs_cis, position_ids=position_ids + k_pe.unsqueeze(2), freqs_cis, positions ) # (bsz, seqlen, 1, qk_rope_head_dim) kv = self.wkv_b( @@ -298,14 +351,15 @@ def forward( k = k.transpose(1, 2) # (bsz, n_heads, seqlen, qk_head_dim) v = v.transpose(1, 2) # (bsz, n_heads, seqlen, v_head_dim) - if self.use_flex_attn: - assert isinstance(attention_masks, BlockMask) - output = self.inner_attention( - q, k, v, block_mask=attention_masks, scale=self.softmax_scale - ) - else: - assert attention_masks is None - output = self.inner_attention(q, k, v, scale=self.softmax_scale) + match self.attn_type: + case "flex": + assert isinstance(attention_masks, BlockMask) + output = self.inner_attention( + q, k, v, block_mask=attention_masks, scale=self.softmax_scale + ) + case _: + assert attention_masks is None + output = self.inner_attention(q, k, v, scale=self.softmax_scale) # Reshape and project output output = output.transpose( @@ -325,8 +379,8 @@ def init_weights(self, init_std: float): linear_list.append(self.wq) for linear in linear_list: - nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) - nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + fast_init_trunc_normal_(linear.weight, mean=0.0, std=0.02) + fast_init_trunc_normal_(self.wo.weight, mean=0.0, std=init_std) self.kv_norm.reset_parameters() if self.q_lora_rank > 0: @@ -353,10 +407,11 @@ def __init__( self.moe_enabled = layer_id >= model_args.n_dense_layers if self.moe_enabled: - self.moe = MoE( - model_args.moe_args, + self.moe = build_moe( + args=model_args.moe_args, dim=model_args.dim, hidden_dim=model_args.moe_inter_dim, + moe_impl=model_args.moe_impl, peft_config=peft_config, ) else: @@ -373,7 +428,7 @@ def forward( x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ): """ Forward pass for the Transformer block. @@ -381,15 +436,14 @@ def forward( Args: x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor with the same shape as the input. """ x = x + self.attention( - self.attention_norm(x), - freqs_cis, - attention_masks, - position_ids=position_ids, + self.attention_norm(x), freqs_cis, attention_masks, positions ) if self.moe_enabled: x = x + self.moe(self.ffn_norm(x)) @@ -407,13 +461,13 @@ def init_weights(self, buffer_device: torch.device): self.feed_forward.init_weights(self.weight_init_std) -class DeepSeekV3Model(nn.Module, ModelProtocol): +class DeepSeekV3Model(ModelProtocol): """ DeepSeek-V3 Transformer model with attention and feed-forward layers. """ def __init__(self, model_args: DeepSeekV3ModelArgs, peft_config: PEFT): - super().__init__() + super().__init__(model_args) self.max_seq_len = model_args.max_seq_len self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) self.register_buffer( @@ -456,16 +510,17 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: with torch.device(buffer_device): self.freqs_cis = precompute_freqs_cis(self.model_args) if self.tok_embeddings is not None: - nn.init.normal_(self.tok_embeddings.weight) + fast_init_normal_(self.tok_embeddings.weight) for layer in self.layers.values(): if layer is not None: + # pyrefly: ignore [not-callable] layer.init_weights(buffer_device=buffer_device) if self.norm is not None: self.norm.reset_parameters() final_out_std = self.model_args.dim**-0.5 cutoff_factor = 3 if self.output is not None: - nn.init.trunc_normal_( + fast_init_trunc_normal_( self.output.weight, mean=0.0, std=final_out_std, @@ -508,7 +563,7 @@ def forward( self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ): """ Forward pass for the Transformer model. @@ -518,6 +573,8 @@ def forward( If pipeline parallelism is enabled, this will be the input token indices for the ranks on the first pipeline stage. This will be the activation of the previous pipeline stage if the current rank is not on the first stage. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Logits tensor of shape (batch_size, vocab_size). @@ -526,7 +583,7 @@ def forward( h = self.tok_embeddings(tokens) if self.tok_embeddings is not None else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis, attention_masks, position_ids=position_ids) + h = layer(h, self.freqs_cis, attention_masks, positions) h = self.norm(h) if self.norm is not None else h output = self.output(h) if self.output is not None else h return output diff --git a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py index fd4ec30284..1970c7a161 100644 --- a/torchtitan/models/deepseek_v3/model/state_dict_adapter.py +++ b/torchtitan/models/deepseek_v3/model/state_dict_adapter.py @@ -106,6 +106,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: for key, value in state_dict.items(): if "moe.experts" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_abstract_key = to_hf_map[abstract_key] @@ -115,6 +116,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: abstract_key ] = value.placements self.grouped_expert_weight_shape[abstract_key] = value.shape + self.grouped_expert_weight_mesh[abstract_key] = value.device_mesh # Split GroupedExperts weight to local individual expert weights local_expert_fqn = self._get_local_experts_weights( @@ -128,15 +130,19 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: else: # keep this path for offline conversion split_values = self._split_experts_weights( - value, self.model_args.moe_args.num_experts + value, + # pyrefly: ignore [missing-attribute] + self.model_args.moe_args.num_experts, ) + # pyrefly: ignore [missing-attribute] for expert_num in range(0, self.model_args.moe_args.num_experts): new_key = new_abstract_key.format(layer_num, expert_num) hf_state_dict[new_key] = split_values[expert_num].squeeze() elif "layers" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_key = to_hf_map[abstract_key] new_key = new_key.format(layer_num) @@ -174,18 +180,20 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: int(expert_num) ] = value - if isinstance(value, DTensor): + # Use stored metadata to decide path (online vs offline) + # Online mode: local_experts_indices was populated during to_hf() + if titan_abstract_key in self.local_experts_indices: stacked_value = self._concatenate_expert_weights_dtensor( expert_weights_by_layer, titan_abstract_key, layer_num, - value.device_mesh, ) else: # keep this path to be compatible with offline conversion stacked_value = self._concatenate_expert_weights( expert_weights_by_layer, titan_abstract_key, layer_num, + # pyrefly: ignore [missing-attribute] self.model_args.moe_args.num_experts, ) @@ -194,6 +202,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: elif "layers" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_key = self.from_hf_map[abstract_key] new_key = new_key.format(layer_num) diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml new file mode 100644 index 0000000000..e3e1c9d497 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml @@ -0,0 +1,72 @@ +# Kimi K2 - 12 nodes - EP=96, CP=16, LBS=11 +# +# Original config: /home/phuc/worklogs/2026-01-30/cp16_sweep/configs/exp1acd_12n_ep96_cp16_lbs11.toml +# Job ID: 2307 +# +# Expected Performance: +# - TPS: 402 +# - Memory: 67.55 GiB (85.2%) +# - MFU: 17.72% +# - TFLOPS: ~175 +# +# Parallelism: EP=96, CP=16, DP=1 (dp_replicate=1, dp_shard=1) +# Nodes: 12 (96 GPUs) +# Seq Length: 32768 +# Local Batch Size: 11 +# + +[job] +dump_folder = "./outputs/kimi_k2/12n_ep96_cp16_lbs11" +description = "Kimi K2 - 12n EP=96 CP=16 LBS=11" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "/home/phuc/kimi_1t/torchtitan/assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +state_dtype = "bfloat16" + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +dtype = "bfloat16" +local_batch_size = 11 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = false +# Aggressive memory management to reduce CUDA fragmentation +aggressive_memory_mode = "maximum" +aggressive_memory_verbose = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 16 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = true +components = ["loss"] + +[debug] +moe_force_load_balance = true + +[comm] +init_timeout_seconds = 1800 +train_timeout_seconds = 1800 diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_k2_36n_ep96_cp16_32k_ctx_hsdp_replicate3_shard6_lbs10.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_36n_ep96_cp16_32k_ctx_hsdp_replicate3_shard6_lbs10.toml new file mode 100644 index 0000000000..7ee95edec6 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_36n_ep96_cp16_32k_ctx_hsdp_replicate3_shard6_lbs10.toml @@ -0,0 +1,73 @@ +# Kimi K2 - 36 nodes - EP=96, CP=16, HSDP (dp_replicate=3, dp_shard=6), LBS=10 +# +# Original config: /home/phuc/worklogs/2026-01-30/cp16_sweep_dp/configs/exp1aj_HSDP_r3_s6_lbs10.toml +# Job ID: 2485 +# +# Expected Performance: +# - TPS: 378 +# - Memory: 69.45 GiB (87.6%) +# - MFU: 16.64% +# +# Parallelism: EP=96, CP=16, dp_replicate=3, dp_shard=6 +# HSDP: Shard within 12 nodes, all-reduce between 3 replica groups +# Nodes: 36 (288 GPUs) +# Seq Length: 32768 +# Local Batch Size: 10 +# + +[job] +dump_folder = "./outputs/kimi_k2/36n_ep96_cp16_hsdp_replicate3_shard6_lbs10" +description = "Kimi K2 - 36n HSDP dp_replicate=3 dp_shard=6 LBS=10" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "/home/phuc/kimi_1t/torchtitan/assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +state_dtype = "bfloat16" + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +dtype = "bfloat16" +local_batch_size = 10 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = false +# Aggressive memory management to reduce CUDA fragmentation +aggressive_memory_mode = "maximum" +aggressive_memory_verbose = true + +[parallelism] +data_parallel_replicate_degree = 3 +data_parallel_shard_degree = 6 +expert_parallel_degree = 96 +context_parallel_degree = 16 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = true +components = ["loss"] + +[debug] +moe_force_load_balance = true + +[comm] +init_timeout_seconds = 1800 +train_timeout_seconds = 1800 diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_k2_ep64_cp1_seq24k_lbs1.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_ep64_cp1_seq24k_lbs1.toml new file mode 100644 index 0000000000..f6328e98b5 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_ep64_cp1_seq24k_lbs1.toml @@ -0,0 +1,74 @@ +# ============================================================================= +# Kimi K2 - Best Configuration (D4) +# ============================================================================= +# Run: D4 | Job ID: 2889 +# Performance: 160 TPS (BEST OVERALL) | MFU: 6.04% | Memory: 76.99GiB (97.06%) +# +# Parameters: +# - EP=64, CP=1, SEQ=24K (24576), LBS=1 +# - Nodes: 8 (64 GPUs) +# - DP_replicate=1, DP_shard=1 (auto-calculated) +# +# Source Files: +# - Config: /home/phuc/worklogs/2026-02-03/ep_cp_sweep/configs/D4_ep64_seq24k_lbs1.toml +# - SLURM: /home/phuc/worklogs/2026-02-03/ep_cp_sweep/scripts/launch_D4_ep64_seq24k_lbs1.slurm +# - Log: /home/phuc/worklogs/2026-02-03/ep_cp_sweep/results/D4_ep64_seq24k_lbs1_2889.out +# - Err: /home/phuc/worklogs/2026-02-03/ep_cp_sweep/results/D4_ep64_seq24k_lbs1_2889.err +# +# Reference: /home/phuc/worklogs/2026-02-03/sweep_ep_cp_upstream_branch.md +# ============================================================================= + +[job] +dump_folder = "./outputs/ep_cp_sweep/D4_ep64_seq24k_lbs1" +description = "EP/CP Sweep D4: EP=64 CP=1 SEQ=24K LBS=1" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "/home/phuc/kimi_1t/torchtitan/assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +state_dtype = "bfloat16" + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +dtype = "bfloat16" +local_batch_size = 1 +seq_len = 24576 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = false +aggressive_memory_mode = "maximum" +aggressive_memory_verbose = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 1 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = true +components = ["loss"] + +[debug] +moe_force_load_balance = true + +[comm] +init_timeout_seconds = 1800 +train_timeout_seconds = 1800 diff --git a/torchtitan/models/flux/README.md b/torchtitan/models/flux/README.md index 2498d1a346..aa83b845db 100644 --- a/torchtitan/models/flux/README.md +++ b/torchtitan/models/flux/README.md @@ -1,11 +1,5 @@ -
- # FLUX model in torchtitan -[![integration tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_flux.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_flux.yaml/badge.svg?branch=main) - -
- ## Overview This directory contains the implementation of the [FLUX](https://github.com/black-forest-labs/flux/tree/main) model in torchtitan. In torchtitan, we showcase the pre-training process of text-to-image part of the FLUX model. diff --git a/torchtitan/models/flux/__init__.py b/torchtitan/models/flux/__init__.py index 0fee76e60d..d5ec94b1d6 100644 --- a/torchtitan/models/flux/__init__.py +++ b/torchtitan/models/flux/__init__.py @@ -20,6 +20,7 @@ __all__ = [ "FluxModelArgs", "FluxModel", + # pyrefly: ignore [missing-module-attribute] "flux_configs", "parallelize_flux", ] diff --git a/torchtitan/models/flux/flux_datasets.py b/torchtitan/models/flux/flux_datasets.py index f3cf283aa6..5b6492dff1 100644 --- a/torchtitan/models/flux/flux_datasets.py +++ b/torchtitan/models/flux/flux_datasets.py @@ -6,10 +6,11 @@ import itertools import math +from dataclasses import asdict from typing import Any, Callable, Optional import numpy as np -import PIL +import PIL.Image import torch from datasets import Dataset, load_dataset @@ -271,6 +272,7 @@ def __iter__(self): # skip low quality image or image with color channel = 1 if sample_dict["image"] is None: + # pyrefly: ignore [missing-attribute] sample = sample.get("__key__", "unknown") logger.warning( f"Low quality image {sample} is skipped in Flux Dataloader." @@ -279,6 +281,7 @@ def __iter__(self): # Classifier-free guidance: Replace some of the strings with empty strings. # Distinct random seed is initialized at the beginning of training for each FSDP rank. + # pyrefly: ignore [missing-attribute] dropout_prob = self.job_config.training.classifier_free_guidance_prob if dropout_prob > 0.0: if torch.rand(1).item() < dropout_prob: @@ -314,7 +317,15 @@ def build_flux_dataloader( tokenizer: FluxTokenizer | None, infinite: bool = True, ) -> ParallelAwareDataloader: - """Build a data loader for HuggingFace datasets.""" + """Build a data loader for HuggingFace datasets. + + Args: + dp_world_size: Data parallelism world size. + dp_rank: Data parallelism rank. + job_config: Job configuration containing dataset and DataLoader settings. + tokenizer: Tokenizer (kept for compatibility, not used). + infinite: Whether to loop the dataset infinitely. + """ dataset_name = job_config.training.dataset dataset_path = job_config.training.dataset_path batch_size = job_config.training.local_batch_size @@ -332,11 +343,16 @@ def build_flux_dataloader( infinite=infinite, ) + dataloader_kwargs = { + **asdict(job_config.training.dataloader), + "batch_size": batch_size, + } + return ParallelAwareDataloader( dataset=ds, dp_rank=dp_rank, dp_world_size=dp_world_size, - batch_size=batch_size, + **dataloader_kwargs, ) @@ -400,7 +416,16 @@ def build_flux_validation_dataloader( generate_timestamps: bool = True, infinite: bool = False, ) -> ParallelAwareDataloader: - """Build a data loader for HuggingFace datasets.""" + """Build a validation data loader for HuggingFace datasets. + + Args: + dp_world_size: Data parallelism world size. + dp_rank: Data parallelism rank. + job_config: Job configuration containing dataset and DataLoader settings. + tokenizer: Tokenizer (kept for compatibility, not used). + generate_timestamps: Whether to generate timesteps for validation. + infinite: Whether to loop the dataset infinitely. + """ dataset_name = job_config.validation.dataset dataset_path = job_config.validation.dataset_path batch_size = job_config.validation.local_batch_size @@ -419,9 +444,14 @@ def build_flux_validation_dataloader( infinite=infinite, ) + dataloader_kwargs = { + **asdict(job_config.validation.dataloader), + "batch_size": batch_size, + } + return ParallelAwareDataloader( - dataset=ds, + ds, dp_rank=dp_rank, dp_world_size=dp_world_size, - batch_size=batch_size, + **dataloader_kwargs, ) diff --git a/torchtitan/models/flux/inference/infer.py b/torchtitan/models/flux/inference/infer.py index 0c06a385ef..bffdb2a2e7 100644 --- a/torchtitan/models/flux/inference/infer.py +++ b/torchtitan/models/flux/inference/infer.py @@ -25,9 +25,16 @@ def inference(config: JobConfig): # Distributed processing setup: Each GPU/process handles a subset of prompts world_size = int(os.environ["WORLD_SIZE"]) global_rank = int(os.environ["RANK"]) + # pyrefly: ignore [missing-attribute] original_prompts = open(config.inference.prompts_path).readlines() total_prompts = len(original_prompts) + if total_prompts < world_size: + raise ValueError( + f"Number of prompts ({total_prompts}) must be >= number of ranks ({world_size}). " + f"FSDP all-gather will hang if some ranks have no prompts to process." + ) + # Distribute prompts across processes using round-robin assignment prompts = original_prompts[global_rank::world_size] @@ -39,13 +46,14 @@ def inference(config: JobConfig): if prompts: # Generate images for this process's assigned prompts + # pyrefly: ignore [missing-attribute] bs = config.inference.local_batch_size output_dir = os.path.join( config.job.dump_folder, + # pyrefly: ignore [missing-attribute] config.inference.save_img_folder, ) - # Create mapping from local indices to global prompt indices global_ids = list(range(global_rank, total_prompts, world_size)) @@ -54,6 +62,7 @@ def inference(config: JobConfig): device=trainer.device, dtype=trainer._dtype, job_config=trainer.job_config, + # pyrefly: ignore [bad-argument-type] model=trainer.model_parts[0], prompt=prompts[i : i + bs], autoencoder=trainer.autoencoder, diff --git a/torchtitan/models/flux/inference/sampling.py b/torchtitan/models/flux/inference/sampling.py index f43d0fc2c5..4c8c4ce993 100644 --- a/torchtitan/models/flux/inference/sampling.py +++ b/torchtitan/models/flux/inference/sampling.py @@ -36,6 +36,7 @@ def time_shift(mu: float, sigma: float, t: Tensor): + # pyrefly: ignore[unsupported-operation] return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) @@ -93,10 +94,13 @@ def generate_image( prompt = [prompt] # allow for packing and conversion to latent space. Use the same resolution as training time. + # pyrefly: ignore [missing-attribute] img_height = 16 * (job_config.training.img_size // 16) + # pyrefly: ignore [missing-attribute] img_width = 16 * (job_config.training.img_size // 16) enable_classifier_free_guidance = ( + # pyrefly: ignore [missing-attribute] job_config.validation.enable_classifier_free_guidance ) @@ -104,7 +108,9 @@ def generate_image( clip_tokens = clip_tokenizer.encode(prompt) t5_tokens = t5_tokenizer.encode(prompt) if len(prompt) == 1: + # pyrefly: ignore [missing-attribute] clip_tokens = clip_tokens.unsqueeze(0) + # pyrefly: ignore [missing-attribute] t5_tokens = t5_tokens.unsqueeze(0) batch = preprocess_data( @@ -113,6 +119,7 @@ def generate_image( autoencoder=None, clip_encoder=clip_encoder, t5_encoder=t5_encoder, + # pyrefly: ignore [bad-argument-type] batch={ "clip_tokens": clip_tokens, "t5_tokens": t5_tokens, @@ -124,7 +131,9 @@ def generate_image( empty_clip_tokens = clip_tokenizer.encode("") empty_t5_tokens = t5_tokenizer.encode("") + # pyrefly: ignore [missing-attribute] empty_clip_tokens = empty_clip_tokens.repeat(num_images, 1) + # pyrefly: ignore [missing-attribute] empty_t5_tokens = empty_t5_tokens.repeat(num_images, 1) empty_batch = preprocess_data( @@ -145,16 +154,24 @@ def generate_image( model=model, img_width=img_width, img_height=img_height, + # pyrefly: ignore [missing-attribute] denoising_steps=job_config.validation.denoising_steps, clip_encodings=batch["clip_encodings"], t5_encodings=batch["t5_encodings"], enable_classifier_free_guidance=enable_classifier_free_guidance, empty_t5_encodings=( - empty_batch["t5_encodings"] if enable_classifier_free_guidance else None + # pyrefly: ignore [unbound-name] + empty_batch["t5_encodings"] + if enable_classifier_free_guidance + else None ), empty_clip_encodings=( - empty_batch["clip_encodings"] if enable_classifier_free_guidance else None + # pyrefly: ignore [unbound-name] + empty_batch["clip_encodings"] + if enable_classifier_free_guidance + else None ), + # pyrefly: ignore [missing-attribute] classifier_free_guidance_scale=job_config.validation.classifier_free_guidance_scale, ) @@ -190,7 +207,9 @@ def denoise( if enable_classifier_free_guidance: # Double batch size for CFG: [unconditional, conditional] latents = torch.cat([latents, latents], dim=0) + # pyrefly: ignore [no-matching-overload] t5_encodings = torch.cat([empty_t5_encodings, t5_encodings], dim=0) + # pyrefly: ignore [no-matching-overload] clip_encodings = torch.cat([empty_clip_encodings, clip_encodings], dim=0) bsz *= 2 diff --git a/torchtitan/models/flux/infra/parallelize.py b/torchtitan/models/flux/infra/parallelize.py index fc9c926af0..321a73dcc9 100644 --- a/torchtitan/models/flux/infra/parallelize.py +++ b/torchtitan/models/flux/infra/parallelize.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Any import torch import torch.nn as nn @@ -16,6 +17,7 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims +from torchtitan.distributed.context_parallel import apply_cp_to_attention_module from torchtitan.tools.logging import logger @@ -27,15 +29,18 @@ def parallelize_flux( if job_config.activation_checkpoint.mode != "none": apply_ac(model, job_config.activation_checkpoint) + if parallel_dims.cp_enabled: + apply_cp(model, parallel_dims.get_mesh("cp")) + if parallel_dims.fsdp_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") - else: - dp_mesh_dim_names = ("dp_shard_cp",) + names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(names) apply_fsdp( model, - parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], + dp_mesh, param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], cpu_offload=job_config.training.enable_cpu_offload, @@ -46,16 +51,6 @@ def parallelize_flux( else: logger.info("Applied FSDP to the model") - if parallel_dims.cp_enabled: - # The attention in Flux does not use causal mask. - # Currently, load_balance must be disabled in order to support Context Parallelism - # in Pytorch's experimental ring attention module - # https://github.com/pytorch/pytorch/blob/v2.9.0/torch/distributed/tensor/experimental/_attention.py#L395 - from torch.distributed.tensor.experimental._attention import _cp_options - - _cp_options.enable_load_balance = False - logger.info("Applied Context Parallel to the model") - return model @@ -77,7 +72,7 @@ def apply_fsdp( cpu_offload (bool): Whether to offload model parameters to CPU. Defaults to False. """ mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) - fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + fsdp_config: dict[str, Any] = {"mesh": dp_mesh, "mp_policy": mp_policy} if cpu_offload: fsdp_config["offload_policy"] = CPUOffloadPolicy() @@ -88,21 +83,27 @@ def apply_fsdp( model.txt_in, ] for layer in linear_layers: + # pyrefly: ignore [no-matching-overload] fully_shard(layer, **fsdp_config) + # pyrefly: ignore [not-iterable] for block in model.double_blocks: + # pyrefly: ignore [no-matching-overload] fully_shard( block, **fsdp_config, ) + # pyrefly: ignore [not-iterable] for block in model.single_blocks: + # pyrefly: ignore [no-matching-overload] fully_shard( block, **fsdp_config, ) # apply FSDP to last layer. Set reshard_after_forward=False for last layer to avoid gather right after reshard + # pyrefly: ignore [no-matching-overload] fully_shard(model.final_layer, **fsdp_config, reshard_after_forward=False) # Wrap all the rest of model @@ -112,17 +113,57 @@ def apply_fsdp( def apply_ac(model: nn.Module, ac_config): """Apply activation checkpointing to the model.""" + # pyrefly: ignore [missing-attribute] for layer_id, block in model.double_blocks.named_children(): block = ptd_checkpoint_wrapper(block, preserve_rng_state=False) + # pyrefly: ignore [missing-attribute] model.double_blocks.register_module(layer_id, block) + # pyrefly: ignore [missing-attribute] for layer_id, block in model.single_blocks.named_children(): block = ptd_checkpoint_wrapper(block, preserve_rng_state=False) + # pyrefly: ignore [missing-attribute] model.single_blocks.register_module(layer_id, block) logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") +def apply_cp(model: nn.Module, cp_mesh: DeviceMesh) -> None: + """ + Apply context parallelism to the Flux model. + + Args: + model: The Flux model with double_blocks and single_blocks containing + inner attention modules. + cp_mesh: Device mesh for context parallel dimension + + Note: + - Uses SDPA attention type + - Applies to all inner_attention modules in double_blocks and single_blocks + """ + # Collect all inner_attention modules from the Flux model + attention_modules = [] + + # pyrefly: ignore [not-iterable] + for double_block in model.double_blocks: + # pyrefly: ignore [missing-attribute] + attention_modules.append(double_block.img_attn.inner_attention) + # pyrefly: ignore [missing-attribute] + attention_modules.append(double_block.txt_attn.inner_attention) + # pyrefly: ignore [missing-attribute] + attention_modules.append(double_block.inner_attention) + + # pyrefly: ignore [not-iterable] + for single_block in model.single_blocks: + # pyrefly: ignore [missing-attribute] + attention_modules.append(single_block.inner_attention) + + # Apply CP using the shared implementation (always uses SDPA for Flux) + apply_cp_to_attention_module(attention_modules, cp_mesh, "sdpa") + + logger.info("Applied Context Parallel to the Flux model") + + def parallelize_encoders( t5_model: nn.Module, clip_model: nn.Module, @@ -130,17 +171,17 @@ def parallelize_encoders( job_config: JobConfig, ): if parallel_dims.dp_shard_enabled: # apply FSDP or HSDP - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard") - else: - dp_mesh_dim_names = ("dp_shard",) + names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) mp_policy = MixedPrecisionPolicy( param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], ) - fsdp_config = { - "mesh": parallel_dims.world_mesh[tuple(dp_mesh_dim_names)], + dp_mesh = parallel_dims.get_mesh(names) + fsdp_config: dict[str, Any] = { + "mesh": dp_mesh, "mp_policy": mp_policy, } if job_config.training.enable_cpu_offload: @@ -148,8 +189,10 @@ def parallelize_encoders( # NOTE: only apply FSDP to the T5 encoder, not the CLIP text encoder. # CLIP Text encoder has low computation / communication ratio, so it's not necessary to apply FSDP to it. + # pyrefly: ignore [missing-attribute] for block in t5_model.hf_module.encoder.block: fully_shard(block, **fsdp_config) + # pyrefly: ignore [no-matching-overload] fully_shard(t5_model.hf_module, **fsdp_config) if parallel_dims.dp_replicate_enabled: diff --git a/torchtitan/models/flux/model/autoencoder.py b/torchtitan/models/flux/model/autoencoder.py index dc6fb1d061..a50e4a5ba3 100644 --- a/torchtitan/models/flux/model/autoencoder.py +++ b/torchtitan/models/flux/model/autoencoder.py @@ -19,7 +19,7 @@ class AutoEncoderParams: in_channels: int = 3 ch: int = 128 out_ch: int = 3 - ch_mult: tuple[int] = (1, 2, 4, 4) + ch_mult: tuple[int, ...] = (1, 2, 4, 4) num_res_blocks: int = 2 z_channels: int = 16 scale_factor: float = 0.3611 @@ -191,17 +191,24 @@ def forward(self, x: Tensor) -> Tensor: hs = [self.conv_in(x)] for i_level in range(self.num_resolutions): for i_block in range(self.num_res_blocks): + # pyrefly: ignore[bad-index, not-callable] h = self.down[i_level].block[i_block](hs[-1]) + # pyrefly: ignore [bad-argument-type] if len(self.down[i_level].attn) > 0: + # pyrefly: ignore[bad-index, not-callable] h = self.down[i_level].attn[i_block](h) hs.append(h) if i_level != self.num_resolutions - 1: + # pyrefly: ignore [not-callable] hs.append(self.down[i_level].downsample(hs[-1])) # middle h = hs[-1] + # pyrefly: ignore [not-callable] h = self.mid.block_1(h) + # pyrefly: ignore [not-callable] h = self.mid.attn_1(h) + # pyrefly: ignore [not-callable] h = self.mid.block_2(h) # end h = self.norm_out(h) @@ -276,8 +283,11 @@ def forward(self, z: Tensor) -> Tensor: h = self.conv_in(z) # middle + # pyrefly: ignore [not-callable] h = self.mid.block_1(h) + # pyrefly: ignore [not-callable] h = self.mid.attn_1(h) + # pyrefly: ignore [not-callable] h = self.mid.block_2(h) # cast to proper dtype @@ -285,10 +295,14 @@ def forward(self, z: Tensor) -> Tensor: # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks + 1): + # pyrefly: ignore[bad-index, not-callable] h = self.up[i_level].block[i_block](h) + # pyrefly: ignore [bad-argument-type] if len(self.up[i_level].attn) > 0: + # pyrefly: ignore[bad-index, not-callable] h = self.up[i_level].attn[i_block](h) if i_level != 0: + # pyrefly: ignore [not-callable] h = self.up[i_level].upsample(h) # end @@ -321,6 +335,7 @@ def __init__(self, params: AutoEncoderParams): resolution=params.resolution, in_channels=params.in_channels, ch=params.ch, + # pyrefly: ignore [bad-argument-type] ch_mult=params.ch_mult, num_res_blocks=params.num_res_blocks, z_channels=params.z_channels, @@ -330,6 +345,7 @@ def __init__(self, params: AutoEncoderParams): in_channels=params.in_channels, ch=params.ch, out_ch=params.out_ch, + # pyrefly: ignore [bad-argument-type] ch_mult=params.ch_mult, num_res_blocks=params.num_res_blocks, z_channels=params.z_channels, diff --git a/torchtitan/models/flux/model/hf_embedder.py b/torchtitan/models/flux/model/hf_embedder.py index 90be8767a9..56d82b1d2f 100644 --- a/torchtitan/models/flux/model/hf_embedder.py +++ b/torchtitan/models/flux/model/hf_embedder.py @@ -7,6 +7,8 @@ import os from torch import nn, Tensor + +# pyrefly: ignore[import-error] from transformers import CLIPTextModel, T5EncoderModel @@ -19,6 +21,7 @@ def __init__(self, version: str, random_init=False, **hf_kwargs): if random_init: # Initialize CLIP model with random weights for test purpose only self.hf_module = CLIPTextModel._from_config( + # pyrefly: ignore [missing-attribute] CLIPTextModel.config_class.from_pretrained( os.path.join(version, "config.json"), **hf_kwargs ) @@ -31,6 +34,7 @@ def __init__(self, version: str, random_init=False, **hf_kwargs): if random_init: # Initialize T5 model with random weights for test purpose only self.hf_module = T5EncoderModel._from_config( + # pyrefly: ignore [missing-attribute] T5EncoderModel.config_class.from_pretrained( os.path.join(version, "config.json"), **hf_kwargs ) diff --git a/torchtitan/models/flux/model/layers.py b/torchtitan/models/flux/model/layers.py index 923c5a422c..6d0e696dd9 100644 --- a/torchtitan/models/flux/model/layers.py +++ b/torchtitan/models/flux/model/layers.py @@ -6,12 +6,15 @@ # imported from black-forest-labs/FLUX import math +from collections.abc import Sequence from dataclasses import dataclass import torch from einops import rearrange from torch import nn, Tensor +from torchtitan.models.attention import ScaledDotProductAttentionWrapper + def rope(pos: Tensor, dim: int, theta: int) -> Tensor: assert dim % 2 == 0 @@ -34,7 +37,7 @@ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tenso class EmbedND(nn.Module): - def __init__(self, dim: int, theta: int, axes_dim: list[int]): + def __init__(self, dim: int, theta: int, axes_dim: Sequence[int]): super().__init__() self.dim = dim self.theta = theta @@ -123,6 +126,7 @@ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) self.norm = QKNorm(head_dim) self.proj = nn.Linear(dim, dim) + self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self): for layer in (self.qkv, self.proj): @@ -135,7 +139,7 @@ def forward(self, x: Tensor, pe: Tensor) -> Tensor: q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads) q, k = self.norm(q, k, v) q, k = apply_rope(q, k, pe) - x = torch.nn.functional.scaled_dot_product_attention(q, k, v) + x = self.inner_attention(q, k, v) x = rearrange(x, "B H L D -> B L (H D)") x = self.proj(x) return x @@ -205,6 +209,8 @@ def __init__( nn.Linear(mlp_hidden_dim, hidden_size, bias=True), ) + self.inner_attention = ScaledDotProductAttentionWrapper() + def init_weights(self): # initialize all the nn.Linear submodules for layer in ( @@ -213,7 +219,9 @@ def init_weights(self): self.txt_mlp[0], self.txt_mlp[2], ): + # pyrefly: ignore [bad-argument-type] nn.init.xavier_uniform_(layer.weight) + # pyrefly: ignore [bad-argument-type] nn.init.constant_(layer.bias, 0) # initialize Modulation layers, SelfAttention layers @@ -254,7 +262,7 @@ def forward( v = torch.cat((txt_v, img_v), dim=2) q, k = apply_rope(q, k, pe) - attn = torch.nn.functional.scaled_dot_product_attention(q, k, v) + attn = self.inner_attention(q, k, v) attn = rearrange(attn, "B H L D -> B L (H D)") txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :] @@ -305,6 +313,7 @@ def __init__( self.mlp_act = nn.GELU(approximate="tanh") self.modulation = Modulation(hidden_size, double=False) + self.inner_attention = ScaledDotProductAttentionWrapper() def init_weights(self): for layer in (self.linear1, self.linear2): @@ -326,7 +335,7 @@ def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor: # compute attention q, k = apply_rope(q, k, pe) - attn = torch.nn.functional.scaled_dot_product_attention(q, k, v) + attn = self.inner_attention(q, k, v) attn = rearrange(attn, "B H L D -> B L (H D)") # compute activation in mlp stream, cat again and run second linear layer @@ -346,7 +355,9 @@ def __init__(self, hidden_size: int, patch_size: int, out_channels: int): ) def init_weights(self): + # pyrefly: ignore [bad-argument-type] nn.init.constant_(self.adaLN_modulation[-1].weight, 0) + # pyrefly: ignore [bad-argument-type] nn.init.constant_(self.adaLN_modulation[-1].bias, 0) nn.init.constant_(self.linear.weight, 0) nn.init.constant_(self.linear.bias, 0) diff --git a/torchtitan/models/flux/model/model.py b/torchtitan/models/flux/model/model.py index 6cfb02c9c0..1f661c074b 100644 --- a/torchtitan/models/flux/model/model.py +++ b/torchtitan/models/flux/model/model.py @@ -21,7 +21,7 @@ from .args import FluxModelArgs -class FluxModel(nn.Module, ModelProtocol): +class FluxModel(ModelProtocol): """ Transformer model for flow matching on sequences. @@ -33,7 +33,7 @@ class FluxModel(nn.Module, ModelProtocol): """ def __init__(self, model_args: FluxModelArgs): - super().__init__() + super().__init__(model_args) self.model_args = model_args @@ -51,7 +51,9 @@ def __init__(self, model_args: FluxModelArgs): self.hidden_size = model_args.hidden_size self.num_heads = model_args.num_heads self.pe_embedder = EmbedND( - dim=pe_dim, theta=model_args.theta, axes_dim=model_args.axes_dim + dim=pe_dim, + theta=model_args.theta, + axes_dim=model_args.axes_dim, ) self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) @@ -95,8 +97,10 @@ def init_weights(self, buffer_device=None): # Initialize transformer blocks: for block in self.single_blocks: + # pyrefly: ignore [not-callable] block.init_weights() for block in self.double_blocks: + # pyrefly: ignore [not-callable] block.init_weights() # Zero-out output layers: diff --git a/torchtitan/models/flux/model/state_dict_adapter.py b/torchtitan/models/flux/model/state_dict_adapter.py index c976df6919..2526bcd521 100644 --- a/torchtitan/models/flux/model/state_dict_adapter.py +++ b/torchtitan/models/flux/model/state_dict_adapter.py @@ -58,6 +58,7 @@ def __init__(self, model_args: FluxModelArgs, hf_assets_path: str | None): if hf_safetensors_indx: self.fqn_to_index_mapping = {} for hf_key, raw_indx in hf_safetensors_indx["weight_map"].items(): + # pyrefly: ignore [missing-attribute] indx = re.search(r"\d+", raw_indx).group(0) self.fqn_to_index_mapping[hf_key] = indx else: @@ -173,6 +174,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: for key, value in state_dict.items(): # Extract layer_num and abstract key if necessary if "blocks" in key: + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) key = re.sub(r"(\d+)", "{}", key, count=1) else: @@ -242,6 +244,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: for key, value in hf_state_dict.items(): # extract layer_num and abstract key if necessary if "blocks" in key: + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) key = re.sub(r"(\d+)", "{}", key, count=1) else: @@ -273,6 +276,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: # combine collected values for tt_fqn, hf_fqn_map in to_combine.items(): + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", tt_fqn).group(0) tt_abstract_key = re.sub(r"(\d+)", "{}", tt_fqn, count=1) combine_values = [] diff --git a/torchtitan/models/flux/tokenizer.py b/torchtitan/models/flux/tokenizer.py index b5cca546b9..bf99026047 100644 --- a/torchtitan/models/flux/tokenizer.py +++ b/torchtitan/models/flux/tokenizer.py @@ -11,6 +11,8 @@ from typing import List import torch + +# pyrefly: ignore[import-error] from transformers import CLIPTokenizer, T5Tokenizer from torchtitan.components.tokenizer import BaseTokenizer, HuggingFaceTokenizer @@ -46,6 +48,7 @@ def _pad_and_chunk_tokens( def get_vocab_size(self) -> int: return self.tiktokenizer.vocab_size + # pyrefly: ignore [bad-override] def encode(self, text: str | list[str]) -> torch.Tensor: """ Use TikTokenizer to encode the text into tokens, and then pad and chunk the tokens to max_length. @@ -72,6 +75,7 @@ def encode(self, text: str | list[str]) -> torch.Tensor: tokens = self._pad_and_chunk_tokens(tokens, self._max_length, self.pad_id) return torch.tensor(tokens) + # pyrefly: ignore [bad-override] def decode(self, t: List[int]) -> str: """ Decode function. This function will not be called. @@ -96,10 +100,12 @@ def __init__(self, model_path: str = "t5-small", max_length: int = 77, **hf_kwar self.is_clip = "clip" in model_path.lower() if self.is_clip: + # pyrefly: ignore [bad-assignment] self._tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained( model_path, max_length=max_length, **hf_kwargs ) else: + # pyrefly: ignore [bad-assignment] self._tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained( model_path, max_length=max_length, **hf_kwargs ) @@ -107,6 +113,7 @@ def __init__(self, model_path: str = "t5-small", max_length: int = 77, **hf_kwar def get_vocab_size(self) -> int: return self._tokenizer.vocab_size + # pyrefly: ignore [bad-override] def encode( self, s: str | list[str], @@ -125,22 +132,27 @@ def encode( )["input_ids"] return tokens - def decode(self, t: List[int]) -> str: + # pyrefly: ignore [bad-override] + def decode(self, t: list[int]) -> list[str] | str: """ Decode function. This function will not be called. """ - return self._tokenizer.decode(t) + return self._tokenizer.decode(t) # pyrefly: ignore[bad-return] def build_flux_tokenizer(job_config: JobConfig) -> tuple[BaseTokenizer, BaseTokenizer]: """ Build the tokenizer for Flux. """ + # pyrefly: ignore [missing-attribute] t5_tokenizer_path = job_config.encoder.t5_encoder + # pyrefly: ignore [missing-attribute] clip_tokenzier_path = job_config.encoder.clip_encoder + # pyrefly: ignore [missing-attribute] max_t5_encoding_len = job_config.encoder.max_t5_encoding_len # NOTE: This tokenizer is used for offline CI and testing only, borrowed from llama3 tokenizer + # pyrefly: ignore [missing-attribute] if job_config.training.test_mode: tokenizer_class = FluxTestTokenizer t5_tokenizer_path = clip_tokenzier_path = job_config.model.hf_assets_path diff --git a/torchtitan/models/flux/train.py b/torchtitan/models/flux/train.py index 5af9959050..7d85d2b3a1 100644 --- a/torchtitan/models/flux/train.py +++ b/torchtitan/models/flux/train.py @@ -28,10 +28,10 @@ def __init__(self, job_config: JobConfig): # (mainly for debugging, expect perf loss). # For Flux model, we need distinct seed across FSDP ranks to ensure we randomly dropout prompts info in dataloader dist_utils.set_determinism( - self.parallel_dims.world_mesh, + self.parallel_dims, self.device, job_config.debug, - distinct_seed_mesh_dims=["dp_shard", "dp_replicate"], + distinct_seed_mesh_dims=["fsdp", "dp_replicate"], ) # NOTE: self._dtype is the data type used for encoders (image encoder, T5 text encoder, CLIP text encoder). @@ -48,23 +48,31 @@ def __init__(self, job_config: JobConfig): model_args = self.train_spec.model_args[job_config.model.flavor] self.autoencoder = load_ae( + # pyrefly: ignore [missing-attribute] job_config.encoder.autoencoder_path, + # pyrefly: ignore [missing-attribute] model_args.autoencoder_params, device=self.device, dtype=self._dtype, + # pyrefly: ignore [missing-attribute] random_init=job_config.training.test_mode, ) self.clip_encoder = FluxEmbedder( + # pyrefly: ignore [missing-attribute] version=job_config.encoder.clip_encoder, + # pyrefly: ignore [missing-attribute] random_init=job_config.training.test_mode, ).to(device=self.device, dtype=self._dtype) self.t5_encoder = FluxEmbedder( + # pyrefly: ignore [missing-attribute] version=job_config.encoder.t5_encoder, + # pyrefly: ignore [missing-attribute] random_init=job_config.training.test_mode, ).to(device=self.device, dtype=self._dtype) # Apply FSDP to the T5 model / CLIP model + # pyrefly: ignore [bad-assignment] self.t5_encoder, self.clip_encoder = parallelize_encoders( t5_model=self.t5_encoder, clip_model=self.clip_encoder, @@ -73,6 +81,7 @@ def __init__(self, job_config: JobConfig): ) if job_config.validation.enable: + # pyrefly: ignore [missing-attribute] self.validator.flux_init( device=self.device, _dtype=self._dtype, @@ -127,30 +136,24 @@ def forward_backward_step( latents = pack_latents(latents) target = pack_latents(noise - labels) - optional_context_parallel_ctx = ( - dist_utils.create_context_parallel_ctx( - cp_mesh=self.parallel_dims.world_mesh["cp"], - cp_buffers=[ - latents, - latent_pos_enc, - t5_encodings, - text_pos_enc, - target, - ], - cp_seq_dims=[1, 1, 1, 1, 1], - cp_no_restore_buffers={ - latents, - latent_pos_enc, - t5_encodings, - text_pos_enc, - target, - }, - cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, + # Apply CP sharding if enabled + if self.parallel_dims.cp_enabled: + from torchtitan.distributed.context_parallel import cp_shard + + ( + latents, + latent_pos_enc, + t5_encodings, + text_pos_enc, + target, + ), _ = cp_shard( + self.parallel_dims.get_mesh("cp"), + (latents, latent_pos_enc, t5_encodings, text_pos_enc, target), + None, # No attention masks for Flux + load_balancer_type=None, ) - if self.parallel_dims.cp_enabled - else None - ) - with self.train_context(optional_context_parallel_ctx): + + with self.train_context(): with self.maybe_enable_amp: latent_noise_pred = model( img=latents, @@ -164,6 +167,7 @@ def forward_backward_step( loss = self.loss_fn(latent_noise_pred, target) # latent_noise_pred.shape=(bs, seq_len, vocab_size) # need to free to before bwd to avoid peaking memory + # pyrefly: ignore[unsupported-delete] del (latent_noise_pred, noise, target) loss.backward() diff --git a/torchtitan/models/flux/train_configs/debug_model.toml b/torchtitan/models/flux/train_configs/debug_model.toml index 47a033c546..b943925c1c 100644 --- a/torchtitan/models/flux/train_configs/debug_model.toml +++ b/torchtitan/models/flux/train_configs/debug_model.toml @@ -21,6 +21,7 @@ enable_wandb = false [model] name = "flux" flavor = "flux-debug" +hf_assets_path = "tests/assets/tokenizer" [optimizer] name = "AdamW" @@ -48,6 +49,7 @@ autoencoder_path = "assets/hf/FLUX.1-dev/ae.safetensors" # Autoencoder to use f [parallelism] data_parallel_replicate_degree = 1 data_parallel_shard_degree = -1 +context_parallel_degree = 1 [activation_checkpoint] mode = "full" diff --git a/torchtitan/models/flux/validate.py b/torchtitan/models/flux/validate.py index 189385e0f2..70dfff4bb3 100644 --- a/torchtitan/models/flux/validate.py +++ b/torchtitan/models/flux/validate.py @@ -5,7 +5,7 @@ # LICENSE file in the root directory of this source tree. import os -from typing import Generator +from contextlib import AbstractContextManager import torch import torch.nn as nn @@ -15,7 +15,7 @@ from torchtitan.components.loss import LossFunction from torchtitan.components.metrics import MetricsProcessor from torchtitan.components.tokenizer import BaseTokenizer -from torchtitan.components.validate import Validator +from torchtitan.components.validate import ValidationContext, Validator from torchtitan.config import JobConfig from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.models.flux.flux_datasets import build_flux_validation_dataloader @@ -53,16 +53,18 @@ def __init__( tokenizer: BaseTokenizer, parallel_dims: ParallelDims, loss_fn: LossFunction, - validation_context: Generator[None, None, None], - maybe_enable_amp: Generator[None, None, None], + validation_context: ValidationContext, + maybe_enable_amp: AbstractContextManager[None], metrics_processor: MetricsProcessor | None = None, pp_schedule: _PipelineSchedule | None = None, pp_has_first_stage: bool | None = None, pp_has_last_stage: bool | None = None, ): self.job_config = job_config + self.tokenizer = tokenizer self.parallel_dims = parallel_dims self.loss_fn = loss_fn + # pyrefly: ignore [missing-attribute] self.all_timesteps = self.job_config.validation.all_timesteps self.validation_dataloader = build_flux_validation_dataloader( job_config=job_config, @@ -74,6 +76,7 @@ def __init__( ) self.validation_context = validation_context self.maybe_enable_amp = maybe_enable_amp + # pyrefly: ignore [bad-assignment] self.metrics_processor = metrics_processor self.t5_tokenizer, self.clip_tokenizer = build_flux_tokenizer(self.job_config) @@ -91,6 +94,7 @@ def flux_init( t5_encoder: FluxEmbedder, clip_encoder: FluxEmbedder, ): + # pyrefly: ignore [read-only] self.device = device self._dtype = _dtype self.autoencoder = autoencoder @@ -109,9 +113,12 @@ def validate( model.eval() # Disable cfg dropout during validation + # pyrefly: ignore [missing-attribute] training_cfg_prob = self.job_config.training.classifier_free_guidance_prob + # pyrefly: ignore [missing-attribute] self.job_config.training.classifier_free_guidance_prob = 0.0 + # pyrefly: ignore [missing-attribute] save_img_count = self.job_config.validation.save_img_count parallel_dims = self.parallel_dims @@ -120,6 +127,7 @@ def validate( device_type = dist_utils.device_type num_steps = 0 + # pyrefly: ignore [not-iterable] for input_dict, labels in self.validation_dataloader: if ( self.job_config.validation.steps != -1 @@ -137,8 +145,9 @@ def validate( device=self.device, dtype=self._dtype, job_config=self.job_config, + # pyrefly: ignore [bad-argument-type] model=model, - prompt=p, + prompt=p, # pyrefly: ignore[bad-argument-type] autoencoder=self.autoencoder, t5_tokenizer=self.t5_tokenizer, clip_tokenizer=self.clip_tokenizer, @@ -150,11 +159,12 @@ def validate( name=f"image_rank{str(torch.distributed.get_rank())}_{step}.png", output_dir=os.path.join( self.job_config.job.dump_folder, + # pyrefly: ignore [missing-attribute] self.job_config.validation.save_img_folder, ), x=image, add_sampling_metadata=True, - prompt=p, + prompt=p, # pyrefly: ignore[bad-argument-type] ) save_img_count -= 1 @@ -211,42 +221,35 @@ def validate( latents = pack_latents(latents) target = pack_latents(noise - labels) - optional_context_parallel_ctx = ( - dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], - cp_buffers=[ - latents, - latent_pos_enc, - t5_encodings, - text_pos_enc, - target, - ], - cp_seq_dims=[1, 1, 1, 1, 1], - cp_no_restore_buffers={ - latents, - latent_pos_enc, - t5_encodings, - text_pos_enc, - target, - }, - cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, - ) - if parallel_dims.cp_enabled - else None + # Apply CP sharding if enabled + if parallel_dims.cp_enabled: + from torchtitan.distributed.context_parallel import cp_shard + + ( + latents, + latent_pos_enc, + t5_encodings, + text_pos_enc, + target, + ), _ = cp_shard( + parallel_dims.get_mesh("cp"), + (latents, latent_pos_enc, t5_encodings, text_pos_enc, target), + None, # No attention masks for Flux + load_balancer_type=None, ) - with self.validation_context(optional_context_parallel_ctx): - with self.maybe_enable_amp: - latent_noise_pred = model( - img=latents, - img_ids=latent_pos_enc, - txt=t5_encodings, - txt_ids=text_pos_enc, - y=clip_encodings, - timesteps=timesteps, - ) + with self.validation_context(): + with self.maybe_enable_amp: + latent_noise_pred = model( + img=latents, + img_ids=latent_pos_enc, + txt=t5_encodings, + txt_ids=text_pos_enc, + y=clip_encodings, + timesteps=timesteps, + ) - loss = self.loss_fn(latent_noise_pred, target) + loss = self.loss_fn(latent_noise_pred, target) del noise, target, latent_noise_pred, latents @@ -259,7 +262,7 @@ def validate( loss /= num_steps if parallel_dims.dp_cp_enabled: global_avg_loss = dist_utils.dist_mean( - loss, parallel_dims.world_mesh["dp_cp"] + loss, parallel_dims.get_optional_mesh("loss") ) else: global_avg_loss = loss.item() @@ -270,6 +273,7 @@ def validate( model.train() # re-enable cfg dropout for training + # pyrefly: ignore [missing-attribute] self.job_config.training.classifier_free_guidance_prob = training_cfg_prob @@ -280,8 +284,8 @@ def build_flux_validator( tokenizer: BaseTokenizer, parallel_dims: ParallelDims, loss_fn: LossFunction, - validation_context: Generator[None, None, None], - maybe_enable_amp: Generator[None, None, None], + validation_context: ValidationContext, + maybe_enable_amp: AbstractContextManager[None], metrics_processor: MetricsProcessor | None = None, pp_schedule: _PipelineSchedule | None = None, pp_has_first_stage: bool | None = None, diff --git a/torchtitan/experiments/gpt_oss/README.md b/torchtitan/models/gpt_oss/README.md similarity index 75% rename from torchtitan/experiments/gpt_oss/README.md rename to torchtitan/models/gpt_oss/README.md index a8283ab7b6..c16898bd80 100644 --- a/torchtitan/experiments/gpt_oss/README.md +++ b/torchtitan/models/gpt_oss/README.md @@ -12,6 +12,3 @@ CONFIG_FILE="./torchtitan/experiments/gpt_oss/train_configs/debug_model.toml" ./ ## TODO 1. More parallelism support: CP, PP -2. Conversion between HF weights (StateDictAdapter) -3. Forward parity verification -4. CI support diff --git a/torchtitan/experiments/gpt_oss/__init__.py b/torchtitan/models/gpt_oss/__init__.py similarity index 95% rename from torchtitan/experiments/gpt_oss/__init__.py rename to torchtitan/models/gpt_oss/__init__.py index c12ad13a5c..0ebc20645f 100644 --- a/torchtitan/experiments/gpt_oss/__init__.py +++ b/torchtitan/models/gpt_oss/__init__.py @@ -16,6 +16,7 @@ from .infra.parallelize import parallelize_gptoss from .model.args import GptOssModelArgs from .model.model import GptOssModel +from .model.state_dict_adapter import GptOssStateDictAdapter __all__ = [ "parallelize_gptoss", @@ -84,4 +85,5 @@ def get_train_spec() -> TrainSpec: build_dataloader_fn=build_text_dataloader, build_tokenizer_fn=build_hf_tokenizer, build_loss_fn=build_cross_entropy_loss, + state_dict_adapter=GptOssStateDictAdapter, ) diff --git a/torchtitan/experiments/gpt_oss/infra/expert_parallel.py b/torchtitan/models/gpt_oss/infra/expert_parallel.py similarity index 72% rename from torchtitan/experiments/gpt_oss/infra/expert_parallel.py rename to torchtitan/models/gpt_oss/infra/expert_parallel.py index 96ad157c2f..33de706648 100644 --- a/torchtitan/experiments/gpt_oss/infra/expert_parallel.py +++ b/torchtitan/models/gpt_oss/infra/expert_parallel.py @@ -6,9 +6,10 @@ import torch.nn as nn -from torch.distributed.tensor import distribute_tensor, Replicate, Shard +from torch.distributed.tensor import DeviceMesh, distribute_tensor, Replicate, Shard from torchtitan.distributed.expert_parallel import ExpertTensorParallel, TensorParallel + # implementation of Tensor Parallel for the GroupedExperts in MoE class GptossTensorParallel(TensorParallel): def _partition_fn(self, name, module, device_mesh): @@ -38,28 +39,32 @@ def _partition_fn(self, name, module, device_mesh): # This class is for dp2ep with TP (without TP we can just use GptossExpertParallel) class GptossExpertTensorParallel(ExpertTensorParallel): - def _partition_fn_2d(self, name, mod, ep_tp_mesh): + def _partition_fn(self, name: str, mod: nn.Module, device_mesh: DeviceMesh) -> None: mod.register_parameter( "mlp1_weight", nn.Parameter( - distribute_tensor(mod.mlp1_weight, ep_tp_mesh, [Shard(0), Shard(1)]) + # pyrefly: ignore [bad-argument-type] + distribute_tensor(mod.mlp1_weight, device_mesh, [Shard(0), Shard(1)]) ), ) # Column-wise sharding mod.register_parameter( "mlp1_bias", nn.Parameter( - distribute_tensor(mod.mlp1_bias, ep_tp_mesh, [Shard(0), Shard(1)]) + # pyrefly: ignore [bad-argument-type] + distribute_tensor(mod.mlp1_bias, device_mesh, [Shard(0), Shard(1)]) ), ) # Column-wise sharding mod.register_parameter( "mlp2_weight", nn.Parameter( - distribute_tensor(mod.mlp2_weight, ep_tp_mesh, [Shard(0), Shard(2)]) + # pyrefly: ignore [bad-argument-type] + distribute_tensor(mod.mlp2_weight, device_mesh, [Shard(0), Shard(2)]) ), ) # Row-wise sharding mod.register_parameter( "mlp2_bias", nn.Parameter( - distribute_tensor(mod.mlp2_bias, ep_tp_mesh, [Shard(0), Replicate()]) + # pyrefly: ignore [bad-argument-type] + distribute_tensor(mod.mlp2_bias, device_mesh, [Shard(0), Replicate()]) ), ) # Replicate diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/models/gpt_oss/infra/parallelize.py similarity index 82% rename from torchtitan/experiments/gpt_oss/infra/parallelize.py rename to torchtitan/models/gpt_oss/infra/parallelize.py index 7714d497e4..80a8bc8bc2 100644 --- a/torchtitan/experiments/gpt_oss/infra/parallelize.py +++ b/torchtitan/models/gpt_oss/infra/parallelize.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import torch +import torch._inductor.config import torch.nn as nn from torch.distributed.device_mesh import DeviceMesh @@ -21,7 +22,12 @@ from torchtitan.config.job_config import JobConfig from torchtitan.distributed import NoParallel, ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.dual_pipe_v import ( + DualPipeExpertParallel, + get_dual_pipe_v_flag, +) from torchtitan.distributed.expert_parallel import ( + BaseExpertParallel, ExpertParallel, ReordererSequenceParallel, ) @@ -37,6 +43,9 @@ torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + torch.ops.aten._scaled_dot_product_attention_math.default, + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, torch.ops._c10d_functional.all_to_all_single.default, # for low precision training, it's useful to always save @@ -44,6 +53,7 @@ # used to compute the scaling factor for quantization. torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, + torch._higher_order_ops.inductor_compiled_code, } @@ -53,8 +63,6 @@ def parallelize_gptoss( parallel_dims: ParallelDims, job_config: JobConfig, ): - world_mesh = parallel_dims.world_mesh - assert ( job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 ), f""" @@ -62,9 +70,9 @@ def parallelize_gptoss( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) if parallel_dims.tp_enabled: if ( @@ -86,55 +94,48 @@ def parallelize_gptoss( apply_non_moe_tp( model, - world_mesh["tp"], + parallel_dims.get_mesh("tp"), loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=False, enable_async_tp=False, ) if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims) + apply_moe_ep_tp( model, - tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, - ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, - ep_tp_mesh=( - world_mesh["ep", "tp"] - if parallel_dims.tp_enabled - and parallel_dims.ep_enabled - and parallel_dims.etp_enabled - else None - ), + tp_mesh=parallel_dims.get_optional_mesh("tp"), + ep_mesh=parallel_dims.get_optional_mesh("ep"), + ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]), etp_enabled=parallel_dims.etp_enabled, + dual_pipe_v=dual_pipe_v, ) - model_compile_enabled = ( - job_config.compile.enable and "model" in job_config.compile.components - ) - if job_config.activation_checkpoint.mode != "none": apply_ac( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, + # pyrefly: ignore [bad-argument-type] op_sac_save_list=_op_sac_save_list, ) dp_mesh: DeviceMesh | None = None if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: # apply FSDP or HSDP, potentially with Context Parallel - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") - else: - dp_mesh_dim_names = ("dp_shard_cp",) - dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + dp_mesh_names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(dp_mesh_names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - dp_mod_ep_mesh_dim_names = [] - if parallel_dims.ep_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh_dim_names.append("dp_replicate") - dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + edp_mesh_names = ( + ["dp_replicate", "efsdp"] + if parallel_dims.dp_replicate_enabled + else ["efsdp"] + ) + edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names) apply_fsdp( model, @@ -145,11 +146,8 @@ def parallelize_gptoss( cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, - dp_mod_ep_mesh=( - world_mesh[tuple(dp_mod_ep_mesh_dim_names)] - if parallel_dims.ep_enabled - else None - ), + edp_mesh=edp_mesh, + disable_prefetch=job_config.parallelism.fsdp_disable_prefetch, ) if parallel_dims.dp_replicate_enabled: @@ -163,9 +161,9 @@ def parallelize_gptoss( if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_mesh = parallel_dims.get_mesh("dp_replicate") + if dp_mesh is not None and dp_mesh.ndim > 1: raise RuntimeError("DDP has not supported > 1D parallelism") - dp_mesh = world_mesh apply_ddp( model, dp_mesh, @@ -205,6 +203,7 @@ def apply_non_moe_tp( ) # Apply tensor + sequence parallelism to every transformer block + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), @@ -228,6 +227,7 @@ def apply_non_moe_tp( } # shard attention.sinks across heads + # pyrefly: ignore [missing-attribute] attn = transformer_block.attention attn.register_parameter( "sinks", @@ -235,16 +235,15 @@ def apply_non_moe_tp( ) parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] parallelize_plan=layer_plan, ) if enable_async_tp: - from torch.distributed._symmetric_memory import enable_symm_mem_for_group - torch._inductor.config._micro_pipeline_tp = True - enable_symm_mem_for_group(tp_mesh.get_group().group_name) logger.info( f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" @@ -256,12 +255,15 @@ def apply_moe_ep_tp( model: nn.Module, tp_mesh: DeviceMesh | None, ep_mesh: DeviceMesh | None, - ep_tp_mesh: DeviceMesh | None, + ep_etp_mesh: DeviceMesh | None, etp_enabled: bool, + dual_pipe_v: bool = False, ): assert ep_mesh is not None or tp_mesh is not None + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: continue @@ -283,11 +285,14 @@ def apply_moe_ep_tp( # If TP is borrowed for EP, then split the tokens across TP ranks so that # the reorderer, the all-to-all comms, and routed experts computation # are effectively running Sequence Parallel (split along the folded bs*slen dim) + # pyrefly: ignore [no-matching-overload] moe_layer_plan.update({"moe.reorderer": ReordererSequenceParallel()}) parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] parallelize_plan=moe_layer_plan, ) @@ -301,10 +306,14 @@ def apply_moe_ep_tp( # input / output sharding on the batch / tokens dim experts_plan = ExpertParallel() else: - experts_mesh = ep_tp_mesh + experts_mesh = ep_etp_mesh experts_plan = GptossExpertTensorParallel() + if dual_pipe_v and isinstance(experts_plan, BaseExpertParallel): + experts_plan = DualPipeExpertParallel(experts_plan) + parallelize_module( + # pyrefly: ignore [missing-attribute] module=transformer_block.moe.experts, device_mesh=experts_mesh, parallelize_plan=experts_plan, diff --git a/torchtitan/experiments/gpt_oss/model/args.py b/torchtitan/models/gpt_oss/model/args.py similarity index 95% rename from torchtitan/experiments/gpt_oss/model/args.py rename to torchtitan/models/gpt_oss/model/args.py index e78eac4d74..2a9aa970e4 100644 --- a/torchtitan/experiments/gpt_oss/model/args.py +++ b/torchtitan/models/gpt_oss/model/args.py @@ -39,7 +39,7 @@ class GptOssModelArgs(BaseModelArgs): n_kv_heads (int): Number of key-value heads. sliding_window_size (int): Size of the sliding attention window. attn_mask_type (str): Type of basic attention mask. - use_flex_attn (bool): Whether to use FlexAttention. Only supports True. + attn_type (bool): Attention type, only supports Flex. original_seq_len (int): Original sequence length. rope_theta (float): Base for rotary positional encoding. rope_factor (float): Scaling factor for extended sequence lengths. @@ -64,7 +64,7 @@ class GptOssModelArgs(BaseModelArgs): n_kv_heads: int = 8 sliding_window_size: int = 128 attn_mask_type: str = "causal" - use_flex_attn: bool = True # NOTE: gpt-oss only support FlexAttention + attn_type: str = "flex" # NOTE: gpt-oss only support FlexAttention # yarn original_seq_len: int = 4096 rope_theta: float = 150000.0 @@ -91,6 +91,7 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: "CP support for gpt-oss model is still in progress." ) + # pyrefly: ignore [bad-override] def get_nparams_and_flops( self, model: nn.Module, seq_len: int ) -> tuple[int, float]: diff --git a/torchtitan/experiments/gpt_oss/model/model.py b/torchtitan/models/gpt_oss/model/model.py similarity index 98% rename from torchtitan/experiments/gpt_oss/model/model.py rename to torchtitan/models/gpt_oss/model/model.py index 8732d70290..79b24805c3 100644 --- a/torchtitan/experiments/gpt_oss/model/model.py +++ b/torchtitan/models/gpt_oss/model/model.py @@ -262,8 +262,10 @@ def forward( """ # Extract the appropriate mask for this layer if self.use_sliding_attention: + # pyrefly: ignore [missing-attribute] layer_mask = attention_masks.get("sliding_window_mask", None) else: + # pyrefly: ignore [missing-attribute] layer_mask = attention_masks.get("basic_mask", None) assert layer_mask is not None @@ -283,13 +285,13 @@ def init_weights(self, buffer_device: torch.device): self.moe.init_weights(self.weight_init_std, buffer_device, self.n_layers) -class GptOssModel(nn.Module, ModelProtocol): +class GptOssModel(ModelProtocol): """ GPT-OSS Transformer model with attention and feed-forward layers. """ def __init__(self, model_args: GptOssModelArgs): - super().__init__() + super().__init__(model_args) self.model_args = model_args self.max_seq_len = model_args.max_seq_len self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) @@ -321,6 +323,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): if layer is not None: + # pyrefly: ignore [not-callable] layer.init_weights(buffer_device=buffer_device) if self.norm is not None: self.norm.reset_parameters() diff --git a/torchtitan/experiments/gpt_oss/model/moe.py b/torchtitan/models/gpt_oss/model/moe.py similarity index 94% rename from torchtitan/experiments/gpt_oss/model/moe.py rename to torchtitan/models/gpt_oss/model/moe.py index 94cd266761..f9f5b085bd 100644 --- a/torchtitan/experiments/gpt_oss/model/moe.py +++ b/torchtitan/models/gpt_oss/model/moe.py @@ -25,6 +25,7 @@ class ScaleBiasForward(torch.autograd.Function): """ @staticmethod + # pyrefly: ignore [bad-override] def forward(ctx, bias, tp_degree): ctx.tp_degree = tp_degree if tp_degree > 1: @@ -32,6 +33,7 @@ def forward(ctx, bias, tp_degree): return bias @staticmethod + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): # Don't scale the gradient - pass it through as-is return grad_output, None @@ -101,6 +103,7 @@ def _run_experts_for_loop( tp_degree: int = 1, ) -> torch.Tensor: # NOTE: this would incur a synchronization between device and host + # pyrefly: ignore [bad-assignment] num_tokens_per_expert = num_tokens_per_expert.tolist() # side-effect code due to the usage of generate_permute_indices @@ -108,8 +111,10 @@ def _run_experts_for_loop( # a tuple of tensors indexed by experts # each with shape (tokens_per_expert(varying), dim) + # pyrefly: ignore [bad-assignment] x = torch.split( x[: sum(num_tokens_per_expert)], + # pyrefly: ignore [bad-argument-type] split_size_or_sections=num_tokens_per_expert, dim=0, ) @@ -127,6 +132,7 @@ def _run_experts_for_loop( out = torch.cat(out_experts_splits, dim=0) # side-effect code due to the usage of generate_permute_indices + # pyrefly: ignore [no-matching-overload] out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) return out @@ -201,8 +207,11 @@ def forward( # Convert parameters from DTensors to plain Tensors, to work with # dynamic-shape inputs in EP which cannot be easily expressed as DTensors. mlp1_weight = self.mlp1_weight.to_local() + # pyrefly: ignore [missing-attribute] mlp1_bias = self.mlp1_bias.to_local() + # pyrefly: ignore [missing-attribute] mlp2_weight = self.mlp2_weight.to_local() + # pyrefly: ignore [missing-attribute] mlp2_bias = self.mlp2_bias.to_local() else: mlp1_weight = self.mlp1_weight @@ -214,13 +223,16 @@ def forward( tp_degree = 1 if isinstance(self.mlp1_weight, DTensor): mesh_dim_names = self.mlp1_weight.device_mesh.mesh_dim_names + # pyrefly: ignore[not-iterable] if "tp" in mesh_dim_names: + # pyrefly: ignore [missing-attribute] tp_dim_idx = mesh_dim_names.index("tp") tp_degree = self.mlp1_weight.device_mesh.size(tp_dim_idx) if self.use_grouped_mm: if ( not isinstance(self.mlp1_weight, DTensor) + # pyrefly: ignore[not-iterable] or "ep" not in self.mlp1_weight.device_mesh.mesh_dim_names ): run_experts_fn = indices_padding_wrapper(_run_experts_grouped_mm) @@ -266,6 +278,7 @@ def __init__(self, model_args: GptOssModelArgs, dim: int, hidden_dim: int): super().__init__(moe_args, dim, hidden_dim) # Override the base GroupedExperts with GptOssGroupedExperts + # pyrefly: ignore [bad-assignment] self.experts = GptOssGroupedExperts( dim=dim, hidden_dim=hidden_dim, diff --git a/torchtitan/models/gpt_oss/model/state_dict_adapter.py b/torchtitan/models/gpt_oss/model/state_dict_adapter.py new file mode 100644 index 0000000000..9198505257 --- /dev/null +++ b/torchtitan/models/gpt_oss/model/state_dict_adapter.py @@ -0,0 +1,117 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import re +from typing import Any + +from torch.distributed.checkpoint import HuggingFaceStorageReader +from torchtitan.models.utils import MoEStateDictAdapter + +from .args import GptOssModelArgs + + +class GptOssStateDictAdapter(MoEStateDictAdapter): + def __init__(self, model_args: GptOssModelArgs, hf_assets_path: str | None): + super().__init__(model_args, hf_assets_path) + self.from_hf_map = { + "model.embed_tokens.weight": "tok_embeddings.weight", + # Attention module + "model.layers.{}.self_attn.q_proj.weight": "layers.{}.attention.wq.weight", + "model.layers.{}.self_attn.q_proj.bias": "layers.{}.attention.wq.bias", + "model.layers.{}.self_attn.k_proj.weight": "layers.{}.attention.wk.weight", + "model.layers.{}.self_attn.k_proj.bias": "layers.{}.attention.wk.bias", + "model.layers.{}.self_attn.v_proj.weight": "layers.{}.attention.wv.weight", + "model.layers.{}.self_attn.v_proj.bias": "layers.{}.attention.wv.bias", + "model.layers.{}.self_attn.o_proj.weight": "layers.{}.attention.wo.weight", + "model.layers.{}.self_attn.o_proj.bias": "layers.{}.attention.wo.bias", + "model.layers.{}.self_attn.sinks": "layers.{}.attention.sinks", + # Transformer layer + "model.layers.{}.input_layernorm.weight": "layers.{}.attention_norm.weight", + "model.layers.{}.post_attention_layernorm.weight": "layers.{}.ffn_norm.weight", + # MoE + "model.layers.{}.mlp.experts.gate_up_proj_blocks": "layers.{}.moe.experts.mlp1_weight", + "model.layers.{}.mlp.experts.gate_up_proj_bias": "layers.{}.moe.experts.mlp1_bias", + "model.layers.{}.mlp.experts.down_proj_blocks": "layers.{}.moe.experts.mlp2_weight", + "model.layers.{}.mlp.experts.down_proj_bias": "layers.{}.moe.experts.mlp2_bias", + "model.layers.{}.mlp.router.weight": "layers.{}.moe.router.gate.weight", + "model.layers.{}.mlp.router.bias": "layers.{}.moe.router.gate.bias", + "model.norm.weight": "norm.weight", + "lm_head.weight": "output.weight", + } + + def get_hf_storage_reader( + self, path: str, from_quantized: bool = False + ) -> HuggingFaceStorageReader: + """ + Override default get_hf_storage_reader function to return QuantizedHFStorageReader. + """ + if from_quantized: + from torch.distributed.checkpoint.quantized_hf_storage import ( + QuantizedHuggingFaceStorageReader, + ) + + # NOTE: Now we use Quantized HF storage reader to read GPT-OSS model where + # expert weights are saved in MXFP4 format. + # If loading checkpoints without quantization, use HuggingFaceStorageReader instead + return QuantizedHuggingFaceStorageReader( + path=path, + thread_count=4, + ) + else: + return HuggingFaceStorageReader(path) + + def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: + """ + Convert from a tt model state dict to a hf format state dict. + + Only map keys without changing shapes to the same as MXFP4 checkpoint. + For loading from quantized checkpoints, the QuantizedHuggingFaceStorageReader + will handle dequantization during load. + + Warning: Conversion does not support saving to mxfp4 quantization format. + One can save into unquantized hf checkpoints with last_save_in_hf = true. + """ + to_hf_map = {v: k for k, v in self.from_hf_map.items()} + hf_state_dict = {} + + for key, value in state_dict.items(): + if "layers" in key: + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + if abstract_key not in to_hf_map: + continue + # pyrefly: ignore + layer_num = re.search(r"\d+", key).group(0) + hf_key = to_hf_map[abstract_key] + hf_key = hf_key.format(layer_num) + hf_state_dict[hf_key] = value + else: + if key not in to_hf_map: + continue + hf_key = to_hf_map[key] + hf_state_dict[hf_key] = value + + return hf_state_dict + + def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: + """ + Convert from hf format state dict to tt model state dict. + """ + + state_dict = {} + + for key, value in hf_state_dict.items(): + if "layers" in key: + # pyrefly: ignore + layer_num = re.search(r"\d+", key).group(0) + abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + tt_key = self.from_hf_map[abstract_key] + tt_key = tt_key.format(layer_num) + state_dict[tt_key] = value + else: + tt_key = self.from_hf_map[key] + state_dict[tt_key] = value + + return state_dict diff --git a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml b/torchtitan/models/gpt_oss/train_configs/debug_model.toml similarity index 100% rename from torchtitan/experiments/gpt_oss/train_configs/debug_model.toml rename to torchtitan/models/gpt_oss/train_configs/debug_model.toml diff --git a/torchtitan/models/llama3/__init__.py b/torchtitan/models/llama3/__init__.py index 3207323a8a..4d4508458f 100644 --- a/torchtitan/models/llama3/__init__.py +++ b/torchtitan/models/llama3/__init__.py @@ -36,7 +36,16 @@ n_heads=16, vocab_size=2048, rope_theta=500000, - use_flex_attn=True, + attn_type="flex", + attn_mask_type="block_causal", + ), + "debugmodel_varlen_attn": TransformerModelArgs( + dim=256, + n_layers=6, + n_heads=16, + vocab_size=2048, + rope_theta=500000, + attn_type="varlen", attn_mask_type="block_causal", ), "8B": TransformerModelArgs( @@ -48,7 +57,7 @@ multiple_of=1024, rope_theta=500000, ), - "8B_flex_attn": TransformerModelArgs( + "8B_flex": TransformerModelArgs( dim=4096, n_layers=32, n_heads=32, @@ -56,19 +65,21 @@ ffn_dim_multiplier=1.3, multiple_of=1024, rope_theta=500000, - use_flex_attn=True, - attn_mask_type="block_causal_by_sequence_lengths", + attn_type="flex", + attn_mask_type="block_causal", ), - "70B": TransformerModelArgs( - dim=8192, - n_layers=80, - n_heads=64, + "8B_varlen": TransformerModelArgs( + dim=4096, + n_layers=32, + n_heads=32, n_kv_heads=8, ffn_dim_multiplier=1.3, - multiple_of=4096, + multiple_of=1024, rope_theta=500000, + attn_type="varlen", + attn_mask_type="block_causal", ), - "70B_flex_attn": TransformerModelArgs( + "70B": TransformerModelArgs( dim=8192, n_layers=80, n_heads=64, @@ -76,8 +87,6 @@ ffn_dim_multiplier=1.3, multiple_of=4096, rope_theta=500000, - use_flex_attn=True, - attn_mask_type="block_causal_by_sequence_lengths", ), "405B": TransformerModelArgs( dim=16384, @@ -96,8 +105,8 @@ ffn_dim_multiplier=2, multiple_of=432, rope_theta=10000000, - use_flex_attn=True, - attn_mask_type="block_causal_by_sequence_lengths", + attn_type="varlen", + attn_mask_type="block_causal", use_qkv_bias=True, vocab_size=155136, head_dim=128, diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py index ccc3c1f07b..98244665a7 100644 --- a/torchtitan/models/llama3/infra/parallelize.py +++ b/torchtitan/models/llama3/infra/parallelize.py @@ -26,6 +26,7 @@ from torchtitan.config.job_config import Compile as CompileConfig from torchtitan.distributed import ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.context_parallel import apply_cp_to_attention_module from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp from torchtitan.tools.logging import logger @@ -35,12 +36,17 @@ torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + torch.ops.aten._scaled_dot_product_attention_math.default, + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, # for low precision training, it's useful to always save # the result of max, since the absolute maximum is # used to compute the scaling factor for quantization. torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, + torch.ops.torch_attn._varlen_attn.default, + torch._higher_order_ops.inductor_compiled_code, } @@ -56,7 +62,6 @@ def parallelize_llama( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ - world_mesh = parallel_dims.world_mesh # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. @@ -67,10 +72,6 @@ def parallelize_llama( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") - if parallel_dims.tp_enabled: enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( @@ -83,13 +84,24 @@ def parallelize_llama( # all-gather happens in high precision. enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + tp_mesh = parallel_dims.get_mesh("tp") apply_tp( model, - world_mesh["tp"], + tp_mesh, loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, + cp_enabled=parallel_dims.cp_enabled, + ) + maybe_enable_async_tp(job_config, tp_mesh) + + attn_type = getattr(model.model_args, "attn_type", "sdpa") + if parallel_dims.cp_enabled: + apply_cp_to_attention_module( + # pyrefly: ignore [missing-attribute, not-callable] + [block.attention.inner_attention for block in model.layers.values()], + parallel_dims.get_mesh("cp"), + attn_type, ) - maybe_enable_async_tp(job_config, world_mesh["tp"]) model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components @@ -100,7 +112,7 @@ def parallelize_llama( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, + # pyrefly: ignore [bad-argument-type] op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) @@ -110,15 +122,14 @@ def parallelize_llama( apply_compile(model, job_config.compile) if parallel_dims.fsdp_enabled: - # apply FSDP or HSDP, potentially with Context Parallel - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") - else: - dp_mesh_dim_names = ("dp_shard_cp",) - + # dp_mesh is the mesh for FSDP/HSDP + names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(names) apply_fsdp( model, - world_mesh[tuple(dp_mesh_dim_names)], + dp_mesh, param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], pp_enabled=parallel_dims.pp_enabled, @@ -131,17 +142,15 @@ def parallelize_llama( else: logger.info("Applied FSDP to the model") - if parallel_dims.cp_enabled: - logger.info("Applied Context Parallel to the model") - if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_replicate_mesh = parallel_dims.get_mesh("dp_replicate") + if parallel_dims.world_size != dp_replicate_mesh.size(): raise RuntimeError("DDP has not supported > 1D parallelism") apply_ddp( model, - world_mesh, + dp_replicate_mesh, enable_compile=model_compile_enabled, ) @@ -153,6 +162,7 @@ def apply_tp( tp_mesh: DeviceMesh, loss_parallel: bool, enable_float8_tensorwise_tp: bool, + cp_enabled: bool = False, ): """Apply tensor parallelism.""" # 1. Parallelize the embedding and shard its outputs (which are the first @@ -202,12 +212,18 @@ def apply_tp( # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), + # NOTE: when the fourth argument (positions) is not None, its input layout + # and desired input layout is still None as we don't convert freqs_cis to + # a DTensor for llama3. + # TODO: https://github.com/pytorch/torchtitan/pull/2149 would fix this + # inconsistency. "attention": prepare_module_input( - input_layouts=(Shard(1), None, None), - desired_input_layouts=(Replicate(), None, None), + input_layouts=(Shard(1), None, None, None), + desired_input_layouts=(Replicate(), None, None, None), ), "attention.wq": colwise_parallel(), "attention.wk": colwise_parallel(), @@ -226,6 +242,7 @@ def apply_tp( } parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, parallelize_plan=layer_plan, @@ -242,10 +259,12 @@ def apply_compile(model: nn.Module, compile_config: CompileConfig): Apply torch.compile to each TransformerBlock, which makes compilation efficient due to repeated structure. Alternatively one can compile the whole model (after applying DP). """ + # pyrefly: ignore [missing-attribute] for layer_id, transformer_block in model.layers.named_children(): transformer_block = torch.compile( transformer_block, backend=compile_config.backend, fullgraph=compile_config.fullgraph ) + # pyrefly: ignore [missing-attribute] model.layers.register_module(layer_id, transformer_block) logger.info("Compiling each TransformerBlock with torch.compile") @@ -280,6 +299,7 @@ def apply_fsdp( mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} if cpu_offload: + # pyrefly: ignore[bad-typed-dict-key] fsdp_config["offload_policy"] = CPUOffloadPolicy() match reshard_after_forward_policy: @@ -297,12 +317,15 @@ def apply_fsdp( ) if model.tok_embeddings is not None: + # pyrefly: ignore [no-matching-overload] fully_shard( model.tok_embeddings, **fsdp_config, reshard_after_forward=reshard_after_forward, ) + # pyrefly: ignore [missing-attribute] for layer_id, transformer_block in model.layers.items(): + # pyrefly: ignore[no-matching-overload] fully_shard( transformer_block, **fsdp_config, @@ -311,11 +334,13 @@ def apply_fsdp( # As an optimization, do not reshard_after_forward the last layers by default # since FSDP would prefetch them immediately after the forward pass if model.norm is not None and model.output is not None: + # pyrefly: ignore [no-matching-overload] fully_shard( [model.norm, model.output], **fsdp_config, reshard_after_forward=reshard_after_forward_policy == "always", ) + # pyrefly: ignore[no-matching-overload] fully_shard(model, **fsdp_config) @@ -327,6 +352,7 @@ def apply_ddp( if enable_compile: torch._dynamo.config.optimize_ddp = "ddp_optimizer" + # pyrefly: ignore [invalid-param-spec] replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100) logger.info("Applied DDP to the model") diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py index ff60cae708..43f5c69dea 100644 --- a/torchtitan/models/llama3/model/args.py +++ b/torchtitan/models/llama3/model/args.py @@ -10,7 +10,6 @@ from dataclasses import dataclass, field from torch import nn - from torchtitan.config import JobConfig from torchtitan.models.utils import get_dense_model_nparams_and_flops from torchtitan.protocols.model import BaseModelArgs @@ -49,7 +48,7 @@ class TransformerModelArgs(BaseModelArgs): # `False`, each uses the total number of transformer blocks depth_init: bool = True - use_flex_attn: bool = False + attn_type: str = "sdpa" attn_mask_type: str = "causal" eos_id: int = 0 @@ -63,14 +62,17 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.max_seq_len = seq_len - if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: + if ( + job_config.parallelism.context_parallel_degree > 1 + and self.attn_type == "varlen" + ): raise NotImplementedError( - "CP support for FlexAttention is still in progress." + f"Context Parallel only supports SDPA and FlexAttention." + f"Got attn_type='{self.attn_type}'. " + f"Varlen attention is not supported with CP." ) - def get_nparams_and_flops( - self, model: nn.Module, seq_len: int - ) -> tuple[int, float]: + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: return get_dense_model_nparams_and_flops( self, model, diff --git a/torchtitan/models/llama3/model/model.py b/torchtitan/models/llama3/model/model.py index 22f0638b99..37da1f85d7 100644 --- a/torchtitan/models/llama3/model/model.py +++ b/torchtitan/models/llama3/model/model.py @@ -16,11 +16,14 @@ from torchtitan.config.job_config import PEFT from torchtitan.models.attention import ( create_attention_mask, + create_varlen_metadata_for_document, + create_varlen_metadata_from_sequence_lengths, FlexAttentionWrapper, - get_block_causal_mask_mod_by_seq_lens, get_causal_mask_mod, get_document_mask_mod, ScaledDotProductAttentionWrapper, + VarlenAttentionWrapper, + VarlenMetadata, ) from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol @@ -86,19 +89,23 @@ def precompute_freqs_cis( return freqs_cis -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: +def reshape_for_broadcast( + freqs_cis: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None +) -> torch.Tensor: """ Reshape frequency tensor for broadcasting it with another tensor. This function reshapes the frequency tensor to have the same shape as the target tensor 'x' for the purpose of broadcasting the frequency tensor during element-wise operations. - The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim // 2), and the first seqlen elements will be sliced, but dim must match x. Args: freqs_cis (torch.Tensor): Frequency tensor to be reshaped. x (torch.Tensor): Target tensor for broadcasting compatibility. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. + Shape is (1, seqlen) or (bz, seqlen). Defaults to None. Returns: torch.Tensor: Reshaped frequency tensor. @@ -106,17 +113,35 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten ndim = x.ndim assert ndim > 1 seqlen = x.shape[1] - freqs_cis = freqs_cis[0:seqlen] - assert freqs_cis.shape == (seqlen, x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) + if positions is None: + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + elif positions.size(0) == 1: + assert positions.shape == (1, seqlen) + freqs_cis = freqs_cis[positions.squeeze(0)] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + else: + assert positions.shape == (x.shape[0], seqlen) + freqs_cis_expanded = freqs_cis[None, :, None, :].expand(x.shape[0], -1, -1, -1) + freqs_cis = torch.gather( + freqs_cis_expanded, + dim=1, + index=positions.view(x.shape[0], seqlen, 1, 1).expand( + x.shape[0], seqlen, 1, freqs_cis_expanded.shape[-1] + ), + ) + return freqs_cis def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. @@ -130,28 +155,17 @@ def apply_rotary_emb( xq (torch.Tensor): Query tensor to apply rotary embeddings. xk (torch.Tensor): Key tensor to apply rotary embeddings. freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. - position_ids (torch.Tensor, optional): Custom position IDs of shape [batch_size, seq_len]. - If provided, will use these to index into freqs_cis. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. """ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - - if position_ids is not None: - gathered_freqs = freqs_cis[position_ids] # [bs, seqlen, head_dim/2] - gathered_freqs = gathered_freqs.unsqueeze(2) # [bs, seqlen, 1, head_dim/2] - - xq_out = torch.view_as_real(xq_ * gathered_freqs).flatten(3) - xk_out = torch.view_as_real(xk_ * gathered_freqs).flatten(3) - - return xq_out.type_as(xq), xk_out.type_as(xk) - else: - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) - xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) - xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) - return xq_out.type_as(xq), xk_out.type_as(xk) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, positions) + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: @@ -237,11 +251,18 @@ def __init__(self, model_args: TransformerModelArgs, peft_config: PEFT): self.q_norm.weight.requires_grad = False self.k_norm.weight.requires_grad = False - self.use_flex_attn = model_args.use_flex_attn - if self.use_flex_attn: - self.inner_attention = FlexAttentionWrapper() - else: - self.inner_attention = ScaledDotProductAttentionWrapper() + self.attn_type = model_args.attn_type + match self.attn_type: + case "flex": + self.inner_attention = FlexAttentionWrapper() + case "varlen": + # pyrefly: ignore [bad-assignment] + self.inner_attention = VarlenAttentionWrapper() + case "sdpa": + # pyrefly: ignore [bad-assignment] + self.inner_attention = ScaledDotProductAttentionWrapper() + case _: + raise ValueError(f"Unknown attention type: {self.attn_type}") def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -257,7 +278,7 @@ def forward( x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ): """ Forward pass of the attention module. @@ -265,6 +286,8 @@ def forward( Args: x (torch.Tensor): Input tensor. freqs_cis (torch.Tensor): Precomputed frequency tensor. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after attention. @@ -285,7 +308,7 @@ def forward( self.q_norm(xq), self.k_norm(xk), freqs_cis=freqs_cis, - position_ids=position_ids, + positions=positions, ) # repeat k/v heads if n_kv_heads < n_heads @@ -296,20 +319,31 @@ def forward( xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - assert ( - isinstance(attention_masks, BlockMask) or attention_masks is None - ), attention_masks - - if self.use_flex_attn: - assert isinstance(attention_masks, BlockMask), attention_masks - output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) - else: - assert attention_masks is None - output = self.inner_attention(xq, xk, xv) + match self.attn_type: + case "flex": + assert isinstance(attention_masks, BlockMask), attention_masks + output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + case "varlen": + assert isinstance(attention_masks, VarlenMetadata), attention_masks + output = self.inner_attention( + xq, + xk, + xv, + self.head_dim, + attention_masks, + ) + case "sdpa": + assert attention_masks is None + output = self.inner_attention(xq, xk, xv) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + case _: + raise ValueError(f"Unknown attention type: {self.attn_type}") - output = output.transpose( - 1, 2 - ).contiguous() # (bs, seqlen, n_local_heads, head_dim) output = output.view(bs, seqlen, -1) return self.wo(output) @@ -421,7 +455,7 @@ def forward( x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the TransformerBlock. @@ -429,16 +463,15 @@ def forward( Args: x (torch.Tensor): Input tensor. freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after applying attention and feedforward layers. """ h = x + self.attention( - self.attention_norm(x), - freqs_cis, - attention_masks, - position_ids=position_ids, + self.attention_norm(x), freqs_cis, attention_masks, positions ) out = h + self.feed_forward(self.ffn_norm(h)) return out @@ -450,7 +483,7 @@ def init_weights(self): self.feed_forward.init_weights(self.weight_init_std) -class Transformer(nn.Module, ModelProtocol): +class Transformer(ModelProtocol): """ Transformer Module @@ -471,7 +504,7 @@ class Transformer(nn.Module, ModelProtocol): """ def __init__(self, model_args: TransformerModelArgs, peft_config: PEFT): - super().__init__() + super().__init__(model_args) self.model_args = model_args self.vocab_size = model_args.vocab_size self.n_layers = model_args.n_layers @@ -528,6 +561,7 @@ def init_weights( nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): if layer is not None: + # pyrefly: ignore [not-callable] layer.init_weights() if self.norm is not None: self.norm.reset_parameters() @@ -557,42 +591,65 @@ def _precompute_freqs_cis(self) -> torch.Tensor: self.model_args.rope_scaling_args, ) - def get_attention_masks( + def _get_flex_attention_masks( self, input_batch: torch.Tensor, tokenizer: BaseTokenizer, extra_inputs: dict[str, torch.Tensor] | None = None, ) -> AttentionMasksType: mask_mods = [get_causal_mask_mod()] + match self.model_args.attn_mask_type: case "causal": B = 1 case "block_causal": B = input_batch.shape[0] mask_mods.append(get_document_mask_mod(input_batch, tokenizer.eos_id)) - case "block_causal_by_sequence_lengths": - sequence_lengths = extra_inputs.pop("sequence_lengths", None) - if sequence_lengths is None: - raise RuntimeError( - "`sequence_lengths` required for `block_causal_by_sequence_lengths`" - ) - B = input_batch.shape[0] - mask_mods.append( - get_block_causal_mask_mod_by_seq_lens(sequence_lengths) - ) case _: raise ValueError( f"Unknown attention mask type: {self.model_args.attn_mask_type}" ) + return create_attention_mask( and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] ) + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + match self.model_args.attn_type: + case "flex": + return self._get_flex_attention_masks( + input_batch, tokenizer, extra_inputs + ) + case "varlen": + if self.model_args.attn_mask_type != "block_causal": + raise ValueError( + f"varlen attention is only supported with block_causal \ + attention mask type, got {self.model_args.attn_mask_type}" + ) + # Use explicit sequence_lengths from extra_inputs if available, + # otherwise fall back to EOS-based document detection + if extra_inputs is not None and "sequence_lengths" in extra_inputs: + return create_varlen_metadata_from_sequence_lengths( + extra_inputs["sequence_lengths"], + input_batch.shape[1], + input_batch.device, + ) + return create_varlen_metadata_for_document( + input_batch, tokenizer.eos_id + ) + case _: + raise TypeError("Only varlen and flex attn masks are supported") + def forward( self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the Transformer model. @@ -602,6 +659,8 @@ def forward( If pipeline parallelism is enabled, this will be the input token indices for the ranks on the first pipeline stage. This will be the activation of the previous pipeline stage if the current rank is not on the first stage. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output logits after applying the Transformer model. @@ -609,16 +668,15 @@ def forward( """ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + # pyrefly: ignore[not-callable, invalid-argument] h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): h = layer( - h, - self.freqs_cis, - attention_masks=attention_masks, - position_ids=position_ids, + h, self.freqs_cis, attention_masks=attention_masks, positions=positions ) - + # pyrefly: ignore[not-callable, invalid-argument] h = self.norm(h) if self.norm else h + # pyrefly: ignore[not-callable, invalid-argument] output = self.output(h) if self.output else h return output diff --git a/torchtitan/models/llama3/model/state_dict_adapter.py b/torchtitan/models/llama3/model/state_dict_adapter.py index 15eec5ec95..f0a1f5f674 100644 --- a/torchtitan/models/llama3/model/state_dict_adapter.py +++ b/torchtitan/models/llama3/model/state_dict_adapter.py @@ -130,6 +130,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: for key, value in state_dict.items(): if "layers" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_key = to_hf_map[abstract_key] # We need to permute the weights in wq and wk layer in order to account for the difference between @@ -175,6 +176,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: for key, value in hf_state_dict.items(): if "layers" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_key = self.from_hf_map[abstract_key] @@ -200,5 +202,6 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: else: new_key = self.from_hf_map[key] + # pyrefly: ignore [unsupported-operation] state_dict[new_key] = value return state_dict diff --git a/torchtitan/models/llama4/__init__.py b/torchtitan/models/llama4/__init__.py index 24196c2326..b8bd9a4484 100644 --- a/torchtitan/models/llama4/__init__.py +++ b/torchtitan/models/llama4/__init__.py @@ -67,7 +67,7 @@ rope_scaling_args=RoPEScalingArgs(), every_n_layers_nope=4, fixed_attn_block_size=256, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), "17bx16e_irope": TransformerModelArgs( @@ -83,7 +83,7 @@ moe_args=MoEArgs(num_experts=16), interleave_moe_layer_step=1, every_n_layers_nope=4, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), "17bx128e_irope": TransformerModelArgs( @@ -96,7 +96,7 @@ rope_theta=500000, moe_args=MoEArgs(num_experts=128), every_n_layers_nope=4, - use_flex_attn=True, + attn_type="flex", attn_mask_type="block_causal", ), } diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index efc73ccc0e..31b49af3fc 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +from typing import Any + import torch import torch.nn as nn from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( @@ -38,10 +40,15 @@ from torchtitan.config.job_config import Compile as CompileConfig from torchtitan.distributed import NoParallel, ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac - +from torchtitan.distributed.context_parallel import apply_cp_to_attention_module +from torchtitan.distributed.dual_pipe_v import ( + DualPipeExpertParallel, + get_dual_pipe_v_flag, +) from torchtitan.distributed.expert_parallel import ( + BaseExpertParallel, + DeepEPExpertParallel, ExpertParallel, - ExpertParallelDeepEP, ExpertTensorParallel, ReordererSequenceParallel, TensorParallel, @@ -51,12 +58,14 @@ from torchtitan.models.moe import moe as moe_module from torchtitan.tools.logging import logger - # for selective op activation checkpointing _op_sac_save_list = { torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + torch.ops.aten._scaled_dot_product_attention_math.default, + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, torch.ops._c10d_functional.all_to_all_single.default, # for low precision training, it's useful to always save @@ -64,6 +73,7 @@ # used to compute the scaling factor for quantization. torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, + torch._higher_order_ops.inductor_compiled_code, } @@ -79,7 +89,6 @@ def parallelize_llama( NOTE: The passed-in model preferably should be on meta device. Otherwise, the model must fit on GPU or CPU memory. """ - world_mesh = parallel_dims.world_mesh # TODO: TP currently cannot handle uneven seq_len because we set # `use_local_output=True` to use plain Tensors for legacy reasons. # Need to revisit this. @@ -90,10 +99,7 @@ def parallelize_llama( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") - + tp_mesh = None if parallel_dims.tp_enabled: enable_float8_linear = "float8" in job_config.model.converters float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( @@ -106,61 +112,97 @@ def parallelize_llama( # all-gather happens in high precision. enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + tp_mesh = parallel_dims.get_mesh("tp") apply_non_moe_tp( model, - world_mesh["tp"], + tp_mesh, loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, ) - maybe_enable_async_tp(job_config, world_mesh["tp"]) + maybe_enable_async_tp(job_config, tp_mesh) + + # Check if using DeepEP for MoE communication + if job_config.parallelism.expert_parallel_comm_backend == "deepep": + if not parallel_dims.ep_enabled: + raise ValueError( + "DeepEP requires expert parallelism (ep_degree > 1). " + "The DeepEP MoE model code does not support EP=1. " + "Please set expert_parallel_degree > 1 or use standard communication backend." + ) + if parallel_dims.etp_enabled: + raise NotImplementedError( + "DeepEP with Expert Tensor Parallelism (ETP) is not supported yet. " + "Please set expert_tensor_parallel_degree=1 or use standard communication backend." + ) + + use_deepep = True + + # Import deepep module to register custom ops before accessing them + import torchtitan.distributed.deepep # noqa: F401 - registers torch.ops.deepep + + _op_sac_save_list.add(torch.ops.deepep.dispatch.default) + _op_sac_save_list.add(torch.ops.deepep.combine.default) + else: + use_deepep = False if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims) + apply_moe_ep_tp( model, - tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, - ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, - ep_tp_mesh=( - world_mesh["ep", "tp"] - if parallel_dims.tp_enabled - and parallel_dims.ep_enabled - and parallel_dims.etp_enabled - else None - ), - etp_enabled=parallel_dims.etp_enabled, + tp_mesh=parallel_dims.get_optional_mesh("tp"), + ep_mesh=parallel_dims.get_optional_mesh("ep"), + etp_mesh=parallel_dims.get_optional_mesh("etp"), + ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]), + dual_pipe_v=dual_pipe_v, + use_deepep=use_deepep, + ) + + attn_type = getattr(model.model_args, "attn_type", "sdpa") + if parallel_dims.cp_enabled: + apply_cp_to_attention_module( + # pyrefly: ignore [missing-attribute, not-callable] + [block.attention.inner_attention for block in model.layers.values()], + parallel_dims.get_mesh("cp"), + attn_type, ) model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components ) if job_config.activation_checkpoint.mode != "none": + if job_config.activation_checkpoint.selective_ac_option == "op": + logger.info( + f"SAC save list contains {len(_op_sac_save_list)} ops: " + f"{sorted([str(op) for op in _op_sac_save_list])}" + ) apply_ac( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, + # pyrefly: ignore [bad-argument-type] op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) # turn on per-TransformerBlock compile after AC wrapping and before FSDP if model_compile_enabled: - apply_compile(model, job_config.compile) + apply_compile(model, job_config.compile, parallel_dims.ep_enabled) - dp_mesh: DeviceMesh | None = None if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: - # apply FSDP or HSDP, potentially with Context Parallel - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") - else: - dp_mesh_dim_names = ("dp_shard_cp",) - dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + # dp_mesh is the mesh for FSDP/HSDP + dp_mesh_names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(dp_mesh_names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - dp_mod_ep_mesh_dim_names = [] - if parallel_dims.ep_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh_dim_names.append("dp_replicate") - dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + edp_mesh_names = ( + ["dp_replicate", "efsdp"] + if parallel_dims.dp_replicate_enabled + else ["efsdp"] + ) + edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names) apply_fsdp( model, @@ -171,12 +213,9 @@ def parallelize_llama( cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, - dp_mod_ep_mesh=( - world_mesh[tuple(dp_mod_ep_mesh_dim_names)] - if parallel_dims.ep_enabled - else None - ), + edp_mesh=edp_mesh, gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, + disable_prefetch=job_config.parallelism.fsdp_disable_prefetch, ) if parallel_dims.dp_replicate_enabled: @@ -184,15 +223,12 @@ def parallelize_llama( else: logger.info("Applied FSDP to the model") - if parallel_dims.cp_enabled: - logger.info("Applied Context Parallel to the model") - if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_mesh = parallel_dims.get_mesh("dp_replicate") + if parallel_dims.world_size != dp_mesh.size(): raise RuntimeError("DDP has not supported > 1D parallelism") - dp_mesh = world_mesh apply_ddp( model, dp_mesh, @@ -253,12 +289,16 @@ def apply_non_moe_tp( ) # Apply tensor + sequence parallelism to every transformer block + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), + # NOTE: when the fourth argument (positions) is not None, its input layout + # and desired input layout is still None as we don't convert freqs_cis to + # a DTensor for llama4. "attention": prepare_module_input( - input_layouts=(Shard(1), None, None), - desired_input_layouts=(Replicate(), None, None), + input_layouts=(Shard(1), None, None, None), + desired_input_layouts=(Replicate(), None, None, None), ), "attention.wq": colwise_parallel(), "attention.wk": colwise_parallel(), @@ -266,6 +306,7 @@ def apply_non_moe_tp( "attention.wo": rowwise_parallel(output_layouts=Shard(1)), "ffn_norm": SequenceParallel(), } + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: layer_plan.update( { @@ -280,6 +321,7 @@ def apply_non_moe_tp( ) parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, parallelize_plan=layer_plan, @@ -298,10 +340,11 @@ def apply_fsdp( reduce_dtype: torch.dtype, pp_enabled: bool, cpu_offload: bool = False, - reshard_after_forward_policy: str = "default", + reshard_after_forward_policy: str | int = "default", ep_degree: int = 1, - dp_mod_ep_mesh: DeviceMesh | None = None, + edp_mesh: DeviceMesh | None = None, gradient_divide_factor: int | None = None, + disable_prefetch: bool = False, ): """ Apply data parallelism (via FSDP2) to the model. @@ -313,46 +356,60 @@ def apply_fsdp( reduce_dtype (torch.dtype): The data type to use for reduction operations. pp_enabled (bool): Whether pipeline parallelism is enabled. cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. - reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". - Other options: "never", "always". + reshard_after_forward_policy (str | int, optional): The policy to use for resharding after forward pass. Defaults to "default". + String options: "never", "always", "default". - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. - "always" will enable `reshard_after_forward` for all forward passes. - "never" will disable `reshard_after_forward` for all forward passes. + Integer option: N (e.g., 8) for partial resharding to N-GPU groups. + - Reduces peak memory by limiting all-gather buffer size to N GPUs instead of full DP world. + - Use N=8 for intra-node resharding (fast NVLink communication). + - N must be a factor of the FSDP shard world size. + disable_prefetch (bool, optional): Whether to disable FSDP forward/backward prefetching. Defaults to False. """ mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) - fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} + fsdp_config: dict[str, Any] = {"mesh": dp_mesh, "mp_policy": mp_policy} if cpu_offload: fsdp_config["offload_policy"] = CPUOffloadPolicy() - match reshard_after_forward_policy: - case "always": - reshard_after_forward = True - case "never": - reshard_after_forward = False - case "default": - # For PP, by default do not reshard after forward to avoid per-microbatch - # all-gathers, which can be expensive and non-overlapped - reshard_after_forward = not pp_enabled - case _: - raise ValueError( - f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." - ) + # Handle integer reshard_after_forward (partial resharding to N-GPU groups) + if isinstance(reshard_after_forward_policy, int): + reshard_after_forward = reshard_after_forward_policy + logger.info( + f"Using partial reshard_after_forward={reshard_after_forward} (resharding to {reshard_after_forward}-GPU groups)" + ) + else: + match reshard_after_forward_policy: + case "always": + reshard_after_forward = True + case "never": + reshard_after_forward = False + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = not pp_enabled + case _: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." + ) if model.tok_embeddings is not None: + # pyrefly: ignore [no-matching-overload] fully_shard( model.tok_embeddings, **fsdp_config, reshard_after_forward=reshard_after_forward, ) + # pyrefly: ignore [missing-attribute] for layer_id, transformer_block in model.layers.items(): # NOTE: When EP is enabled, In an MoE layer, we use the following FSDP wrapping # - the router and the shared experts are sharded together with the TransformerBlock - # - the routed experts are sharded with the remaining dp_mod_ep_mesh + # - the routed experts are sharded with the remaining edp_mesh if transformer_block.moe_enabled and ep_degree > 1: fsdp_mod_ep_config = fsdp_config.copy() - fsdp_mod_ep_config["mesh"] = dp_mod_ep_mesh + fsdp_mod_ep_config["mesh"] = edp_mesh # NOTE: EP alreadys shards the routed experts on dim 0 (num_experts). # When dp_mod_ep * ep > num_experts, FSDP default dim-0 sharding @@ -361,10 +418,10 @@ def apply_fsdp( # on non-0 dim. For now it may not be worth the complexity to support # shard_placement_fn on the outer TransformerBlock-level FSDP. _experts_shard_placement_fn = None - assert dp_mod_ep_mesh is not None + assert edp_mesh is not None assert hasattr(transformer_block, "moe") if ( - dp_mod_ep_mesh.size() * ep_degree + edp_mesh["efsdp"].size() * ep_degree > transformer_block.moe.experts.num_experts ): _experts_shard_placement_fn = lambda param: Shard(1) @@ -448,6 +505,7 @@ def apply_fsdp( # As an optimization, do not reshard_after_forward the last layers by default # since FSDP would prefetch them immediately after the forward pass if model.norm is not None and model.output is not None: + # pyrefly: ignore [no-matching-overload] fully_shard( [model.norm, model.output], **fsdp_config, @@ -461,50 +519,71 @@ def apply_fsdp( if ep_degree == 1: return + # Skip prefetch setup if disabled + if disable_prefetch: + logger.info("FSDP prefetching is disabled") + return + # forward + # pyrefly: ignore [not-callable] transformer_blocks = list(model.layers.values()) next_transformer_blocks = transformer_blocks[1:] + [None] + # pyrefly: ignore [bad-argument-type] if model.tok_embeddings is not None and len(model.layers) > 0: + # pyrefly: ignore [missing-attribute] model.tok_embeddings.set_modules_to_forward_prefetch([transformer_blocks[0]]) for transformer_block, next_transformer_block in zip( transformer_blocks, next_transformer_blocks ): if next_transformer_block is not None: + # pyrefly: ignore [missing-attribute] if next_transformer_block.moe_enabled: + # pyrefly: ignore [missing-attribute] transformer_block.set_modules_to_forward_prefetch( + # pyrefly: ignore [missing-attribute] [next_transformer_block, next_transformer_block.moe.experts] ) else: + # pyrefly: ignore [missing-attribute] transformer_block.set_modules_to_forward_prefetch( [next_transformer_block] ) elif model.norm is not None and model.output is not None: + # pyrefly: ignore [missing-attribute] transformer_block.set_modules_to_forward_prefetch( [model.norm, model.output] ) # backward + # pyrefly: ignore [not-callable] reversed_transformer_blocks = list(reversed(model.layers.values())) prev_transformer_blocks = reversed_transformer_blocks[1:] + [None] + # pyrefly: ignore [bad-argument-type] if model.norm is not None and model.output is not None and len(model.layers) > 0: + # pyrefly: ignore [missing-attribute] model.output.set_modules_to_backward_prefetch([reversed_transformer_blocks[0]]) for transformer_block, prev_transformer_block in zip( reversed_transformer_blocks, prev_transformer_blocks ): if prev_transformer_block is not None: + # pyrefly: ignore [missing-attribute] if prev_transformer_block.moe_enabled: + # pyrefly: ignore [missing-attribute] transformer_block.set_modules_to_backward_prefetch( + # pyrefly: ignore [missing-attribute] [prev_transformer_block, prev_transformer_block.moe.experts] ) else: + # pyrefly: ignore [missing-attribute] transformer_block.set_modules_to_backward_prefetch( [prev_transformer_block] ) elif model.tok_embeddings is not None: + # pyrefly: ignore [missing-attribute] transformer_block.set_modules_to_backward_prefetch([model.tok_embeddings]) @@ -512,13 +591,16 @@ def apply_moe_ep_tp( model: nn.Module, tp_mesh: DeviceMesh | None, ep_mesh: DeviceMesh | None, - ep_tp_mesh: DeviceMesh | None, - etp_enabled: bool, - use_deepep: bool, + etp_mesh: DeviceMesh | None, + ep_etp_mesh: DeviceMesh | None, + dual_pipe_v: bool = False, + use_deepep: bool = False, ): assert ep_mesh is not None or tp_mesh is not None + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: continue @@ -541,13 +623,16 @@ def apply_moe_ep_tp( # replicate computation for the router "moe.router.gate": NoParallel(), } - if ep_mesh is not None and not etp_enabled: + if ep_mesh is not None and etp_mesh is None: # If TP is borrowed for EP, then split the tokens across TP ranks so that # the reorderer, the all-to-all comms, and routed experts computation # are effectively running Sequence Parallel (split along the folded bs*slen dim) + # pyrefly: ignore [no-matching-overload] moe_layer_plan.update({"moe.reorderer": ReordererSequenceParallel()}) + # pyrefly: ignore [missing-attribute] if transformer_block.moe.shared_experts is not None: # input Replicate, output Partial + # pyrefly: ignore [no-matching-overload] moe_layer_plan.update( { "moe.shared_experts.w1": LoraColwiseParallel(), @@ -558,156 +643,144 @@ def apply_moe_ep_tp( } ) parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, + # pyrefly: ignore [bad-argument-type] parallelize_plan=moe_layer_plan, ) experts_mesh, experts_plan = None, None if ep_mesh is None: + assert ep_etp_mesh is None experts_mesh = tp_mesh # input Replicate, output Partial experts_plan = TensorParallel() - elif tp_mesh is None or not etp_enabled: + elif tp_mesh is None or etp_mesh is None: + assert ep_etp_mesh is None experts_mesh = ep_mesh - # input / output sharding on the batch / tokens dim - experts_plan = ( - ExpertParallelDeepEP() if use_deepep is True else ExpertParallel() - ) - if use_deepep is True: - logger.info( - "Enabling deep_ep and fused all-to-all communication for expert parallelism" + if use_deepep: + # pyrefly: ignore [missing-attribute] + score_before_experts = transformer_block.moe.score_before_experts + + experts_plan = DeepEPExpertParallel( + score_before_experts=score_before_experts, ) + logger.info("Applying DeepEP to MoE layer") + else: + # input / output sharding on the batch / tokens dim + experts_plan = ExpertParallel() else: - experts_mesh = ep_tp_mesh + experts_mesh = ep_etp_mesh experts_plan = ExpertTensorParallel() + if dual_pipe_v and isinstance(experts_plan, BaseExpertParallel): + experts_plan = DualPipeExpertParallel(experts_plan) + parallelize_module( + # pyrefly: ignore [missing-attribute] module=transformer_block.moe.experts, device_mesh=experts_mesh, parallelize_plan=experts_plan, ) -def old_apply_compile(model: nn.Module, compile_config: CompileConfig): +def apply_compile(model: nn.Module, compile_config: CompileConfig, ep_enabled: bool): """ Apply torch.compile to each TransformerBlock, which makes compilation efficient due to repeated structure. Alternatively one can compile the whole model (after applying DP). """ # NOTE: This flag is needed for torch.compile to avoid graph breaking on dynamic shapes in token-choice MoE # but it is experimental. - # torch._dynamo.config.capture_scalar_outputs = True + torch._dynamo.config.capture_scalar_outputs = True + # pyrefly: ignore [missing-attribute] for layer_id, transformer_block in model.layers.named_children(): - # IMPORTANT: MoE layers MUST use fullgraph=False, non-MoE layers SHOULD use fullgraph=True. - # - # Why MoE needs fullgraph=False: - # DeepEP's PrimusTurboDeepepManager.setup_metadata() mutates instance state - # (self.token_probs, self.token_indices) inside activation checkpointing regions. - # When fullgraph=True, torch.compile wraps these in HigherOrderOperator which - # raises "Mutating a variable not in the current scope (SideEffects)" error. - # fullgraph=False allows graph breaks, permitting these state mutations. - # - # Why non-MoE layers should use fullgraph=True: - # fullgraph=False on attention/FFN layers causes unnecessary graph breaks, - # which interferes with activation checkpointing and retains more intermediate - # tensors during backward pass, leading to OOM on memory-constrained setups. - # - # Respect user config for non-MoE layers, but MoE layers MUST use False - fullgraph = compile_config.fullgraph + # pyrefly: ignore[missing-attribute] if transformer_block.moe_enabled: - fullgraph = False - transformer_block = torch.compile( - transformer_block, - backend=compile_config.backend, - fullgraph=fullgraph, - ) - model.layers.register_module(layer_id, transformer_block) - - logger.info("Compiling each TransformerBlock with torch.compile") - - -def apply_compile(model: nn.Module, compile_config: CompileConfig): - """ - Apply torch.compile to each TransformerBlock, which makes compilation efficient due to - repeated structure. Alternatively one can compile the whole model (after applying DP). - """ - try: - # Workaround for https://github.com/pytorch/pytorch/issues/166926 - torch._C._dynamo.eval_frame._set_lru_cache(False) - # NOTE: This flag is needed for torch.compile to avoid graph breaking on dynamic shapes in token-choice MoE - # but it is experimental. - torch._dynamo.config.capture_scalar_outputs = True - for layer_id, transformer_block in model.layers.named_children(): - if transformer_block.moe_enabled: - # If it is a MoE layer, FSDP(GroupedExperts) will cause a graph break - # So we must weave compile wrappers around those FSDP hooks to - # prevent AC from falling back the whole graph to eager. - # TODO: Fix Compile(AC(graph break)) - - if isinstance(transformer_block, CheckpointWrapper): - # TODO: Make CheckpointWrapper a transparent wrapper - # unwrap so that .named_children() works - block = transformer_block._checkpoint_wrapped_module - else: - block = transformer_block + # If it is a MoE layer, FSDP(GroupedExperts) will cause a graph break + # So we must weave compile wrappers around those FSDP hooks to + # prevent AC from falling back the whole graph to eager. + # TODO: Fix Compile(AC(graph break)) + + if isinstance(transformer_block, CheckpointWrapper): + # TODO: Make CheckpointWrapper a transparent wrapper + # unwrap so that .named_children() works + block = transformer_block._checkpoint_wrapped_module + else: + block = transformer_block - for attr_name, submod in block.named_children(): - assert getattr(block, attr_name) == getattr( - transformer_block, attr_name - ) + for attr_name, submod in block.named_children(): + assert getattr(block, attr_name) == getattr( + transformer_block, attr_name + ) - if isinstance(submod, moe_module.MoE): - # avoid graph breaking on the GroupedExperts' FSDP hooks - # by wrapping each submod's forward instead of their __call__ - moe = submod - for attr_name, submod in moe.named_children(): - if attr_name == "experts": - # NOTE: We don't compile token dispatch and token combine due to an issue on B200: - # https://github.com/pytorch/torchtitan/issues/1940 - continue - setattr( - moe, - attr_name, - torch.compile( - submod, - backend=compile_config.backend, - fullgraph=compile_config.fullgraph, - ), - ) - else: + if isinstance(submod, moe_module.MoE): + # avoid graph breaking on the GroupedExperts' FSDP hooks + # by wrapping each submod's forward instead of their __call__ + moe = submod + for attr_name, submod in moe.named_children(): + if attr_name == "experts": + # NOTE: We don't compile token dispatch and token combine due to an issue on B200: + # https://github.com/pytorch/torchtitan/issues/1940 + continue setattr( - block, + moe, attr_name, torch.compile( - submod, - backend=compile_config.backend, - fullgraph=compile_config.fullgraph, + submod, backend=compile_config.backend, fullgraph=compile_config.fullgraph ), ) + else: + setattr( + block, + attr_name, + torch.compile( + submod, backend=compile_config.backend, fullgraph=compile_config.fullgraph + ), + ) - else: - # If it's not a MoE layer, there is no FSDP(GroupedExperts) - # So we can compile the whole block - transformer_block = torch.compile( - transformer_block, - backend=compile_config.backend, - fullgraph=compile_config.fullgraph, - ) + else: + # If it's not a MoE layer, there is no FSDP(GroupedExperts) + # So we can compile the whole block + transformer_block = torch.compile( + transformer_block, + backend=compile_config.backend, + fullgraph=compile_config.fullgraph, + ) - model.layers.register_module(layer_id, transformer_block) + # pyrefly: ignore [missing-attribute] + model.layers.register_module(layer_id, transformer_block) + # Patch some globals only once (apply_compile is called multiple times for PP setup) + already_patched = ( + "_run_experts_grouped_mm_dynamic" + in moe_module._run_experts_grouped_mm.__qualname__ + ) + if not already_patched: moe_module._run_experts_grouped_mm = torch.compile( moe_module._run_experts_grouped_mm, backend=compile_config.backend, fullgraph=compile_config.fullgraph, ) - # NOTE: We don't compile for loop code path due to an issue with unbacked symints: - # https://github.com/pytorch/pytorch/issues/166460 + if ep_enabled: + compiled_fn = moe_module._run_experts_grouped_mm - logger.info("Compiling each TransformerBlock with torch.compile") - except AttributeError: - logger.warning( - "Using old compile, if you need new compile support (MoE) please upgrade to torch nightly" - ) - return old_apply_compile(model, compile_config) + # keep function logic in sync with `already_patched` above + def _run_experts_grouped_mm_dynamic( + w1: torch.Tensor, + w2: torch.Tensor, + w3: torch.Tensor, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + ) -> torch.Tensor: + # dynamic number of tokens in expert parallel + torch._dynamo.mark_dynamic(x, 0) + return compiled_fn(w1, w2, w3, x, num_tokens_per_expert) + + moe_module._run_experts_grouped_mm = _run_experts_grouped_mm_dynamic + + # NOTE: We don't compile for loop code path due to an issue with unbacked symints: + # https://github.com/pytorch/pytorch/issues/166460 + + logger.info("Compiling each TransformerBlock with torch.compile") diff --git a/torchtitan/models/llama4/model/args.py b/torchtitan/models/llama4/model/args.py index fcce6eb8d1..a93030d82f 100644 --- a/torchtitan/models/llama4/model/args.py +++ b/torchtitan/models/llama4/model/args.py @@ -44,7 +44,7 @@ class TransformerModelArgs(BaseModelArgs): # `False`, each uses the total number of transformer blocks depth_init: bool = True - use_flex_attn: bool = False + attn_type: str = "sdpa" attn_mask_type: str = "causal" # iRoPE settings # When ``every_n_layers_nope`` is specified, NoPE (no positional embedding) is @@ -76,22 +76,21 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.moe_args.use_grouped_mm = False - if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: + if ( + job_config.parallelism.context_parallel_degree > 1 + and self.attn_type != "sdpa" + ): raise NotImplementedError( - "CP support for FlexAttention is still in progress." + f"Context Parallel only supports SDPA attention. " + f"Got attn_type='{self.attn_type}'. " + f"FlexAttention and varlen attention are not supported with CP." ) self.moe_args._debug_force_load_balance = ( job_config.debug.moe_force_load_balance ) - # Pass DeepEP config to MoE layer and validate - self.moe_args.deepep_config = job_config.deepep - self.moe_args.validate_deepep_config() - - def get_nparams_and_flops( - self, model: nn.Module, seq_len: int - ) -> tuple[int, float]: + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: return get_moe_model_nparams_and_flops( self, model, diff --git a/torchtitan/models/llama4/model/model.py b/torchtitan/models/llama4/model/model.py index 0d8bda6aa2..c98e3b6124 100644 --- a/torchtitan/models/llama4/model/model.py +++ b/torchtitan/models/llama4/model/model.py @@ -88,19 +88,23 @@ def precompute_freqs_cis( return freqs_cis -def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor: +def reshape_for_broadcast( + freqs_cis: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None +) -> torch.Tensor: """ Reshape frequency tensor for broadcasting it with another tensor. This function reshapes the frequency tensor to have the same shape as the target tensor 'x' for the purpose of broadcasting the frequency tensor during element-wise operations. - The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim), + The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim // 2), and the first seqlen elements will be sliced, but dim must match x. Args: freqs_cis (torch.Tensor): Frequency tensor to be reshaped. x (torch.Tensor): Target tensor for broadcasting compatibility. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. + Shape is (1, seqlen) or (bz, seqlen). Defaults to None. Returns: torch.Tensor: Reshaped frequency tensor. @@ -108,16 +112,35 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Ten ndim = x.ndim assert ndim > 1 seqlen = x.shape[1] - freqs_cis = freqs_cis[0:seqlen] - assert freqs_cis.shape == (seqlen, x.shape[-1]) - shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] - return freqs_cis.view(*shape) + if positions is None: + freqs_cis = freqs_cis[0:seqlen] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + elif positions.size(0) == 1: + assert positions.shape == (1, seqlen) + freqs_cis = freqs_cis[positions.squeeze(0)] + assert freqs_cis.shape == (seqlen, x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + else: + assert positions.shape == (x.shape[0], seqlen) + freqs_cis_expanded = freqs_cis[None, :, None, :].expand(x.shape[0], -1, -1, -1) + freqs_cis = torch.gather( + freqs_cis_expanded, + dim=1, + index=positions.view(x.shape[0], seqlen, 1, 1).expand( + x.shape[0], seqlen, 1, freqs_cis_expanded.shape[-1] + ), + ) + return freqs_cis def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor, + positions: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. @@ -131,13 +154,14 @@ def apply_rotary_emb( xq (torch.Tensor): Query tensor to apply rotary embeddings. xk (torch.Tensor): Key tensor to apply rotary embeddings. freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. """ xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) - freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, positions) xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3) xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3) return xq_out.type_as(xq), xk_out.type_as(xk) @@ -222,11 +246,17 @@ def __init__( # values of these two variables. self.use_rope = use_rope - self.use_flex_attn = model_args.use_flex_attn - if self.use_flex_attn: - self.inner_attention = FlexAttentionWrapper() - else: - self.inner_attention = ScaledDotProductAttentionWrapper() + self.attn_type = model_args.attn_type + match self.attn_type: + case "flex": + self.inner_attention = FlexAttentionWrapper() + case "sdpa": + # pyrefly: ignore [bad-assignment] + self.inner_attention = ScaledDotProductAttentionWrapper() + case "varlen": + raise ValueError("Varlen attention is not supported with Llama 4.") + case _: + raise ValueError(f"Unknown attention type: {self.attn_type}") def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -237,7 +267,8 @@ def forward( self, x: torch.Tensor, freqs_cis: torch.Tensor, - attention_masks: AttentionMasksType | None, + attention_masks: AttentionMasksType, + positions: torch.Tensor | None = None, ): """ Forward pass of the attention module. @@ -245,6 +276,8 @@ def forward( Args: x (torch.Tensor): Input tensor. freqs_cis (torch.Tensor): Precomputed frequency tensor. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after attention. @@ -262,7 +295,7 @@ def forward( xv = xv.view(bs, seqlen, -1, self.head_dim) if self.use_rope: - xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis) + xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis, positions=positions) # repeat k/v heads if n_kv_heads < n_heads keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) @@ -272,7 +305,7 @@ def forward( xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - if self.use_flex_attn: + if self.attn_type == "flex": assert isinstance(attention_masks, dict), attention_masks attention_mask = attention_masks["rope" if self.use_rope else "nope"] output = self.inner_attention(xq, xk, xv, block_mask=attention_mask) @@ -426,6 +459,7 @@ def forward( x: torch.Tensor, freqs_cis: torch.Tensor, attention_masks: AttentionMasksType | None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the TransformerBlock. @@ -433,12 +467,16 @@ def forward( Args: x (torch.Tensor): Input tensor. freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after applying attention and feedforward layers. """ - h = x + self.attention(self.attention_norm(x), freqs_cis, attention_masks) + h = x + self.attention( + self.attention_norm(x), freqs_cis, attention_masks, positions + ) if self.moe_enabled: out = h + self.moe(self.ffn_norm(h)) else: @@ -455,7 +493,7 @@ def init_weights(self, buffer_device: torch.device): self.feed_forward.init_weights(self.weight_init_std) -class Transformer(nn.Module, ModelProtocol): +class Transformer(ModelProtocol): """ Transformer Module @@ -476,7 +514,7 @@ class Transformer(nn.Module, ModelProtocol): """ def __init__(self, model_args: TransformerModelArgs, peft_config: PEFT): - super().__init__() + super().__init__(model_args) self.model_args = model_args self.vocab_size = model_args.vocab_size self.n_layers = model_args.n_layers @@ -533,6 +571,7 @@ def init_weights( nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): if layer is not None: + # pyrefly: ignore [not-callable] layer.init_weights(buffer_device=buffer_device) if self.norm is not None: self.norm.reset_parameters() @@ -550,9 +589,6 @@ def init_weights( def _precompute_freqs_cis(self) -> torch.Tensor: return precompute_freqs_cis( self.model_args.dim // self.model_args.n_heads, - # Need to compute until at least the max token limit for generation - # TODO: explain in docs/composability.md why we removed the 2x - # relaxing in our CP enablement PR self.model_args.max_seq_len, self.model_args.rope_theta, self.model_args.rope_scaling_args, @@ -592,6 +628,7 @@ def forward( self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the Transformer model. @@ -601,17 +638,22 @@ def forward( If pipeline parallelism is enabled, this will be the input token indices for the ranks on the first pipeline stage. This will be the activation of the previous pipeline stage if the current rank is not on the first stage. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output logits after applying the Transformer model. """ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + # pyrefly: ignore[not-callable, invalid-argument] h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer(h, self.freqs_cis, attention_masks) + h = layer(h, self.freqs_cis, attention_masks, positions) + # pyrefly: ignore[not-callable, invalid-argument] h = self.norm(h) if self.norm else h + # pyrefly: ignore[not-callable, invalid-argument] output = self.output(h) if self.output else h return output diff --git a/torchtitan/models/llama4/model/state_dict_adapter.py b/torchtitan/models/llama4/model/state_dict_adapter.py index 182981c665..c272b2ac10 100644 --- a/torchtitan/models/llama4/model/state_dict_adapter.py +++ b/torchtitan/models/llama4/model/state_dict_adapter.py @@ -52,6 +52,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: to_combine = defaultdict(dict) for key, value in state_dict.items(): if "layers" in key: + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) key = re.sub(r"(\d+)", "{}", key, count=1) else: @@ -77,6 +78,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: hf_abstract_key = ( "language_model.model.layers.{}.feed_forward.experts.gate_up_proj" ) + # pyrefly: ignore [unnecessary-comparison] if hf_abstract_key is None: continue to_combine[hf_abstract_key.format(layer_num)][ @@ -85,6 +87,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: # combine collected values for hf_fqn, tt_fqn_map in to_combine.items(): + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", hf_fqn).group(0) combine_values = [] # put into correct order to combine @@ -106,6 +109,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: for key, value in hf_state_dict.items(): if "layers" in key: + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) key = re.sub(r"(\d+)", "{}", key, count=1) else: diff --git a/torchtitan/models/llama4/train_configs/llama4_17bx128e.toml b/torchtitan/models/llama4/train_configs/llama4_17bx128e.toml index fa0624bc8e..36d36712a2 100644 --- a/torchtitan/models/llama4/train_configs/llama4_17bx128e.toml +++ b/torchtitan/models/llama4/train_configs/llama4_17bx128e.toml @@ -17,7 +17,7 @@ save_tb_folder = "tb" [model] name = "llama4" flavor = "17bx128e" -hf_assets_path = "./assets/hf/Llama-4-Scout-17B-128E" +hf_assets_path = "./assets/hf/Llama-4-Maverick-17B-128E" # converters = ["float8"] [optimizer] diff --git a/torchtitan/models/moe/__init__.py b/torchtitan/models/moe/__init__.py index c932f6aa83..1fccdaa572 100644 --- a/torchtitan/models/moe/__init__.py +++ b/torchtitan/models/moe/__init__.py @@ -4,6 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .moe import ExpertRoutingHistogram, FeedForward, MoE, MoEArgs +from .moe import build_moe, fast_init_trunc_normal_, fast_init_normal_, ExpertRoutingHistogram, FeedForward, MoE, MoEArgs -__all__ = ["FeedForward", "MoE", "MoEArgs", "ExpertRoutingHistogram"] +__all__ = ["FeedForward", "MoE", "MoEArgs", "build_moe", "ExpertRoutingHistogram", "fast_init_trunc_normal_", "fast_init_normal_"] diff --git a/torchtitan/models/moe/kernels.py b/torchtitan/models/moe/kernels.py index 7aac7b3ac4..a1b1d17771 100644 --- a/torchtitan/models/moe/kernels.py +++ b/torchtitan/models/moe/kernels.py @@ -92,8 +92,11 @@ def fill_indices_wrapper( start_index_values, write_offsets, permuted_indices, + # pyrefly: ignore [bad-argument-type] experts_per_rank, + # pyrefly: ignore [bad-argument-type] num_ranks, + # pyrefly: ignore [bad-argument-type] BLOCK_SIZE=block_size, ) return permuted_indices diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index 46c2aa4484..9caa4a1d9f 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -17,27 +17,56 @@ from torchtitan.tools.logging import logger from .utils import indices_padding_wrapper, indices_padding_wrapper_lora -# Lazy import DeepEP - only required when use_deepep=True in config -try: - from torchtitan.distributed.deepep.fused_activation import fused_silu_gate_prob - from torchtitan.distributed.deepep.utils import DeepEPTokenDispatcher - - DEEPEP_AVAILABLE = True -except ImportError: - DEEPEP_AVAILABLE = False - fused_silu_gate_prob = None - DeepEPTokenDispatcher = None - @dataclass class ExpertRoutingHistogram: counts: list[float] +# see https://arxiv.org/pdf/2310.10837 def moe_init_std(dim_in: int, n_layers: int) -> float: return (2 / (dim_in * n_layers)) ** 0.5 +def fast_init_trunc_normal_( + tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +) -> None: + """ + Fast truncated normal initialization that handles bfloat16 tensors on CPU. + + When tensors are bfloat16 on CPU, nn.init.trunc_normal_ is extremely slow + because CPUs don't have native bfloat16 support. This function temporarily + converts to float32 for the initialization, then converts back. + """ + if tensor.device.type == "cpu" and tensor.dtype == torch.bfloat16: + with torch.no_grad(): + # Initialize in float32 for CPU performance + temp = torch.empty_like(tensor, dtype=torch.float32) + nn.init.trunc_normal_(temp, mean=mean, std=std, a=a, b=b) + tensor.copy_(temp.to(torch.bfloat16)) + else: + nn.init.trunc_normal_(tensor, mean=mean, std=std, a=a, b=b) + + +def fast_init_normal_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0 +) -> None: + """ + Fast normal initialization that handles bfloat16 tensors on CPU. + """ + if tensor.device.type == "cpu" and tensor.dtype == torch.bfloat16: + with torch.no_grad(): + temp = torch.empty_like(tensor, dtype=torch.float32) + nn.init.normal_(temp, mean=mean, std=std) + tensor.copy_(temp.to(torch.bfloat16)) + else: + nn.init.normal_(tensor, mean=mean, std=std) + + @dataclass class MoEArgs: num_experts: int = 8 @@ -50,8 +79,10 @@ class MoEArgs: route_scale: float = 1.0 score_before_experts: bool = True - # token-choice + # token-choice with optional node limited routing top_k: int = 1 + num_expert_groups: int | None = None # must be a divisor of num_experts + num_limited_groups: int | None = None use_grouped_mm: bool = True # grouped mm or for-loop for the experts computation load_balance_coeff: float | None = 1e-3 @@ -174,9 +205,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x)) def init_weights(self, init_std: float = 0.02): - nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + fast_init_trunc_normal_(self.w1.weight, mean=0.0, std=0.02) for linear in (self.w2, self.w3): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + fast_init_trunc_normal_(linear.weight, mean=0.0, std=init_std) # NOTE: keeping this for-loop implementation for comparison @@ -189,20 +220,20 @@ def _run_experts_for_loop( num_tokens_per_expert: torch.Tensor, ) -> torch.Tensor: # NOTE: this would incur a synchronization between device and host - num_tokens_per_expert = num_tokens_per_expert.tolist() + num_tokens_per_expert_list = num_tokens_per_expert.tolist() # side-effect code due to the usage of generate_permute_indices - num_padding = x.shape[0] - sum(num_tokens_per_expert) + num_padding = x.shape[0] - sum(num_tokens_per_expert_list) # a tuple of tensors indexed by experts # each with shape (tokens_per_expert(varying), dim) - x = torch.split( - x[: sum(num_tokens_per_expert)], - split_size_or_sections=num_tokens_per_expert, + x_splits = torch.split( + x[: sum(num_tokens_per_expert_list)], + split_size_or_sections=num_tokens_per_expert_list, dim=0, ) out_experts_splits = [] - for expert_idx, x_expert in enumerate(x): + for expert_idx, x_expert in enumerate(x_splits): h = F.silu(torch.matmul(x_expert, w1[expert_idx].transpose(-2, -1))) h = h * torch.matmul(x_expert, w3[expert_idx].transpose(-2, -1)) h = torch.matmul(h, w2[expert_idx].transpose(-2, -1)) @@ -345,10 +376,6 @@ def __init__( hidden_dim: int, num_experts: int, use_grouped_mm: bool, - deepep_dispatcher: DeepEPTokenDispatcher = None, - score_before_experts: bool = False, - use_fused_weighted_scatter: bool = True, - use_fused_silu_gate_prob: bool = False, ): super().__init__() self.num_experts = num_experts @@ -356,109 +383,47 @@ def __init__( self.w2 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim)) self.w3 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim)) self.use_grouped_mm = use_grouped_mm - self.deepep_dispatcher = deepep_dispatcher - # NOTE(phuc): fix compatibility with ExpertParallelDeepEP - # When DeepEP is used, it passes routed_prob to forward() which needs to be multiplied - # with the input/output. This flag controls whether to apply the scaling before or after - # the expert computation. See torchtitan-amd implementation for reference. - self.score_before_experts = score_before_experts - # When True and score_before_experts=False, skip multiplication here and do it - # in unpermute via fused_weighted_scatter_add kernel (2-3x faster) - self.use_fused_weighted_scatter = use_fused_weighted_scatter - # When True, use fused Triton kernel for silu(x@w1) * (x@w3) * prob (~3.5x faster) - # Only effective when score_before_experts=False and DeepEP is enabled - self.use_fused_silu_gate_prob = use_fused_silu_gate_prob def forward( self, x: torch.Tensor, num_tokens_per_expert: torch.Tensor, - routed_prob: torch.Tensor | None = None, ) -> torch.Tensor: - """Forward pass through grouped experts. - - Args: - x: Input tensor [num_routed_tokens, hidden] - num_tokens_per_expert: Number of tokens per expert [num_experts] - routed_prob: Routing probabilities [num_routed_tokens] (optional, for DeepEP) - """ if isinstance(self.w1, DTensor): # Convert parameters from DTensors to plain Tensors, to work with # dynamic-shape inputs in EP which cannot be easily expressed as DTensors. w1 = self.w1.to_local() + # pyrefly: ignore [missing-attribute] w2 = self.w2.to_local() + # pyrefly: ignore [missing-attribute] w3 = self.w3.to_local() else: w1 = self.w1 w2 = self.w2 w3 = self.w3 - # Apply routing scores BEFORE expert computation if score_before_experts=True - if ( - self.deepep_dispatcher is not None - and routed_prob is not None - and self.score_before_experts - ): - x = (x.to(torch.float32) * routed_prob.reshape(-1, 1)).to(x.dtype) - - # Determine if we should use fused silu-gate-prob kernel - # Only when: DeepEP enabled, score_before_experts=False, and config enabled - should_use_fused_silu = ( - self.use_fused_silu_gate_prob - and self.deepep_dispatcher is not None - and routed_prob is not None - and not self.score_before_experts - ) - if self.use_grouped_mm: # NOTE: If EP is not used, we need to pad the indices # to prepare for grouped_mm; # otherwise, EP will handle the padding. if ( not isinstance(self.w1, DTensor) + # pyrefly: ignore[not-iterable] or "ep" not in self.w1.device_mesh.mesh_dim_names ): run_experts_fn = indices_padding_wrapper(_run_experts_grouped_mm) else: run_experts_fn = _run_experts_grouped_mm - - if should_use_fused_silu: - # Fused path: prob multiplication happens inside the Triton kernel - out = run_experts_fn( - w1, - w2, - w3, - x, - num_tokens_per_expert, - routed_prob=routed_prob, - use_fused_silu_gate_prob=True, - ) - else: - # Original unfused path - out = run_experts_fn(w1, w2, w3, x, num_tokens_per_expert) + return run_experts_fn(w1, w2, w3, x, num_tokens_per_expert) else: - out = _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert) - - # Apply routing scores AFTER expert computation if score_before_experts=False - # Skip if use_fused_weighted_scatter=True (will be done in unpermute instead) - # Skip if use_fused_silu_gate_prob=True (already done in fused kernel) - if ( - self.deepep_dispatcher is not None - and routed_prob is not None - and not self.score_before_experts - and not self.use_fused_weighted_scatter - and not should_use_fused_silu - ): - out = (out.to(torch.float32) * routed_prob.reshape(-1, 1)).to(out.dtype) - - return out + return _run_experts_for_loop(w1, w2, w3, x, num_tokens_per_expert) def init_weights(self, init_std: float, n_layers: int): std_in = moe_init_std(self.w1.shape[-1], n_layers) std_out = moe_init_std(self.w2.shape[0], n_layers) - nn.init.trunc_normal_(self.w1, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w2, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w3, mean=0.0, std=std_out) + fast_init_trunc_normal_(self.w1, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w2, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w3, mean=0.0, std=std_out) def _groupmm(x, w, offs): @@ -602,12 +567,12 @@ def forward( def init_weights(self, init_std: float, n_layers: int): std_in = moe_init_std(self.w1.shape[-1], n_layers) std_out = moe_init_std(self.w2.shape[0], n_layers) - nn.init.trunc_normal_(self.w1, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w2, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w3, mean=0.0, std=std_out) - nn.init.trunc_normal_(self.w1_lora_a, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w2_lora_a, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w3_lora_a, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w1, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w2, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w3, mean=0.0, std=std_out) + fast_init_trunc_normal_(self.w1_lora_a, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w2_lora_a, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w3_lora_a, mean=0.0, std=std_in) nn.init.zeros_(self.w1_lora_b) nn.init.zeros_(self.w2_lora_b) nn.init.zeros_(self.w3_lora_b) @@ -617,9 +582,17 @@ class TokenChoiceTopKRouter(nn.Module): """This class implements token-choice routing. In token-choice top-K routing, each token is routed to top K experts based on the router scores. + Optionally supports node-limited (group-limited) routing where experts are divided into groups + (e.g., by node), and only num_limited_groups groups are considered before selecting top_k experts. + This reduces cross-node communication in distributed settings. + Args: dim (int): Dimension of input tokens. num_experts (int): Number of experts in each moe layer. + num_expert_groups (int | None): Number of expert groups for node-limited routing. If None, standard + top-k routing is used. Must be a divisor of num_experts. + num_limited_groups (int | None): Number of groups to select in node-limited routing. Required when + num_expert_groups is set. top_k (int): Number of experts each token will be routed to in token-choice routing. score_func (Literal["softmax", "sigmoid"]): Whether to use sigmoid or softmax for router scores. route_norm (bool): Whether to normalize the routing scores when using sigmoid. @@ -630,6 +603,8 @@ def __init__( self, dim: int, num_experts: int, + num_expert_groups: int | None, + num_limited_groups: int | None, top_k: int, score_func: Literal["softmax", "sigmoid"], route_norm: bool, @@ -639,6 +614,8 @@ def __init__( super().__init__() self.gate = nn.Linear(dim, num_experts, bias=False) self.num_experts = num_experts + self.num_expert_groups = num_expert_groups + self.num_limited_groups = num_limited_groups self.top_k = top_k self.score_func = score_func self.route_norm = route_norm @@ -662,6 +639,48 @@ def _debug_force_load_balance_routing( top_scores = scores.gather(dim=1, index=selected_experts_indices) # [N,K] return selected_experts_indices, top_scores + def _get_node_limited_routing_scores( + self, + scores_for_choice: torch.Tensor, + ) -> torch.Tensor: + """Select num_limited_groups groups based on group scores, + and set expert scores in non-selected groups as -inf + + Args: + scores_for_choice: Router scores with expert_bias (if any), shape (bs*slen, num_experts) + + Returns: + scores_for_choice: shape (bs*slen, num_experts) + """ + if self.num_limited_groups is None: + raise ValueError( + "num_limited_groups must be set when num_expert_groups is set" + ) + assert self.num_expert_groups is not None + if self.num_experts % self.num_expert_groups != 0: + raise ValueError( + f"num_experts ({self.num_experts}) must be divisible by num_expert_groups ({self.num_expert_groups})" + ) + experts_per_group = self.num_experts // self.num_expert_groups + if experts_per_group < 2: + raise ValueError(f"experts_per_group ({experts_per_group}) must be >= 2") + scores_grouped = scores_for_choice.view( + -1, self.num_expert_groups, experts_per_group + ) + top2_scores_in_group, _ = scores_grouped.topk(2, dim=-1) + group_scores = top2_scores_in_group.sum(dim=-1) + _, group_idx = torch.topk( + group_scores, k=self.num_limited_groups, dim=-1, sorted=False + ) + group_mask = torch.ones_like(group_scores, dtype=torch.bool) + group_mask.scatter_(1, group_idx, False) # False = selected groups (keep) + # Mask out experts from non-selected groups + scores_for_choice = scores_grouped.masked_fill( + group_mask.unsqueeze(-1), float("-inf") + ).view(-1, self.num_experts) + + return scores_for_choice + def forward( self, x: torch.Tensor, expert_bias: torch.Tensor | None = None ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -691,18 +710,18 @@ def forward( else: raise NotImplementedError(f"Unknown score function {self.score_func}") + scores_for_choice = scores if expert_bias is None else scores + expert_bias + # Apply node-limited routing if configured + if self.num_expert_groups is not None: + scores_for_choice = self._get_node_limited_routing_scores(scores_for_choice) + _, selected_experts_indices = torch.topk( + scores_for_choice, k=self.top_k, dim=-1, sorted=False + ) + # top scores shape (bs*slen, top_k) # NOTE: The expert_bias is only used for routing. The gating value # top_scores is still derived from the original scores. - if expert_bias is not None: - _, selected_experts_indices = torch.topk( - scores + expert_bias, k=self.top_k, dim=1 - ) - top_scores = scores.gather(dim=1, index=selected_experts_indices) - else: - top_scores, selected_experts_indices = torch.topk( - scores, k=self.top_k, dim=1 - ) + top_scores = scores.gather(dim=1, index=selected_experts_indices) # debug override: balanced round-robin routing if self._debug_force_load_balance: @@ -728,11 +747,15 @@ def forward( return top_scores, selected_experts_indices, num_tokens_per_expert def init_weights(self, init_std: float, n_layers: int): + # Init gate with each row normalized + # From "Approximating Two-Layer Feedforward Networks for Efficient Transformers" + # https://arxiv.org/pdf/2310.10837 + # NOTE: Must use in-place operations here. When FSDP wraps parameters as # DTensor, direct .data assignment (e.g., self.gate.weight.data = x) is # silently ignored, leaving weights uninitialized. This causes NaN loss # when CPU offload is enabled with 3+ GPUs. - nn.init.normal_(self.gate.weight, mean=0.0, std=1.0) + fast_init_normal_(self.gate.weight, mean=0.0, std=1.0) # Normalize rows in-place with torch.no_grad(): @@ -798,7 +821,6 @@ def forward( ) top_scores_experts_sorted = top_scores.view(-1)[token_indices_experts_sorted] - token_indices_experts_sorted = token_indices_experts_sorted // self.top_k return ( top_scores_experts_sorted, @@ -807,8 +829,6 @@ def forward( ) -# TODO(phuc): would be more clean if separate MoEWithDeepEP, and MoEDefault classes, -# because they have different forward logic class MoE(nn.Module): """ MoE Module @@ -830,34 +850,6 @@ def __init__( super().__init__() num_experts = moe_args.num_experts - - self.use_deepep = moe_args.use_deepep - - # Validate DeepEP availability when use_deepep=True - if self.use_deepep and not DEEPEP_AVAILABLE: - raise ImportError( - "use_deepep=True requires deep_ep to be installed, but it is not available. " - "Please install deep_ep or set use_deepep=False in your model config. " - "See torchtitan/distributed/deepep/README.md for installation instructions." - ) - - if self.use_deepep: - self.deepep_dispatcher = DeepEPTokenDispatcher( - moe_router_topk=moe_args.top_k, - num_moe_experts=num_experts, - deepep_config=moe_args.deepep_config, - score_before_experts=moe_args.score_before_experts, - ) - - # Determine use_fused_weighted_scatter from config - # Only relevant when DeepEP is enabled and score_before_experts=False - use_fused_weighted_scatter = False # default - use_fused_silu_gate_prob = False # default - if self.use_deepep and moe_args.deepep_config is not None: - use_fused_weighted_scatter = ( - moe_args.deepep_config.fused_weighted_scatter_add - ) - use_fused_silu_gate_prob = moe_args.deepep_config.fused_silu_gate_prob if peft_config is not None and peft_config.enable_peft: # TODO: # Update to deepep here @@ -874,20 +866,13 @@ def __init__( hidden_dim=hidden_dim, num_experts=num_experts, use_grouped_mm=moe_args.use_grouped_mm, - deepep_dispatcher=self.deepep_dispatcher - if self.use_deepep is True - else None, - # NOTE(phuc): fix ExpertParallelDeepEP compatibility - # This ensures that GroupedExperts knows whether to apply routing scores before or after - # expert computation when routed_prob is passed to forward() - score_before_experts=moe_args.score_before_experts, - use_fused_weighted_scatter=use_fused_weighted_scatter, - use_fused_silu_gate_prob=use_fused_silu_gate_prob, ) self.router = TokenChoiceTopKRouter( dim=dim, num_experts=num_experts, + num_expert_groups=moe_args.num_expert_groups, + num_limited_groups=moe_args.num_limited_groups, top_k=moe_args.top_k, score_func=moe_args.score_func, route_norm=moe_args.route_norm, @@ -956,7 +941,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: bs, slen, dim = x.shape x = x.view(-1, dim) - # top_scores and selected_experts_indices shape (bs*slen*top_k,) + # top_scores and selected_experts_indices shape (bs*slen, top_k) # num_tokens_per_expert shape (num_experts,) ( top_scores, @@ -974,70 +959,66 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: if self.log_expert_routing: self.expert_routing_counter.add_(num_tokens_per_expert) - # Compute shared expert output (shared by both DeepEP and non-DeepEP paths) - if self.shared_experts is not None: - shared_out = self.shared_experts(x) - if self.shared_gate is not None: - shared_gate_val = F.sigmoid(self.shared_gate(x)) - shared_expert_out = shared_out * shared_gate_val - else: - shared_expert_out = shared_out - else: - shared_expert_out = torch.zeros_like(x) - - if self.use_deepep: - top_scores = top_scores.float() - self.experts.deepep_dispatcher.dispatch_preprocess( - top_scores, selected_experts_indices + # top_scores_experts_sorted and token_indices_experts_sorted shape (bs*slen*top_k,) + # num_tokens_per_expert shape (num_experts,) + # NOTE: the reason we need to compute num_tokens_per_expert again is: + # 1st computation in router is to update self.tokens_per_expert + # which would be the same across all TP ranks. + # 2nd computation in reorderer is for the actual routing and experts computation + # which would be sharded over TP ranks if expert_tensor_parallel_degree==1. + # If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree. + ( + top_scores_experts_sorted, + token_indices_experts_sorted, + num_tokens_per_expert, + ) = self.reorderer(top_scores, selected_experts_indices) + + # shape (bs*slen*top_k, dim) + routed_input = x[token_indices_experts_sorted // self.router.top_k] + + if self.score_before_experts: + routed_input = ( + routed_input.to(torch.float32) + * top_scores_experts_sorted.reshape(-1, 1) + ).to(x.dtype) + + # shape (bs*slen*top_k, dim) + routed_output = self.experts(routed_input, num_tokens_per_expert) + + # shared expert + # Note: we execute the shared expert before scoring the output of the routed expert + # to "implicitly" overlap the shared expert compute with token combine communication + out = self.shared_experts(x) if self.shared_experts is not None else None + + # Apply shared gate if configured + if out is not None and self.shared_gate is not None: + out = F.sigmoid(self.shared_gate(x)) * out + + # Unsort routed outputs + routed_output_unsorted = torch.zeros( + (bs * slen * self.router.top_k, dim), + dtype=routed_output.dtype, + device=routed_output.device, + ) + routed_output_unsorted[token_indices_experts_sorted] = routed_output + routed_output_unsorted = routed_output_unsorted.reshape( + -1, self.router.top_k, dim + ) + if not self.score_before_experts: + out_experts = ( + torch.bmm( + top_scores.reshape(-1, 1, self.router.top_k), + routed_output_unsorted.float(), + ) + .to(x.dtype) + .squeeze(1) ) - # shape (bs*slen*top_k, dim) - routed_output = self.experts(x, num_tokens_per_expert) - out = routed_output + shared_expert_out - out = out.reshape(bs, slen, dim) - return out else: - # top_scores and token_indices_experts_sorted shape (bs*slen*top_k,) - # num_tokens_per_expert shape (num_experts,) - # NOTE: the reason we need to compute num_tokens_per_expert again is: - # 1st computation in router is to update self.tokens_per_expert - # which would be the same across all TP ranks. - # 2nd computation in reorderer is for the actual routing and experts computation - # which would be sharded over TP ranks if expert_tensor_parallel_degree==1. - # If tensor_paralllel_degree == expert_tensor_parallel_degree, they agree. - ( - top_scores_experts_sorted, - token_indices_experts_sorted, - num_tokens_per_expert, - ) = self.reorderer(top_scores, selected_experts_indices) - - # shape (bs*slen*top_k, dim) - token_indices_experts_sorted = token_indices_experts_sorted.reshape( - -1, 1 - ).expand(-1, dim) - - # shape (bs*slen*top_k, dim) - routed_input = torch.gather(x, dim=0, index=token_indices_experts_sorted) + out_experts = routed_output_unsorted.sum(dim=1) - if self.score_before_experts: - routed_input = ( - routed_input.to(torch.float32) - * top_scores_experts_sorted.reshape(-1, 1) - ).to(x.dtype) - - # shape (bs*slen*top_k, dim) - routed_output = self.experts(routed_input, num_tokens_per_expert) - - if not self.score_before_experts: - routed_output = ( - routed_output.to(torch.float32) - * top_scores_experts_sorted.reshape(-1, 1) - ).to(x.dtype) - - out = shared_expert_out.scatter_add( - dim=0, index=token_indices_experts_sorted, src=routed_output - ) - out = out.reshape(bs, slen, dim) - return out + if out is None: + return out_experts.reshape(bs, slen, dim) + return (out + out_experts).reshape(bs, slen, dim) def pop_expert_routing_metrics(self) -> torch.Tensor | None: if not self.log_expert_routing: @@ -1053,7 +1034,7 @@ def init_weights(self, init_std: float, buffer_device: torch.device, n_layers: i if self.shared_experts is not None: self.shared_experts.init_weights(init_std) if self.shared_gate is not None: - nn.init.trunc_normal_( + fast_init_trunc_normal_( self.shared_gate.weight, mean=0.0, std=moe_init_std(self.shared_gate.weight.shape[1], n_layers), @@ -1067,6 +1048,25 @@ def init_weights(self, init_std: float, buffer_device: torch.device, n_layers: i self.experts.num_experts, dtype=torch.float32 ) if self.load_balance_coeff is not None: + # pyrefly: ignore[bad-assignment] self.expert_bias = torch.zeros( self.experts.num_experts, dtype=torch.float32 ) + + +def build_moe( + args: MoEArgs, dim: int, hidden_dim: int, peft_config: PEFT, moe_impl: str = "standard", +) -> nn.Module: + """Factory for MoE with different backends: 'standard' (all-to-all) or 'deepep' (DeepEP).""" + if moe_impl == "deepep": + from .moe_deepep import DeepEPMoE + + logger.info( + f"DeepEP MoE: num_experts={args.num_experts}, top_k={args.top_k}, dim={dim}, hidden_dim={hidden_dim}" + ) + return DeepEPMoE(moe_args=args, dim=dim, hidden_dim=hidden_dim) + + logger.info( + f"Standard MoE: num_experts={args.num_experts}, top_k={args.top_k}, dim={dim}, hidden_dim={hidden_dim}" + ) + return MoE(args, dim=dim, hidden_dim=hidden_dim, peft_config=peft_config) diff --git a/torchtitan/models/moe/moe_deepep.py b/torchtitan/models/moe/moe_deepep.py new file mode 100644 index 0000000000..54e3f0f2a3 --- /dev/null +++ b/torchtitan/models/moe/moe_deepep.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +"""MoE with DeepEP backend for efficient expert-parallel communication.""" + +import torch + +from .moe import MoE, MoEArgs + + +class DeepEPMoE(MoE): + """ + Mixture of Experts with DeepEP communication. + + Inherits from MoE but overrides forward() to pass routing info to experts, + letting DeepEPExpertParallel hooks handle dispatch/combine. + """ + + def __init__(self, moe_args: MoEArgs, dim: int, hidden_dim: int): + super().__init__(moe_args, dim, hidden_dim) + # DeepEP doesn't use reorderer - routing handled by DeepEPExpertParallel + self.reorderer = None # pyrefly: ignore [bad-assignment] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass with DeepEP communication. + + DeepEPExpertParallel hooks intercept experts() call and handle + dispatch/combine via deepep functions. + """ + bs, slen, dim = x.shape + x = x.view(-1, dim) + + top_scores, selected_experts_indices, num_tokens_per_expert = self.router( + x, self.expert_bias + ) + + if self.load_balance_coeff is not None: + with torch.no_grad(): + self.tokens_per_expert.add_(num_tokens_per_expert) + + # Call experts with routing info - hooks handle DeepEP dispatch/combine + routed_output = self.experts( + x, + num_tokens_per_expert, + selected_experts_indices, + top_scores, + self.experts.num_experts, + ) + + out = self.shared_experts(x) if self.shared_experts is not None else None + + if out is None: + return routed_output.reshape(bs, slen, dim) + return (out + routed_output).reshape(bs, slen, dim) diff --git a/torchtitan/models/qwen3/__init__.py b/torchtitan/models/qwen3/__init__.py index ece4a13ccc..a359c40a03 100644 --- a/torchtitan/models/qwen3/__init__.py +++ b/torchtitan/models/qwen3/__init__.py @@ -82,7 +82,7 @@ rope_theta=1000000, enable_weight_tying=True, ), - "4B_vq_flex": Qwen3ModelArgs( + "4B_vq_varlen": Qwen3ModelArgs( vocab_size=168064, max_seq_len=16384, head_dim=128, @@ -95,7 +95,7 @@ rope_theta=1000000, enable_weight_tying=False, eos_id=151643, - use_flex_attn=True, + attn_type="varlen", attn_mask_type="block_causal", ), "8B": Qwen3ModelArgs( @@ -122,7 +122,7 @@ hidden_dim=17408, rope_theta=1000000, ), - "hermes4_14B_vq_flex": Qwen3ModelArgs( + "hermes4_14B_vq_varlen": Qwen3ModelArgs( vocab_size=168064, max_seq_len=16384, head_dim=128, @@ -133,8 +133,8 @@ qk_norm=True, hidden_dim=17408, rope_theta=1000000, - use_flex_attn=True, eos_id=151643, + attn_type="varlen", attn_mask_type="block_causal", ), "32B": Qwen3ModelArgs( @@ -220,7 +220,7 @@ use_deepep=True, ), ), - "10B-A1B-flex": Qwen3ModelArgs( + "10B-A1B-varlen": Qwen3ModelArgs( vocab_size=151936, max_seq_len=8192, head_dim=128, @@ -243,7 +243,7 @@ route_scale=1.0, score_before_experts=False, ), - use_flex_attn=True, + attn_type="varlen", attn_mask_type="block_causal", ), "30B-A3B-deepep": Qwen3ModelArgs( @@ -293,6 +293,56 @@ score_before_experts=False, ), ), + "30B-A3B-varlen": Qwen3ModelArgs( + vocab_size=151936, + max_seq_len=262144, + head_dim=128, + dim=2048, + n_layers=48, + n_heads=32, + n_kv_heads=4, + qk_norm=True, + hidden_dim=6144, + rope_theta=1000000, + moe_enabled=True, + moe_inter_dim=768, + moe_args=MoEArgs( + num_experts=128, + num_shared_experts=0, + top_k=8, + score_func="softmax", + route_norm=True, + route_scale=1.0, + score_before_experts=False, + ), + attn_type="varlen", + attn_mask_type="block_causal", + ), + "30B-A3B-flex": Qwen3ModelArgs( + vocab_size=151936, + max_seq_len=262144, + head_dim=128, + dim=2048, + n_layers=48, + n_heads=32, + n_kv_heads=4, + qk_norm=True, + hidden_dim=6144, + rope_theta=1000000, + moe_enabled=True, + moe_inter_dim=768, + moe_args=MoEArgs( + num_experts=128, + num_shared_experts=0, + top_k=8, + score_func="softmax", + route_norm=True, + route_scale=1.0, + score_before_experts=False, + ), + attn_type="flex", + attn_mask_type="block_causal_by_sequence_lengths", + ), "235B-A22B": Qwen3ModelArgs( vocab_size=151936, max_seq_len=4096, diff --git a/torchtitan/models/qwen3/infra/parallelize.py b/torchtitan/models/qwen3/infra/parallelize.py index 32bc21de3d..3dad54d46a 100644 --- a/torchtitan/models/qwen3/infra/parallelize.py +++ b/torchtitan/models/qwen3/infra/parallelize.py @@ -8,6 +8,7 @@ # training techniques (e.g. activation checkpointing and compile) to the Llama model. import torch +import torch._inductor.config import torch.nn as nn from torch.distributed.device_mesh import DeviceMesh @@ -21,6 +22,8 @@ from torchtitan.config import JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims from torchtitan.distributed.activation_checkpoint import apply_ac +from torchtitan.distributed.context_parallel import apply_cp_to_attention_module +from torchtitan.distributed.dual_pipe_v import get_dual_pipe_v_flag from torchtitan.models.llama3.infra.parallelize import apply_ddp from torchtitan.models.llama4.infra.parallelize import ( apply_compile, @@ -35,12 +38,17 @@ torch.ops.aten.mm.default, torch.ops.aten._scaled_dot_product_efficient_attention.default, torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops.aten._scaled_dot_product_cudnn_attention.default, + torch.ops.aten._scaled_dot_product_attention_math.default, + torch.ops.aten._scaled_dot_product_fused_attention_overrideable.default, torch.ops._c10d_functional.reduce_scatter_tensor.default, # for low precision training, it's useful to always save # the result of max, since the absolute maximum is # used to compute the scaling factor for quantization. torch.ops.aten.max.default, torch._higher_order_ops.flex_attention, + torch.ops.torch_attn._varlen_attn.default, + torch._higher_order_ops.inductor_compiled_code, } @@ -49,7 +57,6 @@ def parallelize_qwen3( parallel_dims: ParallelDims, job_config: JobConfig, ): - world_mesh = parallel_dims.world_mesh assert ( job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 ), f""" @@ -57,9 +64,13 @@ def parallelize_qwen3( ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). """ - use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") + attn_type = getattr(model.model_args, "attn_type", "sdpa") + if job_config.parallelism.context_parallel_degree > 1 and attn_type == "varlen": + raise NotImplementedError( + f"Context Parallel only supports SDPA and FlexAttention." + f"Got attn_type='{attn_type}'. " + f"Varlen attention is not supported with CP." + ) model_compile_enabled = ( job_config.compile.enable and "model" in job_config.compile.components @@ -82,28 +93,34 @@ def parallelize_qwen3( # all-gather happens in high precision. enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + tp_mesh = parallel_dims.get_mesh("tp") apply_non_moe_tp( model, - world_mesh["tp"], + tp_mesh, loss_parallel=not job_config.parallelism.disable_loss_parallel, enable_float8_tensorwise_tp=enable_float8_tensorwise_tp, enable_async_tp=job_config.parallelism.enable_async_tensor_parallel, + positions_enabled=parallel_dims.cp_enabled or job_config.training.dataset_type == "preprocessed", ) if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + dual_pipe_v = get_dual_pipe_v_flag(job_config, parallel_dims) + apply_moe_ep_tp( model, - tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, - ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, - ep_tp_mesh=( - world_mesh["ep", "tp"] - if parallel_dims.tp_enabled - and parallel_dims.ep_enabled - and parallel_dims.etp_enabled - else None - ), - etp_enabled=parallel_dims.etp_enabled, - use_deepep=model.model_args.moe_args.use_deepep, + tp_mesh=parallel_dims.get_optional_mesh("tp"), + ep_mesh=parallel_dims.get_optional_mesh("ep"), + etp_mesh=parallel_dims.get_optional_mesh("etp"), + ep_etp_mesh=parallel_dims.get_optional_mesh(["ep", "etp"]), + dual_pipe_v=dual_pipe_v, + ) + + if parallel_dims.cp_enabled: + apply_cp_to_attention_module( + # pyrefly: ignore [missing-attribute, not-callable] + [block.attention.inner_attention for block in model.layers.values()], + parallel_dims.get_mesh("cp"), + attn_type, ) if job_config.activation_checkpoint.mode != "none": @@ -111,29 +128,29 @@ def parallelize_qwen3( model, job_config.activation_checkpoint, model_compile_enabled=model_compile_enabled, - use_flex_attn=use_flex_attn, + # pyrefly: ignore [bad-argument-type] op_sac_save_list=_op_sac_save_list, base_folder=job_config.job.dump_folder, ) # turn on per-TransformerBlock compile after AC wrapping and before FSDP if model_compile_enabled: - apply_compile(model, job_config.compile) + apply_compile(model, job_config.compile, parallel_dims.ep_enabled) if parallel_dims.fsdp_enabled: # apply FSDP or HSDP, potentially with Context Parallel - if parallel_dims.dp_replicate_enabled: - dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") - else: - dp_mesh_dim_names = ("dp_shard_cp",) - dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + dp_mesh_names = ( + ["dp_replicate", "fsdp"] if parallel_dims.dp_replicate_enabled else ["fsdp"] + ) + dp_mesh = parallel_dims.get_mesh(dp_mesh_names) # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP - dp_mod_ep_mesh_dim_names = [] - if parallel_dims.ep_enabled: - if parallel_dims.dp_replicate_enabled: - dp_mod_ep_mesh_dim_names.append("dp_replicate") - dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + edp_mesh_names = ( + ["dp_replicate", "efsdp"] + if parallel_dims.dp_replicate_enabled + else ["efsdp"] + ) + edp_mesh = parallel_dims.get_optional_mesh(edp_mesh_names) apply_fsdp( model, @@ -144,12 +161,9 @@ def parallelize_qwen3( cpu_offload=job_config.training.enable_cpu_offload, reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, ep_degree=parallel_dims.ep, - dp_mod_ep_mesh=( - world_mesh[tuple(dp_mod_ep_mesh_dim_names)] - if parallel_dims.ep_enabled - else None - ), + edp_mesh=edp_mesh, gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, + disable_prefetch=job_config.parallelism.fsdp_disable_prefetch, ) if parallel_dims.dp_replicate_enabled: @@ -157,22 +171,22 @@ def parallelize_qwen3( else: logger.info("Applied FSDP to the model") - if parallel_dims.cp_enabled: - logger.info("Applied Context Parallel to the model") - if job_config.training.enable_cpu_offload: logger.info("Applied CPU Offloading to the model") elif parallel_dims.dp_replicate_enabled: - if world_mesh.ndim > 1: + dp_mesh = parallel_dims.get_mesh("dp_replicate") + if dp_mesh.ndim > 1: raise RuntimeError("DDP has not supported > 1D parallelism") apply_ddp( model, - world_mesh, + dp_mesh, enable_compile=model_compile_enabled, ) # Enable weight tying after applying parallelisms + # pyrefly: ignore [missing-attribute] if model.model_args.enable_weight_tying: + # pyrefly: ignore [missing-attribute] model.output.weight = model.tok_embeddings.weight return model @@ -184,6 +198,7 @@ def apply_non_moe_tp( loss_parallel: bool, enable_float8_tensorwise_tp: bool, enable_async_tp: bool, + positions_enabled: bool, ): """Apply tensor parallelism.""" # 1. Parallelize the embedding and shard its outputs (which are the first @@ -233,12 +248,19 @@ def apply_non_moe_tp( # NOTE: At the cost of model code change, we can accelerate Sequence Parallel # by folding (and unfolding) the batch dimension and the sequence dimension. # Examples can be found at https://github.com/pytorch/torchtitan/pull/437 + positions_sharding = Replicate() if positions_enabled else None + # pyrefly: ignore [not-callable] for transformer_block in model.layers.values(): layer_plan = { "attention_norm": SequenceParallel(), "attention": prepare_module_input( - input_layouts=(Shard(1), Replicate(), None), - desired_input_layouts=(Replicate(), Replicate(), None), + input_layouts=(Shard(1), Replicate(), None, positions_sharding), + desired_input_layouts=( + Replicate(), + Replicate(), + None, + positions_sharding, + ), ), "attention.wq": colwise_parallel(use_local_output=False), "attention.wk": colwise_parallel(use_local_output=False), @@ -249,6 +271,7 @@ def apply_non_moe_tp( "ffn_norm": SequenceParallel(), } + # pyrefly: ignore [missing-attribute] if not transformer_block.moe_enabled: layer_plan.update( { @@ -263,16 +286,14 @@ def apply_non_moe_tp( ) parallelize_module( + # pyrefly: ignore [bad-argument-type] module=transformer_block, device_mesh=tp_mesh, parallelize_plan=layer_plan, ) if enable_async_tp: - from torch.distributed._symmetric_memory import enable_symm_mem_for_group - torch._inductor.config._micro_pipeline_tp = True - enable_symm_mem_for_group(tp_mesh.get_group().group_name) logger.info( f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" diff --git a/torchtitan/models/qwen3/model/args.py b/torchtitan/models/qwen3/model/args.py index 26769407bf..d0a0556bf1 100644 --- a/torchtitan/models/qwen3/model/args.py +++ b/torchtitan/models/qwen3/model/args.py @@ -36,7 +36,7 @@ class Qwen3ModelArgs(BaseModelArgs): max_seq_len: int = 4096 depth_init: bool = True - use_flex_attn: bool = False + attn_type: str = "sdpa" attn_mask_type: str = "causal" eos_id: int = 151645 @@ -59,11 +59,5 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: job_config.debug.moe_force_load_balance ) - # Pass DeepEP config to MoE layer and validate - self.moe_args.deepep_config = job_config.deepep - self.moe_args.validate_deepep_config() - - def get_nparams_and_flops( - self, model: nn.Module, seq_len: int - ) -> tuple[int, float]: + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: return get_moe_model_nparams_and_flops(self, model, 2 * self.head_dim, seq_len) diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index 5a11eb4f1a..120e8f78fd 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -17,11 +17,15 @@ from torchtitan.config.job_config import PEFT from torchtitan.models.attention import ( create_attention_mask, + create_varlen_metadata_for_document, + create_varlen_metadata_from_sequence_lengths, FlexAttentionWrapper, - get_block_causal_mask_mod_by_seq_lens, get_causal_mask_mod, get_document_mask_mod, + get_block_causal_mask_mod_by_seq_lens, ScaledDotProductAttentionWrapper, + VarlenAttentionWrapper, + VarlenMetadata, ) from torchtitan.models.moe import MoE from torchtitan.protocols.model import AttentionMasksType @@ -57,7 +61,9 @@ def rotate_half(x: torch.Tensor) -> torch.Tensor: return torch.cat((-x2, x1), dim=-1) -def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor: +def reshape_for_broadcast( + rope_cache: torch.Tensor, x: torch.Tensor, positions: torch.Tensor | None = None +) -> torch.Tensor: """ Reshape frequency tensor (represented by cos, sin) for broadcasting it with another tensor. @@ -70,34 +76,51 @@ def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Te Args: rope_cache (torch.Tensor): RoPE tensor (cos and sin) to be reshaped. x (torch.Tensor): Target tensor for broadcasting compatibility. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. + Shape is (1, seqlen) or (bz, seqlen). Defaults to None. Returns: torch.Tensor: Reshaped frequency tensor. """ ndim = x.ndim assert ndim > 1 - _, seqlen, _, head_dim = x.shape - rope_cache = rope_cache[0:seqlen] - # The shape of rope_cache is (seqlen, head_dim * 2) because we concate cos and sin - assert rope_cache.shape == (seqlen, head_dim * 2) - shape = [-1, seqlen, 1, head_dim * 2] - return rope_cache.view(*shape) + bz, seqlen, _, head_dim = x.shape + if positions is None: + rope_cache = rope_cache[0:seqlen] + # The shape of rope_cache is (seqlen, head_dim * 2) because we concate cos and sin + assert rope_cache.shape == (seqlen, head_dim * 2) + shape = [-1, seqlen, 1, head_dim * 2] + return rope_cache.view(*shape) + elif positions.size(0) == 1: + assert positions.shape == (1, seqlen) + rope_cache = rope_cache[positions.squeeze(0)] + # The shape of rope_cache is (seqlen, head_dim * 2) + assert rope_cache.shape == (seqlen, head_dim * 2) + shape = [-1, seqlen, 1, head_dim * 2] + return rope_cache.view(*shape) + else: + assert positions.shape == (bz, seqlen) + rope_cache_expanded = rope_cache[None, :, None, :].expand(bz, -1, -1, -1) + rope_cache = torch.gather( + rope_cache_expanded, + dim=1, + index=positions.view(bz, seqlen, 1, 1).expand(bz, seqlen, 1, head_dim * 2), + ) + # The shape of rope_cache is (bz, seqlen, 1, head_dim * 2) + assert rope_cache.shape == (bz, seqlen, 1, head_dim * 2) + return rope_cache def apply_rotary_emb( xq: torch.Tensor, xk: torch.Tensor, rope_cache: torch.Tensor, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: # input tensor x has shape [bsz, seq_len, num_heads, head_dim] head_dim = xq.shape[-1] - # reshape for broadcast or gather custom positions - if position_ids is not None: - rope_cache = rope_cache[position_ids].unsqueeze(2) - else: - rope_cache = reshape_for_broadcast(rope_cache, xq) + rope_cache = reshape_for_broadcast(rope_cache, xq, positions) # [bsz, seq_len, 1, head_dim] cos = rope_cache[..., :head_dim].to(dtype=xq.dtype, device=xq.device) @@ -142,6 +165,9 @@ class Attention(nn.Module): """ + q_norm: nn.RMSNorm | None + k_norm: nn.RMSNorm | None + def __init__(self, model_args: Qwen3ModelArgs, peft_config: PEFT): super().__init__() self.n_heads = model_args.n_heads @@ -153,7 +179,7 @@ def __init__(self, model_args: Qwen3ModelArgs, peft_config: PEFT): self.n_rep = self.n_heads // self.n_kv_heads self.head_dim = model_args.head_dim self.scaling = self.head_dim**-0.5 - self.use_flex_attn = getattr(model_args, "use_flex_attn", False) + self.attn_type = getattr(model_args, "attn_type", "sdpa") # RMSNorm added here to the here to include the q-k norm # This is one of the main differences between Llama3 and Qwen3 @@ -196,10 +222,17 @@ def __init__(self, model_args: Qwen3ModelArgs, peft_config: PEFT): self.q_norm.weight.requires_grad = False self.k_norm.weight.requires_grad = False - if self.use_flex_attn: - self.inner_attention = FlexAttentionWrapper() - else: - self.inner_attention = ScaledDotProductAttentionWrapper() + match self.attn_type: + case "flex": + self.inner_attention = FlexAttentionWrapper() + case "varlen": + # pyrefly: ignore [bad-assignment] + self.inner_attention = VarlenAttentionWrapper() + case "sdpa": + # pyrefly: ignore [bad-assignment] + self.inner_attention = ScaledDotProductAttentionWrapper() + case _: + raise ValueError(f"Unknown attention type: {self.attn_type}") def init_weights(self, init_std: float): for linear in (self.wq, self.wk, self.wv): @@ -215,16 +248,16 @@ def forward( x: torch.Tensor, rope_cache: torch.Tensor, attention_masks: AttentionMasksType | None, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ): """ Forward pass of the attention module. Args: x (torch.Tensor): Input tensor. - rope_cache (torch.Tensor): Cached RoPE values. - attention_masks (AttentionMasksType | None): Optional attention masks. - position_ids (torch.Tensor | None): Optional custom position ids. + rope_cache (torch.Tensor): Precomputed cosine and sine frequencies. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after attention. @@ -243,13 +276,13 @@ def forward( # Adding the q_norm and k_norm here # Last layer of adding q-k norm - if self.q_norm: + if self.q_norm: # pyrefly: ignore[invalid-argument] xq = self.q_norm(xq) - if self.k_norm: + if self.k_norm: # pyrefly: ignore[invalid-argument] xk = self.k_norm(xk) # Apply rotary embedding - xq, xk = apply_rotary_emb(xq, xk, rope_cache, position_ids=position_ids) + xq, xk = apply_rotary_emb(xq, xk, rope_cache, positions) # repeat k/v heads if n_kv_heads < n_heads keys = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim) @@ -259,16 +292,34 @@ def forward( xk = keys.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) xv = values.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim) - if self.use_flex_attn: - assert isinstance(attention_masks, BlockMask), attention_masks - output = self.inner_attention(xq, xk, xv, block_mask=attention_masks) - else: - assert attention_masks is None - output = self.inner_attention(xq, xk, xv) - - output = output.transpose( - 1, 2 - ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + match self.attn_type: + case "flex": + assert isinstance(attention_masks, BlockMask), attention_masks + output = self.inner_attention( + xq, xk, xv, block_mask=attention_masks, scale=self.scaling + ) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + case "varlen": + # TODO: pass self.scaling into varlen attention + assert isinstance(attention_masks, VarlenMetadata), attention_masks + output = self.inner_attention( + xq, + xk, + xv, + self.head_dim, + attention_masks, + scale=self.scaling, + ) + case "sdpa": + assert attention_masks is None + output = self.inner_attention(xq, xk, xv, scale=self.scaling) + output = output.transpose( + 1, 2 + ).contiguous() # (bs, seqlen, n_local_heads, head_dim) + case _: + raise ValueError(f"Unknown attention type: {self.attn_type}") output = output.view(bs, seqlen, -1) return self.wo(output) @@ -389,7 +440,7 @@ def forward( x: torch.Tensor, rope_cache: torch.Tensor, attention_masks: AttentionMasksType | None, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the TransformerBlock. @@ -397,18 +448,15 @@ def forward( Args: x (torch.Tensor): Input tensor. rope_cache (torch.Tensor): Precomputed cosine and sine frequencies. - attention_masks (AttentionMasksType | None): Optional attention masks. - position_ids (torch.Tensor | None): Optional custom position ids. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output tensor after applying attention and feedforward layers. """ x = x + self.attention( - self.attention_norm(x), - rope_cache, - attention_masks, - position_ids=position_ids, + self.attention_norm(x), rope_cache, attention_masks, positions ) if self.moe_enabled: @@ -428,7 +476,7 @@ def init_weights(self, buffer_device: torch.device): self.feed_forward.init_weights(self.weight_init_std) -class Qwen3Model(nn.Module, ModelProtocol): +class Qwen3Model(ModelProtocol): """ Qwen3Model Module @@ -448,7 +496,7 @@ class Qwen3Model(nn.Module, ModelProtocol): """ def __init__(self, model_args: Qwen3ModelArgs, peft_config: PEFT): - super().__init__() + super().__init__(model_args) self.model_args = model_args self.vocab_size = model_args.vocab_size self.n_layers = model_args.n_layers @@ -508,6 +556,7 @@ def init_weights( nn.init.normal_(self.tok_embeddings.weight) for layer in self.layers.values(): if layer is not None: + # pyrefly: ignore [not-callable] layer.init_weights(buffer_device) if self.norm is not None: self.norm.reset_parameters() @@ -531,7 +580,7 @@ def _precompute_rope_cache(self) -> torch.Tensor: self.model_args.rope_theta, ) - def get_attention_masks( + def _get_flex_attention_masks( self, input_batch: torch.Tensor, tokenizer: BaseTokenizer, @@ -562,11 +611,42 @@ def get_attention_masks( and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] ) + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + match self.model_args.attn_type: + case "flex": + return self._get_flex_attention_masks( + input_batch, tokenizer, extra_inputs + ) + case "varlen": + if self.model_args.attn_mask_type != "block_causal": + raise ValueError( + f"varlen attention is only supported with block_causal \ + attention mask type, got {self.model_args.attn_mask_type}" + ) + # Use explicit sequence_lengths from extra_inputs if available, + # otherwise fall back to EOS-based document detection + if extra_inputs is not None and "sequence_lengths" in extra_inputs: + return create_varlen_metadata_from_sequence_lengths( + extra_inputs["sequence_lengths"], + input_batch.shape[1], + input_batch.device, + ) + return create_varlen_metadata_for_document( + input_batch, tokenizer.eos_id + ) + case _: + raise TypeError("Only varlen and flex attn masks are supported") + def forward( self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None, - position_ids: torch.Tensor | None = None, + positions: torch.Tensor | None = None, ): """ Perform a forward pass through the Transformer model. @@ -576,24 +656,22 @@ def forward( If pipeline parallelism is enabled, this will be the input token indices for the ranks on the first pipeline stage. This will be the activation of the previous pipeline stage if the current rank is not on the first stage. - attention_masks (AttentionMasksType | None): Optional attention masks. - position_ids (torch.Tensor | None): Optional custom position ids. + attention_masks (AttentionMasksType | None): Masks used when calculating attention scores. + positions (torch.Tensor | None): Position indices used to access/shuffle RoPE cache. Defaults to None. Returns: torch.Tensor: Output logits after applying the Transformer model. """ # passthrough for nonexistent layers, allows easy configuration of pipeline parallel stages + # pyrefly: ignore[not-callable, invalid-argument] h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens for layer in self.layers.values(): - h = layer( - h, - self.rope_cache, - attention_masks, - position_ids=position_ids, - ) + h = layer(h, self.rope_cache, attention_masks, positions) + # pyrefly: ignore[not-callable, invalid-argument] h = self.norm(h) if self.norm else h + # pyrefly: ignore[not-callable, invalid-argument] output = self.output(h) if self.output else h return output diff --git a/torchtitan/models/qwen3/model/state_dict_adapter.py b/torchtitan/models/qwen3/model/state_dict_adapter.py index 11bb8058c0..1fcd51081c 100644 --- a/torchtitan/models/qwen3/model/state_dict_adapter.py +++ b/torchtitan/models/qwen3/model/state_dict_adapter.py @@ -63,6 +63,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) if abstract_key not in to_hf_map: continue + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_abstract_key = to_hf_map[abstract_key] @@ -72,6 +73,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: abstract_key ] = value.placements self.grouped_expert_weight_shape[abstract_key] = value.shape + self.grouped_expert_weight_mesh[abstract_key] = value.device_mesh # Split GroupedExperts weight to local individual expert weights local_expert_fqn = self._get_local_experts_weights( @@ -85,9 +87,12 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: else: # keep this path for offline conversion split_values = self._split_experts_weights( - value, self.model_args.moe_args.num_experts + value, + # pyrefly: ignore [missing-attribute] + self.model_args.moe_args.num_experts, ) + # pyrefly: ignore [missing-attribute] for expert_num in range(self.model_args.moe_args.num_experts): new_key = new_abstract_key.format(layer_num, expert_num) hf_state_dict[new_key] = split_values[expert_num].squeeze() @@ -96,6 +101,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) if abstract_key not in to_hf_map: continue + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_key = to_hf_map[abstract_key] new_key = new_key.format(layer_num) @@ -104,6 +110,7 @@ def to_hf(self, state_dict: dict[str, Any]) -> dict[str, Any]: else: if key not in to_hf_map: continue + # pyrefly: ignore [missing-attribute] if self.model_args.enable_weight_tying and key == "output.weight": continue new_key = to_hf_map[key] @@ -121,6 +128,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: expert_weights_by_layer = {} # {layer: {abstract_key: {expert_id: tensor}}} if ( + # pyrefly: ignore [missing-attribute] self.model_args.enable_weight_tying and "lm_head.weight" not in hf_state_dict ): @@ -132,6 +140,7 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: abstract_key = re.sub(r"(\d+)", "{}", key, count=2) layer_num, expert_num = re.findall(r"\d+", key) titan_abstract_key = self.from_hf_map[abstract_key] + assert titan_abstract_key is not None new_key = titan_abstract_key.format(layer_num) # Store the expert's weight in expert_weights_by_layer for concatenating later. @@ -143,18 +152,20 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: int(expert_num) ] = value - if isinstance(value, DTensor): + # Use stored metadata to decide path (online vs offline) + # Online mode: local_experts_indices was populated during to_hf() + if titan_abstract_key in self.local_experts_indices: stacked_value = self._concatenate_expert_weights_dtensor( expert_weights_by_layer, titan_abstract_key, layer_num, - value.device_mesh, ) else: # keep this path to be compatible with offline conversion stacked_value = self._concatenate_expert_weights( expert_weights_by_layer, titan_abstract_key, layer_num, + # pyrefly: ignore [missing-attribute] self.model_args.moe_args.num_experts, ) @@ -163,13 +174,16 @@ def from_hf(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: elif "layers" in key: abstract_key = re.sub(r"(\d+)", "{}", key, count=1) + # pyrefly: ignore [missing-attribute] layer_num = re.search(r"\d+", key).group(0) new_key = self.from_hf_map[abstract_key] + # pyrefly: ignore [missing-attribute] new_key = new_key.format(layer_num) state_dict[new_key] = value else: new_key = self.from_hf_map[key] + # pyrefly: ignore [unsupported-operation] state_dict[new_key] = value return state_dict diff --git a/torchtitan/models/utils.py b/torchtitan/models/utils.py index addfa17421..1e5befbf70 100644 --- a/torchtitan/models/utils.py +++ b/torchtitan/models/utils.py @@ -37,6 +37,7 @@ def __init__( # Store metadata for GroupedExperts <-> individual experts conversion self.grouped_expert_weight_placements = {} # {titan_abstract_key: placements} self.grouped_expert_weight_shape = {} # {titan_abstract_key: shape} + self.grouped_expert_weight_mesh = {} # {titan_abstract_key: device_mesh} self.local_experts_indices = {} # {titan_abstract_key: (start_idx, end_idx)} def _calculate_strided_shard_shard_indices( @@ -102,6 +103,7 @@ def _caculate_indices_from_placements( dim_i_placements = [] # Find all the device mesh dimensios that shard on dim-i + # pyrefly: ignore [bad-argument-type] for i, name in enumerate(device_mesh.mesh_dim_names): placement = dtensor_placements[i] if placement.dim == dim: @@ -109,7 +111,7 @@ def _caculate_indices_from_placements( dim_i_placements.append(placement) # Calculate local expert indices based on sharding strategy - start_index, end_index = None, None + start_index, end_index = 0, dim_size if len(dim_i_placements) == 2: # Handle StridedShard(i) + Shard(i) case assert isinstance( @@ -148,8 +150,8 @@ def _caculate_indices_from_placements( end_index = start_index + block_size elif len(dim_i_placements) == 0: - # No need to split on this dimension - return start_index, end_index + # No sharding on this dimension means all elements are local + pass else: raise NotImplementedError( @@ -179,9 +181,11 @@ def _get_local_experts_weights( grouped_expert_weight: DTensor containing all experts' weights Returns: - Dictionary mapping individual expert keys to their DTensor weights + Dictionary mapping individual expert keys to their DTensor or plain tensor weights """ + # pyrefly: ignore [missing-attribute] device_mesh = grouped_expert_weight.device_mesh + # pyrefly: ignore [missing-attribute] dtensor_placements = grouped_expert_weight.placements # Step 1: Extract dimension-0 placement information @@ -192,32 +196,44 @@ def _get_local_experts_weights( dtensor_placements=dtensor_placements, device_mesh=device_mesh, ) - assert ( - start_index is not None and end_index is not None - ), "Start index and end index can not be None on dim-0!" # Step 2: Store indices for potential future use in from_hf() self.local_experts_indices[titan_abstract_key] = (start_index, end_index) - # Step 3: Create new placements for individual expert weights - new_placements = [] + # Step 3: Identify mesh dimensions that shard on dim-0 (expert dimension) + # exclude expert dimension + # and build new sub-mesh/placements for individual expert weights + sub_mesh_names = [] + sub_placements = [] + for i, name in enumerate(device_mesh.mesh_dim_names): placement = dtensor_placements[i] - if placement.dim == 0: - # Convert dim-0 sharding to replication for individual experts - new_placements.append(Replicate()) + if isinstance(placement, Replicate): + # Replicate (hybrid) doesn't shard any dim, keep in sub-mesh + sub_mesh_names.append(name) + sub_placements.append(Replicate()) + elif isinstance(placement, (Shard, _StridedShard)) and placement.dim == 0: + # Shards on expert dim, exclude from sub-mesh + pass elif isinstance(placement, Shard): - # Keep other shard dimensions (individual expert weight has 2D) - new_placements.append(Shard(placement.dim)) + # Shards on non-expert dim, keep in sub-mesh + sub_mesh_names.append(name) + sub_placements.append(Shard(placement.dim)) elif isinstance(placement, _StridedShard): - # Keep strided shard with same parameters - new_placements.append( + # Strided shard on non-expert dim, keep in sub-mesh + sub_mesh_names.append(name) + sub_placements.append( + # pyrefly: ignore [unexpected-positional-argument] _StridedShard(placement.dim, placement.split_factor) ) else: raise ValueError(f"Unsupported placement type: {type(placement)}") - # Step 4: Create individual expert DTensors + # Step 4: Create sub-mesh excluding dim-0 sharding dimensions + # If all mesh dimensions were sharding on dim-0, sub_mesh will be None (use plain tensors) + sub_mesh = device_mesh[tuple(sub_mesh_names)] if sub_mesh_names else None + + # Step 5: Create individual expert tensors assert isinstance( grouped_expert_weight, DTensor ), "Expected DTensor for grouped expert weight" @@ -236,15 +252,21 @@ def _get_local_experts_weights( expert_key = abstract_key.format(layer_id, expert_id) local_expert_index = expert_id - start_index - # Extract individual expert weight and add batch dimension temporarily - expert_weight = local_grouped_weights[local_expert_index, :, :].unsqueeze(0) - - # Create DTensor and remove batch dimension (experts dimension is removed) - expert_dtensor = DTensor.from_local( - expert_weight, device_mesh, new_placements, run_check=False - ).squeeze(0) - - local_expert_tensors[expert_key] = expert_dtensor + if sub_mesh is None: + # Extract individual expert weight (2D) as plain tensor + expert_weight = local_grouped_weights[local_expert_index, :, :] + else: + # Use slicing and unsqueeze get a 3D tensor, then create DTensor and squeeze + expert_weight_3d = local_grouped_weights[ + local_expert_index, :, : + ].unsqueeze(0) + expert_weight = DTensor.from_local( + expert_weight_3d, + sub_mesh, + sub_placements, + run_check=False, + ).squeeze(0) + local_expert_tensors[expert_key] = expert_weight return local_expert_tensors @@ -253,7 +275,6 @@ def _concatenate_expert_weights_dtensor( expert_weights_by_layer: dict[str, dict[str, dict[int, torch.Tensor]]], abstract_key: str, layer_num: str, - device_mesh: DeviceMesh, ) -> torch.Tensor | None: """ Args: @@ -268,7 +289,6 @@ def _concatenate_expert_weights_dtensor( Used to collect individual expert weights before concatenating them into GroupedExperts. abstract_key: TorchTitan templage key with {} placeholders for layer and expert IDs layer_num: Layer identifier - device_mesh: DeviceMesh for the target GroupedExperts weight DTensor Returns: Concatenated GroupedExperts weight DTensor if all experts are available, otherwise None @@ -284,16 +304,21 @@ def _concatenate_expert_weights_dtensor( sorted_expert_ids = sorted(experts.keys()) sorted_experts = [experts[i] for i in sorted_expert_ids] - local_tensor = torch.stack(sorted_experts, dim=0)._local_tensor + + # Stack experts - result may be DTensor or plain tensor depending on sub_mesh + local_tensor = torch.stack(sorted_experts, dim=0) + if isinstance(local_tensor, DTensor): + local_tensor = local_tensor._local_tensor assert ( abstract_key in self.grouped_expert_weight_placements and abstract_key in self.grouped_expert_weight_shape - ), "GroupedExperts weight metadata (placements, shape) can not be None!" + and abstract_key in self.grouped_expert_weight_mesh + ), "GroupedExperts weight metadata (placements, shape, mesh) can not be None!" stacked_dtensor = DTensor.from_local( local_tensor, - device_mesh, + self.grouped_expert_weight_mesh[abstract_key], self.grouped_expert_weight_placements[abstract_key], run_check=False, ) @@ -306,7 +331,7 @@ def _concatenate_expert_weights_dtensor( def _split_experts_weights( self, weight: torch.Tensor, n_experts: int - ) -> list[torch.Tensor]: + ) -> tuple[torch.Tensor, ...]: """ Split the weights of the experts into a list of tensors. Used for offline conversion. @@ -365,7 +390,7 @@ def get_dense_model_nparams_and_flops( model: nn.Module, head_dims: int, seq_len: int, -) -> tuple[int, float]: +) -> tuple[int, int]: """ Args: model_args: BaseModelArgs object containing model configuration parameters. @@ -395,6 +420,7 @@ def get_dense_model_nparams_and_flops( # 4. we follow the convention and do not account for sparsity in causal attention num_flops_per_token = ( 6 * (nparams - nparams_embedding) + # pyrefly: ignore [missing-attribute] + 6 * model_args.n_layers * model_args.n_heads * head_dims * seq_len ) @@ -410,7 +436,7 @@ def get_moe_model_nparams_and_flops( model: nn.Module, head_dims: int, seq_len: int, -) -> tuple[int, float]: +) -> tuple[int, int]: """ Calculate nparams and nflops for MoE models. @@ -450,6 +476,7 @@ def get_moe_model_nparams_and_flops( nparams_sparse_active = ( nparams_moe_router + nparams_shared_experts + # pyrefly: ignore [missing-attribute] + nparams_experts * model_args.moe_args.top_k // model_args.moe_args.num_experts ) @@ -460,6 +487,7 @@ def get_moe_model_nparams_and_flops( num_flops_per_token = ( 6 * (nparams_dense - nparams_embedding + nparams_sparse_active) + # pyrefly: ignore [missing-attribute] + 6 * model_args.n_layers * model_args.n_heads * head_dims * seq_len ) diff --git a/torchtitan/protocols/model.py b/torchtitan/protocols/model.py index a713bec65b..99e4c34dc0 100644 --- a/torchtitan/protocols/model.py +++ b/torchtitan/protocols/model.py @@ -6,7 +6,6 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import Protocol import torch import torch.nn as nn @@ -16,9 +15,10 @@ from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.config import JobConfig +from torchtitan.models.attention import VarlenMetadata -AttentionMasksType = dict[str, BlockMask] | BlockMask +AttentionMasksType = dict[str, BlockMask] | BlockMask | VarlenMetadata @dataclass @@ -36,21 +36,22 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: pass @abstractmethod - def get_nparams_and_flops( - self, model: nn.Module, seq_len: int - ) -> tuple[int, float]: + def get_nparams_and_flops(self, model: nn.Module, seq_len: int) -> tuple[int, int]: pass -class ModelProtocol(Protocol): +class ModelProtocol(nn.Module): """Defines the interface for a model class. This is used to enforce that all model classes have some methods that are required by the trainer. + + NOTE: We keep protocol name for backward compatibility even though it is + not a Protocol anymore. """ def __init__(self, model_args: BaseModelArgs) -> None: - pass + super().__init__() @abstractmethod def init_weights(self, buffer_device: torch.device | None = None) -> None: diff --git a/torchtitan/protocols/model_converter.py b/torchtitan/protocols/model_converter.py index dbfc3a99c3..cb4804be6f 100644 --- a/torchtitan/protocols/model_converter.py +++ b/torchtitan/protocols/model_converter.py @@ -62,7 +62,9 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims): _registry_model_converter_cls[name] for name in job_config.model.converters ] self.converters = [ - mh_cls(job_config, parallel_dims) for mh_cls in converter_classes + # pyrefly: ignore[bad-instantiation] + mh_cls(job_config, parallel_dims) + for mh_cls in converter_classes ] self.print_after_conversion = job_config.model.print_after_conversion diff --git a/torchtitan/protocols/state_dict_adapter.py b/torchtitan/protocols/state_dict_adapter.py index e22692bd52..7b2b3ef3ad 100644 --- a/torchtitan/protocols/state_dict_adapter.py +++ b/torchtitan/protocols/state_dict_adapter.py @@ -8,7 +8,7 @@ import os import re from abc import ABC, abstractmethod -from typing import Any +from typing import Any, Dict from torch.distributed.checkpoint import HuggingFaceStorageReader @@ -27,6 +27,8 @@ class BaseStateDictAdapter(ABC): hf_assets_path: path to HF assets folder containing tokenizer, model weights, etc. """ + fqn_to_index_mapping: Dict[Any, int] | None + @abstractmethod def __init__( self, @@ -98,6 +100,7 @@ def __init__( if hf_safetensors_indx: self.fqn_to_index_mapping = {} for hf_key, raw_indx in hf_safetensors_indx["weight_map"].items(): + # pyrefly: ignore [missing-attribute] indx = re.search(r"\d+", raw_indx).group(0) self.fqn_to_index_mapping[hf_key] = int(indx) else: diff --git a/torchtitan/tools/aggressive_memory_manager.py b/torchtitan/tools/aggressive_memory_manager.py new file mode 100644 index 0000000000..1c4861cb74 --- /dev/null +++ b/torchtitan/tools/aggressive_memory_manager.py @@ -0,0 +1,414 @@ +""" +Aggressive Memory Manager for reducing CUDA memory fragmentation. + +This module provides aggressive memory clearing strategies to minimize +fragmentation and allocation retries during distributed training. + +Usage: + from torchtitan.tools.aggressive_memory_manager import AggressiveMemoryManager + + # Initialize at start of training + mem_manager = AggressiveMemoryManager( + clear_after_backward=True, + clear_after_optimizer=True, + sync_before_clear=True, + defrag_threshold_mb=1000, # Defrag if fragmentation > 1GB + ) + + # In training loop: + loss.backward() + mem_manager.post_backward() + + optimizer.step() + mem_manager.post_optimizer() + + mem_manager.step_complete() +""" + +import gc +import os +import time +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.distributed as dist + +from torchtitan.tools.logging import logger + + +@dataclass +class MemoryStats: + """Current memory statistics""" + + allocated: int + reserved: int + active: int + fragmentation: int + fragmentation_pct: float + num_alloc_retries: int + + +class AggressiveMemoryManager: + """ + Aggressive memory management to minimize CUDA memory fragmentation. + + Key strategies: + 1. Clear cache at strategic points (post-backward, post-optimizer) + 2. Synchronize before clearing to ensure all async ops complete + 3. Force garbage collection to release Python references + 4. Monitor fragmentation and trigger defrag when threshold exceeded + 5. Set optimal allocator configuration + + Args: + clear_after_backward: Clear cache after backward pass + clear_after_optimizer: Clear cache after optimizer step + clear_every_n_steps: Only clear every N steps (1 = every step) + sync_before_clear: Synchronize CUDA before clearing cache + defrag_threshold_mb: Trigger defrag if fragmentation exceeds this (MB) + gc_generation: Python GC generation to collect (0-2, higher = more thorough) + verbose: Log detailed memory stats + rank: Distributed rank (auto-detected if None) + """ + + def __init__( + self, + clear_after_backward: bool = True, + clear_after_optimizer: bool = True, + clear_every_n_steps: int = 1, + sync_before_clear: bool = True, + defrag_threshold_mb: float = 500.0, + gc_generation: int = 1, + verbose: bool = False, + rank: Optional[int] = None, + ): + self.clear_after_backward = clear_after_backward + self.clear_after_optimizer = clear_after_optimizer + self.clear_every_n_steps = clear_every_n_steps + self.sync_before_clear = sync_before_clear + self.defrag_threshold_mb = defrag_threshold_mb + self.gc_generation = gc_generation + self.verbose = verbose + + self.rank = ( + rank + if rank is not None + else (dist.get_rank() if dist.is_initialized() else 0) + ) + + self.step_count = 0 + self.total_clears = 0 + self.total_defrag_time_ms = 0.0 + + # Disable automatic GC - we'll control it manually + gc.disable() + + # Initial cleanup + self._aggressive_clear("initialization") + + if self.rank == 0: + logger.info( + f"[AggressiveMemoryManager] Initialized: " + f"clear_backward={clear_after_backward}, " + f"clear_optimizer={clear_after_optimizer}, " + f"every_n_steps={clear_every_n_steps}, " + f"sync={sync_before_clear}, " + f"defrag_threshold={defrag_threshold_mb}MB" + ) + + @staticmethod + def configure_allocator( + expandable_segments: bool = True, + max_split_size_mb: int = 128, + garbage_collection_threshold: float = 0.8, + roundup_power2_divisions: int = 4, + ) -> str: + """ + Configure PyTorch CUDA allocator for minimal fragmentation. + + Call this BEFORE any CUDA operations (before model creation). + + Args: + expandable_segments: Enable expandable memory segments + max_split_size_mb: Max size of memory splits (smaller = less fragmentation) + garbage_collection_threshold: Trigger GC when this fraction of memory is fragmented + roundup_power2_divisions: Memory rounding granularity + + Returns: + The PYTORCH_CUDA_ALLOC_CONF string that was set + """ + config_parts = [] + + if expandable_segments: + config_parts.append("expandable_segments:True") + + config_parts.append(f"max_split_size_mb:{max_split_size_mb}") + config_parts.append( + f"garbage_collection_threshold:{garbage_collection_threshold}" + ) + config_parts.append(f"roundup_power2_divisions:{roundup_power2_divisions}") + + config_str = ",".join(config_parts) + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = config_str + + return config_str + + def get_memory_stats(self) -> MemoryStats: + """Get current memory statistics""" + if not torch.cuda.is_available(): + return MemoryStats(0, 0, 0, 0, 0.0, 0) + + stats = torch.cuda.memory_stats() + allocated = torch.cuda.memory_allocated() + reserved = torch.cuda.memory_reserved() + active = stats.get("active_bytes.all.current", 0) + fragmentation = reserved - allocated + fragmentation_pct = (fragmentation / reserved * 100) if reserved > 0 else 0.0 + num_retries = stats.get("num_alloc_retries", 0) + + return MemoryStats( + allocated=allocated, + reserved=reserved, + active=active, + fragmentation=fragmentation, + fragmentation_pct=fragmentation_pct, + num_alloc_retries=num_retries, + ) + + def _should_clear(self) -> bool: + """Check if we should clear cache this step""" + return self.step_count % self.clear_every_n_steps == 0 + + def _aggressive_clear(self, reason: str) -> float: + """ + Perform aggressive memory clearing. + + Returns: + Time taken in milliseconds + """ + if not torch.cuda.is_available(): + return 0.0 + + start = time.perf_counter() + + # 1. Synchronize all CUDA streams to ensure ops complete + if self.sync_before_clear: + torch.cuda.synchronize() + + # 2. Python garbage collection (releases tensor references) + gc.collect(self.gc_generation) + + # 3. Clear CUDA cache (releases unused cached memory) + torch.cuda.empty_cache() + + # 4. Optional: Force synchronization after clear + if self.sync_before_clear: + torch.cuda.synchronize() + + elapsed_ms = (time.perf_counter() - start) * 1000 + self.total_clears += 1 + self.total_defrag_time_ms += elapsed_ms + + if self.verbose and self.rank == 0: + stats = self.get_memory_stats() + logger.info( + f"[AggressiveMemoryManager] {reason}: " + f"cleared in {elapsed_ms:.1f}ms, " + f"frag={stats.fragmentation_pct:.1f}%, " + f"reserved={stats.reserved/1e9:.2f}GB" + ) + + return elapsed_ms + + def _check_and_defrag(self, phase: str) -> bool: + """ + Check fragmentation and defrag if needed. + + Returns: + True if defrag was triggered + """ + stats = self.get_memory_stats() + fragmentation_mb = stats.fragmentation / (1024 * 1024) + + if fragmentation_mb > self.defrag_threshold_mb: + self._aggressive_clear(f"defrag_{phase}_frag={fragmentation_mb:.0f}MB") + return True + + return False + + def post_backward(self): + """Call after backward pass completes""" + if self.clear_after_backward and self._should_clear(): + self._check_and_defrag("post_backward") + self._aggressive_clear("post_backward") + + def post_optimizer(self): + """Call after optimizer step completes""" + if self.clear_after_optimizer and self._should_clear(): + self._check_and_defrag("post_optimizer") + self._aggressive_clear("post_optimizer") + + def step_complete(self): + """Call at the end of each training step""" + self.step_count += 1 + + # Always check for high fragmentation + self._check_and_defrag("step_end") + + def get_summary(self) -> str: + """Get summary of memory management activity""" + avg_time = self.total_defrag_time_ms / max(1, self.total_clears) + return ( + f"AggressiveMemoryManager Summary:\n" + f" Total clears: {self.total_clears}\n" + f" Total defrag time: {self.total_defrag_time_ms:.1f}ms\n" + f" Avg time per clear: {avg_time:.2f}ms\n" + f" Steps processed: {self.step_count}" + ) + + +class BackwardMemoryHook: + """ + Register hooks on model parameters to clear memory during backward pass. + + This clears memory incrementally as gradients are computed, rather than + waiting until the end of backward. + + Args: + clear_every_n_params: Clear cache after every N parameter gradients + sync_on_clear: Synchronize before clearing (slower but more thorough) + """ + + def __init__( + self, + clear_every_n_params: int = 10, + sync_on_clear: bool = False, + ): + self.clear_every_n_params = clear_every_n_params + self.sync_on_clear = sync_on_clear + self.param_count = 0 + self.handles = [] + + def _backward_hook(self, grad): + """Hook called when gradient is computed for a parameter""" + self.param_count += 1 + + if self.param_count % self.clear_every_n_params == 0: + if self.sync_on_clear: + torch.cuda.synchronize() + gc.collect(0) # Fast GC (generation 0 only) + torch.cuda.empty_cache() + + return grad + + def register(self, model: torch.nn.Module): + """Register hooks on all model parameters""" + for name, param in model.named_parameters(): + if param.requires_grad: + handle = param.register_post_accumulate_grad_hook( + lambda p, name=name: self._backward_hook(p.grad) + ) + self.handles.append(handle) + + logger.info( + f"[BackwardMemoryHook] Registered on {len(self.handles)} parameters, " + f"clearing every {self.clear_every_n_params} params" + ) + + def remove(self): + """Remove all registered hooks""" + for handle in self.handles: + handle.remove() + self.handles.clear() + + def reset_count(self): + """Reset parameter count (call at start of each backward)""" + self.param_count = 0 + + +def setup_aggressive_memory_environment(): + """ + Set up environment variables for aggressive memory management. + + Call this BEFORE importing torch or creating any CUDA tensors. + """ + # Optimal allocator settings for minimal fragmentation + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = ( + "expandable_segments:True," + "max_split_size_mb:128," + "garbage_collection_threshold:0.8," + "roundup_power2_divisions:4" + ) + + # Disable NCCL async error handling (can cause memory issues) + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" + + # Force synchronous CUDA operations for debugging + # os.environ["CUDA_LAUNCH_BLOCKING"] = "1" # Uncomment for debugging + + return os.environ.get("PYTORCH_CUDA_ALLOC_CONF") + + +# Convenience function for quick setup +def create_aggressive_memory_manager( + mode: str = "balanced", + verbose: bool = False, +) -> AggressiveMemoryManager: + """ + Create an AggressiveMemoryManager with preset configurations. + + Args: + mode: One of: + - "minimal": Only clear on high fragmentation + - "balanced": Clear after backward and optimizer + - "aggressive": Clear frequently with sync + - "maximum": Clear after every operation + verbose: Enable verbose logging + + Returns: + Configured AggressiveMemoryManager + """ + if mode == "minimal": + return AggressiveMemoryManager( + clear_after_backward=False, + clear_after_optimizer=False, + clear_every_n_steps=10, + sync_before_clear=False, + defrag_threshold_mb=2000, + gc_generation=0, + verbose=verbose, + ) + elif mode == "balanced": + return AggressiveMemoryManager( + clear_after_backward=True, + clear_after_optimizer=True, + clear_every_n_steps=1, + sync_before_clear=False, + defrag_threshold_mb=500, + gc_generation=1, + verbose=verbose, + ) + elif mode == "aggressive": + return AggressiveMemoryManager( + clear_after_backward=True, + clear_after_optimizer=True, + clear_every_n_steps=1, + sync_before_clear=True, + defrag_threshold_mb=200, + gc_generation=2, + verbose=verbose, + ) + elif mode == "maximum": + return AggressiveMemoryManager( + clear_after_backward=True, + clear_after_optimizer=True, + clear_every_n_steps=1, + sync_before_clear=True, + defrag_threshold_mb=100, + gc_generation=2, + verbose=verbose, + ) + else: + raise ValueError( + f"Unknown mode: {mode}. Use minimal/balanced/aggressive/maximum" + ) diff --git a/torchtitan/tools/cuda_memory_tracker.py b/torchtitan/tools/cuda_memory_tracker.py new file mode 100644 index 0000000000..0f7d7af5f4 --- /dev/null +++ b/torchtitan/tools/cuda_memory_tracker.py @@ -0,0 +1,123 @@ +"""Track CUDA memory directly from nvidia-smi and PyTorch""" +import logging +import subprocess +from typing import Dict, Optional + +import torch + +logger = logging.getLogger(__name__) + + +class CUDAMemoryTracker: + """Track memory from both PyTorch and CUDA/nvidia-smi""" + + def __init__(self, enabled: bool = True): + self.enabled = enabled + self.device = torch.cuda.current_device() + self.device_name = torch.cuda.get_device_name(self.device) + + if self.enabled: + logger.info( + f"CUDAMemoryTracker enabled for device {self.device}: {self.device_name}" + ) + + def get_nvidia_smi_memory(self) -> Optional[Dict[str, int]]: + """Get memory from nvidia-smi""" + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=memory.used,memory.free,memory.total", + "--format=csv,noheader,nounits", + "-i", + str(self.device), + ], + capture_output=True, + text=True, + timeout=2, + ) + + if result.returncode == 0: + used, free, total = map(int, result.stdout.strip().split(",")) + return {"used_mb": used, "free_mb": free, "total_mb": total} + except Exception as e: + logger.warning(f"Failed to get nvidia-smi memory: {e}") + + return None + + def get_pytorch_memory(self) -> Dict[str, int]: + """Get memory from PyTorch""" + stats = torch.cuda.memory_stats(self.device) + + return { + "reserved_bytes": torch.cuda.memory_reserved(self.device), + "allocated_bytes": torch.cuda.memory_allocated(self.device), + "active_bytes": stats.get("active_bytes.all.current", 0), + "inactive_bytes": stats.get("inactive_split_bytes.all.current", 0), + "peak_active_bytes": stats.get("active_bytes.all.peak", 0), + "num_alloc_retries": stats.get("num_alloc_retries.all.current", 0), + "num_ooms": stats.get("num_ooms.all.current", 0), + } + + def get_cuda_device_memory(self) -> Dict[str, int]: + """Get memory directly from CUDA device properties""" + props = torch.cuda.get_device_properties(self.device) + + return { + "total_memory": props.total_memory, + "reserved_memory": torch.cuda.memory_reserved(self.device), + "allocated_memory": torch.cuda.memory_allocated(self.device), + } + + def measure_all(self, phase: str, step: int): + """Comprehensive memory measurement""" + if not self.enabled: + return + + # PyTorch memory + pytorch_mem = self.get_pytorch_memory() + + # CUDA device memory + cuda_mem = self.get_cuda_device_memory() + + # nvidia-smi memory (if available) + smi_mem = self.get_nvidia_smi_memory() + + # Calculate fragmentation + reserved = pytorch_mem["reserved_bytes"] + allocated = pytorch_mem["allocated_bytes"] + active = pytorch_mem["active_bytes"] + + fragmentation = reserved - allocated + frag_pct = (fragmentation / reserved * 100) if reserved > 0 else 0 + + # Log PyTorch view + logger.info( + f"[PyTorch] Step {step:2d} | {phase:25s} | " + f"Reserved: {reserved/1e9:6.2f} GB | " + f"Allocated: {allocated/1e6:8.2f} MB | " + f"Active: {active/1e6:8.2f} MB | " + f"Frag: {frag_pct:5.1f}%" + ) + + # Log CUDA/nvidia-smi view + if smi_mem: + logger.info( + f"[CUDA-SMI] Step {step:2d} | {phase:25s} | " + f"Used: {smi_mem['used_mb']/1024:6.2f} GB | " + f"Free: {smi_mem['free_mb']/1024:6.2f} GB | " + f"Total: {smi_mem['total_mb']/1024:6.2f} GB" + ) + + # Log comparison + if smi_mem: + pytorch_used_gb = reserved / 1e9 + smi_used_gb = smi_mem["used_mb"] / 1024 + diff_gb = smi_used_gb - pytorch_used_gb + + logger.info( + f"[Compare] Step {step:2d} | {phase:25s} | " + f"PyTorch reports: {pytorch_used_gb:6.2f} GB | " + f"nvidia-smi reports: {smi_used_gb:6.2f} GB | " + f"Diff: {diff_gb:+6.2f} GB" + ) diff --git a/torchtitan/tools/detailed_memory_tracker.py b/torchtitan/tools/detailed_memory_tracker.py new file mode 100644 index 0000000000..7b513b3e20 --- /dev/null +++ b/torchtitan/tools/detailed_memory_tracker.py @@ -0,0 +1,160 @@ +"""Detailed memory tracking throughout training step""" +import logging +from typing import Dict, List + +import torch + +logger = logging.getLogger(__name__) + + +class DetailedMemoryTracker: + """Track memory at every phase of training with cache clearing""" + + def __init__(self, enabled: bool = True, clear_cache: bool = True): + self.enabled = enabled + self.clear_cache_between_steps = clear_cache + self.measurements: List[Dict] = [] + self.device = torch.cuda.current_device() + + if self.enabled: + logger.info(f"DetailedMemoryTracker enabled (clear_cache={clear_cache})") + + def measure(self, phase: str, step: int): + """Capture memory state at a specific phase""" + if not self.enabled: + return + + stats = torch.cuda.memory_stats(self.device) + + measurement = { + "step": step, + "phase": phase, + "reserved": torch.cuda.memory_reserved(self.device), + "allocated": torch.cuda.memory_allocated(self.device), + "active": stats.get("active_bytes.all.current", 0), + "peak_active": stats.get("active_bytes.all.peak", 0), + "num_allocs": stats.get("num_alloc_retries.all.current", 0), + } + + self.measurements.append(measurement) + + # Calculate fragmentation + fragmentation = measurement["reserved"] - measurement["allocated"] + frag_pct = ( + (fragmentation / measurement["reserved"] * 100) + if measurement["reserved"] > 0 + else 0 + ) + + logger.info( + f"[MemTrack] Step {step} | {phase:20s} | " + f"Reserved: {measurement['reserved']/1e9:6.2f} GB | " + f"Allocated: {measurement['allocated']/1e6:7.2f} MB | " + f"Active: {measurement['active']/1e6:7.2f} MB | " + f"Frag: {frag_pct:5.1f}%" + ) + + def clear_cache_and_measure(self, phase: str, step: int): + """Clear cache and measure to see minimum memory""" + if not self.enabled: + return + + # Measure before clearing + self.measure(f"{phase}_before_clear", step) + + # Clear cache + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # Measure after clearing + self.measure(f"{phase}_after_clear", step) + + def step_complete(self, step: int): + """Called after each training step""" + if not self.enabled: + return + + if self.clear_cache_between_steps: + self.clear_cache_and_measure("step_end", step) + + def get_summary(self) -> str: + """Get summary of all measurements""" + if not self.measurements: + return "No measurements recorded" + + summary = ["", "=" * 100, "DETAILED MEMORY TRACKING SUMMARY", "=" * 100, ""] + + # Group by step + steps = {} + for m in self.measurements: + step = m["step"] + if step not in steps: + steps[step] = [] + steps[step].append(m) + + for step, measures in sorted(steps.items()): + summary.append(f"\nStep {step}:") + summary.append( + f"{'Phase':<30} {'Reserved':>12} {'Allocated':>12} {'Active':>12} {'Frag%':>8}" + ) + summary.append("-" * 80) + + for m in measures: + frag_pct = ( + ((m["reserved"] - m["allocated"]) / m["reserved"] * 100) + if m["reserved"] > 0 + else 0 + ) + summary.append( + f"{m['phase']:<30} " + f"{m['reserved']/1e9:10.2f} GB " + f"{m['allocated']/1e6:10.2f} MB " + f"{m['active']/1e6:10.2f} MB " + f"{frag_pct:7.1f}%" + ) + + # Peak measurements + summary.append("\n" + "=" * 100) + summary.append("PEAK MEASUREMENTS ACROSS ALL STEPS:") + summary.append("=" * 100) + + peak_reserved = max(m["reserved"] for m in self.measurements) + peak_allocated = max(m["allocated"] for m in self.measurements) + peak_active = max(m["active"] for m in self.measurements) + + peak_reserved_phase = [ + m for m in self.measurements if m["reserved"] == peak_reserved + ][0] + peak_allocated_phase = [ + m for m in self.measurements if m["allocated"] == peak_allocated + ][0] + peak_active_phase = [ + m for m in self.measurements if m["active"] == peak_active + ][0] + + summary.append( + f"Peak Reserved: {peak_reserved/1e9:7.2f} GB at Step {peak_reserved_phase['step']} ({peak_reserved_phase['phase']})" + ) + step = peak_allocated_phase["step"] + phase = peak_allocated_phase["phase"] + summary.append( + f"Peak Allocated: {peak_allocated/1e6:7.2f} MB at Step {step} ({phase})" + ) + summary.append( + f"Peak Active: {peak_active/1e6:7.2f} MB at Step {peak_active_phase['step']} ({peak_active_phase['phase']})" + ) + + # Minimum after cache clear + cleared_measures = [m for m in self.measurements if "after_clear" in m["phase"]] + if cleared_measures: + min_reserved_cleared = min(m["reserved"] for m in cleared_measures) + min_measure = [ + m for m in cleared_measures if m["reserved"] == min_reserved_cleared + ][0] + summary.append( + f"\nMinimum Reserved (after cache clear): {min_reserved_cleared/1e9:7.2f} GB at Step {min_measure['step']}" + ) + summary.append(f" Active at minimum: {min_measure['active']/1e6:7.2f} MB") + + summary.append("=" * 100) + return "\n".join(summary) diff --git a/torchtitan/tools/mesh_visualizer.py b/torchtitan/tools/mesh_visualizer.py new file mode 100644 index 0000000000..0ba8fecb03 --- /dev/null +++ b/torchtitan/tools/mesh_visualizer.py @@ -0,0 +1,415 @@ +""" +Device Mesh Visualizer for Distributed Training + +Creates comprehensive visualization of how GPUs are allocated across +all parallelism dimensions: DP, PP, TP, CP, EP. +""" + +import os +from typing import Dict + +import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh + +from torchtitan.tools.logging import logger + + +def get_rank_info() -> Dict: + """Get current rank's information across all process groups.""" + info = { + "global_rank": dist.get_rank() if dist.is_initialized() else 0, + "world_size": dist.get_world_size() if dist.is_initialized() else 1, + "local_rank": int(os.environ.get("LOCAL_RANK", 0)), + "node_rank": int(os.environ.get("GROUP_RANK", os.environ.get("NODE_RANK", 0))), + } + return info + + +def visualize_mesh_structure( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Create a detailed text visualization of the device mesh structure. + + Args: + mesh: The DeviceMesh object + parallel_dims: ParallelDims object with all parallelism settings + rank: Current rank (only rank 0 prints full visualization) + + Returns: + String visualization of the mesh + """ + lines = [] + lines.append("=" * 100) + lines.append("DEVICE MESH VISUALIZATION") + lines.append("=" * 100) + + # Basic info + lines.append("\n[CLUSTER INFO]") + lines.append(f" Total GPUs: {parallel_dims.world_size}") + lines.append(f" Nodes: {parallel_dims.world_size // 8} (assuming 8 GPUs/node)") + + # Parallelism dimensions + lines.append("\n[PARALLELISM DIMENSIONS]") + lines.append(f" DP Replicate (HSDP): {parallel_dims.dp_replicate}") + lines.append(f" DP Shard (FSDP): {parallel_dims.dp_shard}") + lines.append(f" Context Parallel: {parallel_dims.cp}") + lines.append(f" Tensor Parallel: {parallel_dims.tp}") + lines.append(f" Pipeline Parallel: {parallel_dims.pp}") + lines.append(f" Expert Parallel: {parallel_dims.ep}") + lines.append(f" Expert TP: {parallel_dims.etp}") + + # Mesh structure + lines.append("\n[MESH STRUCTURE]") + lines.append(f" Mesh dim names: {mesh.mesh_dim_names}") + lines.append(f" Mesh shape: {mesh.mesh.shape}") + + # Log each dimension + for i, (name, size) in enumerate(zip(mesh.mesh_dim_names, mesh.mesh.shape)): + lines.append(f" Dim {i}: {name:20s} = {size}") + + # EP-specific derived dimensions + if parallel_dims.ep > 1: + if parallel_dims.etp == parallel_dims.tp: + dp_shard_mod_ep = ( + parallel_dims.dp_shard * parallel_dims.cp // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // parallel_dims.cp + else: + dp_shard_mod_ep = ( + parallel_dims.dp_shard + * parallel_dims.cp + * parallel_dims.tp + // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // (parallel_dims.cp * parallel_dims.tp) + + lines.append("\n[EXPERT PARALLEL DERIVED DIMENSIONS]") + lines.append(f" dp_shard_mod_ep (DP for non-experts): {dp_shard_mod_ep}") + lines.append(f" dp_shard_in_ep (DP within EP group): {dp_shard_in_ep}") + lines.append(f" ep_group_size (EP degree): {parallel_dims.ep}") + lines.append("") + lines.append(" Formula: dp_shard = dp_shard_mod_ep * dp_shard_in_ep") + lines.append( + f" {parallel_dims.dp_shard} = {dp_shard_mod_ep} * {dp_shard_in_ep}" + ) + lines.append("") + lines.append(" Formula: ep = dp_shard_in_ep * cp") + lines.append( + f" {parallel_dims.ep} = {dp_shard_in_ep} * {parallel_dims.cp}" + ) + + # Submesh info + lines.append("\n[SUBMESHES]") + + # Try to get submesh info + submesh_names = ["dp", "dp_shard_cp", "dp_cp", "ep", "cp", "tp", "pp"] + for name in submesh_names: + try: + submesh = mesh[name] + lines.append( + f" {name:15s}: size={submesh.size():4d}, dim_names={submesh.mesh_dim_names}" + ) + except (KeyError, RuntimeError): + pass + + return "\n".join(lines) + + +def visualize_gpu_allocation( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Create a grid visualization showing GPU allocation. + + For 16 nodes (128 GPUs) with EP=64, CP=8: + - Shows how each GPU maps to (dp_shard_mod_ep, dp_shard_in_ep, cp) coordinates + """ + lines = [] + lines.append("\n" + "=" * 100) + lines.append("GPU ALLOCATION GRID") + lines.append("=" * 100) + + world_size = parallel_dims.world_size + num_nodes = world_size // 8 + + # For EP-enabled config + if parallel_dims.ep > 1: + if parallel_dims.etp == parallel_dims.tp: + dp_shard_mod_ep = ( + parallel_dims.dp_shard * parallel_dims.cp // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // parallel_dims.cp + else: + dp_shard_mod_ep = ( + parallel_dims.dp_shard + * parallel_dims.cp + * parallel_dims.tp + // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // (parallel_dims.cp * parallel_dims.tp) + + lines.append( + f"\nMesh: [{dp_shard_mod_ep}] x [{dp_shard_in_ep}] x [{parallel_dims.cp}] = {world_size} GPUs" + ) + lines.append(" [dp_shard_mod_ep] x [dp_shard_in_ep] x [cp]") + lines.append("") + + # Create mapping from global rank to mesh coordinates + lines.append("GPU -> Mesh Coordinate Mapping:") + lines.append("-" * 80) + lines.append( + f"{'Node':>6} | {'GPU':>4} | {'Rank':>5} | {'dp_mod_ep':>10} | {'dp_in_ep':>10} | {'cp':>4} | {'EP Group':>10}" + ) + lines.append("-" * 80) + + # The mesh is laid out as: dp_shard_mod_ep (slowest) x dp_shard_in_ep x cp (fastest) + for node in range(num_nodes): + for local_gpu in range(8): + global_rank = node * 8 + local_gpu + + # Compute mesh coordinates (assuming row-major ordering) + # Total size = dp_shard_mod_ep * dp_shard_in_ep * cp + cp_coord = global_rank % parallel_dims.cp + dp_in_ep_coord = (global_rank // parallel_dims.cp) % dp_shard_in_ep + dp_mod_ep_coord = global_rank // (parallel_dims.cp * dp_shard_in_ep) + + # EP group = dp_in_ep_coord * cp + cp_coord (within each dp_shard_mod_ep group) + ep_group = dp_in_ep_coord * parallel_dims.cp + cp_coord + + row = ( + f"{node:>6} | {local_gpu:>4} | {global_rank:>5} | " + f"{dp_mod_ep_coord:>10} | {dp_in_ep_coord:>10} | " + f"{cp_coord:>4} | {ep_group:>10}" + ) + lines.append(row) + + if node < num_nodes - 1: + lines.append("-" * 80) + else: + lines.append( + f"\nMesh: [{parallel_dims.dp_shard}] x [{parallel_dims.cp}] = {world_size} GPUs" + ) + lines.append(" [dp_shard] x [cp]") + + return "\n".join(lines) + + +def visualize_expert_parallel_groups( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Visualize which GPUs belong to which Expert Parallel group. + """ + lines = [] + lines.append("\n" + "=" * 100) + lines.append("EXPERT PARALLEL GROUP ALLOCATION") + lines.append("=" * 100) + + if parallel_dims.ep <= 1: + lines.append("Expert Parallel is disabled (EP=1)") + return "\n".join(lines) + + world_size = parallel_dims.world_size + + if parallel_dims.etp == parallel_dims.tp: + dp_shard_mod_ep = parallel_dims.dp_shard * parallel_dims.cp // parallel_dims.ep + dp_shard_in_ep = parallel_dims.ep // parallel_dims.cp + else: + dp_shard_mod_ep = ( + parallel_dims.dp_shard + * parallel_dims.cp + * parallel_dims.tp + // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // (parallel_dims.cp * parallel_dims.tp) + + lines.append(f"\nEP={parallel_dims.ep} experts distributed across GPUs") + lines.append( + f"Each EP group has {parallel_dims.ep} GPUs working on different experts" + ) + lines.append( + f"There are {dp_shard_mod_ep} such EP groups (for FSDP replication of experts)" + ) + lines.append("") + + # Group GPUs by their dp_shard_mod_ep coordinate + lines.append("EP Groups (GPUs that share the same set of experts):") + lines.append("-" * 80) + + for dp_mod_ep_idx in range(dp_shard_mod_ep): + # Find all ranks in this dp_shard_mod_ep group + ranks_in_group = [] + for global_rank in range(world_size): + dp_mod_ep_coord = global_rank // (parallel_dims.cp * dp_shard_in_ep) + if dp_mod_ep_coord == dp_mod_ep_idx: + ranks_in_group.append(global_rank) + + lines.append(f"\nDP_SHARD_MOD_EP group {dp_mod_ep_idx}:") + lines.append( + f" GPUs: {ranks_in_group[:16]}{'...' if len(ranks_in_group) > 16 else ''}" + ) + lines.append(f" Total: {len(ranks_in_group)} GPUs") + lines.append(" These GPUs have IDENTICAL expert parameters (FSDP sharded)") + + return "\n".join(lines) + + +def visualize_context_parallel_groups( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Visualize Context Parallel groups - GPUs that work on different parts of the sequence. + """ + lines = [] + lines.append("\n" + "=" * 100) + lines.append("CONTEXT PARALLEL GROUP ALLOCATION") + lines.append("=" * 100) + + if parallel_dims.cp <= 1: + lines.append("Context Parallel is disabled (CP=1)") + return "\n".join(lines) + + world_size = parallel_dims.world_size + cp = parallel_dims.cp + + lines.append(f"\nCP={cp} - Each sequence is split into {cp} chunks") + lines.append( + "GPUs with the same (dp_shard, ep) coordinates but different cp coordinates" + ) + lines.append("work on different parts of the same sequence.") + lines.append("") + + # Show a few example CP groups + lines.append("Example CP groups (first few):") + lines.append("-" * 80) + + num_cp_groups = world_size // cp + for cp_group_idx in range(min(4, num_cp_groups)): + ranks_in_group = [cp_group_idx * cp + i for i in range(cp)] + lines.append(f"\nCP group {cp_group_idx}:") + lines.append(f" GPUs: {ranks_in_group}") + lines.append(f" These {cp} GPUs process different chunks of the same sequence") + + if num_cp_groups > 4: + lines.append(f"\n... and {num_cp_groups - 4} more CP groups") + + return "\n".join(lines) + + +def visualize_fsdp_sharding( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Visualize FSDP sharding - which GPUs share which parameters. + """ + lines = [] + lines.append("\n" + "=" * 100) + lines.append("FSDP SHARDING VISUALIZATION") + lines.append("=" * 100) + + if parallel_dims.ep > 1: + if parallel_dims.etp == parallel_dims.tp: + dp_shard_mod_ep = ( + parallel_dims.dp_shard * parallel_dims.cp // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // parallel_dims.cp + else: + dp_shard_mod_ep = ( + parallel_dims.dp_shard + * parallel_dims.cp + * parallel_dims.tp + // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // (parallel_dims.cp * parallel_dims.tp) + + dp_shard_cp_size = parallel_dims.dp_shard * parallel_dims.cp + + lines.append("\n[NON-EXPERT PARAMETERS (Attention, Embeddings, etc.)]") + lines.append(" FSDP mesh: dp_shard_cp") + lines.append(f" FSDP group size: {dp_shard_cp_size} GPUs") + lines.append(f" Each parameter is sharded across {dp_shard_cp_size} GPUs") + lines.append( + f" All-gather buffer size per param: original_size / {dp_shard_cp_size}" + ) + + lines.append("\n[EXPERT PARAMETERS (MoE experts)]") + lines.append(" FSDP mesh: dp_shard_mod_ep") + lines.append(f" FSDP group size: {dp_shard_mod_ep} GPUs") + lines.append( + f" Each expert's parameters are sharded across {dp_shard_mod_ep} GPUs" + ) + lines.append( + f" All-gather buffer size per expert param: original_size / {dp_shard_mod_ep}" + ) + + lines.append("\n[MEMORY IMPLICATIONS]") + lines.append( + f" Non-expert params: sharded {dp_shard_cp_size}x -> small per-GPU footprint" + ) + lines.append( + f" Expert params: sharded only {dp_shard_mod_ep}x -> larger per-GPU footprint" + ) + lines.append(" ") + lines.append(" As DP increases:") + lines.append( + " - dp_shard_cp increases -> non-expert params get more sharded" + ) + lines.append( + " - dp_shard_mod_ep increases -> expert params get more sharded" + ) + lines.append( + " - BUT: all-gather/reduce-scatter buffers scale with group size!" + ) + + else: + dp_shard_cp_size = parallel_dims.dp_shard * parallel_dims.cp + lines.append("\n[ALL PARAMETERS]") + lines.append(" FSDP mesh: dp_shard_cp") + lines.append(f" FSDP group size: {dp_shard_cp_size} GPUs") + lines.append(f" Each parameter is sharded across {dp_shard_cp_size} GPUs") + + return "\n".join(lines) + + +def create_full_visualization( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """Create a comprehensive visualization of the entire mesh structure.""" + parts = [ + visualize_mesh_structure(mesh, parallel_dims, rank), + visualize_gpu_allocation(mesh, parallel_dims, rank), + visualize_expert_parallel_groups(mesh, parallel_dims, rank), + visualize_context_parallel_groups(mesh, parallel_dims, rank), + visualize_fsdp_sharding(mesh, parallel_dims, rank), + ] + + full_viz = "\n".join(parts) + full_viz += "\n" + "=" * 100 + full_viz += "\nEND OF DEVICE MESH VISUALIZATION" + full_viz += "\n" + "=" * 100 + + return full_viz + + +def log_mesh_visualization(mesh: DeviceMesh, parallel_dims): + """Log the full mesh visualization (only on rank 0).""" + rank = dist.get_rank() if dist.is_initialized() else 0 + + if rank == 0: + viz = create_full_visualization(mesh, parallel_dims, rank) + # Log each line separately for better formatting + for line in viz.split("\n"): + logger.info(f"[MESH-VIZ] {line}") diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index c37345f8d3..988534f73d 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -13,6 +13,7 @@ from torchtitan.config import Profiling as ProfilingConfig from torchtitan.tools.logging import logger +from torchtitan.tools.utils import device_module # how much memory allocation/free ops to record in memory snapshots MEMORY_SNAPSHOT_MAX_ENTRIES = 100000 @@ -69,6 +70,7 @@ def trace_handler(prof): elif torch.xpu.is_available(): gpu_device_profiled = torch.profiler.ProfilerActivity.XPU with torch.profiler.profile( + # pyrefly: ignore [bad-argument-type] activities=[ torch.profiler.ProfilerActivity.CPU, gpu_device_profiled, @@ -105,7 +107,7 @@ def maybe_enable_memory_snapshot( class MemoryProfiler: def __init__(self, step_num: int, freq: int): - torch.cuda.memory._record_memory_history( + device_module.memory._record_memory_history( max_entries=MEMORY_SNAPSHOT_MAX_ENTRIES ) # when resume training, we start from the last step @@ -132,7 +134,7 @@ def step(self, exit_ctx: bool = False): curr_snapshot_dir, f"rank{rank}_memory_snapshot.pickle" ) with open(output_file, "wb") as output: - pickle.dump(torch.cuda.memory._snapshot(), output) + pickle.dump(device_module.memory._snapshot(), output) logger.info( f"Finished dumping memory snapshot in {time.monotonic() - begin:.2f} seconds" ) diff --git a/torchtitan/tools/utils.py b/torchtitan/tools/utils.py index ff19502ae2..23c187b520 100644 --- a/torchtitan/tools/utils.py +++ b/torchtitan/tools/utils.py @@ -9,6 +9,7 @@ import subprocess import time from dataclasses import dataclass +from types import ModuleType from typing import Generator, Optional import torch @@ -24,7 +25,15 @@ def has_cuda_capability(major: int, minor: int) -> bool: ) -def get_device_info() -> tuple[str, torch.device]: +def has_rocm_capability(major: int, minor: int) -> bool: + is_rocm = torch.cuda.is_available() and torch.version.hip is not None + return is_rocm and torch.cuda.get_device_capability() >= ( + major, + minor, + ) + + +def get_device_info() -> tuple[str, ModuleType]: device_type = _get_available_device_type() or "cuda" device_module = _get_device_module(device_type) # default device_module:torch.cuda return device_type, device_module @@ -73,7 +82,7 @@ def collect(reason: str, generation: int = 1, empty_cuda_cache: bool = False): # hardcoded BF16 type peak flops for NVIDIA A100, H100, H200, B200 GPU and AMD MI250, MI300X, MI325X, MI355X and Intel PVC -def get_peak_flops(device_name: str) -> int: +def get_peak_flops(device_name: str) -> float: try: # Run the lspci command and capture the output result = subprocess.run(["lspci"], stdout=subprocess.PIPE, text=True) diff --git a/torchtitan/train.py b/torchtitan/train.py index 9c16c2ef47..2ff6b3a49c 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -4,14 +4,16 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import dataclasses import importlib +import json import os import time from datetime import timedelta -from typing import Any, Generator, Iterable +from typing import Any, Iterable import torch - +import torch.distributed.checkpoint.stateful from torch.distributed.elastic.multiprocessing.errors import record import torchtitan.protocols.train_spec as train_spec_module @@ -25,9 +27,14 @@ ) from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.distributed.context_parallel import prepare_context_parallel_input from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils +from torchtitan.tools.aggressive_memory_manager import create_aggressive_memory_manager +from torchtitan.tools.cuda_memory_tracker import CUDAMemoryTracker +from torchtitan.tools.detailed_memory_tracker import DetailedMemoryTracker from torchtitan.tools.logging import init_logger, logger +from torchtitan.tools.mesh_visualizer import log_mesh_visualization from torchtitan.tools.profiling import ( maybe_enable_memory_snapshot, maybe_enable_profiling, @@ -58,7 +65,7 @@ class Trainer(torch.distributed.checkpoint.stateful.Stateful): # runtime utilities device: torch.device gc_handler: utils.GarbageCollection - train_context: Generator[None, None, None] + train_context: dist_utils.TrainContext gradient_accumulation_steps: int pp_has_first_stage: bool pp_has_last_stage: bool @@ -80,34 +87,70 @@ def __init__(self, job_config: JobConfig): importlib.import_module(job_config.experimental.custom_import) device_module, device_type = utils.device_module, utils.device_type + # pyrefly: ignore [read-only] self.device = torch.device(f"{device_type}:{int(os.environ['LOCAL_RANK'])}") # Device has to be set before creating TorchFT manager. device_module.set_device(self.device) - job_config.maybe_log() - # init distributed and build meshes self.parallel_dims = parallel_dims = self.init_distributed() - world_mesh = parallel_dims.world_mesh + # Log mesh visualization for debugging distributed setup (rank 0 only) + log_mesh_visualization(parallel_dims.world_mesh, parallel_dims) + if parallel_dims.dp_enabled: - dp_mesh = world_mesh["dp"] - dp_degree, dp_rank = dp_mesh.size(), dp_mesh.get_local_rank() + batch_mesh = parallel_dims.get_mesh("batch") + batch_degree, batch_rank = batch_mesh.size(), batch_mesh.get_local_rank() else: - dp_degree, dp_rank = 1, 0 + batch_degree, batch_rank = 1, 0 + # pyrefly: ignore [bad-argument-type] self.ft_manager = FTManager(job_config.fault_tolerance) - dp_degree, dp_rank = self.ft_manager.get_dp_info(dp_degree, dp_rank) + batch_degree, batch_rank = self.ft_manager.get_dp_info(batch_degree, batch_rank) # take control of garbage collection to avoid stragglers self.gc_handler = utils.GarbageCollection( gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug ) + # Initialize detailed memory tracker + self.detailed_memory_tracker = DetailedMemoryTracker( + enabled=getattr( + job_config.training, "enable_detailed_memory_tracking", False + ), + clear_cache=getattr( + job_config.training, "clear_cache_between_steps", False + ), + ) + + # Initialize CUDA memory tracker + self.cuda_memory_tracker = CUDAMemoryTracker( + enabled=getattr( + job_config.training, "enable_detailed_memory_tracking", False + ), + ) + + # Initialize aggressive memory manager to reduce CUDA fragmentation + aggressive_mem_mode = getattr( + job_config.training, "aggressive_memory_mode", None + ) + if aggressive_mem_mode: + self.aggressive_mem_manager = create_aggressive_memory_manager( + mode=aggressive_mem_mode, + verbose=getattr( + job_config.training, "aggressive_memory_verbose", False + ), + ) + logger.info( + f"Aggressive memory manager enabled (mode={aggressive_mem_mode})" + ) + else: + self.aggressive_mem_manager = None + # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). dist_utils.set_determinism( - world_mesh, + parallel_dims, self.device, job_config.debug, distinct_seed_mesh_dims=["pp"], @@ -122,8 +165,8 @@ def __init__(self, job_config: JobConfig): ) self.dataloader = self.train_spec.build_dataloader_fn( - dp_world_size=dp_degree, - dp_rank=dp_rank, + dp_world_size=batch_degree, + dp_rank=batch_rank, tokenizer=self.tokenizer, job_config=job_config, ) @@ -135,7 +178,8 @@ def __init__(self, job_config: JobConfig): self.model_args = model_args logger.info( - f"Building {job_config.model.name} {job_config.model.flavor} with {model_args}" + f"Building {job_config.model.name} {job_config.model.flavor}" + f"with {json.dumps(dataclasses.asdict(model_args), indent=2, ensure_ascii=False)}" ) with ( torch.device("meta"), @@ -194,19 +238,20 @@ def __init__(self, job_config: JobConfig): if global_batch_size < 0: # This global batch size results in 1 gradient accumulation # step. - global_batch_size = job_config.training.local_batch_size * dp_degree + global_batch_size = job_config.training.local_batch_size * batch_degree assert global_batch_size > 0 assert ( - global_batch_size % (job_config.training.local_batch_size * dp_degree) == 0 + global_batch_size % (job_config.training.local_batch_size * batch_degree) + == 0 ), ( f"global batch size must be multiple of local batch size times " f"data-parallel degree ({global_batch_size} " - f"% ({job_config.training.local_batch_size} * {dp_degree}) != 0)" + f"% ({job_config.training.local_batch_size} * {batch_degree}) != 0)" ) # calculate gradient accumulation steps self.gradient_accumulation_steps = global_batch_size // ( - job_config.training.local_batch_size * dp_degree + job_config.training.local_batch_size * batch_degree ) assert self.gradient_accumulation_steps > 0 self.loss_fn = rescale_accumulated_loss( @@ -263,10 +308,12 @@ def __init__(self, job_config: JobConfig): for m in self.model_parts: m.to_empty(device=init_device) with torch.no_grad(): + # pyrefly: ignore [not-callable] m.init_weights(buffer_device=buffer_device) m.train() # confirm that user will be able to view loss metrics on the console + # pyrefly: ignore [bad-argument-type] ensure_pp_loss_visible(parallel_dims, job_config, color) else: # apply PT-D Tensor Parallel, activation checkpointing, torch.compile, Data Parallel @@ -274,6 +321,7 @@ def __init__(self, job_config: JobConfig): model.to_empty(device=init_device) with torch.no_grad(): + # pyrefly: ignore [not-callable] model.init_weights(buffer_device=buffer_device) model.train() @@ -331,6 +379,7 @@ def __init__(self, job_config: JobConfig): states={"train_state": self}, checkpoint_config=job_config.checkpoint, sd_adapter=( + # pyrefly: ignore[bad-instantiation] self.train_spec.state_dict_adapter( model_args, job_config.model.hf_assets_path ) @@ -368,8 +417,8 @@ def __init__(self, job_config: JobConfig): self.validator = self.train_spec.build_validator_fn( job_config=job_config, - dp_world_size=dp_degree, - dp_rank=dp_rank, + dp_world_size=batch_degree, + dp_rank=batch_rank, tokenizer=self.tokenizer, parallel_dims=parallel_dims, loss_fn=self.loss_fn, @@ -411,15 +460,13 @@ def format_tokens(num): def init_distributed(self) -> ParallelDims: job_config = self.job_config - dist_utils.init_distributed( + world_size = dist_utils.init_distributed( job_config.comm, enable_cpu_backend=job_config.training.enable_cpu_offload, base_folder=job_config.job.dump_folder, ) - world_size = int(os.environ["WORLD_SIZE"]) parallelism_config = job_config.parallelism - return ParallelDims( dp_shard=parallelism_config.data_parallel_shard_degree, dp_replicate=parallelism_config.data_parallel_replicate_degree, @@ -504,18 +551,34 @@ def post_dataloading_process( """ inputs = input_dict["input"] extra_inputs = {k: v for k, v in input_dict.items() if k != "input"} - # For arguments, like attention_masks, we have to put them in a separate - # dict as extra_inputs are not forwarded to other stages in PP, but - # extra_kwargs are. + extra_kwargs: dict[str, Any] = {} + if "position_ids" in extra_inputs: + extra_kwargs["positions"] = extra_inputs.pop("position_ids") - if getattr(self.model_args, "use_flex_attn", False): + attn_type = getattr(self.model_args, "attn_type", "sdpa") + if attn_type in ["flex", "varlen"]: + # pyrefly: ignore [not-callable] extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks( input_batch=inputs, tokenizer=self.tokenizer, extra_inputs=extra_inputs, ) + # Remove sequence_lengths from extra_inputs after attention mask creation + # as it's not needed by model forward + extra_inputs.pop("sequence_lengths", None) + + if self.parallel_dims.cp_enabled: + inputs, labels, extra_kwargs = prepare_context_parallel_input( + inputs, + labels, + extra_kwargs, + self.parallel_dims.get_mesh("cp"), + self.device, + self.job_config.parallelism.context_parallel_load_balancer, + ) + return inputs, labels, extra_inputs, extra_kwargs def _collect_moe_expert_metrics(self) -> dict[str, Any]: @@ -623,25 +686,10 @@ def forward_backward_step( inputs, labels, extra_inputs, extra_kwargs = self.post_dataloading_process( input_dict, labels ) - # apply context parallelism if cp is enabled - # ensure CP handles the separate freqs_cis buffer for each pp stage - optional_context_parallel_ctx = ( - dist_utils.create_context_parallel_ctx( - cp_mesh=parallel_dims.world_mesh["cp"], - cp_buffers=list(input_dict.values()) - + [labels] - + [m.freqs_cis for m in model_parts], - cp_seq_dims=[1] * len(input_dict) + [1] + [0 for _ in model_parts], - cp_no_restore_buffers=set(input_dict.values()).union([labels]), - cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method, - ) - if parallel_dims.cp_enabled - else None - ) if parallel_dims.pp_enabled: # Pipeline Parallel forward / backward inside step() call - with self.train_context(optional_context_parallel_ctx): + with self.train_context(): targets, losses = ( (labels, []) if self.pp_has_last_stage else (None, None) ) @@ -674,8 +722,8 @@ def forward_backward_step( ) else: # Non-PP forward / backward - with self.train_context(optional_context_parallel_ctx): - assert len(model_parts) == 1 + assert len(model_parts) == 1 + with self.train_context(): with self.maybe_enable_amp: pred = model_parts[0](inputs, **extra_inputs, **extra_kwargs) loss = self.loss_fn(pred, labels) @@ -683,11 +731,30 @@ def forward_backward_step( del pred loss.backward() + # Aggressive memory clearing after backward to reduce fragmentation + if self.aggressive_mem_manager is not None: + self.aggressive_mem_manager.post_backward() + return loss def train_step( self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] ): + # AGGRESSIVE cache clearing before step for accurate memory measurements + if self.job_config.training.aggressive_memory_mode: + import gc + + torch.cuda.synchronize() + gc.collect(0) + gc.collect(1) + gc.collect(2) + torch.cuda.empty_cache() + torch.cuda.synchronize() + gc.collect(2) + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + self.metrics_processor.device_memory_monitor.reset_peak_stats() + self.optimizers.zero_grad() # Save the current step learning rate for logging lr = self.lr_schedulers.schedulers[0].get_last_lr()[0] @@ -696,26 +763,64 @@ def train_step( # the major variables that are used in the training loop. parallel_dims = self.parallel_dims + # Track memory before forward pass + self.detailed_memory_tracker.measure("before_forward", self.step) + self.cuda_memory_tracker.measure_all("before_forward", self.step) + accumulated_losses = [] # If data runs out during gradient accumulation, that # entire step will not be executed. for _microbatch in range(self.gradient_accumulation_steps): + # pyrefly: ignore [no-matching-overload] input_dict, labels = next(data_iterator) loss = self.forward_backward_step(input_dict, labels) accumulated_losses.append(loss.detach()) + # Track memory after forward/backward + self.detailed_memory_tracker.measure("after_forward_backward", self.step) + self.cuda_memory_tracker.measure_all("after_forward_backward", self.step) + grad_norm = dist_utils.clip_grad_norm_( [p for m in self.model_parts for p in m.parameters()], self.job_config.training.max_norm, foreach=True, - pp_mesh=( - parallel_dims.world_mesh["pp"] if parallel_dims.pp_enabled else None - ), + pp_mesh=parallel_dims.get_optional_mesh("pp"), ep_enabled=parallel_dims.ep_enabled, ) self.checkpointer.maybe_wait_for_staging() - self.optimizers.step() - self.lr_schedulers.step() + + # Skip optimizer step if configured (for memory profiling) + if not self.job_config.training.skip_optimizer_step: + import datetime + import time as _time + + # Log step start with timestamp for correlation + if self.device.index == 0: + _ts = datetime.datetime.now().strftime("%H:%M:%S") + logger.info(f"[STEP {self.step}] optimizer.step() START @ {_ts}") + + _optim_start = _time.time() + self.optimizers.step() + _optim_elapsed = _time.time() - _optim_start + + # Aggressive memory clearing after optimizer to reduce fragmentation + if self.aggressive_mem_manager is not None: + self.aggressive_mem_manager.post_optimizer() + + # Log step end with timing + if self.device.index == 0: + _ts = datetime.datetime.now().strftime("%H:%M:%S") + logger.info( + f"[STEP {self.step}] optimizer.step() END @ {_ts} | Duration: {_optim_elapsed:.2f}s" + ) + + self.lr_schedulers.step() + else: + logger.info("Skipping optimizer step (skip_optimizer_step=True)") + + # Track memory after optimizer step + self.detailed_memory_tracker.measure("after_optimizer", self.step) + self.cuda_memory_tracker.measure_all("after_optimizer", self.step) # Reduce the data collected over gradient accumulation steps. loss = torch.sum(torch.stack(accumulated_losses)) @@ -727,14 +832,15 @@ def train_step( if parallel_dims.dp_cp_enabled: loss = loss.detach() ft_pg = self.ft_manager.loss_sync_pg + loss_mesh = parallel_dims.get_optional_mesh("loss") global_avg_loss, global_max_loss, global_ntokens_seen = ( - dist_utils.dist_mean(loss, parallel_dims.world_mesh["dp_cp"], ft_pg), - dist_utils.dist_max(loss, parallel_dims.world_mesh["dp_cp"], ft_pg), + dist_utils.dist_mean(loss, loss_mesh, ft_pg), + dist_utils.dist_max(loss, loss_mesh, ft_pg), dist_utils.dist_sum( torch.tensor( self.ntokens_seen, dtype=torch.int64, device=self.device ), - parallel_dims.world_mesh["dp_cp"], + loss_mesh, ft_pg, ), ) @@ -755,11 +861,25 @@ def train_step( extra_metrics=extra_metrics, ) + # Signal step complete to aggressive memory manager (triggers defrag check) + if self.aggressive_mem_manager is not None: + self.aggressive_mem_manager.step_complete() + @record def train(self): job_config = self.job_config self.checkpointer.load(step=job_config.checkpoint.load_step) + + # Pre-initialize bf16 optimizer states if configured + # This must happen BEFORE training to avoid rank skew during first step + if hasattr(self.optimizers, "init_bf16_states"): + self.optimizers.init_bf16_states() + # Barrier to ensure all ranks finish before training starts + if torch.distributed.is_initialized(): + torch.distributed.barrier() + logger.info("All ranks synchronized after bf16 optimizer state init") + logger.info(f"Training starts at step {self.step + 1}") leaf_folder = ( @@ -781,6 +901,7 @@ def train(self): leaf_folder=leaf_folder, ) as memory_profiler, maybe_semi_sync_training( + # pyrefly: ignore [bad-argument-type] job_config.fault_tolerance, ft_manager=self.ft_manager, model=self.model_parts[0], @@ -823,7 +944,9 @@ def train(self): self.job_config.validation.enable and self.validator.should_validate(self.step) ): + # pyrefly: ignore [missing-attribute] with self.loss_fn.no_rescale(): + # pyrefly: ignore [bad-argument-count] self.validator.validate(self.model_parts, self.step) # signal the profiler that the next profiling step has started @@ -832,6 +955,10 @@ def train(self): if memory_profiler: memory_profiler.step() + # Track memory at step end and optionally clear cache + self.detailed_memory_tracker.step_complete(self.step) + self.cuda_memory_tracker.measure_all("step_end", self.step) + # reduce timeout after first train step for faster signal # (assuming lazy init and compilation are finished) if self.step == 1: @@ -839,13 +966,17 @@ def train(self): timeout=timedelta( seconds=job_config.comm.train_timeout_seconds ), - world_mesh=self.parallel_dims.world_mesh, + parallel_dims=self.parallel_dims, ) if torch.distributed.get_rank() == 0: logger.info("Sleeping 2 seconds for other ranks to complete") time.sleep(2) + # Log detailed memory tracking summary + if torch.distributed.get_rank() == 0: + logger.info(self.detailed_memory_tracker.get_summary()) + logger.info("Training completed") def should_continue_training(self) -> bool: @@ -859,9 +990,9 @@ def load_state_dict(self, state_dict: dict[str, Any]): self.ntokens_seen = state_dict["ntokens_seen"] def close(self) -> None: - if self.checkpointer: + if hasattr(self, "checkpointer") and self.checkpointer: self.checkpointer.close() - if self.metrics_processor: + if hasattr(self, "metrics_processor") and self.metrics_processor: self.metrics_processor.close() @@ -872,6 +1003,14 @@ def main(trainer_class: type[Trainer]) -> None: trainer_class: The trainer class to instantiate (e.g., Trainer, FluxTrainer, TorchCommsTrainer) """ init_logger() + + import torchtitan + + logger.info( + "torchtitan version: %s (0.0.0 means __version__ is not defined correctly).", + torchtitan.__version__, + ) + config_manager = ConfigManager() config = config_manager.parse_args() @@ -880,6 +1019,13 @@ def main(trainer_class: type[Trainer]) -> None: try: trainer = trainer_class(config) + # TODO(local_tensor): Remove this special case once LocalTensor supports + # init_weights() and foreach_allgather. In local tensor mode, skip + # training/checkpointing as the # model is not fully initialized + if config.comm.mode == "local_tensor": + logger.info("Local tensor mode enabled - skipping training execution") + return + if config.checkpoint.create_seed_checkpoint: assert ( int(os.environ["WORLD_SIZE"]) == 1 @@ -897,7 +1043,8 @@ def main(trainer_class: type[Trainer]) -> None: raise else: trainer.close() - torch.distributed.destroy_process_group() + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() logger.info("Process group destroyed")