diff --git a/scripts/deepep/torchtitan_deepep_tune/tune_internode.py b/scripts/deepep/torchtitan_deepep_tune/tune_internode.py index cc7f5fa7c3..7c8b7d9767 100755 --- a/scripts/deepep/torchtitan_deepep_tune/tune_internode.py +++ b/scripts/deepep/torchtitan_deepep_tune/tune_internode.py @@ -29,8 +29,50 @@ print("ERROR: deep_ep not found") sys.exit(1) -sys.path.insert(0, "/home/phuc/workspace/moe/DeepEP/tests") -from utils import bench_kineto, init_dist +DEEPEP_TESTS_PATH = os.environ.get( + "DEEPEP_TESTS_PATH", "/home/phuc/kimi_1t/deepep/tests" +) +sys.path.insert(0, DEEPEP_TESTS_PATH) +from utils import bench_kineto + + +def init_dist_torchrun(local_rank: int, num_local_ranks: int): + """ + Initialize distributed for torchrun environment. + torchrun sets: WORLD_SIZE=total_procs, RANK=global_rank, LOCAL_RANK, LOCAL_WORLD_SIZE + But init_dist expects: WORLD_SIZE=num_nodes, RANK=node_rank + """ + import inspect + + world_size = int(os.environ.get("WORLD_SIZE", 1)) + global_rank = int(os.environ.get("RANK", 0)) + + # Calculate node info from torchrun env vars + num_nodes = world_size // num_local_ranks + node_rank = global_rank // num_local_ranks + + ip = os.getenv("MASTER_ADDR", "127.0.0.1") + port = int(os.getenv("MASTER_PORT", "29500")) + + sig = inspect.signature(dist.init_process_group) + params = { + "backend": "nccl", + "init_method": f"tcp://{ip}:{port}", + "world_size": num_nodes * num_local_ranks, + "rank": node_rank * num_local_ranks + local_rank, + } + if "device_id" in sig.parameters: + params["device_id"] = torch.device(f"cuda:{local_rank}") + dist.init_process_group(**params) + torch.set_default_dtype(torch.bfloat16) + torch.set_default_device("cuda") + torch.cuda.set_device(local_rank) + + return ( + dist.get_rank(), + dist.get_world_size(), + dist.new_group(list(range(num_local_ranks * num_nodes))), + ) @dataclass @@ -71,15 +113,19 @@ def __init__( self.hidden = hidden self.num_experts = num_experts self.num_topk = num_topk - self.num_topk_groups = num_topk_groups - # Init distributed + # Init distributed (using torchrun-compatible init) self.local_rank = int(os.environ.get("LOCAL_RANK", 0)) num_local_ranks = int(os.environ.get("LOCAL_WORLD_SIZE", 8)) - self.rank, self.num_ranks, self.group = init_dist( + self.rank, self.num_ranks, self.group = init_dist_torchrun( self.local_rank, num_local_ranks ) - self.num_nodes = int(os.environ.get("WORLD_SIZE", 1)) + # Calculate num_nodes from torchrun env vars + world_size = int(os.environ.get("WORLD_SIZE", 1)) + self.num_nodes = world_size // num_local_ranks + + # num_topk_groups must be <= num_nodes + self.num_topk_groups = min(num_topk_groups, self.num_nodes) self.num_sms = 24 # Buffer sizes (from benchmark_internode.py) @@ -124,39 +170,31 @@ def setup_data(self): * self.rank ) - # Random scores with group-based routing (like Qwen3) + # UNIFORM distribution across all ranks + # Use same seed for reproducibility across all ranks + torch.manual_seed(42) scores = ( torch.randn( (self.num_tokens, self.num_experts), dtype=torch.float32, device="cuda" ).abs() + 1 ) - group_scores = scores.view(self.num_tokens, self.num_nodes, -1).amax(dim=-1) - group_idx = torch.topk( - group_scores, k=self.num_topk_groups, dim=-1, sorted=False - ).indices - - # Create grouped scores (group-limited routing) - masked_scores = scores.clone() - for i in range(self.num_nodes): - mask = (group_idx == i).any(dim=-1, keepdim=True) - node_mask = torch.zeros( - self.num_tokens, self.num_experts, dtype=torch.bool, device="cuda" - ) - start_expert = i * (self.num_experts // self.num_nodes) - end_expert = (i + 1) * (self.num_experts // self.num_nodes) - node_mask[:, start_expert:end_expert] = True - masked_scores = torch.where( - mask & node_mask, - masked_scores, - torch.tensor(-float("inf"), device="cuda"), - ) - self.topk_idx = torch.topk( - masked_scores, self.num_topk, dim=-1, largest=True, sorted=False + scores, self.num_topk, dim=-1, largest=True, sorted=False )[1] self.topk_idx = self.topk_idx.to(deep_ep.topk_idx_t) + # Verify distribution (only on rank 0) + if self.is_rank0(): + rank_idx_check = self.topk_idx // (self.num_experts // self.num_ranks) + tokens_per_rank = [ + (rank_idx_check == r).sum().item() for r in range(self.num_ranks) + ] + print( + f"[uniform] Tokens per rank: min={min(tokens_per_rank)}, max={max(tokens_per_rank)}, " + f"ratio={max(tokens_per_rank)/max(min(tokens_per_rank), 1):.2f}x" + ) + # Get layout ( self.num_tokens_per_rank, @@ -347,6 +385,13 @@ def cleanup(self): dist.destroy_process_group() +# Model presets +MODEL_CONFIGS = { + "qwen3": {"hidden": 2048, "num_experts": 128, "num_topk": 8}, + "kimi_k2": {"hidden": 7168, "num_experts": 384, "num_topk": 8}, +} + + def main(): parser = argparse.ArgumentParser( description="Tune DeepEP for actual torchtitan setup" @@ -354,15 +399,55 @@ def main(): parser.add_argument("--ep-size", type=int, required=True) parser.add_argument("--mode", choices=["quick", "medium", "full"], default="medium") parser.add_argument("--output-dir", default="results") + parser.add_argument( + "--model", + type=str, + choices=["qwen3", "kimi_k2", "custom"], + default="qwen3", + help="Model preset: qwen3 (dim=2048, 128 experts), kimi_k2 (dim=7168, 384 experts)", + ) + parser.add_argument("--num-tokens", type=int, default=4096, help="Number of tokens") + parser.add_argument( + "--hidden", + type=int, + default=None, + help="Hidden dimension (overrides model preset)", + ) + parser.add_argument( + "--num-experts", + type=int, + default=None, + help="Number of experts (overrides model preset)", + ) + parser.add_argument( + "--num-topk", + type=int, + default=None, + help="Top-k experts (overrides model preset)", + ) args = parser.parse_args() - # Qwen3-30B-A3B parameters + # Apply model preset, allow overrides + if args.model in MODEL_CONFIGS: + preset = MODEL_CONFIGS[args.model] + hidden = args.hidden if args.hidden else preset["hidden"] + num_experts = args.num_experts if args.num_experts else preset["num_experts"] + num_topk = args.num_topk if args.num_topk else preset["num_topk"] + else: + hidden = args.hidden if args.hidden else 2048 + num_experts = args.num_experts if args.num_experts else 128 + num_topk = args.num_topk if args.num_topk else 8 + + print( + f"Model: {args.model}, hidden={hidden}, experts={num_experts}, topk={num_topk}, tokens={args.num_tokens}" + ) + tuner = DeepEPTuner( - num_tokens=4096, - hidden=2048, # Qwen3-30B dim - num_experts=128, # Qwen3-30B-A3B - num_topk=8, + num_tokens=args.num_tokens, + hidden=hidden, + num_experts=num_experts, + num_topk=num_topk, num_topk_groups=4, ) diff --git a/scripts/deepep/torchtitan_deepep_tune/tune_intranode_v2.py b/scripts/deepep/torchtitan_deepep_tune/tune_intranode_v2.py index 441a5f1012..9405fc386e 100755 --- a/scripts/deepep/torchtitan_deepep_tune/tune_intranode_v2.py +++ b/scripts/deepep/torchtitan_deepep_tune/tune_intranode_v2.py @@ -40,7 +40,7 @@ def init_dist(local_rank: int, num_local_ranks: int): os.environ["RANK"] = str(local_rank) os.environ["WORLD_SIZE"] = str(num_local_ranks) os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = str(29500 + local_rank) + os.environ["MASTER_PORT"] = "29500" # Same port for all ranks dist.init_process_group(backend="nccl", rank=local_rank, world_size=num_local_ranks) torch.cuda.set_device(local_rank) @@ -426,6 +426,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): dist.destroy_process_group() +# Model presets +MODEL_CONFIGS = { + "qwen3": {"hidden": 2048, "num_experts": 128, "num_topk": 8}, + "kimi_k2": {"hidden": 7168, "num_experts": 384, "num_topk": 8}, +} + + if __name__ == "__main__": parser = argparse.ArgumentParser( description="Tune DeepEP intranode configs for TorchTitan" @@ -436,26 +443,58 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): parser.add_argument( "--num-tokens", type=int, default=4096, help="Number of tokens (default: 4096)" ) + parser.add_argument( + "--model", + type=str, + choices=["qwen3", "kimi_k2", "custom"], + default="qwen3", + help="Model preset: qwen3 (dim=2048, 128 experts), kimi_k2 (dim=7168, 384 experts)", + ) parser.add_argument( "--hidden", type=int, - default=2048, - help="Hidden dimension - Qwen3-30B (default: 2048)", + default=None, + help="Hidden dimension (overrides model preset)", ) parser.add_argument( - "--num-topk", type=int, default=8, help="Number of top-k experts (default: 8)" + "--num-topk", + type=int, + default=None, + help="Number of top-k experts (overrides model preset)", ) parser.add_argument( "--num-experts", type=int, - default=128, - help="Number of experts - Qwen3-30B-A3B (default: 128)", + default=None, + help="Number of experts (overrides model preset)", ) parser.add_argument( "--output-dir", type=str, default="results", help="Output directory for results" ) args = parser.parse_args() + # Apply model preset, allow overrides + if args.model in MODEL_CONFIGS: + preset = MODEL_CONFIGS[args.model] + if args.hidden is None: + args.hidden = preset["hidden"] + if args.num_experts is None: + args.num_experts = preset["num_experts"] + if args.num_topk is None: + args.num_topk = preset["num_topk"] + else: + # Custom mode - require explicit values + if args.hidden is None: + args.hidden = 2048 + if args.num_experts is None: + args.num_experts = 128 + if args.num_topk is None: + args.num_topk = 8 + + print( + f"Model: {args.model}, hidden={args.hidden}, experts={args.num_experts}, topk={args.num_topk}" + ) + num_processes = args.num_processes torch.multiprocessing.spawn( test_loop, args=(num_processes, args), nprocs=num_processes diff --git a/scripts/deepep/torchtitan_deepep_tune/tune_singlenode.py b/scripts/deepep/torchtitan_deepep_tune/tune_singlenode.py index d8164967f2..c73f61f691 100755 --- a/scripts/deepep/torchtitan_deepep_tune/tune_singlenode.py +++ b/scripts/deepep/torchtitan_deepep_tune/tune_singlenode.py @@ -31,7 +31,10 @@ sys.exit(1) # Import from DeepEP tests -sys.path.insert(0, "/home/phuc/workspace/moe/DeepEP/tests") +DEEPEP_TESTS_PATH = os.environ.get( + "DEEPEP_TESTS_PATH", "/home/phuc/kimi_1t/deepep/tests" +) +sys.path.insert(0, DEEPEP_TESTS_PATH) from utils import bench, init_dist, inplace_unique diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index 62ec4331ef..1ec434bfdf 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -36,6 +36,8 @@ "max_reserved_pct", "num_alloc_retries", "num_ooms", + "nvidia_smi_used_gib", # nvidia-smi reported memory for verification + "nvidia_smi_used_pct", ], ) @@ -62,6 +64,49 @@ def _to_gib(self, memory_in_bytes): def _to_pct(self, memory): return 100 * memory / self.device_capacity + def _get_nvidia_smi_memory(self): + """Get GPU memory usage from nvidia-smi for verification.""" + try: + import os + import subprocess + + # In SLURM with CUDA_VISIBLE_DEVICES, PyTorch device index 0-7 maps to + # physical GPUs listed in CUDA_VISIBLE_DEVICES. We need the physical GPU index. + cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") + if cuda_visible: + # CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" means device 0 is physical GPU 0 + # But it could also be "4,5,6,7,0,1,2,3" meaning device 0 is physical GPU 4 + visible_gpus = [ + int(x.strip()) for x in cuda_visible.split(",") if x.strip() + ] + if self.device_index < len(visible_gpus): + physical_gpu_index = visible_gpus[self.device_index] + else: + physical_gpu_index = self.device_index + else: + physical_gpu_index = self.device_index + + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=memory.used", + "--format=csv,noheader,nounits", + f"--id={physical_gpu_index}", + ], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + # nvidia-smi reports in MiB + used_mib = float(result.stdout.strip()) + used_gib = used_mib / 1024 + used_pct = (used_mib * 1024 * 1024) / self.device_capacity * 100 + return used_gib, used_pct + except Exception: + pass + return -1.0, -1.0 + def get_peak_stats(self): device_info = device_module.memory_stats(self.device) @@ -83,6 +128,9 @@ def get_peak_stats(self): if num_ooms > 0: logger.warning(f"{num_ooms} {device_type.upper()} OOM errors thrown.") + # Get nvidia-smi memory for verification + nvidia_smi_gib, nvidia_smi_pct = self._get_nvidia_smi_memory() + return DeviceMemStats( max_active_gib, max_active_pct, @@ -90,6 +138,8 @@ def get_peak_stats(self): max_reserved_pct, num_retries, num_ooms, + nvidia_smi_gib, + nvidia_smi_pct, ) def reset_peak_stats(self): @@ -480,16 +530,26 @@ def log( self.logger.log(metrics, step) color = self.color + # Show ACTIVE memory (actual tensor usage) as primary, with reserved and nvidia-smi for verification + # active = actual tensors, reserved = active + cached, nvidia-smi = OS-level GPU memory logger.info( f"{color.red}step: {step:2} " f"{color.green}loss: {global_avg_loss:7.4f} " f"{color.orange}grad_norm: {grad_norm:7.4f} " - f"{color.turquoise}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" - f"({device_mem_stats.max_reserved_pct:.2f}%) " + f"{color.turquoise}memory: {device_mem_stats.max_active_gib:5.2f}GiB" + f"({device_mem_stats.max_active_pct:.2f}%) " f"{color.blue}tps: {round(tps):,} " f"{color.cyan}tflops: {tflops:,.2f} " f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" ) + # Log detailed memory breakdown for verification (only on rank 0, step 2+) + if step >= 2 and torch.distributed.get_rank() == 0: + logger.info( + f" [Memory Detail] active={device_mem_stats.max_active_gib:.2f}GiB " + f"reserved={device_mem_stats.max_reserved_gib:.2f}GiB " + f"nvidia-smi={device_mem_stats.nvidia_smi_used_gib:.2f}GiB " + f"retries={device_mem_stats.num_alloc_retries}" + ) self.ntokens_since_last_log = 0 self.data_loading_times.clear() diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 77d317c0b9..a17258f8ea 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -21,6 +21,7 @@ from torchtitan.components.ft import FTManager, has_torchft from torchtitan.config import Optimizer as OptimizerConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims +from torchtitan.tools.logging import logger # Dion optimizer availability will be checked lazily when needed DION_AVAILABLE = None @@ -32,11 +33,11 @@ def _check_dion_availability(): global DION_AVAILABLE if DION_AVAILABLE is None: try: - from torchtitan.experiments.dion_optimizer.dion import ( + from torchtitan.experiments.dion_optimizer.dion import ( # noqa: F401 Dion, DionMixedPrecisionConfig, ) - from torchtitan.experiments.dion_optimizer.titan_dion import ( + from torchtitan.experiments.dion_optimizer.titan_dion import ( # noqa: F401 DionOptimizersContainer, ) @@ -51,8 +52,8 @@ def _check_muon_availability(): global MUON_AVAILABLE if MUON_AVAILABLE is None: try: - from torchtitan.experiments.dion_optimizer.muon import Muon - from torchtitan.experiments.dion_optimizer.titan_muon import ( + from torchtitan.experiments.dion_optimizer.muon import Muon # noqa: F401 + from torchtitan.experiments.dion_optimizer.titan_muon import ( # noqa: F401 MuonOptimizersContainer, ) @@ -76,6 +77,127 @@ def _check_muon_availability(): T = TypeVar("T", bound=Optimizer) +def preinit_optimizer_states_bf16(optimizers_container: "OptimizersContainer") -> None: + """ + Pre-initialize optimizer states (exp_avg, exp_avg_sq) directly in bfloat16. + This MUST be called BEFORE the first optimizer.step() to avoid fp32 allocation spike. + + This reduces optimizer state memory by ~50% (from fp32 to bf16). + States are allocated in bf16 from the start, avoiding the memory spike from fp32 allocation. + """ + total_params = 0 + total_bytes = 0 + + # For detailed logging + dtype_device_samples = [] + + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + + for opt_idx, optimizer in enumerate(optimizers_container.optimizers): + for pg_idx, param_group in enumerate(optimizer.param_groups): + for p_idx, p in enumerate(param_group["params"]): + if p.requires_grad: + # Log first few params for debugging + if total_params < 5: + dtype_device_samples.append( + f"param[{opt_idx}][{pg_idx}][{p_idx}]: dtype={p.dtype}, device={p.device}, shape={list(p.shape)}" + ) + + # Initialize state dict for this parameter + state = optimizer.state[p] + if len(state) == 0: # Only initialize if not already initialized + state["step"] = torch.tensor(0, dtype=torch.float32) + # Allocate exp_avg and exp_avg_sq in SAME dtype as param for fused Adam compatibility + state["exp_avg"] = torch.zeros_like( + p, dtype=p.dtype, device=p.device + ) + state["exp_avg_sq"] = torch.zeros_like( + p, dtype=p.dtype, device=p.device + ) + total_params += 1 + # Calculate bytes based on actual param dtype + bytes_per_element = 2 if p.dtype == torch.bfloat16 else 4 + total_bytes += p.numel() * 2 * bytes_per_element # 2 states + + # Log first few state allocations + if total_params <= 3: + logger.info( + f"[Rank {rank}] State init sample: param dtype={p.dtype}, device={p.device}, " + f"exp_avg dtype={state['exp_avg'].dtype}, device={state['exp_avg'].device}" + ) + + # Log dtype/device samples + for sample in dtype_device_samples: + logger.info(f"[Rank {rank}] {sample}") + + logger.info( + f"[Rank {rank}] Pre-initialized {total_params} optimizer states matching param dtype, " + f"this rank: {total_bytes / 1e9:.2f} GB" + ) + + +class BF16StateOptimizersContainer(Generic[T]): + """ + Wrapper that pre-initializes optimizer states in bfloat16 BEFORE first step. + This prevents the memory spike from fp32 state allocation. + + IMPORTANT: Call init_bf16_states() BEFORE the first step() to avoid + rank skew during state allocation. This should be called after model + setup but before training starts, ideally with a barrier afterwards. + """ + + def __init__( + self, + base_container: "OptimizersContainer", + state_dtype: torch.dtype = torch.bfloat16, + ): + self._base = base_container + self._state_dtype = state_dtype + self._states_initialized = False + + def init_bf16_states(self): + """ + Pre-initialize optimizer states in bf16. + Call this BEFORE training starts, then call a distributed barrier. + This avoids rank skew during the first optimizer.step(). + """ + if not self._states_initialized: + logger.info("Pre-initializing optimizer states in bfloat16...") + preinit_optimizer_states_bf16(self._base) + self._states_initialized = True + logger.info("BF16 optimizer state pre-initialization complete.") + + def step(self, *args, **kwargs) -> None: + # If states weren't pre-initialized, do it now (fallback) + if not self._states_initialized: + logger.warning( + "BF16 optimizer states not pre-initialized! " + "Call init_bf16_states() before training to avoid rank skew." + ) + self.init_bf16_states() + # Call base step + self._base.step(*args, **kwargs) + + def zero_grad(self, *args, **kwargs) -> None: + self._base.zero_grad(*args, **kwargs) + + def state_dict(self) -> dict[str, Any]: + return self._base.state_dict() + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + self._base.load_state_dict(state_dict) + + def __iter__(self): + return iter(self._base) + + def __len__(self): + return len(self._base) + + def __getattr__(self, name): + # Delegate all other attributes to base container + return getattr(self._base, name) + + class OptimizersContainer(Optimizer, Stateful, Generic[T]): """A container for multiple optimizers. @@ -504,7 +626,15 @@ def build_optimizers( use_ft_optimizer=ft_manager.use_async_quorum, ) - return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) + container = OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) + + # Wrap with BF16 state container if configured + state_dtype = getattr(optimizer_config, "state_dtype", "float32") + if state_dtype == "bfloat16": + logger.info("Using bfloat16 optimizer states (will convert after first step)") + return BF16StateOptimizersContainer(container, torch.bfloat16) + + return container def build_optimizers_with_moe_load_balancing( diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index ff909e6ae9..643e6e27e8 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -287,6 +287,13 @@ class Optimizer: use_triton: bool = False """Whether to use Triton kernel for Newton-Schulz in Muon optimizer.""" + state_dtype: Literal["float32", "bfloat16"] = "float32" + """ + Dtype for optimizer states (exp_avg, exp_avg_sq for Adam/AdamW). + Using bfloat16 reduces memory by ~50% but may affect training stability. + Only applies to Adam/AdamW optimizers. + """ + @dataclass class LRScheduler: @@ -370,6 +377,28 @@ class Training: Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP """ + enable_detailed_memory_tracking: bool = False + """ + Whether to enable detailed memory tracking at every training phase + """ + + clear_cache_between_steps: bool = False + """ + Whether to clear CUDA cache between training steps to measure minimum memory requirements + """ + + drop_page_cache_before_training: bool = False + """ + Whether to drop Linux page cache before training starts (after model/optimizer init). + This helps when using FSDP2 CPU offload with mmap'd model weights that fill page cache. + Requires root/sudo or appropriate permissions to write to /proc/sys/vm/drop_caches. + """ + + skip_optimizer_step: bool = False + """ + Whether to skip the optimizer step (for memory profiling purposes only) + """ + dtype: Literal["bfloat16", "float32"] = "float32" """ torch dtype for training. In contrast to mixed precision training, setting training_dtype=bfloat16 will @@ -402,6 +431,26 @@ class Training: many temporary files. """ + aggressive_memory_mode: Literal[ + "minimal", "balanced", "aggressive", "maximum" + ] | None = None + """ + Enable aggressive memory management to reduce CUDA memory fragmentation. + This clears CUDA cache and Python GC at strategic points (post-backward, post-optimizer). + Modes: + - None: Disabled (default) + - "minimal": Only clear on high fragmentation (<1% overhead) + - "balanced": Clear after backward and optimizer (2-3% overhead) + - "aggressive": Clear frequently with sync (5-8% overhead) + - "maximum": Clear after every operation (10-15% overhead, for debugging) + """ + + aggressive_memory_verbose: bool = False + """ + Enable verbose logging for aggressive memory manager. + Logs detailed memory stats after each clear operation. + """ + @dataclass class Parallelism: @@ -428,19 +477,22 @@ class Parallelism: only `data_parallel_shard_degree` can be negative. 1 means disabled. """ - fsdp_reshard_after_forward: Literal["default", "always", "never"] = "default" + fsdp_reshard_after_forward: Literal["default", "always", "never"] | int = "default" """ `reshard_after_forward` specifies the policy for applying `reshard_after_forward` within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward, trading off memory and communication. See torch's `fully_shard` API for more documentation on `reshard_after_forward`. - The supported policies include "default", "always" and "never": + The supported policies include "default", "always", "never", or an integer: - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. - "always" will enable `reshard_after_forward` for all forward passes. - "never" will disable `reshard_after_forward` for all forward passes. + - integer N: Partially reshard to groups of N GPUs after forward. Must be a factor of + the FSDP shard world size. Use N=8 for intra-node resharding (reduces memory while + keeping communication fast via NVLink). This trades memory for communication. """ tensor_parallel_degree: int = 1 @@ -555,6 +607,18 @@ class Parallelism: Note that this is still an experimental feature. """ + fsdp_bucket_cap_mb: int | None = None + """ + FSDP bucket size in MB for gradient reduction. None means use PyTorch default (25MB). + Smaller values (e.g., 20) can reduce peak memory at the cost of more communication overhead. + """ + + fsdp_disable_prefetch: bool = False + """ + Whether to disable FSDP forward/backward prefetching. Disabling prefetch can reduce memory + at the cost of performance (less overlap of communication and computation). + """ + @dataclass class DeepEP: @@ -1221,6 +1285,12 @@ class Debug: moe_force_load_balance: bool = False """If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only.""" + enable_nan_tracker: bool = False + """If True, enable lightweight NaN/Inf tracking to find where NaN first appears in the model.""" + + nan_tracker_verbose: bool = False + """If True, print stats for every layer (very verbose output).""" + @dataclass class JobConfig: diff --git a/torchtitan/experiments/dion_optimizer/muon.py b/torchtitan/experiments/dion_optimizer/muon.py index 432ab1399f..0ac1602f71 100644 --- a/torchtitan/experiments/dion_optimizer/muon.py +++ b/torchtitan/experiments/dion_optimizer/muon.py @@ -609,27 +609,17 @@ def muon_update_batch_dim_sharded_async( - This is mathematically equivalent to orthogonalizing each expert's weights independently This function processes all params locally without all-to-all or all-gather. + + Optimized for CPU offloading with: + - Double-buffered CUDA streams to overlap transfer and compute + - Batched Newton-Schulz for fewer kernel launches + - Single sync point at end (no intermediate cuda.synchronize()) """ - U = muon_update_pre_orthogonalize( - G=G, - M=M, - momentum=momentum, - nesterov=nesterov, - ) - - # Orthogonalize each tensor locally - # Newton-Schulz treats dim 0 as batch, processing each slice independently - U = [ - muon_update_newton_schulz( - u, - newton_schulz_func=newton_schulz_func, - flatten=flatten, - epsilon=epsilon, - ) - for u in U - ] + # Check if we need CPU offloading (tensors are on CPU) + original_device = G[0].device + needs_gpu_transfer = original_device.type != "cuda" - # Compute scaled learning rate + # Compute scaled learning rate upfront # Use the first tensor's shape (they should all be the same shape within a batch) if adjust_lr is None: adjusted_lr = lr @@ -640,16 +630,132 @@ def muon_update_batch_dim_sharded_async( else: raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") - # Update model parameters with orthogonalized output - muon_update_post_orthogonalize( - X=X, - U=U, - base_lr=lr, - adjusted_lr=adjusted_lr, - weight_decay=weight_decay, - ) + if needs_gpu_transfer: + # PIPELINED MODE: Double-buffered streams for maximum overlap + # Timeline: transfer[i+1] overlaps with compute[i] overlaps with writeback[i-1] + cuda_device = torch.device("cuda") + dtype = M[0].dtype + n_tensors = len(X) + + # Mini-batch size for batched Newton-Schulz (fewer kernel launches) + BATCH_SIZE = 4 + + # Create streams: one for H2D transfers, one for compute, one for D2H transfers + h2d_stream = torch.cuda.Stream() + compute_stream = torch.cuda.Stream() + d2h_stream = torch.cuda.Stream() + + # Double buffer: prefetch next batch while computing current + prefetch_data = None # Will hold (g_batch, m_batch, x_batch, indices) for next iteration + + def prefetch_batch(start_idx): + """Prefetch a batch of tensors to GPU (non-blocking).""" + end_idx = min(start_idx + BATCH_SIZE, n_tensors) + indices = list(range(start_idx, end_idx)) + with torch.cuda.stream(h2d_stream): + g_batch = [G[i].to(dtype=dtype).to(cuda_device, non_blocking=True) for i in indices] + m_batch = [M[i].to(cuda_device, non_blocking=True) for i in indices] + x_batch = [X[i].to(cuda_device, non_blocking=True) for i in indices] + return (g_batch, m_batch, x_batch, indices) + + def compute_batch(g_batch, m_batch, x_batch, indices): + """Compute momentum update and Newton-Schulz on GPU.""" + with torch.cuda.stream(compute_stream): + # Wait for H2D transfer to complete (lightweight stream sync) + compute_stream.wait_stream(h2d_stream) + + u_batch = [] + for j in range(len(indices)): + g_gpu, m_gpu = g_batch[j], m_batch[j] + # Update momentum: M = mu * M + G + m_gpu.mul_(momentum) + m_gpu.add_(g_gpu) + # Compute U + if nesterov: + u_gpu = m_gpu * momentum + g_gpu + else: + u_gpu = m_gpu.clone() + u_batch.append(u_gpu.to(dtype=torch.bfloat16)) + + # Batched Newton-Schulz: stack same-shape tensors for single kernel + if len(u_batch) > 1 and all(u.shape == u_batch[0].shape for u in u_batch): + u_stacked = torch.stack(u_batch, dim=0) + u_stacked = muon_update_newton_schulz(u_stacked, newton_schulz_func, flatten, epsilon) + u_batch = list(u_stacked.unbind(0)) + else: + u_batch = [muon_update_newton_schulz(u, newton_schulz_func, flatten, epsilon) for u in u_batch] + + # Apply weight decay and update + for j in range(len(indices)): + x_batch[j].mul_(1 - lr * weight_decay) + x_batch[j].sub_(u_batch[j] * adjusted_lr) + + return m_batch, x_batch + + def writeback_batch(m_batch, x_batch, indices): + """Write results back to CPU (non-blocking).""" + with torch.cuda.stream(d2h_stream): + # Wait for compute to complete + d2h_stream.wait_stream(compute_stream) + for j, i in enumerate(indices): + M[i].copy_(m_batch[j], non_blocking=True) + X[i].copy_(x_batch[j], non_blocking=True) + + # Pipeline: prefetch first batch + if n_tensors > 0: + prefetch_data = prefetch_batch(0) + + # Main loop with double buffering + for batch_start in range(0, n_tensors, BATCH_SIZE): + # Get current batch (already prefetched) + g_batch, m_batch, x_batch, indices = prefetch_data + + # Start prefetching NEXT batch (overlaps with current compute) + next_start = batch_start + BATCH_SIZE + if next_start < n_tensors: + prefetch_data = prefetch_batch(next_start) + + # Compute current batch + m_batch, x_batch = compute_batch(g_batch, m_batch, x_batch, indices) + + # Writeback current batch (overlaps with next iteration's prefetch/compute) + writeback_batch(m_batch, x_batch, indices) + + # Single sync at end to ensure all D2H transfers complete + torch.cuda.synchronize() + + yield # Single yield to make this a generator + else: + # STANDARD GPU MODE: Process all tensors together (original behavior) + U = muon_update_pre_orthogonalize( + G=G, + M=M, + momentum=momentum, + nesterov=nesterov, + ) + + # Orthogonalize each tensor locally + # Newton-Schulz treats dim 0 as batch, processing each slice independently + U = [ + muon_update_newton_schulz( + u, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) + for u in U + ] - yield # Single yield to make this a generator + # Update model parameters with orthogonalized output + muon_update_post_orthogonalize( + X=X, + U=U, + base_lr=lr, + adjusted_lr=adjusted_lr, + weight_decay=weight_decay, + ) + + yield # Single yield to make this a generator def muon_update_batch_async( @@ -673,146 +779,333 @@ def muon_update_batch_async( Batched version of Muon update. Batch size should be equal to number of GPUs. All tensors in a batch should have identical shape, sharding, and dtype. Identical hyperparameters are used for all tensors in the batch. + + Memory-optimized for CPU offloading: when tensors are on CPU, moves ALL computation + to GPU (momentum update, all_to_all, Newton-Schulz, weight update) then copies back. """ assert len(X) == len(G) assert len(X) == len(M) assert len(X) == world_size - # Update momentum and compute the inputs for orthogonalization - U = muon_update_pre_orthogonalize( - G=to_local(G), - M=to_local(M), - momentum=momentum, - nesterov=nesterov, - ) - - # Get one whole matrix for each device to orthogonalize - if shard_dim is not None: - # Use all-to-all to transform from a batch of shards to a single whole matrix - # https://www.essential.ai/blog/infra - assert ( - process_group is not None - ), "process_group must be provided for sharded DTensors" - assert isinstance(X[0], DTensor), "X should contain DTensors" - assert not isinstance(U[0], DTensor), "U should contain local shards" - - # Debug: print full tensor info before the divisibility check - x0 = X[0] - x0_mesh = x0.device_mesh - x0_mesh_sizes = {name: x0_mesh.size(i) for i, name in enumerate(x0_mesh.mesh_dim_names)} - - assert ( - X[0].size(shard_dim) % world_size == 0 - ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}. " \ - f"Tensor info: global_shape={tuple(X[0].shape)}, local_shape={X[0].to_local().shape}, " \ - f"mesh={X[0].device_mesh.mesh_dim_names}, mesh_sizes={x0_mesh_sizes}, placements={X[0].placements}" - - # Allocate buffers to receive shards of one whole matrix from other devices - single_matrix_shards = [torch.empty_like(u) for u in U] - - # Redistribute the shards to form one unique full tensor on each device - # Sync CUDA before collective to ensure all prior GPU ops are complete - # This can prevent NCCL hangs due to async GPU operations + # Check early if we're in CPU offloading mode + G_local = to_local(G) + M_local = to_local(M) + X_local = to_local(X) + original_device = M_local[0].device + needs_gpu_transfer = original_device.type != "cuda" + + if needs_gpu_transfer: + # ====== CPU OFFLOADING PATH: Do ALL computation on GPU ====== + # This avoids slow CPU foreach operations for momentum and weight updates + cuda_device = torch.device("cuda") + dtype = M_local[0].dtype + + # Transfer G, M to GPU for momentum update + G_gpu = [g.to(dtype=dtype).to(cuda_device, non_blocking=True) for g in G_local] + M_gpu = [m.to(cuda_device, non_blocking=True) for m in M_local] torch.cuda.synchronize() - # N sequential all_gathers - only keep result for our assigned param - single_matrix_shards = None - for param_idx in range(world_size): - # Allocate output buffer for this all_gather - gathered = [torch.empty_like(U[param_idx]) for _ in range(world_size)] + # Momentum update on GPU (equivalent to muon_update_pre_orthogonalize) + torch._foreach_mul_(M_gpu, momentum) + torch._foreach_add_(M_gpu, G_gpu) - # All ranks send their shard of param_idx - dist.all_gather(gathered, U[param_idx].contiguous(), group=process_group) + if nesterov: + U_gpu = torch._foreach_mul(M_gpu, momentum) + torch._foreach_add_(U_gpu, G_gpu) + else: + # U shares memory with M when not using nesterov + U_gpu = M_gpu + + # Free G_gpu - no longer needed + del G_gpu + + # Convert to bfloat16 for communication + U_gpu = [u.to(dtype=torch.bfloat16) for u in U_gpu] + + # Get one whole matrix for each device to orthogonalize + if shard_dim is not None: + # Use all-to-all to transform from a batch of shards to a single whole matrix + assert process_group is not None, "process_group must be provided for sharded DTensors" + assert isinstance(X[0], DTensor), "X should contain DTensors" + + # Validation + x0 = X[0] + x0_mesh = x0.device_mesh + x0_mesh_sizes = {name: x0_mesh.size(i) for i, name in enumerate(x0_mesh.mesh_dim_names)} + assert ( + X[0].size(shard_dim) % world_size == 0 + ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}. " \ + f"Tensor info: global_shape={tuple(X[0].shape)}, local_shape={X[0].to_local().shape}, " \ + f"mesh={X[0].device_mesh.mesh_dim_names}, mesh_sizes={x0_mesh_sizes}, placements={X[0].placements}" + + # Make contiguous for all_to_all + U_gpu = [u.contiguous() for u in U_gpu] + + # First all_to_all: batch of shards -> single whole matrix + single_matrix_shards = [torch.empty_like(U_gpu[0]) for _ in range(world_size)] + dist.all_to_all(single_matrix_shards, U_gpu, group=process_group) + del U_gpu - # Only keep if this is our assigned parameter - if param_idx == device_rank: - single_matrix_shards = gathered - # Otherwise 'gathered' goes out of scope and memory can be freed + yield - yield + # Concatenate shards to form whole matrix + single_matrix = torch.cat(single_matrix_shards, dim=shard_dim) + del single_matrix_shards - # Concatentate shards to form a whole matrix to orthogonalize - single_matrix = torch.cat(single_matrix_shards, dim=shard_dim) - single_matrix = muon_update_newton_schulz( - single_matrix, - newton_schulz_func=newton_schulz_func, - flatten=flatten, - epsilon=epsilon, - ) + # Newton-Schulz orthogonalization (on GPU) + single_matrix = muon_update_newton_schulz( + single_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) - # Split result back into shards - # Contiguous is needed for communication to work correctly - orth_shards = [ - x.contiguous() - for x in torch.tensor_split(single_matrix, world_size, dim=shard_dim) - ] + # Split result back into shards + orth_shards = [ + x.contiguous() + for x in torch.tensor_split(single_matrix, world_size, dim=shard_dim) + ] + del single_matrix - # N sequential all_gathers - collect results as we go - for shard_idx in range(world_size): - # Allocate output buffer for this all_gather - gathered = [torch.empty_like(orth_shards[shard_idx]) for _ in range(world_size)] + # Second all_to_all to redistribute orthogonalized shards + U_orth_gpu = [torch.empty_like(orth_shards[0]) for _ in range(world_size)] + dist.all_to_all(U_orth_gpu, orth_shards, group=process_group) + del orth_shards - # All ranks send their shard at index shard_idx - dist.all_gather(gathered, orth_shards[shard_idx].contiguous(), group=process_group) + yield - # gathered[r] = rank r's orth_shards[shard_idx] = O^r_{shard_idx} - # We need U[r] = O^r_{device_rank} - # So when shard_idx == device_rank: U[r] = gathered[r] for all r - if shard_idx == device_rank: - for r in range(world_size): - U[r].copy_(gathered[r]) + else: + # Matrices are not sharded, orthogonalize directly + single_matrix = U_gpu[device_rank] + + single_matrix = muon_update_newton_schulz( + single_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) + + if process_group is not None and process_group.size() > 1: + U_orth_gpu = [torch.empty_like(single_matrix) for _ in range(world_size)] + work = dist.all_gather( + U_orth_gpu, single_matrix.contiguous(), group=process_group, async_op=True + ) + yield + work.wait() + del single_matrix + else: + assert world_size == 1 + U_orth_gpu = [single_matrix] + + # Compute scaled learning rate (use full tensor shape from X[0]) + if adjust_lr is None: + adjusted_lr = lr + elif adjust_lr == "spectral_norm": + adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape) + elif adjust_lr == "rms_norm": + adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape) + else: + raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") + + # Transfer X to GPU for weight update + X_gpu = [x.to(cuda_device, non_blocking=True) for x in X_local] + torch.cuda.synchronize() + + # Weight update on GPU (equivalent to muon_update_post_orthogonalize) + torch._foreach_mul_(X_gpu, 1 - lr * weight_decay) + U_scaled = torch._foreach_mul(U_orth_gpu, adjusted_lr) + torch._foreach_sub_(X_gpu, U_scaled) + del U_scaled, U_orth_gpu + + # Copy M and X back to CPU + for i in range(world_size): + M_local[i].copy_(M_gpu[i], non_blocking=True) + X_local[i].copy_(X_gpu[i], non_blocking=True) - yield + torch.cuda.synchronize() + del M_gpu, X_gpu else: - # Matrices are not sharded, so we can directly orthogonalize - # Get a single matrix corresponding to this device - single_matrix = U[device_rank] - assert not isinstance(single_matrix, DTensor) - - single_matrix = muon_update_newton_schulz( - single_matrix, - newton_schulz_func=newton_schulz_func, - flatten=flatten, - epsilon=epsilon, + # ====== STANDARD GPU PATH ====== + # Update momentum and compute the inputs for orthogonalization + U = muon_update_pre_orthogonalize( + G=G_local, + M=M_local, + momentum=momentum, + nesterov=nesterov, ) - if process_group is not None and process_group.size() > 1: - # Allocate empty tensors to receive updates from other devices - U = [torch.empty_like(u) for u in U] + # Get one whole matrix for each device to orthogonalize + # JQ: This is the N sequential gather version + # if shard_dim is not None: + # # Use all-to-all to transform from a batch of shards to a single whole matrix + # # https://www.essential.ai/blog/infra + # assert ( + # process_group is not None + # ), "process_group must be provided for sharded DTensors" + # assert isinstance(X[0], DTensor), "X should contain DTensors" + # assert not isinstance(U[0], DTensor), "U should contain local shards" + + # # Debug: print full tensor info before the divisibility check + # x0 = X[0] + # x0_mesh = x0.device_mesh + # x0_mesh_sizes = {name: x0_mesh.size(i) for i, name in enumerate(x0_mesh.mesh_dim_names)} + + # assert ( + # X[0].size(shard_dim) % world_size == 0 + # ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}. " \ + # f"Tensor info: global_shape={tuple(X[0].shape)}, local_shape={X[0].to_local().shape}, " \ + # f"mesh={X[0].device_mesh.mesh_dim_names}, mesh_sizes={x0_mesh_sizes}, placements={X[0].placements}" + + # # Allocate buffers to receive shards of one whole matrix from other devices + # single_matrix_shards = [torch.empty_like(u) for u in U] + + # # Redistribute the shards to form one unique full tensor on each device + # # Sync CUDA before collective to ensure all prior GPU ops are complete + # # This can prevent NCCL hangs due to async GPU operations + # torch.cuda.synchronize() + + # # N sequential all_gathers - only keep result for our assigned param + # single_matrix_shards = None + # for param_idx in range(world_size): + # # Allocate output buffer for this all_gather + # gathered = [torch.empty_like(U[param_idx]) for _ in range(world_size)] + + # # All ranks send their shard of param_idx + # dist.all_gather(gathered, U[param_idx].contiguous(), group=process_group) + + # # Only keep if this is our assigned parameter + # if param_idx == device_rank: + # single_matrix_shards = gathered + # # Otherwise 'gathered' goes out of scope and memory can be freed + + # yield + + # # Concatentate shards to form a whole matrix to orthogonalize + # single_matrix = torch.cat(single_matrix_shards, dim=shard_dim) + # single_matrix = muon_update_newton_schulz( + # single_matrix, + # newton_schulz_func=newton_schulz_func, + # flatten=flatten, + # epsilon=epsilon, + # ) + + # # Split result back into shards + # # Contiguous is needed for communication to work correctly + # orth_shards = [ + # x.contiguous() + # for x in torch.tensor_split(single_matrix, world_size, dim=shard_dim) + # ] + + # # N sequential all_gathers - collect results as we go + # for shard_idx in range(world_size): + # # Allocate output buffer for this all_gather + # gathered = [torch.empty_like(orth_shards[shard_idx]) for _ in range(world_size)] + + # # All ranks send their shard at index shard_idx + # dist.all_gather(gathered, orth_shards[shard_idx].contiguous(), group=process_group) + + # # gathered[r] = rank r's orth_shards[shard_idx] = O^r_{shard_idx} + # # We need U[r] = O^r_{device_rank} + # # So when shard_idx == device_rank: U[r] = gathered[r] for all r + # if shard_idx == device_rank: + # for r in range(world_size): + # U[r].copy_(gathered[r]) + + # yield + + # Get one whole matrix for each device to orthogonalize + if shard_dim is not None: + assert process_group is not None, "process_group must be provided for sharded DTensors" + assert isinstance(X[0], DTensor), "X should contain DTensors" + assert not isinstance(U[0], DTensor), "U should contain local shards" + + x0 = X[0] + x0_mesh = x0.device_mesh + x0_mesh_sizes = {name: x0_mesh.size(i) for i, name in enumerate(x0_mesh.mesh_dim_names)} + assert ( + X[0].size(shard_dim) % world_size == 0 + ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}. " \ + f"Tensor info: global_shape={tuple(X[0].shape)}, local_shape={X[0].to_local().shape}, " \ + f"mesh={X[0].device_mesh.mesh_dim_names}, mesh_sizes={x0_mesh_sizes}, placements={X[0].placements}" + + # Sync CUDA before collective to prevent NCCL hangs from async GPU ops + torch.cuda.synchronize() + + single_matrix_shards = [torch.empty_like(U[0]) for _ in range(world_size)] + dist.all_to_all(single_matrix_shards, [u.contiguous() for u in U], group=process_group) - # All gather orthogonalized results from other devices into buffer - work = dist.all_gather( - U, single_matrix.contiguous(), group=process_group, async_op=True + yield + + single_matrix = torch.cat(single_matrix_shards, dim=shard_dim) + del single_matrix_shards + + single_matrix = muon_update_newton_schulz( + single_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, ) + + orth_shards = [ + x.contiguous() + for x in torch.tensor_split(single_matrix, world_size, dim=shard_dim) + ] + del single_matrix + + output_shards = [torch.empty_like(orth_shards[0]) for _ in range(world_size)] + dist.all_to_all(output_shards, orth_shards, group=process_group) + del orth_shards + + for i in range(world_size): + U[i].copy_(output_shards[i]) + del output_shards + yield - work.wait() else: - # Single GPU case, no need to gather - assert world_size == 1 - U = [single_matrix] - - # Compute scaled learning rate - # Do this before to_local(X) because we use the full tensor shape, not the shard shape - if adjust_lr is None: - adjusted_lr = lr - elif adjust_lr == "spectral_norm": - adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape) - elif adjust_lr == "rms_norm": - adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape) - else: - raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") + single_matrix = U[device_rank] + assert not isinstance(single_matrix, DTensor) + + single_matrix = muon_update_newton_schulz( + single_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) - # Update model parameters with orthogonalized output - muon_update_post_orthogonalize( - X=to_local(X), - U=U, - base_lr=lr, - adjusted_lr=adjusted_lr, - weight_decay=weight_decay, - ) + if process_group is not None and process_group.size() > 1: + U_gathered = [torch.empty_like(single_matrix) for _ in range(world_size)] + work = dist.all_gather( + U_gathered, single_matrix.contiguous(), group=process_group, async_op=True + ) + yield + work.wait() + del single_matrix + U = U_gathered + else: + assert world_size == 1 + U = [single_matrix] + + # Compute scaled learning rate + if adjust_lr is None: + adjusted_lr = lr + elif adjust_lr == "spectral_norm": + adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape) + elif adjust_lr == "rms_norm": + adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape) + else: + raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") + + # Update model parameters with orthogonalized output + muon_update_post_orthogonalize( + X=X_local, + U=U, + base_lr=lr, + adjusted_lr=adjusted_lr, + weight_decay=weight_decay, + ) def adamw_update_foreach_async( diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 8873ef2f90..c2b3f51ef2 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -46,9 +46,10 @@ class FlexAttentionWrapper(torch.nn.Module): block_mask as a keyword argument to be compatible with _ContextParallel. """ + # Using dynamic=True to avoid CUDA crashes with CP, but need to debug NaN issue _compiled_flex_attn: ClassVar[Callable] = torch.compile( flex_attention, - mode="max-autotune-no-cudagraphs", + dynamic=True, ) def forward( diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index eedc20cbb5..452a873fb0 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -50,6 +50,52 @@ v_head_dim=128, mscale=0.70, ), + "debugmodel_1b": DeepSeekV3ModelArgs( + vocab_size=2048, + dim=1024, + inter_dim=4096, + moe_inter_dim=1024, + n_layers=24, + n_dense_layers=2, + n_heads=16, + moe_args=MoEArgs( + num_experts=8, + num_shared_experts=2, + top_k=3, + score_func="softmax", + route_norm=False, + score_before_experts=False, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + ), + "debugmodel_7b": DeepSeekV3ModelArgs( + vocab_size=2048, + dim=2048, + inter_dim=8192, + moe_inter_dim=2048, + n_layers=32, + n_dense_layers=3, + n_heads=32, + moe_args=MoEArgs( + num_experts=16, + num_shared_experts=4, + top_k=4, + score_func="softmax", + route_norm=False, + score_before_experts=False, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + ), "debugmodel_flex_attn": DeepSeekV3ModelArgs( vocab_size=2048, dim=256, @@ -75,6 +121,58 @@ use_flex_attn=True, attn_mask_type="block_causal", ), + # debugmodel_1b with FlexAttention for CP NaN testing + "debugmodel_1b_flex_attn": DeepSeekV3ModelArgs( + vocab_size=2048, + dim=1024, + inter_dim=4096, + moe_inter_dim=1024, + n_layers=24, + n_dense_layers=2, + n_heads=16, + moe_args=MoEArgs( + num_experts=8, + num_shared_experts=2, + top_k=3, + score_func="softmax", + route_norm=False, + score_before_experts=False, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + use_flex_attn=True, # Enable FlexAttention + attn_mask_type="block_causal", # Same as kimi_k2 + ), + # Debug model with FlexAttention + causal-only mask for CP testing + "debugmodel_flex_attn_causal": DeepSeekV3ModelArgs( + vocab_size=2048, + dim=256, + inter_dim=1024, + moe_inter_dim=256, + n_layers=6, + n_dense_layers=1, + n_heads=16, + moe_args=MoEArgs( + num_experts=8, + num_shared_experts=2, + top_k=3, + score_func="softmax", + route_norm=False, + score_before_experts=False, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + use_flex_attn=True, + attn_mask_type="causal", # Simple causal, no document mask + ), "16B": DeepSeekV3ModelArgs( vocab_size=129280, dim=2048, @@ -186,6 +284,38 @@ rope_factor=32.0, beta_fast=1, ), + # kimi_k2 with SDPA (no FlexAttention) for Context Parallel compatibility + "kimi_k2_sdpa": DeepSeekV3ModelArgs( + vocab_size=163840, + dim=7168, + inter_dim=18432, + moe_inter_dim=2048, + n_layers=61, + n_dense_layers=1, + n_heads=64, + norm_eps=1e-6, + moe_args=MoEArgs( + num_experts=384, + num_shared_experts=1, + top_k=8, + score_func="sigmoid", + route_norm=True, + route_scale=2.827, + score_before_experts=False, + ), + n_expert_groups=1, + n_limited_groups=1, + q_lora_rank=1536, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + use_flex_attn=False, # Use SDPA for CP compatibility + attn_mask_type="causal", + rope_theta=50000.0, + rope_factor=32.0, + beta_fast=1, + ), } diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 7da79c361e..5009609281 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -62,8 +62,6 @@ def parallelize_deepseekv3( """ use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") if parallel_dims.tp_enabled: enable_float8_linear = "float8" in job_config.model.converters @@ -108,7 +106,11 @@ def parallelize_deepseekv3( job_config.compile.enable and "model" in job_config.compile.components ) - if job_config.activation_checkpoint.mode != "none": + # Apply activation checkpointing or CPU offloading + if ( + job_config.activation_checkpoint.mode != "none" + or job_config.activation_checkpoint.cpu_offload + ): apply_ac( model, job_config.activation_checkpoint, @@ -152,6 +154,8 @@ def parallelize_deepseekv3( else None ), gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, + bucket_cap_mb=job_config.parallelism.fsdp_bucket_cap_mb, + disable_prefetch=job_config.parallelism.fsdp_disable_prefetch, ) if parallel_dims.dp_replicate_enabled: diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 1a6ff3cf6e..dd9056ee91 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -86,7 +86,9 @@ class DeepSeekV3ModelArgs(BaseModelArgs): beta_fast: int = 32 beta_slow: int = 1 mscale: float = 1.0 - mscale_all_dim: float = 1.0 # When mscale == mscale_all_dim, effective mscale is 1.0 + mscale_all_dim: float = ( + 1.0 # When mscale == mscale_all_dim, effective mscale is 1.0 + ) def update_from_config(self, job_config: JobConfig, **kwargs) -> None: seq_len = job_config.training.seq_len @@ -102,11 +104,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.moe_args.use_grouped_mm = False - if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: - raise NotImplementedError( - "CP support for FlexAttention is still in progress." - ) - self.moe_args._debug_force_load_balance = ( job_config.debug.moe_force_load_balance ) diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 67bb37480e..29591c81ab 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -5,12 +5,20 @@ # LICENSE file in the root directory of this source tree. import math +from typing import Optional import torch from torch import nn +from torch.distributed.device_mesh import DeviceMesh from torch.nn.attention.flex_attention import and_masks, BlockMask +# Import CP block mask creator for Context Parallel + FlexAttention +try: + from torch.distributed.tensor.experimental._attention import create_cp_block_mask +except ImportError: + create_cp_block_mask = None + from torchtitan.components.peft.lora import lora_or_linear, per_layer_config from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.config.job_config import PEFT @@ -22,7 +30,12 @@ get_document_mask_mod, ScaledDotProductAttentionWrapper, ) -from torchtitan.models.moe import FeedForward, MoE +from torchtitan.models.moe import ( + fast_init_normal_, + fast_init_trunc_normal_, + FeedForward, + MoE, +) from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol @@ -325,8 +338,8 @@ def init_weights(self, init_std: float): linear_list.append(self.wq) for linear in linear_list: - nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) - nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + fast_init_trunc_normal_(linear.weight, mean=0.0, std=0.02) + fast_init_trunc_normal_(self.wo.weight, mean=0.0, std=init_std) self.kv_norm.reset_parameters() if self.q_lora_rank > 0: @@ -456,7 +469,8 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: with torch.device(buffer_device): self.freqs_cis = precompute_freqs_cis(self.model_args) if self.tok_embeddings is not None: - nn.init.normal_(self.tok_embeddings.weight) + fast_init_normal_(self.tok_embeddings.weight) + fast_init_normal_(self.tok_embeddings.weight) for layer in self.layers.values(): if layer is not None: layer.init_weights(buffer_device=buffer_device) @@ -465,7 +479,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: final_out_std = self.model_args.dim**-0.5 cutoff_factor = 3 if self.output is not None: - nn.init.trunc_normal_( + fast_init_trunc_normal_( self.output.weight, mean=0.0, std=final_out_std, @@ -478,6 +492,7 @@ def get_attention_masks( input_batch: torch.Tensor, tokenizer: BaseTokenizer, extra_inputs: dict[str, torch.Tensor] | None = None, + cp_mesh: Optional[DeviceMesh] = None, ) -> AttentionMasksType: mask_mods = [get_causal_mask_mod()] match self.model_args.attn_mask_type: @@ -500,15 +515,33 @@ def get_attention_masks( raise ValueError( f"Unknown attention mask type: {self.model_args.attn_mask_type}" ) - return create_attention_mask( - and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] - ) + + combined_mask_mod = and_masks(*mask_mods) + seq_len = input_batch.shape[1] + H = self.model_args.n_heads # Number of attention heads + + # Use CP-aware block mask when Context Parallel is enabled + if cp_mesh is not None: + if create_cp_block_mask is None: + raise RuntimeError("Cannot do context parallel without a PyTorch that supports `create_cp_block_mask`") + return create_cp_block_mask( + mask_mod=combined_mask_mod, + B=B, + H=H, + Q_LEN=seq_len, + KV_LEN=seq_len, + device_mesh=cp_mesh, + ) + else: + return create_attention_mask(combined_mask_mod, B, None, seq_len, seq_len) def forward( self, tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None, position_ids: torch.Tensor | None = None, + return_outputs: bool = False, # For pipeline parallelism compatibility + **kwargs, # Accept additional kwargs for PP compatibility ): """ Forward pass for the Transformer model. diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml new file mode 100644 index 0000000000..e3e1c9d497 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml @@ -0,0 +1,72 @@ +# Kimi K2 - 12 nodes - EP=96, CP=16, LBS=11 +# +# Original config: /home/phuc/worklogs/2026-01-30/cp16_sweep/configs/exp1acd_12n_ep96_cp16_lbs11.toml +# Job ID: 2307 +# +# Expected Performance: +# - TPS: 402 +# - Memory: 67.55 GiB (85.2%) +# - MFU: 17.72% +# - TFLOPS: ~175 +# +# Parallelism: EP=96, CP=16, DP=1 (dp_replicate=1, dp_shard=1) +# Nodes: 12 (96 GPUs) +# Seq Length: 32768 +# Local Batch Size: 11 +# + +[job] +dump_folder = "./outputs/kimi_k2/12n_ep96_cp16_lbs11" +description = "Kimi K2 - 12n EP=96 CP=16 LBS=11" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "/home/phuc/kimi_1t/torchtitan/assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +state_dtype = "bfloat16" + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +dtype = "bfloat16" +local_batch_size = 11 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = false +# Aggressive memory management to reduce CUDA fragmentation +aggressive_memory_mode = "maximum" +aggressive_memory_verbose = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 16 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = true +components = ["loss"] + +[debug] +moe_force_load_balance = true + +[comm] +init_timeout_seconds = 1800 +train_timeout_seconds = 1800 diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_k2_36n_ep96_cp16_32k_ctx_hsdp_replicate3_shard6_lbs10.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_36n_ep96_cp16_32k_ctx_hsdp_replicate3_shard6_lbs10.toml new file mode 100644 index 0000000000..7ee95edec6 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_36n_ep96_cp16_32k_ctx_hsdp_replicate3_shard6_lbs10.toml @@ -0,0 +1,73 @@ +# Kimi K2 - 36 nodes - EP=96, CP=16, HSDP (dp_replicate=3, dp_shard=6), LBS=10 +# +# Original config: /home/phuc/worklogs/2026-01-30/cp16_sweep_dp/configs/exp1aj_HSDP_r3_s6_lbs10.toml +# Job ID: 2485 +# +# Expected Performance: +# - TPS: 378 +# - Memory: 69.45 GiB (87.6%) +# - MFU: 16.64% +# +# Parallelism: EP=96, CP=16, dp_replicate=3, dp_shard=6 +# HSDP: Shard within 12 nodes, all-reduce between 3 replica groups +# Nodes: 36 (288 GPUs) +# Seq Length: 32768 +# Local Batch Size: 10 +# + +[job] +dump_folder = "./outputs/kimi_k2/36n_ep96_cp16_hsdp_replicate3_shard6_lbs10" +description = "Kimi K2 - 36n HSDP dp_replicate=3 dp_shard=6 LBS=10" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "/home/phuc/kimi_1t/torchtitan/assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +state_dtype = "bfloat16" + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +dtype = "bfloat16" +local_batch_size = 10 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = false +# Aggressive memory management to reduce CUDA fragmentation +aggressive_memory_mode = "maximum" +aggressive_memory_verbose = true + +[parallelism] +data_parallel_replicate_degree = 3 +data_parallel_shard_degree = 6 +expert_parallel_degree = 96 +context_parallel_degree = 16 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = true +components = ["loss"] + +[debug] +moe_force_load_balance = true + +[comm] +init_timeout_seconds = 1800 +train_timeout_seconds = 1800 diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index efc73ccc0e..f5899e2de1 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -177,6 +177,8 @@ def parallelize_llama( else None ), gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, + bucket_cap_mb=job_config.parallelism.fsdp_bucket_cap_mb, + disable_prefetch=job_config.parallelism.fsdp_disable_prefetch, ) if parallel_dims.dp_replicate_enabled: @@ -298,10 +300,12 @@ def apply_fsdp( reduce_dtype: torch.dtype, pp_enabled: bool, cpu_offload: bool = False, - reshard_after_forward_policy: str = "default", + reshard_after_forward_policy: str | int = "default", ep_degree: int = 1, dp_mod_ep_mesh: DeviceMesh | None = None, gradient_divide_factor: int | None = None, + bucket_cap_mb: int | None = None, + disable_prefetch: bool = False, ): """ Apply data parallelism (via FSDP2) to the model. @@ -313,31 +317,50 @@ def apply_fsdp( reduce_dtype (torch.dtype): The data type to use for reduction operations. pp_enabled (bool): Whether pipeline parallelism is enabled. cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. - reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". - Other options: "never", "always". + reshard_after_forward_policy (str | int, optional): The policy to use for + resharding after forward pass. Defaults to "default". + String options: "never", "always", "default". - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. - "always" will enable `reshard_after_forward` for all forward passes. - "never" will disable `reshard_after_forward` for all forward passes. + Integer option: N (e.g., 8) for partial resharding to N-GPU groups. + - Reduces peak memory by limiting all-gather buffer size to N GPUs instead of full DP world. + - Use N=8 for intra-node resharding (fast NVLink communication). + - N must be a factor of the FSDP shard world size. + bucket_cap_mb (int | None, optional): FSDP bucket size in MB for gradient + reduction. None means use PyTorch default (25MB). Defaults to None. + disable_prefetch (bool, optional): Whether to disable FSDP forward/backward prefetching. Defaults to False. """ mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} if cpu_offload: fsdp_config["offload_policy"] = CPUOffloadPolicy() + if bucket_cap_mb is not None: + logger.warning( + f"bucket_cap_mb={bucket_cap_mb} requested but not supported by FSDP2 fully_shard() API - ignoring" + ) - match reshard_after_forward_policy: - case "always": - reshard_after_forward = True - case "never": - reshard_after_forward = False - case "default": - # For PP, by default do not reshard after forward to avoid per-microbatch - # all-gathers, which can be expensive and non-overlapped - reshard_after_forward = not pp_enabled - case _: - raise ValueError( - f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." - ) + # Handle integer reshard_after_forward (partial resharding to N-GPU groups) + if isinstance(reshard_after_forward_policy, int): + reshard_after_forward = reshard_after_forward_policy + logger.info( + f"Using partial reshard_after_forward={reshard_after_forward} (resharding to {reshard_after_forward}-GPU groups)" + ) + else: + match reshard_after_forward_policy: + case "always": + reshard_after_forward = True + case "never": + reshard_after_forward = False + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = not pp_enabled + case _: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." + ) if model.tok_embeddings is not None: fully_shard( @@ -461,6 +484,11 @@ def apply_fsdp( if ep_degree == 1: return + # Skip prefetch setup if disabled + if disable_prefetch: + logger.info("FSDP prefetching is disabled") + return + # forward transformer_blocks = list(model.layers.values()) next_transformer_blocks = transformer_blocks[1:] + [None] diff --git a/torchtitan/models/moe/__init__.py b/torchtitan/models/moe/__init__.py index c932f6aa83..fffd491ae4 100644 --- a/torchtitan/models/moe/__init__.py +++ b/torchtitan/models/moe/__init__.py @@ -4,6 +4,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .moe import ExpertRoutingHistogram, FeedForward, MoE, MoEArgs +from .moe import ( + ExpertRoutingHistogram, + fast_init_normal_, + fast_init_trunc_normal_, + FeedForward, + MoE, + MoEArgs, +) -__all__ = ["FeedForward", "MoE", "MoEArgs", "ExpertRoutingHistogram"] +__all__ = [ + "FeedForward", + "MoE", + "MoEArgs", + "ExpertRoutingHistogram", + "fast_init_trunc_normal_", + "fast_init_normal_", +] diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index 46c2aa4484..af1f32a7bf 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -34,10 +34,50 @@ class ExpertRoutingHistogram: counts: list[float] +# see https://arxiv.org/pdf/2310.10837 def moe_init_std(dim_in: int, n_layers: int) -> float: return (2 / (dim_in * n_layers)) ** 0.5 +def fast_init_trunc_normal_( + tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +) -> None: + """ + Fast truncated normal initialization that handles bfloat16 tensors on CPU. + + When tensors are bfloat16 on CPU, nn.init.trunc_normal_ is extremely slow + because CPUs don't have native bfloat16 support. This function temporarily + converts to float32 for the initialization, then converts back. + """ + if tensor.device.type == "cpu" and tensor.dtype == torch.bfloat16: + with torch.no_grad(): + # Initialize in float32 for CPU performance + temp = torch.empty_like(tensor, dtype=torch.float32) + nn.init.trunc_normal_(temp, mean=mean, std=std, a=a, b=b) + tensor.copy_(temp.to(torch.bfloat16)) + else: + nn.init.trunc_normal_(tensor, mean=mean, std=std, a=a, b=b) + + +def fast_init_normal_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0 +) -> None: + """ + Fast normal initialization that handles bfloat16 tensors on CPU. + """ + if tensor.device.type == "cpu" and tensor.dtype == torch.bfloat16: + with torch.no_grad(): + temp = torch.empty_like(tensor, dtype=torch.float32) + nn.init.normal_(temp, mean=mean, std=std) + tensor.copy_(temp.to(torch.bfloat16)) + else: + nn.init.normal_(tensor, mean=mean, std=std) + + @dataclass class MoEArgs: num_experts: int = 8 @@ -174,9 +214,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x)) def init_weights(self, init_std: float = 0.02): - nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + fast_init_trunc_normal_(self.w1.weight, mean=0.0, std=0.02) for linear in (self.w2, self.w3): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + fast_init_trunc_normal_(linear.weight, mean=0.0, std=init_std) # NOTE: keeping this for-loop implementation for comparison @@ -456,9 +496,9 @@ def forward( def init_weights(self, init_std: float, n_layers: int): std_in = moe_init_std(self.w1.shape[-1], n_layers) std_out = moe_init_std(self.w2.shape[0], n_layers) - nn.init.trunc_normal_(self.w1, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w2, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w3, mean=0.0, std=std_out) + fast_init_trunc_normal_(self.w1, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w2, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w3, mean=0.0, std=std_out) def _groupmm(x, w, offs): @@ -602,12 +642,12 @@ def forward( def init_weights(self, init_std: float, n_layers: int): std_in = moe_init_std(self.w1.shape[-1], n_layers) std_out = moe_init_std(self.w2.shape[0], n_layers) - nn.init.trunc_normal_(self.w1, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w2, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w3, mean=0.0, std=std_out) - nn.init.trunc_normal_(self.w1_lora_a, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w2_lora_a, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w3_lora_a, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w1, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w2, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w3, mean=0.0, std=std_out) + fast_init_trunc_normal_(self.w1_lora_a, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w2_lora_a, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w3_lora_a, mean=0.0, std=std_in) nn.init.zeros_(self.w1_lora_b) nn.init.zeros_(self.w2_lora_b) nn.init.zeros_(self.w3_lora_b) @@ -728,11 +768,15 @@ def forward( return top_scores, selected_experts_indices, num_tokens_per_expert def init_weights(self, init_std: float, n_layers: int): + # Init gate with each row normalized + # From "Approximating Two-Layer Feedforward Networks for Efficient Transformers" + # https://arxiv.org/pdf/2310.10837 + # NOTE: Must use in-place operations here. When FSDP wraps parameters as # DTensor, direct .data assignment (e.g., self.gate.weight.data = x) is # silently ignored, leaving weights uninitialized. This causes NaN loss # when CPU offload is enabled with 3+ GPUs. - nn.init.normal_(self.gate.weight, mean=0.0, std=1.0) + fast_init_normal_(self.gate.weight, mean=0.0, std=1.0) # Normalize rows in-place with torch.no_grad(): @@ -1053,7 +1097,7 @@ def init_weights(self, init_std: float, buffer_device: torch.device, n_layers: i if self.shared_experts is not None: self.shared_experts.init_weights(init_std) if self.shared_gate is not None: - nn.init.trunc_normal_( + fast_init_trunc_normal_( self.shared_gate.weight, mean=0.0, std=moe_init_std(self.shared_gate.weight.shape[1], n_layers), diff --git a/torchtitan/tools/aggressive_memory_manager.py b/torchtitan/tools/aggressive_memory_manager.py new file mode 100644 index 0000000000..1c4861cb74 --- /dev/null +++ b/torchtitan/tools/aggressive_memory_manager.py @@ -0,0 +1,414 @@ +""" +Aggressive Memory Manager for reducing CUDA memory fragmentation. + +This module provides aggressive memory clearing strategies to minimize +fragmentation and allocation retries during distributed training. + +Usage: + from torchtitan.tools.aggressive_memory_manager import AggressiveMemoryManager + + # Initialize at start of training + mem_manager = AggressiveMemoryManager( + clear_after_backward=True, + clear_after_optimizer=True, + sync_before_clear=True, + defrag_threshold_mb=1000, # Defrag if fragmentation > 1GB + ) + + # In training loop: + loss.backward() + mem_manager.post_backward() + + optimizer.step() + mem_manager.post_optimizer() + + mem_manager.step_complete() +""" + +import gc +import os +import time +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.distributed as dist + +from torchtitan.tools.logging import logger + + +@dataclass +class MemoryStats: + """Current memory statistics""" + + allocated: int + reserved: int + active: int + fragmentation: int + fragmentation_pct: float + num_alloc_retries: int + + +class AggressiveMemoryManager: + """ + Aggressive memory management to minimize CUDA memory fragmentation. + + Key strategies: + 1. Clear cache at strategic points (post-backward, post-optimizer) + 2. Synchronize before clearing to ensure all async ops complete + 3. Force garbage collection to release Python references + 4. Monitor fragmentation and trigger defrag when threshold exceeded + 5. Set optimal allocator configuration + + Args: + clear_after_backward: Clear cache after backward pass + clear_after_optimizer: Clear cache after optimizer step + clear_every_n_steps: Only clear every N steps (1 = every step) + sync_before_clear: Synchronize CUDA before clearing cache + defrag_threshold_mb: Trigger defrag if fragmentation exceeds this (MB) + gc_generation: Python GC generation to collect (0-2, higher = more thorough) + verbose: Log detailed memory stats + rank: Distributed rank (auto-detected if None) + """ + + def __init__( + self, + clear_after_backward: bool = True, + clear_after_optimizer: bool = True, + clear_every_n_steps: int = 1, + sync_before_clear: bool = True, + defrag_threshold_mb: float = 500.0, + gc_generation: int = 1, + verbose: bool = False, + rank: Optional[int] = None, + ): + self.clear_after_backward = clear_after_backward + self.clear_after_optimizer = clear_after_optimizer + self.clear_every_n_steps = clear_every_n_steps + self.sync_before_clear = sync_before_clear + self.defrag_threshold_mb = defrag_threshold_mb + self.gc_generation = gc_generation + self.verbose = verbose + + self.rank = ( + rank + if rank is not None + else (dist.get_rank() if dist.is_initialized() else 0) + ) + + self.step_count = 0 + self.total_clears = 0 + self.total_defrag_time_ms = 0.0 + + # Disable automatic GC - we'll control it manually + gc.disable() + + # Initial cleanup + self._aggressive_clear("initialization") + + if self.rank == 0: + logger.info( + f"[AggressiveMemoryManager] Initialized: " + f"clear_backward={clear_after_backward}, " + f"clear_optimizer={clear_after_optimizer}, " + f"every_n_steps={clear_every_n_steps}, " + f"sync={sync_before_clear}, " + f"defrag_threshold={defrag_threshold_mb}MB" + ) + + @staticmethod + def configure_allocator( + expandable_segments: bool = True, + max_split_size_mb: int = 128, + garbage_collection_threshold: float = 0.8, + roundup_power2_divisions: int = 4, + ) -> str: + """ + Configure PyTorch CUDA allocator for minimal fragmentation. + + Call this BEFORE any CUDA operations (before model creation). + + Args: + expandable_segments: Enable expandable memory segments + max_split_size_mb: Max size of memory splits (smaller = less fragmentation) + garbage_collection_threshold: Trigger GC when this fraction of memory is fragmented + roundup_power2_divisions: Memory rounding granularity + + Returns: + The PYTORCH_CUDA_ALLOC_CONF string that was set + """ + config_parts = [] + + if expandable_segments: + config_parts.append("expandable_segments:True") + + config_parts.append(f"max_split_size_mb:{max_split_size_mb}") + config_parts.append( + f"garbage_collection_threshold:{garbage_collection_threshold}" + ) + config_parts.append(f"roundup_power2_divisions:{roundup_power2_divisions}") + + config_str = ",".join(config_parts) + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = config_str + + return config_str + + def get_memory_stats(self) -> MemoryStats: + """Get current memory statistics""" + if not torch.cuda.is_available(): + return MemoryStats(0, 0, 0, 0, 0.0, 0) + + stats = torch.cuda.memory_stats() + allocated = torch.cuda.memory_allocated() + reserved = torch.cuda.memory_reserved() + active = stats.get("active_bytes.all.current", 0) + fragmentation = reserved - allocated + fragmentation_pct = (fragmentation / reserved * 100) if reserved > 0 else 0.0 + num_retries = stats.get("num_alloc_retries", 0) + + return MemoryStats( + allocated=allocated, + reserved=reserved, + active=active, + fragmentation=fragmentation, + fragmentation_pct=fragmentation_pct, + num_alloc_retries=num_retries, + ) + + def _should_clear(self) -> bool: + """Check if we should clear cache this step""" + return self.step_count % self.clear_every_n_steps == 0 + + def _aggressive_clear(self, reason: str) -> float: + """ + Perform aggressive memory clearing. + + Returns: + Time taken in milliseconds + """ + if not torch.cuda.is_available(): + return 0.0 + + start = time.perf_counter() + + # 1. Synchronize all CUDA streams to ensure ops complete + if self.sync_before_clear: + torch.cuda.synchronize() + + # 2. Python garbage collection (releases tensor references) + gc.collect(self.gc_generation) + + # 3. Clear CUDA cache (releases unused cached memory) + torch.cuda.empty_cache() + + # 4. Optional: Force synchronization after clear + if self.sync_before_clear: + torch.cuda.synchronize() + + elapsed_ms = (time.perf_counter() - start) * 1000 + self.total_clears += 1 + self.total_defrag_time_ms += elapsed_ms + + if self.verbose and self.rank == 0: + stats = self.get_memory_stats() + logger.info( + f"[AggressiveMemoryManager] {reason}: " + f"cleared in {elapsed_ms:.1f}ms, " + f"frag={stats.fragmentation_pct:.1f}%, " + f"reserved={stats.reserved/1e9:.2f}GB" + ) + + return elapsed_ms + + def _check_and_defrag(self, phase: str) -> bool: + """ + Check fragmentation and defrag if needed. + + Returns: + True if defrag was triggered + """ + stats = self.get_memory_stats() + fragmentation_mb = stats.fragmentation / (1024 * 1024) + + if fragmentation_mb > self.defrag_threshold_mb: + self._aggressive_clear(f"defrag_{phase}_frag={fragmentation_mb:.0f}MB") + return True + + return False + + def post_backward(self): + """Call after backward pass completes""" + if self.clear_after_backward and self._should_clear(): + self._check_and_defrag("post_backward") + self._aggressive_clear("post_backward") + + def post_optimizer(self): + """Call after optimizer step completes""" + if self.clear_after_optimizer and self._should_clear(): + self._check_and_defrag("post_optimizer") + self._aggressive_clear("post_optimizer") + + def step_complete(self): + """Call at the end of each training step""" + self.step_count += 1 + + # Always check for high fragmentation + self._check_and_defrag("step_end") + + def get_summary(self) -> str: + """Get summary of memory management activity""" + avg_time = self.total_defrag_time_ms / max(1, self.total_clears) + return ( + f"AggressiveMemoryManager Summary:\n" + f" Total clears: {self.total_clears}\n" + f" Total defrag time: {self.total_defrag_time_ms:.1f}ms\n" + f" Avg time per clear: {avg_time:.2f}ms\n" + f" Steps processed: {self.step_count}" + ) + + +class BackwardMemoryHook: + """ + Register hooks on model parameters to clear memory during backward pass. + + This clears memory incrementally as gradients are computed, rather than + waiting until the end of backward. + + Args: + clear_every_n_params: Clear cache after every N parameter gradients + sync_on_clear: Synchronize before clearing (slower but more thorough) + """ + + def __init__( + self, + clear_every_n_params: int = 10, + sync_on_clear: bool = False, + ): + self.clear_every_n_params = clear_every_n_params + self.sync_on_clear = sync_on_clear + self.param_count = 0 + self.handles = [] + + def _backward_hook(self, grad): + """Hook called when gradient is computed for a parameter""" + self.param_count += 1 + + if self.param_count % self.clear_every_n_params == 0: + if self.sync_on_clear: + torch.cuda.synchronize() + gc.collect(0) # Fast GC (generation 0 only) + torch.cuda.empty_cache() + + return grad + + def register(self, model: torch.nn.Module): + """Register hooks on all model parameters""" + for name, param in model.named_parameters(): + if param.requires_grad: + handle = param.register_post_accumulate_grad_hook( + lambda p, name=name: self._backward_hook(p.grad) + ) + self.handles.append(handle) + + logger.info( + f"[BackwardMemoryHook] Registered on {len(self.handles)} parameters, " + f"clearing every {self.clear_every_n_params} params" + ) + + def remove(self): + """Remove all registered hooks""" + for handle in self.handles: + handle.remove() + self.handles.clear() + + def reset_count(self): + """Reset parameter count (call at start of each backward)""" + self.param_count = 0 + + +def setup_aggressive_memory_environment(): + """ + Set up environment variables for aggressive memory management. + + Call this BEFORE importing torch or creating any CUDA tensors. + """ + # Optimal allocator settings for minimal fragmentation + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = ( + "expandable_segments:True," + "max_split_size_mb:128," + "garbage_collection_threshold:0.8," + "roundup_power2_divisions:4" + ) + + # Disable NCCL async error handling (can cause memory issues) + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" + + # Force synchronous CUDA operations for debugging + # os.environ["CUDA_LAUNCH_BLOCKING"] = "1" # Uncomment for debugging + + return os.environ.get("PYTORCH_CUDA_ALLOC_CONF") + + +# Convenience function for quick setup +def create_aggressive_memory_manager( + mode: str = "balanced", + verbose: bool = False, +) -> AggressiveMemoryManager: + """ + Create an AggressiveMemoryManager with preset configurations. + + Args: + mode: One of: + - "minimal": Only clear on high fragmentation + - "balanced": Clear after backward and optimizer + - "aggressive": Clear frequently with sync + - "maximum": Clear after every operation + verbose: Enable verbose logging + + Returns: + Configured AggressiveMemoryManager + """ + if mode == "minimal": + return AggressiveMemoryManager( + clear_after_backward=False, + clear_after_optimizer=False, + clear_every_n_steps=10, + sync_before_clear=False, + defrag_threshold_mb=2000, + gc_generation=0, + verbose=verbose, + ) + elif mode == "balanced": + return AggressiveMemoryManager( + clear_after_backward=True, + clear_after_optimizer=True, + clear_every_n_steps=1, + sync_before_clear=False, + defrag_threshold_mb=500, + gc_generation=1, + verbose=verbose, + ) + elif mode == "aggressive": + return AggressiveMemoryManager( + clear_after_backward=True, + clear_after_optimizer=True, + clear_every_n_steps=1, + sync_before_clear=True, + defrag_threshold_mb=200, + gc_generation=2, + verbose=verbose, + ) + elif mode == "maximum": + return AggressiveMemoryManager( + clear_after_backward=True, + clear_after_optimizer=True, + clear_every_n_steps=1, + sync_before_clear=True, + defrag_threshold_mb=100, + gc_generation=2, + verbose=verbose, + ) + else: + raise ValueError( + f"Unknown mode: {mode}. Use minimal/balanced/aggressive/maximum" + ) diff --git a/torchtitan/tools/cuda_memory_tracker.py b/torchtitan/tools/cuda_memory_tracker.py new file mode 100644 index 0000000000..0f7d7af5f4 --- /dev/null +++ b/torchtitan/tools/cuda_memory_tracker.py @@ -0,0 +1,123 @@ +"""Track CUDA memory directly from nvidia-smi and PyTorch""" +import logging +import subprocess +from typing import Dict, Optional + +import torch + +logger = logging.getLogger(__name__) + + +class CUDAMemoryTracker: + """Track memory from both PyTorch and CUDA/nvidia-smi""" + + def __init__(self, enabled: bool = True): + self.enabled = enabled + self.device = torch.cuda.current_device() + self.device_name = torch.cuda.get_device_name(self.device) + + if self.enabled: + logger.info( + f"CUDAMemoryTracker enabled for device {self.device}: {self.device_name}" + ) + + def get_nvidia_smi_memory(self) -> Optional[Dict[str, int]]: + """Get memory from nvidia-smi""" + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=memory.used,memory.free,memory.total", + "--format=csv,noheader,nounits", + "-i", + str(self.device), + ], + capture_output=True, + text=True, + timeout=2, + ) + + if result.returncode == 0: + used, free, total = map(int, result.stdout.strip().split(",")) + return {"used_mb": used, "free_mb": free, "total_mb": total} + except Exception as e: + logger.warning(f"Failed to get nvidia-smi memory: {e}") + + return None + + def get_pytorch_memory(self) -> Dict[str, int]: + """Get memory from PyTorch""" + stats = torch.cuda.memory_stats(self.device) + + return { + "reserved_bytes": torch.cuda.memory_reserved(self.device), + "allocated_bytes": torch.cuda.memory_allocated(self.device), + "active_bytes": stats.get("active_bytes.all.current", 0), + "inactive_bytes": stats.get("inactive_split_bytes.all.current", 0), + "peak_active_bytes": stats.get("active_bytes.all.peak", 0), + "num_alloc_retries": stats.get("num_alloc_retries.all.current", 0), + "num_ooms": stats.get("num_ooms.all.current", 0), + } + + def get_cuda_device_memory(self) -> Dict[str, int]: + """Get memory directly from CUDA device properties""" + props = torch.cuda.get_device_properties(self.device) + + return { + "total_memory": props.total_memory, + "reserved_memory": torch.cuda.memory_reserved(self.device), + "allocated_memory": torch.cuda.memory_allocated(self.device), + } + + def measure_all(self, phase: str, step: int): + """Comprehensive memory measurement""" + if not self.enabled: + return + + # PyTorch memory + pytorch_mem = self.get_pytorch_memory() + + # CUDA device memory + cuda_mem = self.get_cuda_device_memory() + + # nvidia-smi memory (if available) + smi_mem = self.get_nvidia_smi_memory() + + # Calculate fragmentation + reserved = pytorch_mem["reserved_bytes"] + allocated = pytorch_mem["allocated_bytes"] + active = pytorch_mem["active_bytes"] + + fragmentation = reserved - allocated + frag_pct = (fragmentation / reserved * 100) if reserved > 0 else 0 + + # Log PyTorch view + logger.info( + f"[PyTorch] Step {step:2d} | {phase:25s} | " + f"Reserved: {reserved/1e9:6.2f} GB | " + f"Allocated: {allocated/1e6:8.2f} MB | " + f"Active: {active/1e6:8.2f} MB | " + f"Frag: {frag_pct:5.1f}%" + ) + + # Log CUDA/nvidia-smi view + if smi_mem: + logger.info( + f"[CUDA-SMI] Step {step:2d} | {phase:25s} | " + f"Used: {smi_mem['used_mb']/1024:6.2f} GB | " + f"Free: {smi_mem['free_mb']/1024:6.2f} GB | " + f"Total: {smi_mem['total_mb']/1024:6.2f} GB" + ) + + # Log comparison + if smi_mem: + pytorch_used_gb = reserved / 1e9 + smi_used_gb = smi_mem["used_mb"] / 1024 + diff_gb = smi_used_gb - pytorch_used_gb + + logger.info( + f"[Compare] Step {step:2d} | {phase:25s} | " + f"PyTorch reports: {pytorch_used_gb:6.2f} GB | " + f"nvidia-smi reports: {smi_used_gb:6.2f} GB | " + f"Diff: {diff_gb:+6.2f} GB" + ) diff --git a/torchtitan/tools/detailed_memory_tracker.py b/torchtitan/tools/detailed_memory_tracker.py new file mode 100644 index 0000000000..7b513b3e20 --- /dev/null +++ b/torchtitan/tools/detailed_memory_tracker.py @@ -0,0 +1,160 @@ +"""Detailed memory tracking throughout training step""" +import logging +from typing import Dict, List + +import torch + +logger = logging.getLogger(__name__) + + +class DetailedMemoryTracker: + """Track memory at every phase of training with cache clearing""" + + def __init__(self, enabled: bool = True, clear_cache: bool = True): + self.enabled = enabled + self.clear_cache_between_steps = clear_cache + self.measurements: List[Dict] = [] + self.device = torch.cuda.current_device() + + if self.enabled: + logger.info(f"DetailedMemoryTracker enabled (clear_cache={clear_cache})") + + def measure(self, phase: str, step: int): + """Capture memory state at a specific phase""" + if not self.enabled: + return + + stats = torch.cuda.memory_stats(self.device) + + measurement = { + "step": step, + "phase": phase, + "reserved": torch.cuda.memory_reserved(self.device), + "allocated": torch.cuda.memory_allocated(self.device), + "active": stats.get("active_bytes.all.current", 0), + "peak_active": stats.get("active_bytes.all.peak", 0), + "num_allocs": stats.get("num_alloc_retries.all.current", 0), + } + + self.measurements.append(measurement) + + # Calculate fragmentation + fragmentation = measurement["reserved"] - measurement["allocated"] + frag_pct = ( + (fragmentation / measurement["reserved"] * 100) + if measurement["reserved"] > 0 + else 0 + ) + + logger.info( + f"[MemTrack] Step {step} | {phase:20s} | " + f"Reserved: {measurement['reserved']/1e9:6.2f} GB | " + f"Allocated: {measurement['allocated']/1e6:7.2f} MB | " + f"Active: {measurement['active']/1e6:7.2f} MB | " + f"Frag: {frag_pct:5.1f}%" + ) + + def clear_cache_and_measure(self, phase: str, step: int): + """Clear cache and measure to see minimum memory""" + if not self.enabled: + return + + # Measure before clearing + self.measure(f"{phase}_before_clear", step) + + # Clear cache + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # Measure after clearing + self.measure(f"{phase}_after_clear", step) + + def step_complete(self, step: int): + """Called after each training step""" + if not self.enabled: + return + + if self.clear_cache_between_steps: + self.clear_cache_and_measure("step_end", step) + + def get_summary(self) -> str: + """Get summary of all measurements""" + if not self.measurements: + return "No measurements recorded" + + summary = ["", "=" * 100, "DETAILED MEMORY TRACKING SUMMARY", "=" * 100, ""] + + # Group by step + steps = {} + for m in self.measurements: + step = m["step"] + if step not in steps: + steps[step] = [] + steps[step].append(m) + + for step, measures in sorted(steps.items()): + summary.append(f"\nStep {step}:") + summary.append( + f"{'Phase':<30} {'Reserved':>12} {'Allocated':>12} {'Active':>12} {'Frag%':>8}" + ) + summary.append("-" * 80) + + for m in measures: + frag_pct = ( + ((m["reserved"] - m["allocated"]) / m["reserved"] * 100) + if m["reserved"] > 0 + else 0 + ) + summary.append( + f"{m['phase']:<30} " + f"{m['reserved']/1e9:10.2f} GB " + f"{m['allocated']/1e6:10.2f} MB " + f"{m['active']/1e6:10.2f} MB " + f"{frag_pct:7.1f}%" + ) + + # Peak measurements + summary.append("\n" + "=" * 100) + summary.append("PEAK MEASUREMENTS ACROSS ALL STEPS:") + summary.append("=" * 100) + + peak_reserved = max(m["reserved"] for m in self.measurements) + peak_allocated = max(m["allocated"] for m in self.measurements) + peak_active = max(m["active"] for m in self.measurements) + + peak_reserved_phase = [ + m for m in self.measurements if m["reserved"] == peak_reserved + ][0] + peak_allocated_phase = [ + m for m in self.measurements if m["allocated"] == peak_allocated + ][0] + peak_active_phase = [ + m for m in self.measurements if m["active"] == peak_active + ][0] + + summary.append( + f"Peak Reserved: {peak_reserved/1e9:7.2f} GB at Step {peak_reserved_phase['step']} ({peak_reserved_phase['phase']})" + ) + step = peak_allocated_phase["step"] + phase = peak_allocated_phase["phase"] + summary.append( + f"Peak Allocated: {peak_allocated/1e6:7.2f} MB at Step {step} ({phase})" + ) + summary.append( + f"Peak Active: {peak_active/1e6:7.2f} MB at Step {peak_active_phase['step']} ({peak_active_phase['phase']})" + ) + + # Minimum after cache clear + cleared_measures = [m for m in self.measurements if "after_clear" in m["phase"]] + if cleared_measures: + min_reserved_cleared = min(m["reserved"] for m in cleared_measures) + min_measure = [ + m for m in cleared_measures if m["reserved"] == min_reserved_cleared + ][0] + summary.append( + f"\nMinimum Reserved (after cache clear): {min_reserved_cleared/1e9:7.2f} GB at Step {min_measure['step']}" + ) + summary.append(f" Active at minimum: {min_measure['active']/1e6:7.2f} MB") + + summary.append("=" * 100) + return "\n".join(summary) diff --git a/torchtitan/tools/memory_profiler.py b/torchtitan/tools/memory_profiler.py new file mode 100644 index 0000000000..d555de15ee --- /dev/null +++ b/torchtitan/tools/memory_profiler.py @@ -0,0 +1,173 @@ +""" +Detailed memory profiler for distributed training. +Instruments key allocation points to track exactly where memory goes. +""" + +import json +from collections import defaultdict +from typing import Dict + +import torch +from torchtitan.tools.logging import logger + + +class DetailedMemoryProfiler: + """Track memory allocations at key points in training.""" + + def __init__(self, enabled: bool = True): + self.enabled = enabled + self.checkpoints = {} + self.allocations = defaultdict(list) + self.device = torch.cuda.current_device() + + def checkpoint(self, name: str): + """Record memory state at a checkpoint.""" + if not self.enabled: + return + + stats = torch.cuda.memory_stats(self.device) + + self.checkpoints[name] = { + "active": stats.get("active_bytes.all.current", 0), + "allocated": stats.get("allocated_bytes.all.current", 0), + "reserved": stats.get("reserved_bytes.all.current", 0), + "peak_active": stats.get("active_bytes.all.peak", 0), + "num_allocs": stats.get("num_alloc.all.current", 0), + } + + def compute_delta(self, name: str, prev_checkpoint: str) -> Dict: + """Compute memory delta between two checkpoints.""" + if ( + not self.enabled + or name not in self.checkpoints + or prev_checkpoint not in self.checkpoints + ): + return {} + + current = self.checkpoints[name] + previous = self.checkpoints[prev_checkpoint] + + return { + "active_delta": current["active"] - previous["active"], + "allocated_delta": current["allocated"] - previous["allocated"], + "reserved_delta": current["reserved"] - previous["reserved"], + } + + def log_checkpoint(self, name: str, prev_checkpoint: str = None): + """Log checkpoint with optional delta.""" + if not self.enabled or name not in self.checkpoints: + return + + stats = self.checkpoints[name] + active_gb = stats["active"] / (1024**3) + reserved_gb = stats["reserved"] / (1024**3) + + msg = f"[MemProfile] {name}: Active={active_gb:.2f} GB, Reserved={reserved_gb:.2f} GB" + + if prev_checkpoint: + delta = self.compute_delta(name, prev_checkpoint) + if delta: + delta_gb = delta["active_delta"] / (1024**3) + msg += f", Delta={delta_gb:+.2f} GB" + + logger.info(msg) + + def get_breakdown(self) -> Dict: + """Compute memory breakdown from checkpoints.""" + if not self.enabled or not self.checkpoints: + return {} + + # Define key memory components based on checkpoints + breakdown = {} + + checkpoint_names = list(self.checkpoints.keys()) + for i in range(len(checkpoint_names) - 1): + curr_name = checkpoint_names[i + 1] + prev_name = checkpoint_names[i] + + delta = self.compute_delta(curr_name, prev_name) + if delta and delta["active_delta"] > 0: + breakdown[f"{prev_name}_to_{curr_name}"] = delta["active_delta"] + + return breakdown + + def save_report(self, filepath: str): + """Save detailed memory report to JSON.""" + if not self.enabled: + return + + report = { + "checkpoints": self.checkpoints, + "breakdown": self.get_breakdown(), + } + + with open(filepath, "w") as f: + json.dump(report, f, indent=2) + + logger.info(f"Memory report saved to {filepath}") + + def print_summary(self): + """Print summary table of memory usage.""" + if not self.enabled or not self.checkpoints: + return + + logger.info("=" * 80) + logger.info("DETAILED MEMORY PROFILE SUMMARY") + logger.info("=" * 80) + + # Print checkpoints + logger.info(f"\n{'Checkpoint':<40} {'Active (GB)':>15} {'Reserved (GB)':>15}") + logger.info("-" * 72) + + for name, stats in self.checkpoints.items(): + active_gb = stats["active"] / (1024**3) + reserved_gb = stats["reserved"] / (1024**3) + logger.info(f"{name:<40} {active_gb:>14.2f} {reserved_gb:>14.2f}") + + # Print breakdown + breakdown = self.get_breakdown() + if breakdown: + logger.info(f"\n{'Component':<40} {'Memory (GB)':>15} {'Percentage':>12}") + logger.info("-" * 72) + + total = sum(breakdown.values()) + for name, size in sorted( + breakdown.items(), key=lambda x: x[1], reverse=True + ): + size_gb = size / (1024**3) + pct = (size / total * 100) if total > 0 else 0 + logger.info(f"{name:<40} {size_gb:>14.2f} {pct:>11.1f}%") + + logger.info("=" * 80) + + +# Global profiler instance +_profiler = None + + +def get_memory_profiler() -> DetailedMemoryProfiler: + """Get or create global memory profiler.""" + global _profiler + if _profiler is None: + _profiler = DetailedMemoryProfiler(enabled=True) + return _profiler + + +def checkpoint(name: str): + """Record memory checkpoint.""" + get_memory_profiler().checkpoint(name) + + +def log_checkpoint(name: str, prev_checkpoint: str = None): + """Log memory checkpoint.""" + get_memory_profiler().log_checkpoint(name, prev_checkpoint) + + +def print_summary(): + """Print memory profile summary.""" + get_memory_profiler().print_summary() + + +def save_report(filepath: str): + """Save memory report.""" + get_memory_profiler().save_report(filepath) diff --git a/torchtitan/tools/mesh_visualizer.py b/torchtitan/tools/mesh_visualizer.py new file mode 100644 index 0000000000..0ba8fecb03 --- /dev/null +++ b/torchtitan/tools/mesh_visualizer.py @@ -0,0 +1,415 @@ +""" +Device Mesh Visualizer for Distributed Training + +Creates comprehensive visualization of how GPUs are allocated across +all parallelism dimensions: DP, PP, TP, CP, EP. +""" + +import os +from typing import Dict + +import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh + +from torchtitan.tools.logging import logger + + +def get_rank_info() -> Dict: + """Get current rank's information across all process groups.""" + info = { + "global_rank": dist.get_rank() if dist.is_initialized() else 0, + "world_size": dist.get_world_size() if dist.is_initialized() else 1, + "local_rank": int(os.environ.get("LOCAL_RANK", 0)), + "node_rank": int(os.environ.get("GROUP_RANK", os.environ.get("NODE_RANK", 0))), + } + return info + + +def visualize_mesh_structure( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Create a detailed text visualization of the device mesh structure. + + Args: + mesh: The DeviceMesh object + parallel_dims: ParallelDims object with all parallelism settings + rank: Current rank (only rank 0 prints full visualization) + + Returns: + String visualization of the mesh + """ + lines = [] + lines.append("=" * 100) + lines.append("DEVICE MESH VISUALIZATION") + lines.append("=" * 100) + + # Basic info + lines.append("\n[CLUSTER INFO]") + lines.append(f" Total GPUs: {parallel_dims.world_size}") + lines.append(f" Nodes: {parallel_dims.world_size // 8} (assuming 8 GPUs/node)") + + # Parallelism dimensions + lines.append("\n[PARALLELISM DIMENSIONS]") + lines.append(f" DP Replicate (HSDP): {parallel_dims.dp_replicate}") + lines.append(f" DP Shard (FSDP): {parallel_dims.dp_shard}") + lines.append(f" Context Parallel: {parallel_dims.cp}") + lines.append(f" Tensor Parallel: {parallel_dims.tp}") + lines.append(f" Pipeline Parallel: {parallel_dims.pp}") + lines.append(f" Expert Parallel: {parallel_dims.ep}") + lines.append(f" Expert TP: {parallel_dims.etp}") + + # Mesh structure + lines.append("\n[MESH STRUCTURE]") + lines.append(f" Mesh dim names: {mesh.mesh_dim_names}") + lines.append(f" Mesh shape: {mesh.mesh.shape}") + + # Log each dimension + for i, (name, size) in enumerate(zip(mesh.mesh_dim_names, mesh.mesh.shape)): + lines.append(f" Dim {i}: {name:20s} = {size}") + + # EP-specific derived dimensions + if parallel_dims.ep > 1: + if parallel_dims.etp == parallel_dims.tp: + dp_shard_mod_ep = ( + parallel_dims.dp_shard * parallel_dims.cp // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // parallel_dims.cp + else: + dp_shard_mod_ep = ( + parallel_dims.dp_shard + * parallel_dims.cp + * parallel_dims.tp + // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // (parallel_dims.cp * parallel_dims.tp) + + lines.append("\n[EXPERT PARALLEL DERIVED DIMENSIONS]") + lines.append(f" dp_shard_mod_ep (DP for non-experts): {dp_shard_mod_ep}") + lines.append(f" dp_shard_in_ep (DP within EP group): {dp_shard_in_ep}") + lines.append(f" ep_group_size (EP degree): {parallel_dims.ep}") + lines.append("") + lines.append(" Formula: dp_shard = dp_shard_mod_ep * dp_shard_in_ep") + lines.append( + f" {parallel_dims.dp_shard} = {dp_shard_mod_ep} * {dp_shard_in_ep}" + ) + lines.append("") + lines.append(" Formula: ep = dp_shard_in_ep * cp") + lines.append( + f" {parallel_dims.ep} = {dp_shard_in_ep} * {parallel_dims.cp}" + ) + + # Submesh info + lines.append("\n[SUBMESHES]") + + # Try to get submesh info + submesh_names = ["dp", "dp_shard_cp", "dp_cp", "ep", "cp", "tp", "pp"] + for name in submesh_names: + try: + submesh = mesh[name] + lines.append( + f" {name:15s}: size={submesh.size():4d}, dim_names={submesh.mesh_dim_names}" + ) + except (KeyError, RuntimeError): + pass + + return "\n".join(lines) + + +def visualize_gpu_allocation( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Create a grid visualization showing GPU allocation. + + For 16 nodes (128 GPUs) with EP=64, CP=8: + - Shows how each GPU maps to (dp_shard_mod_ep, dp_shard_in_ep, cp) coordinates + """ + lines = [] + lines.append("\n" + "=" * 100) + lines.append("GPU ALLOCATION GRID") + lines.append("=" * 100) + + world_size = parallel_dims.world_size + num_nodes = world_size // 8 + + # For EP-enabled config + if parallel_dims.ep > 1: + if parallel_dims.etp == parallel_dims.tp: + dp_shard_mod_ep = ( + parallel_dims.dp_shard * parallel_dims.cp // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // parallel_dims.cp + else: + dp_shard_mod_ep = ( + parallel_dims.dp_shard + * parallel_dims.cp + * parallel_dims.tp + // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // (parallel_dims.cp * parallel_dims.tp) + + lines.append( + f"\nMesh: [{dp_shard_mod_ep}] x [{dp_shard_in_ep}] x [{parallel_dims.cp}] = {world_size} GPUs" + ) + lines.append(" [dp_shard_mod_ep] x [dp_shard_in_ep] x [cp]") + lines.append("") + + # Create mapping from global rank to mesh coordinates + lines.append("GPU -> Mesh Coordinate Mapping:") + lines.append("-" * 80) + lines.append( + f"{'Node':>6} | {'GPU':>4} | {'Rank':>5} | {'dp_mod_ep':>10} | {'dp_in_ep':>10} | {'cp':>4} | {'EP Group':>10}" + ) + lines.append("-" * 80) + + # The mesh is laid out as: dp_shard_mod_ep (slowest) x dp_shard_in_ep x cp (fastest) + for node in range(num_nodes): + for local_gpu in range(8): + global_rank = node * 8 + local_gpu + + # Compute mesh coordinates (assuming row-major ordering) + # Total size = dp_shard_mod_ep * dp_shard_in_ep * cp + cp_coord = global_rank % parallel_dims.cp + dp_in_ep_coord = (global_rank // parallel_dims.cp) % dp_shard_in_ep + dp_mod_ep_coord = global_rank // (parallel_dims.cp * dp_shard_in_ep) + + # EP group = dp_in_ep_coord * cp + cp_coord (within each dp_shard_mod_ep group) + ep_group = dp_in_ep_coord * parallel_dims.cp + cp_coord + + row = ( + f"{node:>6} | {local_gpu:>4} | {global_rank:>5} | " + f"{dp_mod_ep_coord:>10} | {dp_in_ep_coord:>10} | " + f"{cp_coord:>4} | {ep_group:>10}" + ) + lines.append(row) + + if node < num_nodes - 1: + lines.append("-" * 80) + else: + lines.append( + f"\nMesh: [{parallel_dims.dp_shard}] x [{parallel_dims.cp}] = {world_size} GPUs" + ) + lines.append(" [dp_shard] x [cp]") + + return "\n".join(lines) + + +def visualize_expert_parallel_groups( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Visualize which GPUs belong to which Expert Parallel group. + """ + lines = [] + lines.append("\n" + "=" * 100) + lines.append("EXPERT PARALLEL GROUP ALLOCATION") + lines.append("=" * 100) + + if parallel_dims.ep <= 1: + lines.append("Expert Parallel is disabled (EP=1)") + return "\n".join(lines) + + world_size = parallel_dims.world_size + + if parallel_dims.etp == parallel_dims.tp: + dp_shard_mod_ep = parallel_dims.dp_shard * parallel_dims.cp // parallel_dims.ep + dp_shard_in_ep = parallel_dims.ep // parallel_dims.cp + else: + dp_shard_mod_ep = ( + parallel_dims.dp_shard + * parallel_dims.cp + * parallel_dims.tp + // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // (parallel_dims.cp * parallel_dims.tp) + + lines.append(f"\nEP={parallel_dims.ep} experts distributed across GPUs") + lines.append( + f"Each EP group has {parallel_dims.ep} GPUs working on different experts" + ) + lines.append( + f"There are {dp_shard_mod_ep} such EP groups (for FSDP replication of experts)" + ) + lines.append("") + + # Group GPUs by their dp_shard_mod_ep coordinate + lines.append("EP Groups (GPUs that share the same set of experts):") + lines.append("-" * 80) + + for dp_mod_ep_idx in range(dp_shard_mod_ep): + # Find all ranks in this dp_shard_mod_ep group + ranks_in_group = [] + for global_rank in range(world_size): + dp_mod_ep_coord = global_rank // (parallel_dims.cp * dp_shard_in_ep) + if dp_mod_ep_coord == dp_mod_ep_idx: + ranks_in_group.append(global_rank) + + lines.append(f"\nDP_SHARD_MOD_EP group {dp_mod_ep_idx}:") + lines.append( + f" GPUs: {ranks_in_group[:16]}{'...' if len(ranks_in_group) > 16 else ''}" + ) + lines.append(f" Total: {len(ranks_in_group)} GPUs") + lines.append(" These GPUs have IDENTICAL expert parameters (FSDP sharded)") + + return "\n".join(lines) + + +def visualize_context_parallel_groups( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Visualize Context Parallel groups - GPUs that work on different parts of the sequence. + """ + lines = [] + lines.append("\n" + "=" * 100) + lines.append("CONTEXT PARALLEL GROUP ALLOCATION") + lines.append("=" * 100) + + if parallel_dims.cp <= 1: + lines.append("Context Parallel is disabled (CP=1)") + return "\n".join(lines) + + world_size = parallel_dims.world_size + cp = parallel_dims.cp + + lines.append(f"\nCP={cp} - Each sequence is split into {cp} chunks") + lines.append( + "GPUs with the same (dp_shard, ep) coordinates but different cp coordinates" + ) + lines.append("work on different parts of the same sequence.") + lines.append("") + + # Show a few example CP groups + lines.append("Example CP groups (first few):") + lines.append("-" * 80) + + num_cp_groups = world_size // cp + for cp_group_idx in range(min(4, num_cp_groups)): + ranks_in_group = [cp_group_idx * cp + i for i in range(cp)] + lines.append(f"\nCP group {cp_group_idx}:") + lines.append(f" GPUs: {ranks_in_group}") + lines.append(f" These {cp} GPUs process different chunks of the same sequence") + + if num_cp_groups > 4: + lines.append(f"\n... and {num_cp_groups - 4} more CP groups") + + return "\n".join(lines) + + +def visualize_fsdp_sharding( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Visualize FSDP sharding - which GPUs share which parameters. + """ + lines = [] + lines.append("\n" + "=" * 100) + lines.append("FSDP SHARDING VISUALIZATION") + lines.append("=" * 100) + + if parallel_dims.ep > 1: + if parallel_dims.etp == parallel_dims.tp: + dp_shard_mod_ep = ( + parallel_dims.dp_shard * parallel_dims.cp // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // parallel_dims.cp + else: + dp_shard_mod_ep = ( + parallel_dims.dp_shard + * parallel_dims.cp + * parallel_dims.tp + // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // (parallel_dims.cp * parallel_dims.tp) + + dp_shard_cp_size = parallel_dims.dp_shard * parallel_dims.cp + + lines.append("\n[NON-EXPERT PARAMETERS (Attention, Embeddings, etc.)]") + lines.append(" FSDP mesh: dp_shard_cp") + lines.append(f" FSDP group size: {dp_shard_cp_size} GPUs") + lines.append(f" Each parameter is sharded across {dp_shard_cp_size} GPUs") + lines.append( + f" All-gather buffer size per param: original_size / {dp_shard_cp_size}" + ) + + lines.append("\n[EXPERT PARAMETERS (MoE experts)]") + lines.append(" FSDP mesh: dp_shard_mod_ep") + lines.append(f" FSDP group size: {dp_shard_mod_ep} GPUs") + lines.append( + f" Each expert's parameters are sharded across {dp_shard_mod_ep} GPUs" + ) + lines.append( + f" All-gather buffer size per expert param: original_size / {dp_shard_mod_ep}" + ) + + lines.append("\n[MEMORY IMPLICATIONS]") + lines.append( + f" Non-expert params: sharded {dp_shard_cp_size}x -> small per-GPU footprint" + ) + lines.append( + f" Expert params: sharded only {dp_shard_mod_ep}x -> larger per-GPU footprint" + ) + lines.append(" ") + lines.append(" As DP increases:") + lines.append( + " - dp_shard_cp increases -> non-expert params get more sharded" + ) + lines.append( + " - dp_shard_mod_ep increases -> expert params get more sharded" + ) + lines.append( + " - BUT: all-gather/reduce-scatter buffers scale with group size!" + ) + + else: + dp_shard_cp_size = parallel_dims.dp_shard * parallel_dims.cp + lines.append("\n[ALL PARAMETERS]") + lines.append(" FSDP mesh: dp_shard_cp") + lines.append(f" FSDP group size: {dp_shard_cp_size} GPUs") + lines.append(f" Each parameter is sharded across {dp_shard_cp_size} GPUs") + + return "\n".join(lines) + + +def create_full_visualization( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """Create a comprehensive visualization of the entire mesh structure.""" + parts = [ + visualize_mesh_structure(mesh, parallel_dims, rank), + visualize_gpu_allocation(mesh, parallel_dims, rank), + visualize_expert_parallel_groups(mesh, parallel_dims, rank), + visualize_context_parallel_groups(mesh, parallel_dims, rank), + visualize_fsdp_sharding(mesh, parallel_dims, rank), + ] + + full_viz = "\n".join(parts) + full_viz += "\n" + "=" * 100 + full_viz += "\nEND OF DEVICE MESH VISUALIZATION" + full_viz += "\n" + "=" * 100 + + return full_viz + + +def log_mesh_visualization(mesh: DeviceMesh, parallel_dims): + """Log the full mesh visualization (only on rank 0).""" + rank = dist.get_rank() if dist.is_initialized() else 0 + + if rank == 0: + viz = create_full_visualization(mesh, parallel_dims, rank) + # Log each line separately for better formatting + for line in viz.split("\n"): + logger.info(f"[MESH-VIZ] {line}") diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index c37345f8d3..4e4d03606c 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -76,6 +76,7 @@ def trace_handler(prof): schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active), on_trace_ready=trace_handler, record_shapes=True, + profile_memory=True, # Track memory allocations per operation with_stack=profiling_config.with_stack, with_modules=profiling_config.with_modules, ) as torch_profiler: diff --git a/torchtitan/train.py b/torchtitan/train.py index 9c16c2ef47..ca82996c2f 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -27,11 +27,15 @@ from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils +from torchtitan.tools.aggressive_memory_manager import create_aggressive_memory_manager +from torchtitan.tools.cuda_memory_tracker import CUDAMemoryTracker +from torchtitan.tools.detailed_memory_tracker import DetailedMemoryTracker from torchtitan.tools.logging import init_logger, logger from torchtitan.tools.profiling import ( maybe_enable_memory_snapshot, maybe_enable_profiling, ) +from torchtitan.utils.nan_tracker import create_nan_tracker_for_deepseek class Trainer(torch.distributed.checkpoint.stateful.Stateful): @@ -104,6 +108,41 @@ def __init__(self, job_config: JobConfig): gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug ) + # Initialize detailed memory tracker + self.detailed_memory_tracker = DetailedMemoryTracker( + enabled=getattr( + job_config.training, "enable_detailed_memory_tracking", False + ), + clear_cache=getattr( + job_config.training, "clear_cache_between_steps", False + ), + ) + + # Initialize CUDA memory tracker + self.cuda_memory_tracker = CUDAMemoryTracker( + enabled=getattr( + job_config.training, "enable_detailed_memory_tracking", False + ), + ) + + # Initialize aggressive memory manager to reduce CUDA fragmentation + # This clears cache after backward/optimizer to prevent allocation retries + aggressive_mem_mode = getattr( + job_config.training, "aggressive_memory_mode", None + ) + if aggressive_mem_mode: + self.aggressive_mem_manager = create_aggressive_memory_manager( + mode=aggressive_mem_mode, + verbose=getattr( + job_config.training, "aggressive_memory_verbose", False + ), + ) + logger.info( + f"Aggressive memory manager enabled (mode={aggressive_mem_mode})" + ) + else: + self.aggressive_mem_manager = None + # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). dist_utils.set_determinism( @@ -281,6 +320,17 @@ def __init__(self, job_config: JobConfig): self.ft_manager.maybe_set_all_reduce_hook(self.model_parts) + # Initialize NaN tracker if enabled + self.nan_tracker = None + if job_config.debug.enable_nan_tracker: + rank = int(os.environ.get("RANK", 0)) + self.nan_tracker = create_nan_tracker_for_deepseek( + self.model_parts[0], + rank=rank, + verbose=job_config.debug.nan_tracker_verbose, + ) + logger.info("NaN tracker enabled - will track activations and gradients") + # initialize device memory monitor and get peak flops for MFU calculation device_memory_monitor = self.metrics_processor.device_memory_monitor gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) @@ -510,10 +560,15 @@ def post_dataloading_process( extra_kwargs: dict[str, Any] = {} if getattr(self.model_args, "use_flex_attn", False): + # Pass CP mesh for Context Parallel + FlexAttention support + cp_mesh = None + if self.parallel_dims.cp_enabled: + cp_mesh = self.parallel_dims.world_mesh["cp"] extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks( input_batch=inputs, tokenizer=self.tokenizer, extra_inputs=extra_inputs, + cp_mesh=cp_mesh, ) return inputs, labels, extra_inputs, extra_kwargs @@ -683,11 +738,37 @@ def forward_backward_step( del pred loss.backward() + # Aggressive memory clearing after backward to reduce fragmentation + if self.aggressive_mem_manager is not None: + self.aggressive_mem_manager.post_backward() + return loss def train_step( self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] ): + # AGGRESSIVE cache clearing before step for accurate memory measurements + # Without this, cached memory from previous steps inflates readings + if self.job_config.training.aggressive_memory_mode: + import gc + + # 1. Synchronize all CUDA streams + torch.cuda.synchronize() + # 2. Python garbage collection (all generations for thorough cleanup) + gc.collect(0) + gc.collect(1) + gc.collect(2) + # 3. Clear CUDA cache (releases cached memory back to GPU) + torch.cuda.empty_cache() + # 4. Synchronize again to ensure cache clear completed + torch.cuda.synchronize() + # 5. Second round of clearing (catches any stragglers) + gc.collect(2) + torch.cuda.empty_cache() + # 6. Reset peak stats so we measure THIS step's peak only + torch.cuda.reset_peak_memory_stats() + self.metrics_processor.device_memory_monitor.reset_peak_stats() + self.optimizers.zero_grad() # Save the current step learning rate for logging lr = self.lr_schedulers.schedulers[0].get_last_lr()[0] @@ -696,6 +777,10 @@ def train_step( # the major variables that are used in the training loop. parallel_dims = self.parallel_dims + # Track memory before forward pass + self.detailed_memory_tracker.measure("before_forward", self.step) + self.cuda_memory_tracker.measure_all("before_forward", self.step) + accumulated_losses = [] # If data runs out during gradient accumulation, that # entire step will not be executed. @@ -704,6 +789,17 @@ def train_step( loss = self.forward_backward_step(input_dict, labels) accumulated_losses.append(loss.detach()) + # Track memory after forward/backward + self.detailed_memory_tracker.measure("after_forward_backward", self.step) + self.cuda_memory_tracker.measure_all("after_forward_backward", self.step) + + # Check for NaN/Inf if tracker is enabled + if self.nan_tracker is not None: + if self.nan_tracker.has_nan(): + self.nan_tracker.print_nan_report() + logger.error(f"NaN detected at step {self.step}! See report above.") + self.nan_tracker.step() # Reset for next step + grad_norm = dist_utils.clip_grad_norm_( [p for m in self.model_parts for p in m.parameters()], self.job_config.training.max_norm, @@ -714,8 +810,39 @@ def train_step( ep_enabled=parallel_dims.ep_enabled, ) self.checkpointer.maybe_wait_for_staging() - self.optimizers.step() - self.lr_schedulers.step() + + # Skip optimizer step if configured (for memory profiling) + if not self.job_config.training.skip_optimizer_step: + import datetime + import time as _time + + # Log step start with timestamp for correlation with vmstat + if self.device.index == 0: + _ts = datetime.datetime.now().strftime("%H:%M:%S") + logger.info(f"[STEP {self.step}] optimizer.step() START @ {_ts}") + + _optim_start = _time.time() + self.optimizers.step() + _optim_elapsed = _time.time() - _optim_start + + # Aggressive memory clearing after optimizer to reduce fragmentation + if self.aggressive_mem_manager is not None: + self.aggressive_mem_manager.post_optimizer() + + # Log step end with timing + if self.device.index == 0: + _ts = datetime.datetime.now().strftime("%H:%M:%S") + logger.info( + f"[STEP {self.step}] optimizer.step() END @ {_ts} | Duration: {_optim_elapsed:.2f}s" + ) + + self.lr_schedulers.step() + else: + logger.info("Skipping optimizer step (skip_optimizer_step=True)") + + # Track memory after optimizer step + self.detailed_memory_tracker.measure("after_optimizer", self.step) + self.cuda_memory_tracker.measure_all("after_optimizer", self.step) # Reduce the data collected over gradient accumulation steps. loss = torch.sum(torch.stack(accumulated_losses)) @@ -755,11 +882,25 @@ def train_step( extra_metrics=extra_metrics, ) + # Signal step complete to aggressive memory manager (triggers defrag check) + if self.aggressive_mem_manager is not None: + self.aggressive_mem_manager.step_complete() + @record def train(self): job_config = self.job_config self.checkpointer.load(step=job_config.checkpoint.load_step) + + # Pre-initialize bf16 optimizer states if configured + # This must happen BEFORE training to avoid rank skew during first step + if hasattr(self.optimizers, "init_bf16_states"): + self.optimizers.init_bf16_states() + # Barrier to ensure all ranks finish before training starts + if torch.distributed.is_initialized(): + torch.distributed.barrier() + logger.info("All ranks synchronized after bf16 optimizer state init") + logger.info(f"Training starts at step {self.step + 1}") leaf_folder = ( @@ -832,6 +973,10 @@ def train(self): if memory_profiler: memory_profiler.step() + # Track memory at step end and optionally clear cache + self.detailed_memory_tracker.step_complete(self.step) + self.cuda_memory_tracker.measure_all("step_end", self.step) + # reduce timeout after first train step for faster signal # (assuming lazy init and compilation are finished) if self.step == 1: @@ -846,6 +991,10 @@ def train(self): logger.info("Sleeping 2 seconds for other ranks to complete") time.sleep(2) + # Log detailed memory tracking summary + if torch.distributed.get_rank() == 0: + logger.info(self.detailed_memory_tracker.get_summary()) + logger.info("Training completed") def should_continue_training(self) -> bool: diff --git a/torchtitan/utils/__init__.py b/torchtitan/utils/__init__.py new file mode 100644 index 0000000000..1b4ed35bac --- /dev/null +++ b/torchtitan/utils/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from torchtitan.utils.nan_tracker import ( + create_nan_tracker_for_deepseek, + LayerStats, + NaNTracker, + TensorStats, +) + +__all__ = [ + "NaNTracker", + "create_nan_tracker_for_deepseek", + "TensorStats", + "LayerStats", +] diff --git a/torchtitan/utils/nan_tracker.py b/torchtitan/utils/nan_tracker.py new file mode 100644 index 0000000000..c65a5bd26a --- /dev/null +++ b/torchtitan/utils/nan_tracker.py @@ -0,0 +1,495 @@ +""" +Lightweight NaN/Inf tracker for debugging training issues. + +This module provides hooks to track tensor statistics (min, max, mean, nan_count, inf_count) +at each layer without saving tensors, making it suitable for large model debugging. + +Usage: + from torchtitan.utils.nan_tracker import NaNTracker + + tracker = NaNTracker() + tracker.register_hooks(model) + + # In training loop: + loss = model(inputs) + loss.backward() + + # Check for NaN + if tracker.has_nan(): + tracker.print_nan_report() + + tracker.step() # Reset for next step +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn + + +@dataclass +class TensorStats: + """Statistics for a single tensor.""" + + name: str + shape: Tuple[int, ...] + dtype: str + min_val: float + max_val: float + mean_val: float + std_val: float + nan_count: int + inf_count: int + total_elements: int + + @property + def has_nan(self) -> bool: + return self.nan_count > 0 + + @property + def has_inf(self) -> bool: + return self.inf_count > 0 + + def __str__(self) -> str: + status = "" + if self.has_nan: + status += f" [NaN: {self.nan_count}/{self.total_elements}]" + if self.has_inf: + status += f" [Inf: {self.inf_count}/{self.total_elements}]" + return ( + f"{self.name}: shape={self.shape}, dtype={self.dtype}, " + f"min={self.min_val:.4g}, max={self.max_val:.4g}, " + f"mean={self.mean_val:.4g}, std={self.std_val:.4g}{status}" + ) + + +@dataclass +class LayerStats: + """Statistics for a layer's inputs and outputs.""" + + layer_name: str + layer_type: str + input_stats: List[TensorStats] = field(default_factory=list) + output_stats: List[TensorStats] = field(default_factory=list) + grad_input_stats: List[TensorStats] = field(default_factory=list) + grad_output_stats: List[TensorStats] = field(default_factory=list) + + @property + def has_nan(self) -> bool: + for stats_list in [ + self.input_stats, + self.output_stats, + self.grad_input_stats, + self.grad_output_stats, + ]: + for stats in stats_list: + if stats.has_nan: + return True + return False + + @property + def has_inf(self) -> bool: + for stats_list in [ + self.input_stats, + self.output_stats, + self.grad_input_stats, + self.grad_output_stats, + ]: + for stats in stats_list: + if stats.has_inf: + return True + return False + + +def compute_tensor_stats(tensor: torch.Tensor, name: str) -> Optional[TensorStats]: + """Compute statistics for a tensor without storing it.""" + if tensor is None: + return None + + if not isinstance(tensor, torch.Tensor): + return None + + # Handle DTensor by getting local tensor + if hasattr(tensor, "_local_tensor"): + tensor = tensor._local_tensor + + # Flatten for stats computation + flat = tensor.detach().float().flatten() + + # Count NaN and Inf + nan_mask = torch.isnan(flat) + inf_mask = torch.isinf(flat) + nan_count = nan_mask.sum().item() + inf_count = inf_mask.sum().item() + + # Compute stats on valid values only + valid_mask = ~(nan_mask | inf_mask) + valid_vals = flat[valid_mask] + + if valid_vals.numel() > 0: + min_val = valid_vals.min().item() + max_val = valid_vals.max().item() + mean_val = valid_vals.mean().item() + std_val = valid_vals.std().item() if valid_vals.numel() > 1 else 0.0 + else: + min_val = float("nan") + max_val = float("nan") + mean_val = float("nan") + std_val = float("nan") + + return TensorStats( + name=name, + shape=tuple(tensor.shape), + dtype=str(tensor.dtype), + min_val=min_val, + max_val=max_val, + mean_val=mean_val, + std_val=std_val, + nan_count=int(nan_count), + inf_count=int(inf_count), + total_elements=tensor.numel(), + ) + + +class NaNTracker: + """ + Lightweight tracker for NaN/Inf in model activations and gradients. + + Registers forward and backward hooks on model layers to compute statistics + without storing tensors. + + Args: + track_forward: Track forward pass activations + track_backward: Track backward pass gradients + layer_types: Only track these layer types (default: all) + include_patterns: Only track layers matching these patterns + exclude_patterns: Exclude layers matching these patterns + log_every_layer: Print stats for every layer (verbose) + rank: Current process rank for distributed training + """ + + def __init__( + self, + track_forward: bool = True, + track_backward: bool = True, + layer_types: Optional[Tuple[type, ...]] = None, + include_patterns: Optional[List[str]] = None, + exclude_patterns: Optional[List[str]] = None, + log_every_layer: bool = False, + rank: int = 0, + ): + self.track_forward = track_forward + self.track_backward = track_backward + self.layer_types = layer_types + self.include_patterns = include_patterns or [] + self.exclude_patterns = exclude_patterns or [] + self.log_every_layer = log_every_layer + self.rank = rank + + self.step_num = 0 + self.layer_stats: Dict[str, LayerStats] = {} + self.hooks: List[torch.utils.hooks.RemovableHandle] = [] + self._first_nan_layer: Optional[str] = None + self._first_nan_phase: Optional[str] = None + self._forward_order: List[str] = [] + + def _should_track(self, name: str, module: nn.Module) -> bool: + """Check if this layer should be tracked.""" + # Check layer type filter + if self.layer_types is not None: + if not isinstance(module, self.layer_types): + return False + + # Check include patterns + if self.include_patterns: + if not any(p in name for p in self.include_patterns): + return False + + # Check exclude patterns + if self.exclude_patterns: + if any(p in name for p in self.exclude_patterns): + return False + + return True + + def _create_forward_hook(self, layer_name: str, layer_type: str): + """Create a forward hook for a layer.""" + + def hook(module, inputs, outputs): + if layer_name not in self.layer_stats: + self.layer_stats[layer_name] = LayerStats( + layer_name=layer_name, + layer_type=layer_type, + ) + self._forward_order.append(layer_name) + + stats = self.layer_stats[layer_name] + + # Process inputs + if isinstance(inputs, tuple): + for i, inp in enumerate(inputs): + if isinstance(inp, torch.Tensor): + tensor_stats = compute_tensor_stats(inp, f"input_{i}") + if tensor_stats: + stats.input_stats.append(tensor_stats) + if tensor_stats.has_nan and self._first_nan_layer is None: + self._first_nan_layer = layer_name + self._first_nan_phase = f"forward_input_{i}" + + # Process outputs + if isinstance(outputs, torch.Tensor): + tensor_stats = compute_tensor_stats(outputs, "output") + if tensor_stats: + stats.output_stats.append(tensor_stats) + if tensor_stats.has_nan and self._first_nan_layer is None: + self._first_nan_layer = layer_name + self._first_nan_phase = "forward_output" + elif isinstance(outputs, tuple): + for i, out in enumerate(outputs): + if isinstance(out, torch.Tensor): + tensor_stats = compute_tensor_stats(out, f"output_{i}") + if tensor_stats: + stats.output_stats.append(tensor_stats) + if tensor_stats.has_nan and self._first_nan_layer is None: + self._first_nan_layer = layer_name + self._first_nan_phase = f"forward_output_{i}" + + if self.log_every_layer and self.rank == 0: + self._print_layer_stats(layer_name, "forward") + + return hook + + def _create_backward_hook(self, layer_name: str, layer_type: str): + """Create a backward hook for a layer.""" + + def hook(module, grad_input, grad_output): + if layer_name not in self.layer_stats: + self.layer_stats[layer_name] = LayerStats( + layer_name=layer_name, + layer_type=layer_type, + ) + + stats = self.layer_stats[layer_name] + + # Process grad_output (gradient w.r.t. layer output) + if isinstance(grad_output, tuple): + for i, grad in enumerate(grad_output): + if isinstance(grad, torch.Tensor): + tensor_stats = compute_tensor_stats(grad, f"grad_output_{i}") + if tensor_stats: + stats.grad_output_stats.append(tensor_stats) + if tensor_stats.has_nan and self._first_nan_layer is None: + self._first_nan_layer = layer_name + self._first_nan_phase = f"backward_grad_output_{i}" + + # Process grad_input (gradient w.r.t. layer input) + if isinstance(grad_input, tuple): + for i, grad in enumerate(grad_input): + if isinstance(grad, torch.Tensor): + tensor_stats = compute_tensor_stats(grad, f"grad_input_{i}") + if tensor_stats: + stats.grad_input_stats.append(tensor_stats) + if tensor_stats.has_nan and self._first_nan_layer is None: + self._first_nan_layer = layer_name + self._first_nan_phase = f"backward_grad_input_{i}" + + if self.log_every_layer and self.rank == 0: + self._print_layer_stats(layer_name, "backward") + + return hook + + def register_hooks(self, model: nn.Module) -> None: + """Register forward and backward hooks on model layers.""" + for name, module in model.named_modules(): + if not self._should_track(name, module): + continue + + # Skip container modules + if isinstance(module, (nn.ModuleList, nn.ModuleDict, nn.Sequential)): + continue + + layer_type = type(module).__name__ + + if self.track_forward: + hook = module.register_forward_hook( + self._create_forward_hook(name, layer_type) + ) + self.hooks.append(hook) + + if self.track_backward: + hook = module.register_full_backward_hook( + self._create_backward_hook(name, layer_type) + ) + self.hooks.append(hook) + + def remove_hooks(self) -> None: + """Remove all registered hooks.""" + for hook in self.hooks: + hook.remove() + self.hooks.clear() + + def step(self) -> None: + """Reset statistics for next step.""" + self.step_num += 1 + self.layer_stats.clear() + self._first_nan_layer = None + self._first_nan_phase = None + self._forward_order.clear() + + def has_nan(self) -> bool: + """Check if any NaN was detected this step.""" + return self._first_nan_layer is not None + + def has_any_nan_or_inf(self) -> bool: + """Check if any NaN or Inf was detected this step.""" + for stats in self.layer_stats.values(): + if stats.has_nan or stats.has_inf: + return True + return False + + def get_first_nan_location(self) -> Optional[Tuple[str, str]]: + """Get the first layer and phase where NaN appeared.""" + if self._first_nan_layer: + return (self._first_nan_layer, self._first_nan_phase) + return None + + def _print_layer_stats(self, layer_name: str, phase: str) -> None: + """Print statistics for a single layer.""" + if layer_name not in self.layer_stats: + return + + stats = self.layer_stats[layer_name] + print(f"[Step {self.step_num}][{phase}] {layer_name} ({stats.layer_type}):") + + if phase == "forward": + for s in stats.input_stats: + print(f" IN: {s}") + for s in stats.output_stats: + print(f" OUT: {s}") + else: + for s in stats.grad_output_stats: + print(f" GRAD_OUT: {s}") + for s in stats.grad_input_stats: + print(f" GRAD_IN: {s}") + + def print_nan_report(self) -> None: + """Print a detailed report of where NaN/Inf occurred.""" + if self.rank != 0: + return + + print(f"\n{'='*80}") + print(f"NaN/Inf REPORT - Step {self.step_num}") + print(f"{'='*80}") + + if self._first_nan_layer: + print( + f"\n** FIRST NaN detected at: {self._first_nan_layer} ({self._first_nan_phase}) **\n" + ) + + # Print layers in forward order + nan_layers = [] + for layer_name in self._forward_order: + if layer_name in self.layer_stats: + stats = self.layer_stats[layer_name] + if stats.has_nan or stats.has_inf: + nan_layers.append(layer_name) + + if nan_layers: + print(f"Layers with NaN/Inf ({len(nan_layers)} total):") + for layer_name in nan_layers: + stats = self.layer_stats[layer_name] + print(f"\n {layer_name} ({stats.layer_type}):") + for s in stats.input_stats: + if s.has_nan or s.has_inf: + print(f" [FWD IN] {s}") + for s in stats.output_stats: + if s.has_nan or s.has_inf: + print(f" [FWD OUT] {s}") + for s in stats.grad_output_stats: + if s.has_nan or s.has_inf: + print(f" [BWD GRAD_OUT] {s}") + for s in stats.grad_input_stats: + if s.has_nan or s.has_inf: + print(f" [BWD GRAD_IN] {s}") + else: + print("No NaN/Inf detected in tracked layers.") + + print(f"\n{'='*80}\n") + + def print_summary(self) -> None: + """Print a summary of all layer statistics.""" + if self.rank != 0: + return + + print(f"\n{'='*80}") + print(f"LAYER STATISTICS SUMMARY - Step {self.step_num}") + print(f"{'='*80}") + print(f"Total layers tracked: {len(self.layer_stats)}") + + nan_count = sum(1 for s in self.layer_stats.values() if s.has_nan) + inf_count = sum(1 for s in self.layer_stats.values() if s.has_inf) + print(f"Layers with NaN: {nan_count}") + print(f"Layers with Inf: {inf_count}") + + if self._first_nan_layer: + print(f"\nFirst NaN at: {self._first_nan_layer} ({self._first_nan_phase})") + + print(f"{'='*80}\n") + + def get_stats_dict(self) -> Dict[str, Any]: + """Get statistics as a dictionary for logging.""" + result = { + "step": self.step_num, + "has_nan": self.has_nan(), + "first_nan_layer": self._first_nan_layer, + "first_nan_phase": self._first_nan_phase, + "layers_with_nan": [], + "layers_with_inf": [], + } + + for name, stats in self.layer_stats.items(): + if stats.has_nan: + result["layers_with_nan"].append(name) + if stats.has_inf: + result["layers_with_inf"].append(name) + + return result + + +def create_nan_tracker_for_deepseek( + model: nn.Module, + rank: int = 0, + verbose: bool = False, +) -> NaNTracker: + """ + Create a NaN tracker optimized for DeepSeek models. + + Tracks key layers that are most likely to produce NaN: + - Attention layers (FlexAttention, MLA) + - MoE layers (router, experts) + - Normalization layers + - Output projection + """ + tracker = NaNTracker( + track_forward=True, + track_backward=True, + include_patterns=[ + "attention", + "moe", + "router", + "expert", + "norm", + "output", + "feed_forward", + "tok_embeddings", + ], + exclude_patterns=[ + "freqs_cis", + ], + log_every_layer=verbose, + rank=rank, + ) + + tracker.register_hooks(model) + return tracker