From 114b8af4ecd44aa551496bc0b224488bcc00d8f6 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Mon, 19 Jan 2026 12:41:06 -0800 Subject: [PATCH 01/18] hacky --- torchtitan/config/job_config.py | 23 +++ .../distributed/activation_checkpoint.py | 32 +++- torchtitan/models/deepseek_v3/__init__.py | 46 +++++ .../models/deepseek_v3/infra/parallelize.py | 6 +- torchtitan/tools/cuda_memory_tracker.py | 123 +++++++++++++ torchtitan/tools/detailed_memory_tracker.py | 160 ++++++++++++++++ torchtitan/tools/memory_profiler.py | 173 ++++++++++++++++++ torchtitan/train.py | 59 +++++- 8 files changed, 617 insertions(+), 5 deletions(-) create mode 100644 torchtitan/tools/cuda_memory_tracker.py create mode 100644 torchtitan/tools/detailed_memory_tracker.py create mode 100644 torchtitan/tools/memory_profiler.py diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index ff909e6ae9..fb936ff24b 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -370,6 +370,21 @@ class Training: Whether to apply CPU offloading of parameters, gradients, and optimizer states in FSDP """ + enable_detailed_memory_tracking: bool = False + """ + Whether to enable detailed memory tracking at every training phase + """ + + clear_cache_between_steps: bool = False + """ + Whether to clear CUDA cache between training steps to measure minimum memory requirements + """ + + skip_optimizer_step: bool = False + """ + Whether to skip the optimizer step (for memory profiling purposes only) + """ + dtype: Literal["bfloat16", "float32"] = "float32" """ torch dtype for training. In contrast to mixed precision training, setting training_dtype=bfloat16 will @@ -846,6 +861,14 @@ class ActivationCheckpoint: https://docs.pytorch.org/docs/stable/checkpoint.html for details. """ + cpu_offload: bool = False + """ + Enable CPU offloading for activation checkpoints. 
When enabled, saved activations + are moved to CPU RAM during forward pass and brought back to GPU during backward pass. + This trades memory for PCIe bandwidth, saving GPU memory at the cost of data transfer time. + Only applies when mode is 'full' or 'selective'. + """ + @dataclass class Compile: diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 8359f71730..3cb5378637 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -67,6 +67,14 @@ def _apply_op_sac( Returns: nn.Module: The module with selective activation checkpointing applied. """ + # Use CPU offload if enabled + if ac_config.cpu_offload: + from torchtitan.distributed.activation_checkpoint_offload import ( + apply_selective_ac_with_cpu_offload, + ) + + return apply_selective_ac_with_cpu_offload(module, ac_config, base_fqn=base_fqn) + from torch.utils.checkpoint import ( CheckpointPolicy, create_selective_checkpoint_contexts, @@ -146,6 +154,14 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module: Returns: nn.Module: The module with full activation checkpointing applied. 
""" + # Use CPU offload if enabled + if ac_config.cpu_offload: + from torchtitan.distributed.activation_checkpoint_offload import ( + apply_full_ac_with_cpu_offload, + ) + + return apply_full_ac_with_cpu_offload(module, ac_config) + return ptd_checkpoint_wrapper( module, preserve_rng_state=ac_config.preserve_rng_state, @@ -308,6 +324,18 @@ def apply_ac( Returns: None """ + # Special case: CPU offload without activation checkpointing + if ac_config.mode == "none" and ac_config.cpu_offload: + from torchtitan.distributed.activation_checkpoint_offload import ( + apply_offload_wrapper_only, + ) + + for layer_id, transformer_block in model.layers.named_children(): + transformer_block = apply_offload_wrapper_only(transformer_block) + model.layers.register_module(layer_id, transformer_block) + logger.info("Applied activation offloading WITHOUT checkpointing to the model") + return + if ac_config.mode == "memory_budget": assert model_compile_enabled, "Memory budget mode requires model to be compiled" if ac_config.visualize_memory_budget_pareto: @@ -319,7 +347,7 @@ def apply_ac( torch._functorch.config.activation_memory_budget = ac_config.memory_budget logger.info(f"Selected {ac_config.memory_budget} budget option") - else: + elif ac_config.mode != "none": for layer_id, transformer_block in model.layers.named_children(): transformer_block = _apply_ac_to_transformer_block( transformer_block, @@ -331,4 +359,4 @@ def apply_ac( ) model.layers.register_module(layer_id, transformer_block) - logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") + logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index eedc20cbb5..6b186c5e53 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -50,6 +50,52 @@ v_head_dim=128, mscale=0.70, ), + "debugmodel_1b": DeepSeekV3ModelArgs( + 
vocab_size=2048, + dim=1024, + inter_dim=4096, + moe_inter_dim=1024, + n_layers=24, + n_dense_layers=2, + n_heads=16, + moe_args=MoEArgs( + num_experts=8, + num_shared_experts=2, + top_k=3, + score_func="softmax", + route_norm=False, + score_before_experts=False, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + ), + "debugmodel_7b": DeepSeekV3ModelArgs( + vocab_size=2048, + dim=2048, + inter_dim=8192, + moe_inter_dim=2048, + n_layers=32, + n_dense_layers=3, + n_heads=32, + moe_args=MoEArgs( + num_experts=16, + num_shared_experts=4, + top_k=4, + score_func="softmax", + route_norm=False, + score_before_experts=False, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + ), "debugmodel_flex_attn": DeepSeekV3ModelArgs( vocab_size=2048, dim=256, diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 7da79c361e..262ed992bf 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -108,7 +108,11 @@ def parallelize_deepseekv3( job_config.compile.enable and "model" in job_config.compile.components ) - if job_config.activation_checkpoint.mode != "none": + # Apply activation checkpointing or CPU offloading + if ( + job_config.activation_checkpoint.mode != "none" + or job_config.activation_checkpoint.cpu_offload + ): apply_ac( model, job_config.activation_checkpoint, diff --git a/torchtitan/tools/cuda_memory_tracker.py b/torchtitan/tools/cuda_memory_tracker.py new file mode 100644 index 0000000000..0f7d7af5f4 --- /dev/null +++ b/torchtitan/tools/cuda_memory_tracker.py @@ -0,0 +1,123 @@ +"""Track CUDA memory directly from nvidia-smi and PyTorch""" +import logging +import subprocess +from typing import Dict, Optional + +import torch + +logger = logging.getLogger(__name__) + + +class CUDAMemoryTracker: + 
"""Track memory from both PyTorch and CUDA/nvidia-smi""" + + def __init__(self, enabled: bool = True): + self.enabled = enabled + self.device = torch.cuda.current_device() + self.device_name = torch.cuda.get_device_name(self.device) + + if self.enabled: + logger.info( + f"CUDAMemoryTracker enabled for device {self.device}: {self.device_name}" + ) + + def get_nvidia_smi_memory(self) -> Optional[Dict[str, int]]: + """Get memory from nvidia-smi""" + try: + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=memory.used,memory.free,memory.total", + "--format=csv,noheader,nounits", + "-i", + str(self.device), + ], + capture_output=True, + text=True, + timeout=2, + ) + + if result.returncode == 0: + used, free, total = map(int, result.stdout.strip().split(",")) + return {"used_mb": used, "free_mb": free, "total_mb": total} + except Exception as e: + logger.warning(f"Failed to get nvidia-smi memory: {e}") + + return None + + def get_pytorch_memory(self) -> Dict[str, int]: + """Get memory from PyTorch""" + stats = torch.cuda.memory_stats(self.device) + + return { + "reserved_bytes": torch.cuda.memory_reserved(self.device), + "allocated_bytes": torch.cuda.memory_allocated(self.device), + "active_bytes": stats.get("active_bytes.all.current", 0), + "inactive_bytes": stats.get("inactive_split_bytes.all.current", 0), + "peak_active_bytes": stats.get("active_bytes.all.peak", 0), + "num_alloc_retries": stats.get("num_alloc_retries.all.current", 0), + "num_ooms": stats.get("num_ooms.all.current", 0), + } + + def get_cuda_device_memory(self) -> Dict[str, int]: + """Get memory directly from CUDA device properties""" + props = torch.cuda.get_device_properties(self.device) + + return { + "total_memory": props.total_memory, + "reserved_memory": torch.cuda.memory_reserved(self.device), + "allocated_memory": torch.cuda.memory_allocated(self.device), + } + + def measure_all(self, phase: str, step: int): + """Comprehensive memory measurement""" + if not self.enabled: + return + + # 
PyTorch memory + pytorch_mem = self.get_pytorch_memory() + + # CUDA device memory + cuda_mem = self.get_cuda_device_memory() + + # nvidia-smi memory (if available) + smi_mem = self.get_nvidia_smi_memory() + + # Calculate fragmentation + reserved = pytorch_mem["reserved_bytes"] + allocated = pytorch_mem["allocated_bytes"] + active = pytorch_mem["active_bytes"] + + fragmentation = reserved - allocated + frag_pct = (fragmentation / reserved * 100) if reserved > 0 else 0 + + # Log PyTorch view + logger.info( + f"[PyTorch] Step {step:2d} | {phase:25s} | " + f"Reserved: {reserved/1e9:6.2f} GB | " + f"Allocated: {allocated/1e6:8.2f} MB | " + f"Active: {active/1e6:8.2f} MB | " + f"Frag: {frag_pct:5.1f}%" + ) + + # Log CUDA/nvidia-smi view + if smi_mem: + logger.info( + f"[CUDA-SMI] Step {step:2d} | {phase:25s} | " + f"Used: {smi_mem['used_mb']/1024:6.2f} GB | " + f"Free: {smi_mem['free_mb']/1024:6.2f} GB | " + f"Total: {smi_mem['total_mb']/1024:6.2f} GB" + ) + + # Log comparison + if smi_mem: + pytorch_used_gb = reserved / 1e9 + smi_used_gb = smi_mem["used_mb"] / 1024 + diff_gb = smi_used_gb - pytorch_used_gb + + logger.info( + f"[Compare] Step {step:2d} | {phase:25s} | " + f"PyTorch reports: {pytorch_used_gb:6.2f} GB | " + f"nvidia-smi reports: {smi_used_gb:6.2f} GB | " + f"Diff: {diff_gb:+6.2f} GB" + ) diff --git a/torchtitan/tools/detailed_memory_tracker.py b/torchtitan/tools/detailed_memory_tracker.py new file mode 100644 index 0000000000..7b513b3e20 --- /dev/null +++ b/torchtitan/tools/detailed_memory_tracker.py @@ -0,0 +1,160 @@ +"""Detailed memory tracking throughout training step""" +import logging +from typing import Dict, List + +import torch + +logger = logging.getLogger(__name__) + + +class DetailedMemoryTracker: + """Track memory at every phase of training with cache clearing""" + + def __init__(self, enabled: bool = True, clear_cache: bool = True): + self.enabled = enabled + self.clear_cache_between_steps = clear_cache + self.measurements: List[Dict] = [] + 
self.device = torch.cuda.current_device() + + if self.enabled: + logger.info(f"DetailedMemoryTracker enabled (clear_cache={clear_cache})") + + def measure(self, phase: str, step: int): + """Capture memory state at a specific phase""" + if not self.enabled: + return + + stats = torch.cuda.memory_stats(self.device) + + measurement = { + "step": step, + "phase": phase, + "reserved": torch.cuda.memory_reserved(self.device), + "allocated": torch.cuda.memory_allocated(self.device), + "active": stats.get("active_bytes.all.current", 0), + "peak_active": stats.get("active_bytes.all.peak", 0), + "num_allocs": stats.get("num_alloc_retries.all.current", 0), + } + + self.measurements.append(measurement) + + # Calculate fragmentation + fragmentation = measurement["reserved"] - measurement["allocated"] + frag_pct = ( + (fragmentation / measurement["reserved"] * 100) + if measurement["reserved"] > 0 + else 0 + ) + + logger.info( + f"[MemTrack] Step {step} | {phase:20s} | " + f"Reserved: {measurement['reserved']/1e9:6.2f} GB | " + f"Allocated: {measurement['allocated']/1e6:7.2f} MB | " + f"Active: {measurement['active']/1e6:7.2f} MB | " + f"Frag: {frag_pct:5.1f}%" + ) + + def clear_cache_and_measure(self, phase: str, step: int): + """Clear cache and measure to see minimum memory""" + if not self.enabled: + return + + # Measure before clearing + self.measure(f"{phase}_before_clear", step) + + # Clear cache + torch.cuda.empty_cache() + torch.cuda.synchronize() + + # Measure after clearing + self.measure(f"{phase}_after_clear", step) + + def step_complete(self, step: int): + """Called after each training step""" + if not self.enabled: + return + + if self.clear_cache_between_steps: + self.clear_cache_and_measure("step_end", step) + + def get_summary(self) -> str: + """Get summary of all measurements""" + if not self.measurements: + return "No measurements recorded" + + summary = ["", "=" * 100, "DETAILED MEMORY TRACKING SUMMARY", "=" * 100, ""] + + # Group by step + steps = {} + for m 
in self.measurements: + step = m["step"] + if step not in steps: + steps[step] = [] + steps[step].append(m) + + for step, measures in sorted(steps.items()): + summary.append(f"\nStep {step}:") + summary.append( + f"{'Phase':<30} {'Reserved':>12} {'Allocated':>12} {'Active':>12} {'Frag%':>8}" + ) + summary.append("-" * 80) + + for m in measures: + frag_pct = ( + ((m["reserved"] - m["allocated"]) / m["reserved"] * 100) + if m["reserved"] > 0 + else 0 + ) + summary.append( + f"{m['phase']:<30} " + f"{m['reserved']/1e9:10.2f} GB " + f"{m['allocated']/1e6:10.2f} MB " + f"{m['active']/1e6:10.2f} MB " + f"{frag_pct:7.1f}%" + ) + + # Peak measurements + summary.append("\n" + "=" * 100) + summary.append("PEAK MEASUREMENTS ACROSS ALL STEPS:") + summary.append("=" * 100) + + peak_reserved = max(m["reserved"] for m in self.measurements) + peak_allocated = max(m["allocated"] for m in self.measurements) + peak_active = max(m["active"] for m in self.measurements) + + peak_reserved_phase = [ + m for m in self.measurements if m["reserved"] == peak_reserved + ][0] + peak_allocated_phase = [ + m for m in self.measurements if m["allocated"] == peak_allocated + ][0] + peak_active_phase = [ + m for m in self.measurements if m["active"] == peak_active + ][0] + + summary.append( + f"Peak Reserved: {peak_reserved/1e9:7.2f} GB at Step {peak_reserved_phase['step']} ({peak_reserved_phase['phase']})" + ) + step = peak_allocated_phase["step"] + phase = peak_allocated_phase["phase"] + summary.append( + f"Peak Allocated: {peak_allocated/1e6:7.2f} MB at Step {step} ({phase})" + ) + summary.append( + f"Peak Active: {peak_active/1e6:7.2f} MB at Step {peak_active_phase['step']} ({peak_active_phase['phase']})" + ) + + # Minimum after cache clear + cleared_measures = [m for m in self.measurements if "after_clear" in m["phase"]] + if cleared_measures: + min_reserved_cleared = min(m["reserved"] for m in cleared_measures) + min_measure = [ + m for m in cleared_measures if m["reserved"] == 
min_reserved_cleared + ][0] + summary.append( + f"\nMinimum Reserved (after cache clear): {min_reserved_cleared/1e9:7.2f} GB at Step {min_measure['step']}" + ) + summary.append(f" Active at minimum: {min_measure['active']/1e6:7.2f} MB") + + summary.append("=" * 100) + return "\n".join(summary) diff --git a/torchtitan/tools/memory_profiler.py b/torchtitan/tools/memory_profiler.py new file mode 100644 index 0000000000..d555de15ee --- /dev/null +++ b/torchtitan/tools/memory_profiler.py @@ -0,0 +1,173 @@ +""" +Detailed memory profiler for distributed training. +Instruments key allocation points to track exactly where memory goes. +""" + +import json +from collections import defaultdict +from typing import Dict + +import torch +from torchtitan.tools.logging import logger + + +class DetailedMemoryProfiler: + """Track memory allocations at key points in training.""" + + def __init__(self, enabled: bool = True): + self.enabled = enabled + self.checkpoints = {} + self.allocations = defaultdict(list) + self.device = torch.cuda.current_device() + + def checkpoint(self, name: str): + """Record memory state at a checkpoint.""" + if not self.enabled: + return + + stats = torch.cuda.memory_stats(self.device) + + self.checkpoints[name] = { + "active": stats.get("active_bytes.all.current", 0), + "allocated": stats.get("allocated_bytes.all.current", 0), + "reserved": stats.get("reserved_bytes.all.current", 0), + "peak_active": stats.get("active_bytes.all.peak", 0), + "num_allocs": stats.get("num_alloc.all.current", 0), + } + + def compute_delta(self, name: str, prev_checkpoint: str) -> Dict: + """Compute memory delta between two checkpoints.""" + if ( + not self.enabled + or name not in self.checkpoints + or prev_checkpoint not in self.checkpoints + ): + return {} + + current = self.checkpoints[name] + previous = self.checkpoints[prev_checkpoint] + + return { + "active_delta": current["active"] - previous["active"], + "allocated_delta": current["allocated"] - previous["allocated"], 
+ "reserved_delta": current["reserved"] - previous["reserved"], + } + + def log_checkpoint(self, name: str, prev_checkpoint: str = None): + """Log checkpoint with optional delta.""" + if not self.enabled or name not in self.checkpoints: + return + + stats = self.checkpoints[name] + active_gb = stats["active"] / (1024**3) + reserved_gb = stats["reserved"] / (1024**3) + + msg = f"[MemProfile] {name}: Active={active_gb:.2f} GB, Reserved={reserved_gb:.2f} GB" + + if prev_checkpoint: + delta = self.compute_delta(name, prev_checkpoint) + if delta: + delta_gb = delta["active_delta"] / (1024**3) + msg += f", Delta={delta_gb:+.2f} GB" + + logger.info(msg) + + def get_breakdown(self) -> Dict: + """Compute memory breakdown from checkpoints.""" + if not self.enabled or not self.checkpoints: + return {} + + # Define key memory components based on checkpoints + breakdown = {} + + checkpoint_names = list(self.checkpoints.keys()) + for i in range(len(checkpoint_names) - 1): + curr_name = checkpoint_names[i + 1] + prev_name = checkpoint_names[i] + + delta = self.compute_delta(curr_name, prev_name) + if delta and delta["active_delta"] > 0: + breakdown[f"{prev_name}_to_{curr_name}"] = delta["active_delta"] + + return breakdown + + def save_report(self, filepath: str): + """Save detailed memory report to JSON.""" + if not self.enabled: + return + + report = { + "checkpoints": self.checkpoints, + "breakdown": self.get_breakdown(), + } + + with open(filepath, "w") as f: + json.dump(report, f, indent=2) + + logger.info(f"Memory report saved to {filepath}") + + def print_summary(self): + """Print summary table of memory usage.""" + if not self.enabled or not self.checkpoints: + return + + logger.info("=" * 80) + logger.info("DETAILED MEMORY PROFILE SUMMARY") + logger.info("=" * 80) + + # Print checkpoints + logger.info(f"\n{'Checkpoint':<40} {'Active (GB)':>15} {'Reserved (GB)':>15}") + logger.info("-" * 72) + + for name, stats in self.checkpoints.items(): + active_gb = stats["active"] / 
(1024**3) + reserved_gb = stats["reserved"] / (1024**3) + logger.info(f"{name:<40} {active_gb:>14.2f} {reserved_gb:>14.2f}") + + # Print breakdown + breakdown = self.get_breakdown() + if breakdown: + logger.info(f"\n{'Component':<40} {'Memory (GB)':>15} {'Percentage':>12}") + logger.info("-" * 72) + + total = sum(breakdown.values()) + for name, size in sorted( + breakdown.items(), key=lambda x: x[1], reverse=True + ): + size_gb = size / (1024**3) + pct = (size / total * 100) if total > 0 else 0 + logger.info(f"{name:<40} {size_gb:>14.2f} {pct:>11.1f}%") + + logger.info("=" * 80) + + +# Global profiler instance +_profiler = None + + +def get_memory_profiler() -> DetailedMemoryProfiler: + """Get or create global memory profiler.""" + global _profiler + if _profiler is None: + _profiler = DetailedMemoryProfiler(enabled=True) + return _profiler + + +def checkpoint(name: str): + """Record memory checkpoint.""" + get_memory_profiler().checkpoint(name) + + +def log_checkpoint(name: str, prev_checkpoint: str = None): + """Log memory checkpoint.""" + get_memory_profiler().log_checkpoint(name, prev_checkpoint) + + +def print_summary(): + """Print memory profile summary.""" + get_memory_profiler().print_summary() + + +def save_report(filepath: str): + """Save memory report.""" + get_memory_profiler().save_report(filepath) diff --git a/torchtitan/train.py b/torchtitan/train.py index 9c16c2ef47..71c906a625 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -25,8 +25,11 @@ ) from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims, utils as dist_utils +from torchtitan.memory_defrag import MemoryDefragManager from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils +from torchtitan.tools.cuda_memory_tracker import CUDAMemoryTracker +from torchtitan.tools.detailed_memory_tracker import DetailedMemoryTracker from torchtitan.tools.logging import init_logger, 
logger from torchtitan.tools.profiling import ( maybe_enable_memory_snapshot, @@ -104,6 +107,30 @@ def __init__(self, job_config: JobConfig): gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug ) + # Initialize memory defragmentation manager + self.defrag_manager = MemoryDefragManager( + enabled=getattr(job_config.training, "enable_memory_defrag", False), + defrag_freq=getattr(job_config.training, "defrag_freq", 1), + aggressive=getattr(job_config.training, "aggressive_defrag", False), + ) + + # Initialize detailed memory tracker + self.detailed_memory_tracker = DetailedMemoryTracker( + enabled=getattr( + job_config.training, "enable_detailed_memory_tracking", False + ), + clear_cache=getattr( + job_config.training, "clear_cache_between_steps", False + ), + ) + + # Initialize CUDA memory tracker + self.cuda_memory_tracker = CUDAMemoryTracker( + enabled=getattr( + job_config.training, "enable_detailed_memory_tracking", False + ), + ) + # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). dist_utils.set_determinism( @@ -696,6 +723,10 @@ def train_step( # the major variables that are used in the training loop. parallel_dims = self.parallel_dims + # Track memory before forward pass + self.detailed_memory_tracker.measure("before_forward", self.step) + self.cuda_memory_tracker.measure_all("before_forward", self.step) + accumulated_losses = [] # If data runs out during gradient accumulation, that # entire step will not be executed. 
@@ -704,6 +735,10 @@ def train_step( loss = self.forward_backward_step(input_dict, labels) accumulated_losses.append(loss.detach()) + # Track memory after forward/backward + self.detailed_memory_tracker.measure("after_forward_backward", self.step) + self.cuda_memory_tracker.measure_all("after_forward_backward", self.step) + grad_norm = dist_utils.clip_grad_norm_( [p for m in self.model_parts for p in m.parameters()], self.job_config.training.max_norm, @@ -714,8 +749,17 @@ def train_step( ep_enabled=parallel_dims.ep_enabled, ) self.checkpointer.maybe_wait_for_staging() - self.optimizers.step() - self.lr_schedulers.step() + + # Skip optimizer step if configured (for memory profiling) + if not self.job_config.training.skip_optimizer_step: + self.optimizers.step() + self.lr_schedulers.step() + else: + logger.info("Skipping optimizer step (skip_optimizer_step=True)") + + # Track memory after optimizer step + self.detailed_memory_tracker.measure("after_optimizer", self.step) + self.cuda_memory_tracker.measure_all("after_optimizer", self.step) # Reduce the data collected over gradient accumulation steps. 
loss = torch.sum(torch.stack(accumulated_losses)) @@ -832,6 +876,13 @@ def train(self): if memory_profiler: memory_profiler.step() + # Run memory defragmentation if enabled + self.defrag_manager.step(self.step) + + # Track memory at step end and optionally clear cache + self.detailed_memory_tracker.step_complete(self.step) + self.cuda_memory_tracker.measure_all("step_end", self.step) + # reduce timeout after first train step for faster signal # (assuming lazy init and compilation are finished) if self.step == 1: @@ -846,6 +897,10 @@ def train(self): logger.info("Sleeping 2 seconds for other ranks to complete") time.sleep(2) + # Log detailed memory tracking summary + if torch.distributed.get_rank() == 0: + logger.info(self.detailed_memory_tracker.get_summary()) + logger.info("Training completed") def should_continue_training(self) -> bool: From 1fda9bb8cf2da98d0f0794f3c809c1f27d27b51c Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Mon, 19 Jan 2026 19:16:07 -0800 Subject: [PATCH 02/18] Add Kimi 1T training configs, activation offload, and memory defrag - Add activation checkpoint offload module - Add memory defragmentation utilities - Add deep memory profiler script - Add various Kimi 1T training configs (EP64, EP96, EP128, CP2, etc.) 
- Add Qwen3 activation offload test configs - Add slurm launch scripts - Update DeepSeek V3 model with MoE improvements Co-Authored-By: Claude Opus 4.5 --- deep_memory_profiler.py | 245 ++++++++++++++ launch_kimi_1t_emozilla.slurm | 84 +++++ test_single_node.slurm | 40 +++ .../activation_checkpoint_offload.py | 313 ++++++++++++++++++ torchtitan/memory_defrag.py | 101 ++++++ torchtitan/models/deepseek_v3/__init__.py | 32 ++ .../models/deepseek_v3/infra/parallelize.py | 5 +- torchtitan/models/deepseek_v3/model/args.py | 13 +- torchtitan/models/deepseek_v3/model/model.py | 29 +- .../train_configs/debug_1b_baseline.toml | 84 +++++ .../debug_1b_no_ac_baseline.toml | 54 +++ .../train_configs/debug_1b_offload.toml | 84 +++++ .../train_configs/debug_1b_offload_only.toml | 55 +++ .../train_configs/debug_7b_baseline.toml | 55 +++ .../train_configs/debug_7b_offload.toml | 55 +++ .../debug_activation_offload.toml | 82 +++++ .../exp1aa0_8n_EP64_CP2_LBS1_ctx16k.toml | 46 +++ .../exp1aa10_8n_EP64_CP2_LBS1_ctx32k.toml | 46 +++ .../exp1aa11_8n_EP64_CP2_LBS2_ctx32k.toml | 46 +++ .../exp1aa12_8n_EP64_CP2_LBS4_ctx32k.toml | 46 +++ .../exp1aa13_8n_EP64_CP2_LBS6_ctx32k.toml | 46 +++ .../exp1aa14_8n_EP64_CP2_LBS8_ctx32k.toml | 46 +++ .../exp1aa1_8n_EP64_CP2_LBS2_ctx16k.toml | 46 +++ .../exp1aa2_8n_EP64_CP2_LBS4_ctx16k.toml | 46 +++ .../exp1aa3_8n_EP64_CP2_LBS6_ctx16k.toml | 46 +++ .../exp1aa4_8n_EP64_CP2_LBS8_ctx16k.toml | 46 +++ .../exp1aa5_8n_EP64_CP2_LBS1_ctx24k.toml | 46 +++ .../exp1aa6_8n_EP64_CP2_LBS2_ctx24k.toml | 46 +++ .../exp1aa7_8n_EP64_CP2_LBS4_ctx24k.toml | 46 +++ .../exp1aa8_8n_EP64_CP2_LBS6_ctx24k.toml | 46 +++ .../exp1aa9_8n_EP64_CP2_LBS8_ctx24k.toml | 46 +++ .../exp1ab0_12n_EP96_CP2_LBS1_ctx16k.toml | 46 +++ .../exp1ab10_12n_EP96_CP2_LBS1_ctx32k.toml | 46 +++ .../exp1ab11_12n_EP96_CP2_LBS2_ctx32k.toml | 46 +++ .../exp1ab12_12n_EP96_CP2_LBS4_ctx32k.toml | 46 +++ .../exp1ab13_12n_EP96_CP2_LBS6_ctx32k.toml | 46 +++ .../exp1ab14_12n_EP96_CP2_LBS8_ctx32k.toml | 46 +++ 
.../exp1ab1_12n_EP96_CP2_LBS2_ctx16k.toml | 46 +++ .../exp1ab2_12n_EP96_CP2_LBS4_ctx16k.toml | 46 +++ .../exp1ab3_12n_EP96_CP2_LBS6_ctx16k.toml | 46 +++ .../exp1ab4_12n_EP96_CP2_LBS8_ctx16k.toml | 46 +++ .../exp1ab5_12n_EP96_CP2_LBS1_ctx24k.toml | 46 +++ .../exp1ab6_12n_EP96_CP2_LBS2_ctx24k.toml | 46 +++ .../exp1ab7_12n_EP96_CP2_LBS4_ctx24k.toml | 46 +++ .../exp1ab8_12n_EP96_CP2_LBS6_ctx24k.toml | 46 +++ .../exp1ab9_12n_EP96_CP2_LBS8_ctx24k.toml | 46 +++ .../train_configs/kimi_1t_10n_28k_flb.toml | 46 +++ .../train_configs/kimi_1t_10n_ep16_2k.toml | 45 +++ .../train_configs/kimi_1t_10n_ep16_4k.toml | 43 +++ .../train_configs/kimi_1t_10n_ep16_8k.toml | 43 +++ .../kimi_1t_12n_28k_ac_offload.toml | 48 +++ .../kimi_1t_12n_28k_selective_ac.toml | 49 +++ .../kimi_1t_12n_cp2_28k_flex_fix.toml | 48 +++ .../train_configs/kimi_1t_12n_cp2_30720.toml | 46 +++ .../train_configs/kimi_1t_12n_cp2_32768.toml | 46 +++ .../kimi_1t_12n_ep12_28k_flb.toml | 47 +++ .../kimi_1t_12n_ep12_32k_flb.toml | 47 +++ .../kimi_1t_12n_ep96_28k_flb.toml | 47 +++ .../kimi_1t_12n_ep96_32k_flb.toml | 47 +++ .../train_configs/kimi_1t_16k_force_lb.toml | 46 +++ .../train_configs/kimi_1t_16n_ep128_2k.toml | 45 +++ .../train_configs/kimi_1t_16n_ep128_4k.toml | 45 +++ .../train_configs/kimi_1t_16n_ep128_8k.toml | 45 +++ .../train_configs/kimi_1t_20k_force_lb.toml | 46 +++ .../train_configs/kimi_1t_24k_force_lb.toml | 46 +++ .../train_configs/kimi_1t_28k_ac_offload.toml | 48 +++ .../train_configs/kimi_1t_28k_force_lb.toml | 46 +++ .../kimi_1t_28k_selective_ac.toml | 49 +++ .../train_configs/kimi_1t_32k_force_lb.toml | 46 +++ .../train_configs/kimi_1t_4k_force_lb.toml | 47 +++ .../train_configs/kimi_1t_6k_force_lb.toml | 46 +++ .../train_configs/kimi_1t_8k_force_lb.toml | 46 +++ .../train_configs/kimi_1t_8n_cp2_28k_flb.toml | 48 +++ .../kimi_1t_8n_cp2_28k_flex_fix.toml | 48 +++ .../kimi_1t_8n_cp2_28k_sdpa.toml | 48 +++ .../train_configs/kimi_1t_8n_cp2_30720.toml | 46 +++ 
.../train_configs/kimi_1t_8n_cp2_32768.toml | 46 +++ .../train_configs/kimi_1t_8n_cp2_ep1_28k.toml | 48 +++ .../train_configs/kimi_1t_8n_tp2_28k.toml | 48 +++ .../kimi_1t_activation_offload.toml | 81 +++++ ...i_1t_baseline_ep32_40nodes_no_offload.toml | 50 +++ ...mi_1t_baseline_ep64_8nodes_no_offload.toml | 51 +++ ..._cpuoffload_ep32_40nodes_with_offload.toml | 52 +++ ...mi_1t_debug_ep32_40nodes_with_offload.toml | 53 +++ .../kimi_1t_debug_ep32_8nodes_seq512.toml | 52 +++ .../train_configs/kimi_1t_defrag_test.toml | 60 ++++ .../kimi_1t_detailed_memory_profiling.toml | 78 +++++ .../kimi_1t_detailed_tracking.toml | 47 +++ .../train_configs/kimi_1t_emozilla.toml | 116 +++++++ .../train_configs/kimi_1t_fix_combined.toml | 50 +++ .../train_configs/kimi_1t_fix_defrag.toml | 48 +++ .../kimi_1t_fix_selective_ac.toml | 47 +++ .../train_configs/kimi_1t_memory_16k_ctx.toml | 46 +++ .../train_configs/kimi_1t_memory_1k_ctx.toml | 43 +++ .../train_configs/kimi_1t_memory_2k_ctx.toml | 43 +++ .../train_configs/kimi_1t_memory_4k_ctx.toml | 43 +++ .../train_configs/kimi_1t_memory_8k_ctx.toml | 46 +++ .../kimi_1t_memprof_24k_flb.toml | 53 +++ .../kimi_1t_memprof_28k_flb.toml | 53 +++ .../train_configs/kimi_1t_memprof_2k.toml | 50 +++ .../kimi_1t_memprof_2k_force_lb.toml | 53 +++ .../train_configs/kimi_1t_memprof_4k.toml | 50 +++ .../kimi_1t_memprof_4k_force_lb.toml | 53 +++ .../kimi_1t_offload_no_cache_clear.toml | 51 +++ ...t_optimized_ep32_40nodes_with_offload.toml | 50 +++ ...1t_optimized_ep64_8nodes_with_offload.toml | 51 +++ .../train_configs/kimi_1t_profiling.toml | 116 +++++++ .../kimi_1t_profiling_ep64_8nodes.toml | 51 +++ .../qwen3_1.7b_local_test_baseline.toml | 70 ++++ .../qwen3_1.7b_local_test_offload.toml | 70 ++++ ...qwen3_30b_a3b_activation_offload_test.toml | 75 +++++ torchtitan/train.py | 5 + 112 files changed, 6093 insertions(+), 10 deletions(-) create mode 100644 deep_memory_profiler.py create mode 100755 launch_kimi_1t_emozilla.slurm create mode 100755 
test_single_node.slurm create mode 100644 torchtitan/distributed/activation_checkpoint_offload.py create mode 100644 torchtitan/memory_defrag.py create mode 100644 torchtitan/models/deepseek_v3/train_configs/debug_1b_baseline.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/debug_1b_no_ac_baseline.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/debug_1b_offload.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/debug_1b_offload_only.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/debug_7b_baseline.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/debug_7b_offload.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/debug_activation_offload.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa0_8n_EP64_CP2_LBS1_ctx16k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa10_8n_EP64_CP2_LBS1_ctx32k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa11_8n_EP64_CP2_LBS2_ctx32k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa12_8n_EP64_CP2_LBS4_ctx32k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa13_8n_EP64_CP2_LBS6_ctx32k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa14_8n_EP64_CP2_LBS8_ctx32k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa1_8n_EP64_CP2_LBS2_ctx16k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa2_8n_EP64_CP2_LBS4_ctx16k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa3_8n_EP64_CP2_LBS6_ctx16k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa4_8n_EP64_CP2_LBS8_ctx16k.toml create mode 100644 
torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa5_8n_EP64_CP2_LBS1_ctx24k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa6_8n_EP64_CP2_LBS2_ctx24k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa7_8n_EP64_CP2_LBS4_ctx24k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa8_8n_EP64_CP2_LBS6_ctx24k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa9_8n_EP64_CP2_LBS8_ctx24k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab0_12n_EP96_CP2_LBS1_ctx16k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab10_12n_EP96_CP2_LBS1_ctx32k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab11_12n_EP96_CP2_LBS2_ctx32k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab12_12n_EP96_CP2_LBS4_ctx32k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab13_12n_EP96_CP2_LBS6_ctx32k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab14_12n_EP96_CP2_LBS8_ctx32k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab1_12n_EP96_CP2_LBS2_ctx16k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab2_12n_EP96_CP2_LBS4_ctx16k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab3_12n_EP96_CP2_LBS6_ctx16k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab4_12n_EP96_CP2_LBS8_ctx16k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab5_12n_EP96_CP2_LBS1_ctx24k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab6_12n_EP96_CP2_LBS2_ctx24k.toml create mode 100644 
torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab7_12n_EP96_CP2_LBS4_ctx24k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab8_12n_EP96_CP2_LBS6_ctx24k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab9_12n_EP96_CP2_LBS8_ctx24k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_28k_flb.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_2k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_4k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_8k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_28k_ac_offload.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_28k_selective_ac.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_28k_flex_fix.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_30720.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_32768.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep12_28k_flb.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep12_32k_flb.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep96_28k_flb.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep96_32k_flb.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_16k_force_lb.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_2k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_4k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_8k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_20k_force_lb.toml create mode 100644 
torchtitan/models/deepseek_v3/train_configs/kimi_1t_24k_force_lb.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_ac_offload.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_force_lb.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_selective_ac.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_32k_force_lb.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_4k_force_lb.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_6k_force_lb.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_8k_force_lb.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_flb.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_flex_fix.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_sdpa.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_30720.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_32768.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_ep1_28k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_tp2_28k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_activation_offload.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_baseline_ep32_40nodes_no_offload.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_baseline_ep64_8nodes_no_offload.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_cpuoffload_ep32_40nodes_with_offload.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_debug_ep32_40nodes_with_offload.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_debug_ep32_8nodes_seq512.toml create mode 100644 
torchtitan/models/deepseek_v3/train_configs/kimi_1t_defrag_test.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_detailed_memory_profiling.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_detailed_tracking.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_emozilla.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_combined.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_defrag.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_selective_ac.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_16k_ctx.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_1k_ctx.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_2k_ctx.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_4k_ctx.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_8k_ctx.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_24k_flb.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_28k_flb.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_2k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_2k_force_lb.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_4k.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_4k_force_lb.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_offload_no_cache_clear.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_optimized_ep32_40nodes_with_offload.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_optimized_ep64_8nodes_with_offload.toml create mode 100644 
# --- [git patch metadata, reconstructed from mangled line wrapping] ---
# create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_profiling.toml
# create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_profiling_ep64_8nodes.toml
# create mode 100644 torchtitan/models/qwen3/train_configs/qwen3_1.7b_local_test_baseline.toml
# create mode 100644 torchtitan/models/qwen3/train_configs/qwen3_1.7b_local_test_offload.toml
# create mode 100644 torchtitan/models/qwen3/train_configs/qwen3_30b_a3b_activation_offload_test.toml
# diff --git a/deep_memory_profiler.py b/deep_memory_profiler.py
# new file mode 100644  index 0000000000..e6bfb93932  (@@ -0,0 +1,245 @@)
#!/usr/bin/env python3
"""
Deep Memory Profiler for Kimi K2 1T Model.

Tracks GPU memory at named events and around selected module forward/backward
passes to identify where OOM occurs.  Invoked with two saved profile paths on
the command line, it compares them instead of profiling.
"""

import json
import sys
from collections import defaultdict
from typing import Dict, List, Optional

import torch


class DeepMemoryProfiler:
    """Records CUDA memory statistics at user-defined events and layer hooks."""

    def __init__(self, output_file: str = "memory_profile.json") -> None:
        self.output_file = output_file
        self.memory_events: List[Dict] = []  # chronological event records
        self.hooks = []  # live hook handles; cleared by remove_hooks()
        self.current_step = 0
        self.current_phase = "init"

    def _get_memory_stats(self) -> Dict:
        """Return a snapshot of GPU memory statistics ({} when CUDA is absent)."""
        if not torch.cuda.is_available():
            return {}

        stats = torch.cuda.memory_stats()
        return {
            "allocated_gb": torch.cuda.memory_allocated() / 1e9,
            "reserved_gb": torch.cuda.memory_reserved() / 1e9,
            "max_allocated_gb": torch.cuda.max_memory_allocated() / 1e9,
            "active_gb": stats.get("active_bytes.all.current", 0) / 1e9,
            "inactive_gb": stats.get("inactive_split_bytes.all.current", 0) / 1e9,
            "num_alloc_retries": stats.get("num_alloc_retries", 0),
            "num_ooms": stats.get("num_ooms", 0),
        }

    def log_memory(self, event_name: str, extra_info: Optional[Dict] = None) -> None:
        """Record memory at a specific event and echo it for real-time monitoring.

        Args:
            event_name: Label stored with the snapshot (e.g. "forward:layers.0").
            extra_info: Optional payload (e.g. tensor shapes) stored under "extra".
        """
        mem_stats = self._get_memory_stats()
        event = {
            "step": self.current_step,
            "phase": self.current_phase,
            "event": event_name,
            "memory": mem_stats,
        }
        if extra_info:
            event["extra"] = extra_info
        self.memory_events.append(event)

        # Print for real-time monitoring
        print(
            f"[MemProf] Step {self.current_step} | {self.current_phase} | {event_name} | "
            f"Alloc: {mem_stats.get('allocated_gb', 0):.2f} GB | "
            f"Reserved: {mem_stats.get('reserved_gb', 0):.2f} GB"
        )

    def _make_forward_hook(self, layer_name: str):
        """Create a forward hook that logs memory plus input/output tensor shapes."""

        def hook(module, input, output):
            input_shapes = []
            for inp in input:
                if isinstance(inp, torch.Tensor):
                    input_shapes.append(list(inp.shape))

            output_shapes = []
            if isinstance(output, torch.Tensor):
                output_shapes.append(list(output.shape))
            elif isinstance(output, (tuple, list)):
                for out in output:
                    if isinstance(out, torch.Tensor):
                        output_shapes.append(list(out.shape))

            self.log_memory(
                f"forward:{layer_name}",
                {
                    "input_shapes": input_shapes,
                    "output_shapes": output_shapes,
                },
            )

        return hook

    def _make_backward_hook(self, layer_name: str):
        """Create a backward hook that logs memory for a layer."""

        def hook(module, grad_input, grad_output):
            self.log_memory(f"backward:{layer_name}")

        return hook

    def attach_hooks(
        self, model: torch.nn.Module, layers_to_track: Optional[List[str]] = None
    ) -> None:
        """Attach memory-tracking hooks to every module whose name contains one
        of the substrings in `layers_to_track`.

        Args:
            model: Model to instrument.
            layers_to_track: Substrings matched against `named_modules()` names;
                defaults to key layers of a DeepSeek/MoE-style model.
        """
        if layers_to_track is None:
            # Default: track key layers in DeepSeek/MoE model
            layers_to_track = [
                "embed_tokens",
                "layers.0",  # First transformer layer
                "layers.30",  # Middle layer
                "layers.60",  # Last layer (if exists)
                "moe",  # MoE layers
                "experts",  # Expert modules
                "norm",
                "lm_head",
            ]

        for name, module in model.named_modules():
            should_track = any(track_name in name for track_name in layers_to_track)
            if should_track:
                # Forward hook
                handle = module.register_forward_hook(self._make_forward_hook(name))
                self.hooks.append(handle)
                # Backward hook
                handle = module.register_full_backward_hook(
                    self._make_backward_hook(name)
                )
                self.hooks.append(handle)
                print(f"[MemProf] Attached hooks to: {name}")

    def remove_hooks(self) -> None:
        """Remove all hooks."""
        for hook in self.hooks:
            hook.remove()
        self.hooks = []

    def set_step(self, step: int) -> None:
        self.current_step = step

    def set_phase(self, phase: str) -> None:
        self.current_phase = phase

    def save_profile(self) -> None:
        """Save memory profile to JSON file."""
        with open(self.output_file, "w") as f:
            json.dump(self.memory_events, f, indent=2)
        print(f"[MemProf] Saved profile to {self.output_file}")

    def print_summary(self) -> None:
        """Print per-event max allocation and the overall peak reserved memory."""
        print("\n" + "=" * 80)
        print("MEMORY PROFILE SUMMARY")
        print("=" * 80)

        # Group by event name and find max memory
        event_max_mem = defaultdict(float)
        event_counts = defaultdict(int)

        for event in self.memory_events:
            name = event["event"]
            mem = event["memory"].get("allocated_gb", 0)
            event_max_mem[name] = max(event_max_mem[name], mem)
            event_counts[name] += 1

        # Sort by max memory
        sorted_events = sorted(event_max_mem.items(), key=lambda x: x[1], reverse=True)

        print(f"\n{'Event':<60} {'Max Alloc (GB)':<15} {'Count':<10}")
        print("-" * 85)
        for event_name, max_mem in sorted_events[:30]:
            print(f"{event_name:<60} {max_mem:<15.2f} {event_counts[event_name]:<10}")

        # Find peak memory point
        if self.memory_events:
            peak_event = max(
                self.memory_events, key=lambda x: x["memory"].get("reserved_gb", 0)
            )
            print(f"\n{'='*80}")
            print(
                f"PEAK MEMORY: {peak_event['memory'].get('reserved_gb', 0):.2f} GB reserved"
            )
            print(
                f"  At: Step {peak_event['step']} | Phase: {peak_event['phase']} | Event: {peak_event['event']}"
            )
            if "extra" in peak_event:
                print(f"  Extra: {peak_event['extra']}")


def analyze_memory_difference(profile_2k: str, profile_4k: str) -> None:
    """Compare two saved memory profiles (e.g. 2k vs 4k context) event-by-event.

    Events are keyed by (step, phase, event); only differences above 0.1 GB of
    allocated memory are reported, largest first.
    """
    with open(profile_2k) as f:
        events_2k = json.load(f)
    with open(profile_4k) as f:
        events_4k = json.load(f)

    print("\n" + "=" * 80)
    print("MEMORY COMPARISON: 2k vs 4k context")
    print("=" * 80)

    # Build event maps
    def build_event_map(events):
        event_map = {}
        for e in events:
            key = (e["step"], e["phase"], e["event"])
            event_map[key] = e["memory"]
        return event_map

    map_2k = build_event_map(events_2k)
    map_4k = build_event_map(events_4k)

    # Find common events and compare
    common_keys = set(map_2k.keys()) & set(map_4k.keys())

    differences = []
    for key in common_keys:
        mem_2k = map_2k[key].get("allocated_gb", 0)
        mem_4k = map_4k[key].get("allocated_gb", 0)
        diff = mem_4k - mem_2k
        if abs(diff) > 0.1:  # Only show significant differences
            differences.append((key, mem_2k, mem_4k, diff))

    # Sort by difference
    differences.sort(key=lambda x: x[3], reverse=True)

    print(f"\n{'Event':<50} {'2k (GB)':<10} {'4k (GB)':<10} {'Diff (GB)':<10}")
    print("-" * 80)
    for key, mem_2k, mem_4k, diff in differences[:20]:
        step, phase, event = key
        event_short = event[:45] if len(event) > 45 else event
        print(f"{event_short:<50} {mem_2k:<10.2f} {mem_4k:<10.2f} {diff:<+10.2f}")

    # Summary
    total_2k = (
        max(e["memory"].get("reserved_gb", 0) for e in events_2k) if events_2k else 0
    )
    total_4k = (
        max(e["memory"].get("reserved_gb", 0) for e in events_4k) if events_4k else 0
    )

    print(f"\n{'='*80}")
    print("Peak Reserved Memory:")
    print(f"  2k context: {total_2k:.2f} GB")
    print(f"  4k context: {total_4k:.2f} GB")
    print(f"  Difference: {total_4k - total_2k:+.2f} GB")


if __name__ == "__main__":
    if len(sys.argv) > 2:
        # Compare mode
        analyze_memory_difference(sys.argv[1], sys.argv[2])
    else:
        # BUGFIX: original usage string was truncated and named no arguments.
        print(
            "Usage: python deep_memory_profiler.py <profile_2k.json> <profile_4k.json>"
        )
# --- [next file in patch: launch_kimi_1t_emozilla.slurm (new file, mode 100755)] ---
# #!/bin/bash
# # Copyright (c) Meta Platforms, Inc. and affiliates.
# # All rights reserved.
# --- [launch_kimi_1t_emozilla.slurm, continued; line structure restored from mangled patch] ---

# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# --- This script is optimized for AWS with EFA
# --- adjust NCCL_BUFFSIZE if you encounter memory
# --- constraint issues or to tune for improved performance.
# ---

#SBATCH --job-name=kimi_1t_emozilla
#SBATCH --ntasks=40
#SBATCH --nodes=40
#SBATCH --gpus-per-task=8
#SBATCH --cpus-per-task=64
#SBATCH --partition=batch

# Resolve the head node's IP for the c10d rendezvous endpoint.
nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
nodes_array=($nodes)
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

echo Node IP: $head_node_ip
export LOGLEVEL=INFO
# Enable for A100
export FI_PROVIDER="efa"
# Ensure that P2P is available
# export NCCL_P2P_DISABLE=1
# export NCCL_IB_DISABLE=1

# debugging flags (optional)
export NCCL_DEBUG=WARN
export PYTHONFAULTHANDLER=1
# optional debug settings
# export NCCL_DEBUG=INFO
# NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV

export LD_LIBRARY_PATH=/opt/amazon/efa/lib:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH
export CUDA_LAUNCH_BLOCKING=0

# on your cluster you might need these:
# set the network interface
export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond"
export NCCL_BUFFSIZE=2097152
#export TORCH_DIST_INIT_BARRIER=1
export FI_EFA_SET_CUDA_SYNC_MEMOPS=0

# Per-job cache dirs keyed on PID + proc id to avoid cross-run clobbering.
export TRITON_HOME=/tmp/emotritoncache_$$_$SLURM_PROCID
export NUMBA_CACHE_DIR=/tmp/numbacache_$$_$SLURM_PROCID
mkdir -p $TRITON_HOME $NUMBA_CACHE_DIR

export WANDB_ENTITY="nous_research"
#export WANDB_PROJECT="torchtune"
export WANDB_PROJECT="moe"

export HF_HOME="/home/phuc/.cache/huggingface"
mkdir -p $HF_HOME
#export NVSHMEM_DISABLE_NIC_LOCKING=1
#export NVSHMEM_VERBOSE=3
#export PYTORCH_ALLOC_CONF="expandable_segments:True,max_split_size_mb:128,garbage_collection_threshold:0.95"

# Activate conda environment on all nodes
export PATH="/home/phuc/kimi_1t/env/bin:$PATH"
export CONDA_PREFIX="/home/phuc/kimi_1t/env"

CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/deepseek_v3/train_configs/kimi_1t_emozilla.toml"}
#CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/llama3_8b.toml"}
#CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/llama3_405b.toml"}


#dcgmi profile --pause
# adjust sbatch --ntasks and sbatch --nodes above and --nnodes below
# to your specific node count, and update target launch file.
srun --export=ALL torchrun --nnodes 40 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" -m torchtitan.train --job.config_file ${CONFIG_FILE} "$@"
#dcgmi profile --resume

# --- [next file in patch: test_single_node.slurm (new file, mode 100755)] ---
#!/bin/bash
#SBATCH --job-name=kimi_1t_test_1node
#SBATCH --ntasks=1
#SBATCH --nodes=1
#SBATCH --gpus-per-task=8
#SBATCH --cpus-per-task=64
#SBATCH --partition=batch
#SBATCH --time=00:30:00

nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) )
nodes_array=($nodes)
head_node=${nodes_array[0]}
head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)

echo "Node IP: $head_node_ip"
export LOGLEVEL=INFO
export FI_PROVIDER="efa"
export NCCL_DEBUG=WARN
export PYTHONFAULTHANDLER=1
export LD_LIBRARY_PATH=/opt/amazon/efa/lib:$LD_LIBRARY_PATH
export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH
export CUDA_LAUNCH_BLOCKING=0
export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond"
export NCCL_BUFFSIZE=2097152
export FI_EFA_SET_CUDA_SYNC_MEMOPS=0
export TRITON_HOME=/tmp/emotritoncache
export NUMBA_CACHE_DIR=/tmp/numbacache
export WANDB_ENTITY="nous_research"
export WANDB_PROJECT="moe"
export HF_HOME="/home/phuc/.cache/huggingface"
mkdir -p $HF_HOME

# Activate conda environment
export PATH="/home/phuc/kimi_1t/env/bin:$PATH"
# --- [tail of test_single_node.slurm (bash), preserved from mangled patch] ---
# export CONDA_PREFIX="/home/phuc/kimi_1t/env"
# CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/kimi_1t_emozilla.toml"
# echo "Testing single node with 8 GPUs"
# srun torchrun --nnodes 1 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d \
#   --rdzv_endpoint "$head_node_ip:29500" -m torchtitan.train --job.config_file ${CONFIG_FILE} "$@"
# --- [next file in patch: torchtitan/distributed/activation_checkpoint_offload.py (new)] ---
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Activation Checkpointing with CPU Offloading Support

This module extends torchtitan's activation checkpointing with CPU offloading capability,
inspired by DeepSpeed's CPU checkpointing implementation.

CPU offloading moves activation tensors to CPU RAM during the forward pass and brings them
back to GPU during the backward pass, trading memory for PCIe bandwidth.
"""

from collections import defaultdict

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper as ptd_checkpoint_wrapper,
)
from torch.utils.checkpoint import (
    CheckpointPolicy,
    create_selective_checkpoint_contexts,
)

from torchtitan.config.job_config import ActivationCheckpoint as ACConfig
from torchtitan.tools.logging import logger


def _cpu_offload_context_fn():
    """
    Create the saved-tensors-hooks context pair for CPU offloading.

    Tensors saved for backward are packed to CPU during the forward pass and
    unpacked back onto the current CUDA device during backward.

    Returns:
        A tuple (forward_context, recompute_context).  The same context object
        serves both phases since the hooks carry no state.
    """

    def pack_to_cpu(tensor):
        """Move tensor to CPU during forward pass."""
        if not isinstance(tensor, torch.Tensor):
            return tensor
        # Only offload CUDA tensors that are floating point and non-empty.
        if tensor.is_cuda and tensor.is_floating_point() and tensor.numel() > 0:
            # NOTE(review): non_blocking D2H copies into pageable memory are
            # effectively synchronous; pinned staging buffers would be needed
            # for true overlap — TODO confirm transfer perf is acceptable.
            return tensor.to("cpu", non_blocking=True)
        return tensor

    def unpack_from_cpu(tensor):
        """Move tensor back to GPU during backward pass."""
        if not isinstance(tensor, torch.Tensor):
            return tensor
        # If tensor is on CPU, move it back to the current CUDA device.
        if tensor.device.type == "cpu":
            return tensor.to(torch.cuda.current_device(), non_blocking=True)
        return tensor

    # Reuse one context for both forward and recompute phases.
    ctx = torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu)
    return (ctx, ctx)


def _cpu_offload_selective_context_fn(ac_config: ACConfig, mm_recompute_shapes: set):
    """
    Create a selective-checkpoint context factory with CPU offloading support.

    Combines selective activation checkpointing (choosing which ops to save vs
    recompute) with CPU offloading of the tensors that are saved.

    Args:
        ac_config: Activation checkpoint configuration (kept for signature
            compatibility; the policy below is currently config-independent).
        mm_recompute_shapes: Set of (in, out) mm weight shapes to force recompute.

    Returns:
        A zero-argument callable returning (forward_context, recompute_context),
        i.e. exactly the `context_fn` shape `checkpoint_wrapper` expects.
    """
    # Ops whose outputs are worth saving under selective AC.
    op_sac_save_list = {
        torch.ops.aten.mm.default,
        torch.ops.aten.addmm.default,
        torch.ops.aten.bmm.default,
        torch.ops.aten.linear.default,
    }

    def _get_custom_policy(meta):
        def _custom_policy(ctx, func, *args, **kwargs):
            # Always save the device-transfer ops produced by CPU offloading.
            if (
                func == torch.ops.aten._to_copy.default
                and "cuda" in str(args[0].device)
                and "device" in kwargs
                and str(kwargs["device"]) == "cpu"
            ):
                return CheckpointPolicy.MUST_SAVE

            # Count mms separately for the original forward and the recompute
            # so the "save every other mm" pattern lines up between the two.
            mode = "recompute" if ctx.is_recompute else "forward"
            mm_count_key = f"{mode}_mm_count"
            if func == torch.ops.aten.mm.default:
                if args[1].shape in mm_recompute_shapes:
                    return CheckpointPolicy.PREFER_RECOMPUTE
                meta[mm_count_key] += 1

            # Saves output of all compute ops, except every second mm
            to_save = func in op_sac_save_list and not (
                func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0
            )
            return (
                CheckpointPolicy.MUST_SAVE
                if to_save
                else CheckpointPolicy.PREFER_RECOMPUTE
            )

        return _custom_policy

    def selective_checkpointing_with_cpu_offload():
        """Combined context for selective AC + CPU offload."""
        meta = defaultdict(int)
        (
            selective_forward_ctx,
            selective_recompute_ctx,
        ) = create_selective_checkpoint_contexts(_get_custom_policy(meta))
        cpu_offload_forward_ctx, cpu_offload_recompute_ctx = _cpu_offload_context_fn()

        # Stack both contexts for forward phase
        class CombinedForwardContext:
            def __enter__(self):
                self.selective = selective_forward_ctx.__enter__()
                self.cpu_offload = cpu_offload_forward_ctx.__enter__()
                return self

            def __exit__(self, *args):
                # Exit in reverse order of entry.
                self.cpu_offload.__exit__(*args)
                self.selective.__exit__(*args)

        # Stack both contexts for recompute phase
        class CombinedRecomputeContext:
            def __enter__(self):
                self.selective = selective_recompute_ctx.__enter__()
                self.cpu_offload = cpu_offload_recompute_ctx.__enter__()
                return self

            def __exit__(self, *args):
                self.cpu_offload.__exit__(*args)
                self.selective.__exit__(*args)

        return (CombinedForwardContext(), CombinedRecomputeContext())

    return selective_checkpointing_with_cpu_offload


def apply_full_ac_with_cpu_offload(module: nn.Module, ac_config: ACConfig) -> nn.Module:
    """
    Apply full activation checkpointing with CPU offloading to the module.

    This will checkpoint all activations and offload them to CPU RAM.

    Args:
        module: The module to apply full AC with CPU offload to.
        ac_config: The activation checkpointing config.

    Returns:
        The wrapped module with full AC + CPU offload applied.
    """
    logger.info("Applying full activation checkpointing with CPU offload")

    return ptd_checkpoint_wrapper(
        module,
        context_fn=_cpu_offload_context_fn,
        preserve_rng_state=ac_config.preserve_rng_state,
        determinism_check=ac_config.determinism_check,
        early_stop=ac_config.early_stop,
        debug=ac_config.debug,
    )


def apply_selective_ac_with_cpu_offload(
    module: nn.Module,
    ac_config: ACConfig,
    *,
    base_fqn: str | None = None,
) -> nn.Module:
    """
    Apply selective activation checkpointing with CPU offloading to the module.

    This selectively checkpoints certain operations while offloading saved
    tensors to CPU.

    Args:
        module: The module to apply selective AC with CPU offload to.
        ac_config: The activation checkpointing config.
        base_fqn: The base fully qualified name of the module.

    Returns:
        The wrapped module with selective AC + CPU offload applied.
    """
    logger.info("Applying selective activation checkpointing with CPU offload")

    # Collect mm shapes to force recompute if configured
    mm_recompute_shapes = set()
    if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0:
        for module_fqn, submod in module.named_modules():
            fqn = module_fqn
            if base_fqn is not None:
                fqn = f"{base_fqn}.{module_fqn}"
            if not any(
                filter_fqn in fqn
                for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns
            ):
                continue
            if not isinstance(submod, nn.Linear):
                raise ValueError(
                    "per_op_sac_force_recompute_mm_shapes_by_fqns expected to match "
                    f"a nn.Linear, but got: {submod}"
                )
            out_f, in_f = submod.weight.shape
            mm_recompute_shapes.add((in_f, out_f))

    # BUGFIX: _cpu_offload_selective_context_fn already returns the zero-arg
    # context factory that checkpoint_wrapper expects.  The previous code
    # wrapped it in another function, so context_fn() yielded a *function*
    # instead of a (forward_ctx, recompute_ctx) tuple, breaking checkpointing
    # at runtime.  Pass the factory through directly.
    context_fn = _cpu_offload_selective_context_fn(ac_config, mm_recompute_shapes)

    return ptd_checkpoint_wrapper(
        module,
        context_fn=context_fn,
        preserve_rng_state=ac_config.preserve_rng_state,
        determinism_check=ac_config.determinism_check,
        early_stop=ac_config.early_stop,
        debug=ac_config.debug,
    )


class ActivationOffloadWrapper(nn.Module):
    """
    Wrapper that offloads layer activations to CPU without checkpointing/recomputation.

    This keeps all activations but moves them to CPU RAM to save GPU memory.

    NOTE(review): the gradient-hook scheme below looks suspect — a tensor's
    `register_hook` callback is expected to return a (possibly modified)
    *gradient*, but `_backward_hook` returns the saved activation instead,
    and the hook is registered on the pre-offload GPU tensor that is not the
    one returned to callers.  Behavior is preserved here; verify before use.
    """

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        self._cpu_activations = []

    def forward(self, *args, **kwargs):
        # Move inputs to GPU if they were offloaded
        args = tuple(
            self._to_gpu(arg) if isinstance(arg, torch.Tensor) else arg for arg in args
        )
        kwargs = {
            k: self._to_gpu(v) if isinstance(v, torch.Tensor) else v
            for k, v in kwargs.items()
        }

        # Run forward pass
        output = self.module(*args, **kwargs)

        # Offload output to CPU during forward pass
        if isinstance(output, torch.Tensor):
            output_cpu = output.to("cpu", non_blocking=True)
            # Register hook to bring it back for backward
            output.register_hook(lambda grad: self._backward_hook(grad, output_cpu))
            return output_cpu
        elif isinstance(output, tuple):
            output_cpu = tuple(
                o.to("cpu", non_blocking=True) if isinstance(o, torch.Tensor) else o
                for o in output
            )
            # Register hooks for tensor outputs
            for i, (o, o_cpu) in enumerate(zip(output, output_cpu)):
                if isinstance(o, torch.Tensor):
                    o.register_hook(
                        lambda grad, oc=o_cpu: self._backward_hook(grad, oc)
                    )
            return output_cpu
        return output

    def _to_gpu(self, tensor):
        """Move tensor from CPU to GPU."""
        if tensor.device.type == "cpu":
            return tensor.to(torch.cuda.current_device(), non_blocking=True)
        return tensor

    def _backward_hook(self, grad, cpu_activation):
        """Called during backward to move activation back to GPU."""
        return cpu_activation.to(grad.device, non_blocking=True)

    def __getattr__(self, name):
        """Forward attribute access to the wrapped module."""
        try:
            return super().__getattr__(name)
        except AttributeError:
            return getattr(self.module, name)


def apply_offload_wrapper_only(module: nn.Module) -> nn.Module:
    """
    Apply activation offloading WITHOUT checkpointing.

    This wraps the module to offload all activations to CPU, keeping them in memory
    but freeing GPU RAM.  No recomputation happens - activations are transferred
    back to GPU during backward pass.

    Args:
        module: The module to wrap.

    Returns:
        The wrapped module with activation offloading.
    """
    return ActivationOffloadWrapper(module)


# --- [next file in patch: torchtitan/memory_defrag.py (new)] ---
"""Memory defragmentation utilities for training"""
import logging
from typing import Optional

import torch
import torch.distributed as dist

logger = logging.getLogger(__name__)


class MemoryDefragManager:
    """Manages memory defragmentation during training."""

    def __init__(
        self,
        enabled: bool = True,
        defrag_freq: int = 10,  # Defrag every N steps
        aggressive: bool = False,
    ):
        self.enabled = enabled
        self.defrag_freq = defrag_freq
        self.aggressive = aggressive
        self.step_count = 0

        if self.enabled:
            logger.info(
                f"MemoryDefragManager enabled: freq={defrag_freq}, aggressive={aggressive}"
            )

    def step(self, step_num: int):
        """Called after each training step.

        Note: `step_num` is accepted for caller convenience but the internal
        counter drives the defrag cadence (original behavior preserved).
        """
        if not self.enabled:
            return

        self.step_count += 1

        if self.step_count % self.defrag_freq == 0:
            self._defragment()

    def _defragment(self):
        """Release cached CUDA blocks back to the driver and log how much was freed."""
        if not self.enabled:
            return

        device = torch.cuda.current_device()

        # Get memory stats before
        before_reserved = torch.cuda.memory_reserved(device)

        # Method 1: Empty cache (basic)
        torch.cuda.empty_cache()

        if self.aggressive:
            # Method 2: Synchronize (and rank-barrier) before a second pass so
            # in-flight allocations can settle, then empty cache again.
            torch.cuda.synchronize()
            if dist.is_initialized():
                dist.barrier()
            torch.cuda.empty_cache()

        # Get memory stats after
        after_reserved = torch.cuda.memory_reserved(device)
        after_allocated = torch.cuda.memory_allocated(device)

        freed_mb = (before_reserved - after_reserved) / (1024**2)

        if freed_mb > 0:
            logger.info(
                f"[Defrag] Freed {freed_mb:.2f} MB "
                f"(reserved: {before_reserved/(1024**3):.2f} GB → {after_reserved/(1024**3):.2f} GB, "
                f"allocated: {after_allocated/(1024**2):.2f} MB)"
            )


def setup_allocator_config(
    max_split_size_mb: Optional[int] = None,
    garbage_collection_threshold: Optional[float] = None,
    roundup_power2_divisions: Optional[int] = None,
):
    """Configure PyTorch CUDA allocator for reduced fragmentation.

    NOTE(review): PYTORCH_CUDA_ALLOC_CONF is read when the CUDA context is
    initialized — call this before any CUDA work or it has no effect.

    Returns:
        The allocator config string that was written to the environment.
    """
    import os

    config_parts = ["expandable_segments:True"]

    if max_split_size_mb is not None:
        config_parts.append(f"max_split_size_mb:{max_split_size_mb}")

    if garbage_collection_threshold is not None:
        config_parts.append(
            f"garbage_collection_threshold:{garbage_collection_threshold}"
        )

    if roundup_power2_divisions is not None:
        config_parts.append(f"roundup_power2_divisions:{roundup_power2_divisions}")

    config = ",".join(config_parts)
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = config

    logger.info(f"Allocator config: {config}")

    return config


# --- [remaining patch hunk (fragment, continues past this chunk; left untouched)] ---
# diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py
# index 6b186c5e53..165f2ba156 100644  @@ -232,6 +232,38 @@
#   adds a "kimi_k2_sdpa" DeepSeekV3ModelArgs entry: kimi_k2 with SDPA
#   (use_flex_attn=False, attn_mask_type="causal") for Context Parallel
#   compatibility; vocab_size=163840, dim=7168, inter_dim=18432,
#   moe_inter_dim=2048, n_layers=61, n_dense_layers=1, n_heads=64,
#   MoE: 384 experts, 1 shared, top_k=8, score_func="sigmoid", route_norm=True,
#   route_scale=2.827, score_before_experts=False; q_lora_rank=1536,
#   kv_lora_rank=512, qk_nope_head_dim=128, qk_rope_head_dim=64,
#   v_head_dim=128, rope_theta=50000.0, ...
rope_factor=32.0, + beta_fast=1, + ), } diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index 262ed992bf..a54ac81be8 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -62,8 +62,9 @@ def parallelize_deepseekv3( """ use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - raise NotImplementedError("CP support for FlexAttention is still in progress.") + # NOTE: CP + FlexAttention now supported in PyTorch 2.9+ (PRs #145896, #146397) + # if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: + # raise NotImplementedError("CP support for FlexAttention is still in progress.") if parallel_dims.tp_enabled: enable_float8_linear = "float8" in job_config.model.converters diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index 1a6ff3cf6e..cbf7548a72 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -86,7 +86,9 @@ class DeepSeekV3ModelArgs(BaseModelArgs): beta_fast: int = 32 beta_slow: int = 1 mscale: float = 1.0 - mscale_all_dim: float = 1.0 # When mscale == mscale_all_dim, effective mscale is 1.0 + mscale_all_dim: float = ( + 1.0 # When mscale == mscale_all_dim, effective mscale is 1.0 + ) def update_from_config(self, job_config: JobConfig, **kwargs) -> None: seq_len = job_config.training.seq_len @@ -102,10 +104,11 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.moe_args.use_grouped_mm = False - if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: - raise NotImplementedError( - "CP support for FlexAttention is still in progress." 
- ) + # NOTE: CP + FlexAttention now supported in PyTorch 2.9+ (PRs #145896, #146397) + # if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: + # raise NotImplementedError( + # "CP support for FlexAttention is still in progress." + # ) self.moe_args._debug_force_load_balance = ( job_config.debug.moe_force_load_balance diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 67bb37480e..0c8917edbc 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -5,12 +5,20 @@ # LICENSE file in the root directory of this source tree. import math +from typing import Optional import torch from torch import nn +from torch.distributed.device_mesh import DeviceMesh from torch.nn.attention.flex_attention import and_masks, BlockMask +# Import CP block mask creator for Context Parallel + FlexAttention +try: + from torch.distributed.tensor.experimental._attention import create_cp_block_mask +except ImportError: + create_cp_block_mask = None + from torchtitan.components.peft.lora import lora_or_linear, per_layer_config from torchtitan.components.tokenizer import BaseTokenizer from torchtitan.config.job_config import PEFT @@ -478,6 +486,7 @@ def get_attention_masks( input_batch: torch.Tensor, tokenizer: BaseTokenizer, extra_inputs: dict[str, torch.Tensor] | None = None, + cp_mesh: Optional[DeviceMesh] = None, ) -> AttentionMasksType: mask_mods = [get_causal_mask_mod()] match self.model_args.attn_mask_type: @@ -500,9 +509,23 @@ def get_attention_masks( raise ValueError( f"Unknown attention mask type: {self.model_args.attn_mask_type}" ) - return create_attention_mask( - and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1] - ) + + combined_mask_mod = and_masks(*mask_mods) + seq_len = input_batch.shape[1] + H = self.model_args.n_heads # Number of attention heads + + # Use CP-aware block mask when Context Parallel is enabled + if 
cp_mesh is not None and create_cp_block_mask is not None: + return create_cp_block_mask( + mask_mod=combined_mask_mod, + B=B, + H=H, + Q_LEN=seq_len, + KV_LEN=seq_len, + device_mesh=cp_mesh, + ) + else: + return create_attention_mask(combined_mask_mod, B, None, seq_len, seq_len) def forward( self, diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_1b_baseline.toml b/torchtitan/models/deepseek_v3/train_configs/debug_1b_baseline.toml new file mode 100644 index 0000000000..50c110d127 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/debug_1b_baseline.toml @@ -0,0 +1,84 @@ +# DeepSeek V3 ~1B debug model - BASELINE (no activation offload) + +[job] +dump_folder = "./outputs/debug_1b_baseline" +description = "DeepSeek-V3 ~1B debug - BASELINE (no activation offload)" +print_config = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "deepseek_v3" +flavor = "debugmodel_1b" # ~1B parameters +# test tokenizer, for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 2 +seq_len = 4096 # Longer sequence to see activation memory +max_norm = 1.0 +steps = 10 +dataset = "c4_test" +dtype = "bfloat16" +enable_cpu_offload = false + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +pipeline_parallel_schedule = "1F1B" +context_parallel_degree = 1 +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 + 
+[checkpoint] +enable = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" + +[activation_checkpoint] +mode = "full" +selective_ac_option = 'op' +# NO CPU OFFLOAD - BASELINE +cpu_offload = false + +[compile] +enable = false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] + +[quantize.grouped_mm.float8] +fqns = ["experts"] diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_1b_no_ac_baseline.toml b/torchtitan/models/deepseek_v3/train_configs/debug_1b_no_ac_baseline.toml new file mode 100644 index 0000000000..932fbb7084 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/debug_1b_no_ac_baseline.toml @@ -0,0 +1,54 @@ +# DeepSeek V3 ~1B - NO AC, NO OFFLOAD (will show true activation memory) + +[job] +dump_folder = "./outputs/debug_1b_no_ac_baseline" +description = "DeepSeek-V3 ~1B - NO AC, NO OFFLOAD" +print_config = false + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +enable_wandb = false + +[model] +name = "deepseek_v3" +flavor = "debugmodel_1b" +hf_assets_path = "./tests/assets/tokenizer" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 + +[training] +local_batch_size = 2 +seq_len = 4096 +max_norm = 1.0 +steps = 5 +dataset = "c4_test" +dtype = "bfloat16" +enable_cpu_offload = false + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" +tensor_parallel_degree = 1 +pipeline_parallel_degree = 1 +expert_parallel_degree = 1 + +[checkpoint] +enable = false + +[activation_checkpoint] +mode = "none" # NO ACTIVATION CHECKPOINTING + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_1b_offload.toml 
b/torchtitan/models/deepseek_v3/train_configs/debug_1b_offload.toml new file mode 100644 index 0000000000..98f0ea4598 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/debug_1b_offload.toml @@ -0,0 +1,84 @@ +# DeepSeek V3 ~1B debug model - WITH activation offload + +[job] +dump_folder = "./outputs/debug_1b_offload" +description = "DeepSeek-V3 ~1B debug - WITH activation offload" +print_config = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "deepseek_v3" +flavor = "debugmodel_1b" # ~1B parameters +# test tokenizer, for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 2 +seq_len = 4096 # Longer sequence to see activation memory +max_norm = 1.0 +steps = 10 +dataset = "c4_test" +dtype = "bfloat16" +enable_cpu_offload = false + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +pipeline_parallel_schedule = "1F1B" +context_parallel_degree = 1 +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" + +[activation_checkpoint] +mode = "full" +selective_ac_option = 'op' +# ENABLE CPU OFFLOAD FOR ACTIVATIONS +cpu_offload = true + +[compile] +enable = false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false 
+precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] + +[quantize.grouped_mm.float8] +fqns = ["experts"] diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_1b_offload_only.toml b/torchtitan/models/deepseek_v3/train_configs/debug_1b_offload_only.toml new file mode 100644 index 0000000000..a823e7e856 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/debug_1b_offload_only.toml @@ -0,0 +1,55 @@ +# DeepSeek V3 ~1B - Offload-only (NO AC, just offload to CPU) + +[job] +dump_folder = "./outputs/debug_1b_offload_only" +description = "DeepSeek-V3 ~1B - Offload-only (NO AC)" +print_config = false + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +enable_wandb = false + +[model] +name = "deepseek_v3" +flavor = "debugmodel_1b" +hf_assets_path = "./tests/assets/tokenizer" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 + +[training] +local_batch_size = 2 +seq_len = 4096 +max_norm = 1.0 +steps = 5 +dataset = "c4_test" +dtype = "bfloat16" +enable_cpu_offload = false + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" +tensor_parallel_degree = 1 +pipeline_parallel_degree = 1 +expert_parallel_degree = 1 + +[checkpoint] +enable = false + +[activation_checkpoint] +mode = "none" # NO AC - just offload +cpu_offload = true # But enable offload + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_7b_baseline.toml b/torchtitan/models/deepseek_v3/train_configs/debug_7b_baseline.toml new file mode 100644 index 0000000000..e04aa9c7f7 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/debug_7b_baseline.toml @@ -0,0 +1,55 @@ +# DeepSeek V3 ~7B debug model - BASELINE (no activation offload) + +[job] +dump_folder = "./outputs/debug_7b_baseline" +description = "DeepSeek-V3 ~7B debug - BASELINE 
(no activation offload)" +print_config = false + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +enable_wandb = false + +[model] +name = "deepseek_v3" +flavor = "debugmodel_7b" +hf_assets_path = "./tests/assets/tokenizer" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 + +[training] +local_batch_size = 1 +seq_len = 4096 +max_norm = 1.0 +steps = 10 +dataset = "c4_test" +dtype = "bfloat16" +enable_cpu_offload = false + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" +tensor_parallel_degree = 1 +pipeline_parallel_degree = 1 +expert_parallel_degree = 1 + +[checkpoint] +enable = false + +[activation_checkpoint] +mode = "full" +cpu_offload = false + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_7b_offload.toml b/torchtitan/models/deepseek_v3/train_configs/debug_7b_offload.toml new file mode 100644 index 0000000000..19a99d0843 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/debug_7b_offload.toml @@ -0,0 +1,55 @@ +# DeepSeek V3 ~7B debug model - WITH activation offload + +[job] +dump_folder = "./outputs/debug_7b_offload" +description = "DeepSeek-V3 ~7B debug - WITH activation offload" +print_config = false + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +enable_wandb = false + +[model] +name = "deepseek_v3" +flavor = "debugmodel_7b" +hf_assets_path = "./tests/assets/tokenizer" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 + +[training] +local_batch_size = 1 +seq_len = 4096 +max_norm = 1.0 +steps = 10 +dataset = "c4_test" +dtype = "bfloat16" +enable_cpu_offload = false + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" 
+tensor_parallel_degree = 1 +pipeline_parallel_degree = 1 +expert_parallel_degree = 1 + +[checkpoint] +enable = false + +[activation_checkpoint] +mode = "full" +cpu_offload = true + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_activation_offload.toml b/torchtitan/models/deepseek_v3/train_configs/debug_activation_offload.toml new file mode 100644 index 0000000000..7c4f45cbe0 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/debug_activation_offload.toml @@ -0,0 +1,82 @@ +[job] +dump_folder = "./outputs" +description = "DeepSeek-V3 debug training with Activation Offloading" +print_config = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "deepseek_v3" +flavor = "debugmodel" +# test tokenizer, for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 2 +seq_len = 512 +max_norm = 1.0 +steps = 5 +dataset = "c4_test" +dtype = "bfloat16" +enable_cpu_offload = false + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +pipeline_parallel_schedule = "1F1B" +context_parallel_degree = 1 +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" + +[activation_checkpoint] +mode = "full" +selective_ac_option = 'op' +# Enable CPU 
offloading for activations +cpu_offload = true + +[compile] +enable = false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] + +[quantize.grouped_mm.float8] +fqns = ["experts"] diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa0_8n_EP64_CP2_LBS1_ctx16k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa0_8n_EP64_CP2_LBS1_ctx16k.toml new file mode 100644 index 0000000000..ad3249c845 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa0_8n_EP64_CP2_LBS1_ctx16k.toml @@ -0,0 +1,46 @@ +# Exp1aa0: 8 nodes EP=64 CP=2 LBS=1 ctx=16k +[job] +dump_folder = "./outputs/exp1a/exp1aa0_8n_EP64_CP2_LBS1_ctx16k" +description = "exp1aa0_8n_EP64_CP2_LBS1_ctx16k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 16384 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa10_8n_EP64_CP2_LBS1_ctx32k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa10_8n_EP64_CP2_LBS1_ctx32k.toml new file mode 100644 index 0000000000..1614e22495 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa10_8n_EP64_CP2_LBS1_ctx32k.toml @@ -0,0 +1,46 @@ +# Exp1aa10: 8 nodes EP=64 CP=2 LBS=1 ctx=32k +[job] +dump_folder = 
"./outputs/exp1a/exp1aa10_8n_EP64_CP2_LBS1_ctx32k" +description = "exp1aa10_8n_EP64_CP2_LBS1_ctx32k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa11_8n_EP64_CP2_LBS2_ctx32k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa11_8n_EP64_CP2_LBS2_ctx32k.toml new file mode 100644 index 0000000000..ff5eae2822 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa11_8n_EP64_CP2_LBS2_ctx32k.toml @@ -0,0 +1,46 @@ +# Exp1aa11: 8 nodes EP=64 CP=2 LBS=2 ctx=32k +[job] +dump_folder = "./outputs/exp1a/exp1aa11_8n_EP64_CP2_LBS2_ctx32k" +description = "exp1aa11_8n_EP64_CP2_LBS2_ctx32k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 2 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git 
a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa12_8n_EP64_CP2_LBS4_ctx32k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa12_8n_EP64_CP2_LBS4_ctx32k.toml new file mode 100644 index 0000000000..227b601d57 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa12_8n_EP64_CP2_LBS4_ctx32k.toml @@ -0,0 +1,46 @@ +# Exp1aa12: 8 nodes EP=64 CP=2 LBS=4 ctx=32k +[job] +dump_folder = "./outputs/exp1a/exp1aa12_8n_EP64_CP2_LBS4_ctx32k" +description = "exp1aa12_8n_EP64_CP2_LBS4_ctx32k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 4 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa13_8n_EP64_CP2_LBS6_ctx32k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa13_8n_EP64_CP2_LBS6_ctx32k.toml new file mode 100644 index 0000000000..31467d95f7 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa13_8n_EP64_CP2_LBS6_ctx32k.toml @@ -0,0 +1,46 @@ +# Exp1aa13: 8 nodes EP=64 CP=2 LBS=6 ctx=32k +[job] +dump_folder = "./outputs/exp1a/exp1aa13_8n_EP64_CP2_LBS6_ctx32k" +description = "exp1aa13_8n_EP64_CP2_LBS6_ctx32k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] 
+warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 6 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa14_8n_EP64_CP2_LBS8_ctx32k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa14_8n_EP64_CP2_LBS8_ctx32k.toml new file mode 100644 index 0000000000..e269a0f35a --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa14_8n_EP64_CP2_LBS8_ctx32k.toml @@ -0,0 +1,46 @@ +# Exp1aa14: 8 nodes EP=64 CP=2 LBS=8 ctx=32k +[job] +dump_folder = "./outputs/exp1a/exp1aa14_8n_EP64_CP2_LBS8_ctx32k" +description = "exp1aa14_8n_EP64_CP2_LBS8_ctx32k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 8 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa1_8n_EP64_CP2_LBS2_ctx16k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa1_8n_EP64_CP2_LBS2_ctx16k.toml new file mode 100644 index 0000000000..e0a8c4725d --- /dev/null +++ 
b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa1_8n_EP64_CP2_LBS2_ctx16k.toml @@ -0,0 +1,46 @@ +# Exp1aa1: 8 nodes EP=64 CP=2 LBS=2 ctx=16k +[job] +dump_folder = "./outputs/exp1a/exp1aa1_8n_EP64_CP2_LBS2_ctx16k" +description = "exp1aa1_8n_EP64_CP2_LBS2_ctx16k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 2 +seq_len = 16384 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa2_8n_EP64_CP2_LBS4_ctx16k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa2_8n_EP64_CP2_LBS4_ctx16k.toml new file mode 100644 index 0000000000..a3b1845a19 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa2_8n_EP64_CP2_LBS4_ctx16k.toml @@ -0,0 +1,46 @@ +# Exp1aa2: 8 nodes EP=64 CP=2 LBS=4 ctx=16k +[job] +dump_folder = "./outputs/exp1a/exp1aa2_8n_EP64_CP2_LBS4_ctx16k" +description = "exp1aa2_8n_EP64_CP2_LBS4_ctx16k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 4 +seq_len = 16384 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 
+expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa3_8n_EP64_CP2_LBS6_ctx16k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa3_8n_EP64_CP2_LBS6_ctx16k.toml new file mode 100644 index 0000000000..92abbed703 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa3_8n_EP64_CP2_LBS6_ctx16k.toml @@ -0,0 +1,46 @@ +# Exp1aa3: 8 nodes EP=64 CP=2 LBS=6 ctx=16k +[job] +dump_folder = "./outputs/exp1a/exp1aa3_8n_EP64_CP2_LBS6_ctx16k" +description = "exp1aa3_8n_EP64_CP2_LBS6_ctx16k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 6 +seq_len = 16384 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa4_8n_EP64_CP2_LBS8_ctx16k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa4_8n_EP64_CP2_LBS8_ctx16k.toml new file mode 100644 index 0000000000..dfb310182c --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa4_8n_EP64_CP2_LBS8_ctx16k.toml @@ -0,0 +1,46 @@ +# Exp1aa4: 8 nodes EP=64 CP=2 LBS=8 ctx=16k +[job] +dump_folder = "./outputs/exp1a/exp1aa4_8n_EP64_CP2_LBS8_ctx16k" +description = "exp1aa4_8n_EP64_CP2_LBS8_ctx16k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq 
= 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 8 +seq_len = 16384 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa5_8n_EP64_CP2_LBS1_ctx24k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa5_8n_EP64_CP2_LBS1_ctx24k.toml new file mode 100644 index 0000000000..33ffb89ef3 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa5_8n_EP64_CP2_LBS1_ctx24k.toml @@ -0,0 +1,46 @@ +# Exp1aa5: 8 nodes EP=64 CP=2 LBS=1 ctx=24k +[job] +dump_folder = "./outputs/exp1a/exp1aa5_8n_EP64_CP2_LBS1_ctx24k" +description = "exp1aa5_8n_EP64_CP2_LBS1_ctx24k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 24576 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa6_8n_EP64_CP2_LBS2_ctx24k.toml 
b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa6_8n_EP64_CP2_LBS2_ctx24k.toml new file mode 100644 index 0000000000..7201988625 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa6_8n_EP64_CP2_LBS2_ctx24k.toml @@ -0,0 +1,46 @@ +# Exp1aa6: 8 nodes EP=64 CP=2 LBS=2 ctx=24k +[job] +dump_folder = "./outputs/exp1a/exp1aa6_8n_EP64_CP2_LBS2_ctx24k" +description = "exp1aa6_8n_EP64_CP2_LBS2_ctx24k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 2 +seq_len = 24576 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa7_8n_EP64_CP2_LBS4_ctx24k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa7_8n_EP64_CP2_LBS4_ctx24k.toml new file mode 100644 index 0000000000..5127dafa97 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa7_8n_EP64_CP2_LBS4_ctx24k.toml @@ -0,0 +1,46 @@ +# Exp1aa7: 8 nodes EP=64 CP=2 LBS=4 ctx=24k +[job] +dump_folder = "./outputs/exp1a/exp1aa7_8n_EP64_CP2_LBS4_ctx24k" +description = "exp1aa7_8n_EP64_CP2_LBS4_ctx24k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 4 +seq_len = 24576 +steps = 5 +dataset 
= "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa8_8n_EP64_CP2_LBS6_ctx24k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa8_8n_EP64_CP2_LBS6_ctx24k.toml new file mode 100644 index 0000000000..f6a143b448 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa8_8n_EP64_CP2_LBS6_ctx24k.toml @@ -0,0 +1,46 @@ +# Exp1aa8: 8 nodes EP=64 CP=2 LBS=6 ctx=24k +[job] +dump_folder = "./outputs/exp1a/exp1aa8_8n_EP64_CP2_LBS6_ctx24k" +description = "exp1aa8_8n_EP64_CP2_LBS6_ctx24k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 6 +seq_len = 24576 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa9_8n_EP64_CP2_LBS8_ctx24k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa9_8n_EP64_CP2_LBS8_ctx24k.toml new file mode 100644 index 0000000000..0de9b9730c --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa9_8n_EP64_CP2_LBS8_ctx24k.toml @@ -0,0 +1,46 @@ +# Exp1aa9: 8 nodes EP=64 CP=2 LBS=8 ctx=24k +[job] +dump_folder = 
"./outputs/exp1a/exp1aa9_8n_EP64_CP2_LBS8_ctx24k" +description = "exp1aa9_8n_EP64_CP2_LBS8_ctx24k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 8 +seq_len = 24576 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab0_12n_EP96_CP2_LBS1_ctx16k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab0_12n_EP96_CP2_LBS1_ctx16k.toml new file mode 100644 index 0000000000..3a60dfc42a --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab0_12n_EP96_CP2_LBS1_ctx16k.toml @@ -0,0 +1,46 @@ +# Exp1ab0: 12 nodes EP=96 CP=2 LBS=1 ctx=16k +[job] +dump_folder = "./outputs/exp1a/exp1ab0_12n_EP96_CP2_LBS1_ctx16k" +description = "exp1ab0_12n_EP96_CP2_LBS1_ctx16k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 16384 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git 
a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab10_12n_EP96_CP2_LBS1_ctx32k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab10_12n_EP96_CP2_LBS1_ctx32k.toml new file mode 100644 index 0000000000..2ae12c912d --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab10_12n_EP96_CP2_LBS1_ctx32k.toml @@ -0,0 +1,46 @@ +# Exp1ab10: 12 nodes EP=96 CP=2 LBS=1 ctx=32k +[job] +dump_folder = "./outputs/exp1a/exp1ab10_12n_EP96_CP2_LBS1_ctx32k" +description = "exp1ab10_12n_EP96_CP2_LBS1_ctx32k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab11_12n_EP96_CP2_LBS2_ctx32k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab11_12n_EP96_CP2_LBS2_ctx32k.toml new file mode 100644 index 0000000000..95083976af --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab11_12n_EP96_CP2_LBS2_ctx32k.toml @@ -0,0 +1,46 @@ +# Exp1ab11: 12 nodes EP=96 CP=2 LBS=2 ctx=32k +[job] +dump_folder = "./outputs/exp1a/exp1ab11_12n_EP96_CP2_LBS2_ctx32k" +description = "exp1ab11_12n_EP96_CP2_LBS2_ctx32k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + 
+[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 2 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab12_12n_EP96_CP2_LBS4_ctx32k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab12_12n_EP96_CP2_LBS4_ctx32k.toml new file mode 100644 index 0000000000..9d96209b09 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab12_12n_EP96_CP2_LBS4_ctx32k.toml @@ -0,0 +1,46 @@ +# Exp1ab12: 12 nodes EP=96 CP=2 LBS=4 ctx=32k +[job] +dump_folder = "./outputs/exp1a/exp1ab12_12n_EP96_CP2_LBS4_ctx32k" +description = "exp1ab12_12n_EP96_CP2_LBS4_ctx32k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 4 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab13_12n_EP96_CP2_LBS6_ctx32k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab13_12n_EP96_CP2_LBS6_ctx32k.toml new file mode 100644 index 0000000000..3f53673681 --- /dev/null +++ 
b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab13_12n_EP96_CP2_LBS6_ctx32k.toml @@ -0,0 +1,46 @@ +# Exp1ab13: 12 nodes EP=96 CP=2 LBS=6 ctx=32k +[job] +dump_folder = "./outputs/exp1a/exp1ab13_12n_EP96_CP2_LBS6_ctx32k" +description = "exp1ab13_12n_EP96_CP2_LBS6_ctx32k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 6 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab14_12n_EP96_CP2_LBS8_ctx32k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab14_12n_EP96_CP2_LBS8_ctx32k.toml new file mode 100644 index 0000000000..b40dd6b684 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab14_12n_EP96_CP2_LBS8_ctx32k.toml @@ -0,0 +1,46 @@ +# Exp1ab14: 12 nodes EP=96 CP=2 LBS=8 ctx=32k +[job] +dump_folder = "./outputs/exp1a/exp1ab14_12n_EP96_CP2_LBS8_ctx32k" +description = "exp1ab14_12n_EP96_CP2_LBS8_ctx32k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 8 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] 
+data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab1_12n_EP96_CP2_LBS2_ctx16k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab1_12n_EP96_CP2_LBS2_ctx16k.toml new file mode 100644 index 0000000000..4d7a0bbf2d --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab1_12n_EP96_CP2_LBS2_ctx16k.toml @@ -0,0 +1,46 @@ +# Exp1ab1: 12 nodes EP=96 CP=2 LBS=2 ctx=16k +[job] +dump_folder = "./outputs/exp1a/exp1ab1_12n_EP96_CP2_LBS2_ctx16k" +description = "exp1ab1_12n_EP96_CP2_LBS2_ctx16k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 2 +seq_len = 16384 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab2_12n_EP96_CP2_LBS4_ctx16k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab2_12n_EP96_CP2_LBS4_ctx16k.toml new file mode 100644 index 0000000000..817ea3793a --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab2_12n_EP96_CP2_LBS4_ctx16k.toml @@ -0,0 +1,46 @@ +# Exp1ab2: 12 nodes EP=96 CP=2 LBS=4 ctx=16k +[job] +dump_folder = "./outputs/exp1a/exp1ab2_12n_EP96_CP2_LBS4_ctx16k" +description = "exp1ab2_12n_EP96_CP2_LBS4_ctx16k" + +[profiling] 
+enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 4 +seq_len = 16384 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab3_12n_EP96_CP2_LBS6_ctx16k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab3_12n_EP96_CP2_LBS6_ctx16k.toml new file mode 100644 index 0000000000..7592a5d554 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab3_12n_EP96_CP2_LBS6_ctx16k.toml @@ -0,0 +1,46 @@ +# Exp1ab3: 12 nodes EP=96 CP=2 LBS=6 ctx=16k +[job] +dump_folder = "./outputs/exp1a/exp1ab3_12n_EP96_CP2_LBS6_ctx16k" +description = "exp1ab3_12n_EP96_CP2_LBS6_ctx16k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 6 +seq_len = 16384 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab4_12n_EP96_CP2_LBS8_ctx16k.toml 
b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab4_12n_EP96_CP2_LBS8_ctx16k.toml new file mode 100644 index 0000000000..5a6dbfa6c0 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab4_12n_EP96_CP2_LBS8_ctx16k.toml @@ -0,0 +1,46 @@ +# Exp1ab4: 12 nodes EP=96 CP=2 LBS=8 ctx=16k +[job] +dump_folder = "./outputs/exp1a/exp1ab4_12n_EP96_CP2_LBS8_ctx16k" +description = "exp1ab4_12n_EP96_CP2_LBS8_ctx16k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 8 +seq_len = 16384 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab5_12n_EP96_CP2_LBS1_ctx24k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab5_12n_EP96_CP2_LBS1_ctx24k.toml new file mode 100644 index 0000000000..4780f17a45 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab5_12n_EP96_CP2_LBS1_ctx24k.toml @@ -0,0 +1,46 @@ +# Exp1ab5: 12 nodes EP=96 CP=2 LBS=1 ctx=24k +[job] +dump_folder = "./outputs/exp1a/exp1ab5_12n_EP96_CP2_LBS1_ctx24k" +description = "exp1ab5_12n_EP96_CP2_LBS1_ctx24k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 24576 +steps = 
5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab6_12n_EP96_CP2_LBS2_ctx24k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab6_12n_EP96_CP2_LBS2_ctx24k.toml new file mode 100644 index 0000000000..4d4a50f8d4 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab6_12n_EP96_CP2_LBS2_ctx24k.toml @@ -0,0 +1,46 @@ +# Exp1ab6: 12 nodes EP=96 CP=2 LBS=2 ctx=24k +[job] +dump_folder = "./outputs/exp1a/exp1ab6_12n_EP96_CP2_LBS2_ctx24k" +description = "exp1ab6_12n_EP96_CP2_LBS2_ctx24k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 2 +seq_len = 24576 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab7_12n_EP96_CP2_LBS4_ctx24k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab7_12n_EP96_CP2_LBS4_ctx24k.toml new file mode 100644 index 0000000000..353ae40ead --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab7_12n_EP96_CP2_LBS4_ctx24k.toml @@ -0,0 +1,46 @@ +# Exp1ab7: 12 nodes EP=96 CP=2 LBS=4 ctx=24k +[job] 
+dump_folder = "./outputs/exp1a/exp1ab7_12n_EP96_CP2_LBS4_ctx24k" +description = "exp1ab7_12n_EP96_CP2_LBS4_ctx24k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 4 +seq_len = 24576 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab8_12n_EP96_CP2_LBS6_ctx24k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab8_12n_EP96_CP2_LBS6_ctx24k.toml new file mode 100644 index 0000000000..24df0535b3 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab8_12n_EP96_CP2_LBS6_ctx24k.toml @@ -0,0 +1,46 @@ +# Exp1ab8: 12 nodes EP=96 CP=2 LBS=6 ctx=24k +[job] +dump_folder = "./outputs/exp1a/exp1ab8_12n_EP96_CP2_LBS6_ctx24k" +description = "exp1ab8_12n_EP96_CP2_LBS6_ctx24k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 6 +seq_len = 24576 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true 
diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab9_12n_EP96_CP2_LBS8_ctx24k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab9_12n_EP96_CP2_LBS8_ctx24k.toml new file mode 100644 index 0000000000..50de7dac1e --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab9_12n_EP96_CP2_LBS8_ctx24k.toml @@ -0,0 +1,46 @@ +# Exp1ab9: 12 nodes EP=96 CP=2 LBS=8 ctx=24k +[job] +dump_folder = "./outputs/exp1a/exp1ab9_12n_EP96_CP2_LBS8_ctx24k" +description = "exp1ab9_12n_EP96_CP2_LBS8_ctx24k" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 8 +seq_len = 24576 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_28k_flb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_28k_flb.toml new file mode 100644 index 0000000000..e5eecac4d0 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_28k_flb.toml @@ -0,0 +1,46 @@ +# 28k context - 10 nodes EP=16 with FORCE LOAD BALANCE + +[job] +dump_folder = "./outputs/10n_28k_flb" +description = "28k context 10 nodes EP=16 force LB" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len 
= 28672 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 16 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_2k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_2k.toml new file mode 100644 index 0000000000..cb66bc4a3d --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_2k.toml @@ -0,0 +1,45 @@ +# MEMORY TEST: 10 nodes, EP=16, 2k context +# 10 nodes × 8 GPUs = 80 GPUs, EP=16, 384/16 = 24 experts per GPU + +[job] +dump_folder = "./outputs/kimi_1t_10n_ep16_2k" +description = "Memory test - 10 nodes EP=16 - 2k context" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 2048 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 16 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_4k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_4k.toml new file mode 100644 index 0000000000..d72e81d9ea --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_4k.toml @@ -0,0 +1,43 @@ +# MEMORY TEST: 10 nodes, EP=16, 4k context +[job] +dump_folder = "./outputs/kimi_1t_10n_ep16_4k" +description = "Memory test - 10 nodes EP=16 - 4k context" + +[profiling] +enable_profiling = false + 
+[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 4096 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 16 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_8k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_8k.toml new file mode 100644 index 0000000000..79d285cc96 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_8k.toml @@ -0,0 +1,43 @@ +# MEMORY TEST: 10 nodes, EP=16, 8k context +[job] +dump_folder = "./outputs/kimi_1t_10n_ep16_8k" +description = "Memory test - 10 nodes EP=16 - 8k context" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 8192 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 16 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_28k_ac_offload.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_28k_ac_offload.toml new file mode 100644 index 0000000000..166f5594af --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_28k_ac_offload.toml 
@@ -0,0 +1,48 @@ +# 28k context with FORCE LOAD BALANCE + AC CPU OFFLOAD - 12 nodes EP=96 +# Testing if activation checkpoint CPU offload helps avoid OOM + +[job] +dump_folder = "./outputs/12n_28k_ac_offload" +description = "28k context 12n EP=96 with AC CPU offload" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 28672 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 + +[activation_checkpoint] +mode = "full" +cpu_offload = true + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_28k_selective_ac.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_28k_selective_ac.toml new file mode 100644 index 0000000000..0dd5b8e9bf --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_28k_selective_ac.toml @@ -0,0 +1,49 @@ +# 28k context with FORCE LOAD BALANCE + SELECTIVE AC - 12 nodes EP=96 +# Testing if selective op-level AC helps avoid OOM + +[job] +dump_folder = "./outputs/12n_28k_selective_ac" +description = "28k context 12n EP=96 with selective AC" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 28672 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 
+expert_parallel_degree = 96 + +[activation_checkpoint] +mode = "selective" +selective_ac_option = "op" +cpu_offload = true + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_28k_flex_fix.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_28k_flex_fix.toml new file mode 100644 index 0000000000..642fa35563 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_28k_flex_fix.toml @@ -0,0 +1,48 @@ +# 28k context - 12 nodes EP=96 with CP=2 + FlexAttention (with fix) +# Testing create_cp_block_mask fix with more nodes + +[job] +dump_folder = "./outputs/12n_cp2_28k_flex_fix" +description = "28k context 12n EP=96 CP=2 FlexAttention with CP block mask fix" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 28672 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_30720.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_30720.toml new file mode 100644 index 0000000000..06678106e6 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_30720.toml @@ -0,0 +1,46 @@ +# 30720 context - 12 nodes EP=96 with CP=2 + FlexAttention +[job] +dump_folder = "./outputs/12n_cp2_30720" +description = "30720 context 12n EP=96 CP=2" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + 
+[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 30720 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_32768.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_32768.toml new file mode 100644 index 0000000000..fbd8c45b73 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_32768.toml @@ -0,0 +1,46 @@ +# 32768 context - 12 nodes EP=96 with CP=2 + FlexAttention +[job] +dump_folder = "./outputs/12n_cp2_32768" +description = "32768 context 12n EP=96 CP=2" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep12_28k_flb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep12_28k_flb.toml new file mode 100644 index 0000000000..d431b45824 --- /dev/null +++ 
b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep12_28k_flb.toml @@ -0,0 +1,47 @@ +# 28k context - 12 nodes EP=12 with FORCE LOAD BALANCE +# 32 experts per GPU + +[job] +dump_folder = "./outputs/12n_ep12_28k_flb" +description = "28k context 12 nodes EP=12 force LB" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 28672 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 12 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep12_32k_flb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep12_32k_flb.toml new file mode 100644 index 0000000000..1d45a540f3 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep12_32k_flb.toml @@ -0,0 +1,47 @@ +# 32k context - 12 nodes EP=12 with FORCE LOAD BALANCE +# 32 experts per GPU + +[job] +dump_folder = "./outputs/12n_ep12_32k_flb" +description = "32k context 12 nodes EP=12 force LB" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 12 + +[activation_checkpoint] +mode = "full" + +[compile] 
+enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep96_28k_flb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep96_28k_flb.toml new file mode 100644 index 0000000000..bdb5b51596 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep96_28k_flb.toml @@ -0,0 +1,47 @@ +# 28k context - 12 nodes EP=96 with FORCE LOAD BALANCE +# 4 experts per GPU (better than 8 nodes with 6 experts/GPU) + +[job] +dump_folder = "./outputs/12n_ep96_28k_flb" +description = "28k context 12 nodes EP=96 force LB" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 28672 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep96_32k_flb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep96_32k_flb.toml new file mode 100644 index 0000000000..ba0bfd78c7 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep96_32k_flb.toml @@ -0,0 +1,47 @@ +# 32k context - 12 nodes EP=96 with FORCE LOAD BALANCE +# 4 experts per GPU (better than 8 nodes with 6 experts/GPU) + +[job] +dump_folder = "./outputs/12n_ep96_32k_flb" +description = "32k context 12 nodes EP=96 force LB" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 
2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16k_force_lb.toml new file mode 100644 index 0000000000..c2c2584d39 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16k_force_lb.toml @@ -0,0 +1,46 @@ +# 16k context with FORCE LOAD BALANCE - 8 nodes EP=64 + +[job] +dump_folder = "./outputs/16k_force_lb" +description = "16k context with force load balance" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 16384 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_2k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_2k.toml new file mode 100644 index 0000000000..c782b42092 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_2k.toml @@ -0,0 +1,45 @@ +# MEMORY TEST: 16 nodes, EP=128, 2k context +# 16 nodes × 8 GPUs = 128 GPUs, 384 experts / 128 = 3 experts per GPU + +[job] +dump_folder = 
"./outputs/kimi_1t_16n_ep128_2k" +description = "Memory test - 16 nodes EP=128 - 2k context" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 2048 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 128 # 3 experts per GPU + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_4k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_4k.toml new file mode 100644 index 0000000000..483f36cdc0 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_4k.toml @@ -0,0 +1,45 @@ +# MEMORY TEST: 16 nodes, EP=128, 4k context +# 16 nodes × 8 GPUs = 128 GPUs, 384 experts / 128 = 3 experts per GPU + +[job] +dump_folder = "./outputs/kimi_1t_16n_ep128_4k" +description = "Memory test - 16 nodes EP=128 - 4k context" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 4096 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 128 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false diff --git 
a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_8k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_8k.toml new file mode 100644 index 0000000000..cf587b35df --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_8k.toml @@ -0,0 +1,45 @@ +# MEMORY TEST: 16 nodes, EP=128, 8k context +# 16 nodes × 8 GPUs = 128 GPUs, 384 experts / 128 = 3 experts per GPU + +[job] +dump_folder = "./outputs/kimi_1t_16n_ep128_8k" +description = "Memory test - 16 nodes EP=128 - 8k context" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 8192 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 128 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_20k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_20k_force_lb.toml new file mode 100644 index 0000000000..49d382f70e --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_20k_force_lb.toml @@ -0,0 +1,46 @@ +# 20k context with FORCE LOAD BALANCE - 8 nodes EP=64 + +[job] +dump_folder = "./outputs/20k_force_lb" +description = "20k context with force load balance" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 20480 +steps = 5 +dataset = "c4_test" 
+enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_24k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_24k_force_lb.toml new file mode 100644 index 0000000000..36c664c81b --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_24k_force_lb.toml @@ -0,0 +1,46 @@ +# 24k context with FORCE LOAD BALANCE - 8 nodes EP=64 + +[job] +dump_folder = "./outputs/24k_force_lb" +description = "24k context with force load balance" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 24576 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_ac_offload.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_ac_offload.toml new file mode 100644 index 0000000000..2aa63e98d7 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_ac_offload.toml @@ -0,0 +1,48 @@ +# 28k context with FORCE LOAD BALANCE + AC CPU OFFLOAD +# Testing if activation checkpoint CPU offload helps avoid OOM + +[job] +dump_folder = "./outputs/28k_ac_offload" +description = "28k context with AC CPU offload" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + 
+[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 28672 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" +cpu_offload = true + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_force_lb.toml new file mode 100644 index 0000000000..b4022f85df --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_force_lb.toml @@ -0,0 +1,46 @@ +# 28k context with FORCE LOAD BALANCE - 8 nodes EP=64 + +[job] +dump_folder = "./outputs/28k_force_lb" +description = "28k context with force load balance" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 28672 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_selective_ac.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_selective_ac.toml new file mode 100644 index 0000000000..a0b5a40665 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_selective_ac.toml @@ 
-0,0 +1,49 @@ +# 28k context with FORCE LOAD BALANCE + SELECTIVE AC (op level) +# Testing if selective op-level AC helps avoid OOM + +[job] +dump_folder = "./outputs/28k_selective_ac" +description = "28k context with selective op-level AC" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 28672 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "selective" +selective_ac_option = "op" +cpu_offload = true + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_32k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_32k_force_lb.toml new file mode 100644 index 0000000000..ffa6efdc54 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_32k_force_lb.toml @@ -0,0 +1,46 @@ +# 32k context with FORCE LOAD BALANCE - 8 nodes EP=64 + +[job] +dump_folder = "./outputs/32k_force_lb" +description = "32k context with force load balance" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + 
+[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_4k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_4k_force_lb.toml new file mode 100644 index 0000000000..beea82e717 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_4k_force_lb.toml @@ -0,0 +1,47 @@ +# 4k context with FORCE LOAD BALANCE - 8 nodes EP=64 +# Test if forced uniform expert distribution prevents OOM + +[job] +dump_folder = "./outputs/4k_force_lb" +description = "4k context with force load balance" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 4096 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_6k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_6k_force_lb.toml new file mode 100644 index 0000000000..ab7214f598 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_6k_force_lb.toml @@ -0,0 +1,46 @@ +# 6k context with FORCE LOAD BALANCE - 8 nodes EP=64 + +[job] +dump_folder = "./outputs/6k_force_lb" +description = "6k context with force load balance" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 6144 +steps = 5 +dataset 
= "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8k_force_lb.toml new file mode 100644 index 0000000000..ebab4dda5d --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8k_force_lb.toml @@ -0,0 +1,46 @@ +# 8k context with FORCE LOAD BALANCE - 8 nodes EP=64 + +[job] +dump_folder = "./outputs/8k_force_lb" +description = "8k context with force load balance" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 8192 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_flb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_flb.toml new file mode 100644 index 0000000000..44f9f7b80b --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_flb.toml @@ -0,0 +1,48 @@ +# 28k context - 8 nodes EP=64 with CP=2 (Context Parallel) +# Testing FlexAttention + Context Parallel + +[job] +dump_folder = "./outputs/8n_cp2_28k_flb" +description = "28k context 8 nodes EP=64 CP=2 force LB" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] 
+name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 28672 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_flex_fix.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_flex_fix.toml new file mode 100644 index 0000000000..8e93ca2dc5 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_flex_fix.toml @@ -0,0 +1,48 @@ +# 28k context - 8 nodes EP=64 with CP=2 + FlexAttention (with fix) +# Testing create_cp_block_mask fix + +[job] +dump_folder = "./outputs/8n_cp2_28k_flex_fix" +description = "28k context CP=2 FlexAttention with CP block mask fix" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 28672 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_sdpa.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_sdpa.toml new file mode 100644 
index 0000000000..9f9f93eba3 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_sdpa.toml @@ -0,0 +1,48 @@ +# 28k context - 8 nodes EP=64 with CP=2 (Context Parallel) +# Using SDPA instead of FlexAttention for CP compatibility + +[job] +dump_folder = "./outputs/8n_cp2_28k_sdpa" +description = "28k context 8 nodes EP=64 CP=2 SDPA" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2_sdpa" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 28672 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_30720.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_30720.toml new file mode 100644 index 0000000000..c1b2172a88 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_30720.toml @@ -0,0 +1,46 @@ +# 30720 context - 8 nodes EP=64 with CP=2 + FlexAttention +[job] +dump_folder = "./outputs/8n_cp2_30720" +description = "30720 context 8n EP=64 CP=2" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 30720 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 
+expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_32768.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_32768.toml new file mode 100644 index 0000000000..59a6394f09 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_32768.toml @@ -0,0 +1,46 @@ +# 32768 context - 8 nodes EP=64 with CP=2 + FlexAttention +[job] +dump_folder = "./outputs/8n_cp2_32768" +description = "32768 context 8n EP=64 CP=2" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_ep1_28k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_ep1_28k.toml new file mode 100644 index 0000000000..1d5f3d5277 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_ep1_28k.toml @@ -0,0 +1,48 @@ +# 28k context - 8 nodes with CP=2 but NO EP +# Testing if CP works without Expert Parallelism + +[job] +dump_folder = "./outputs/8n_cp2_ep1_28k" +description = "28k context CP=2 NO EP (isolate CP issue)" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2_sdpa" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + 
+[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 28672 +steps = 3 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 1 +context_parallel_degree = 2 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_tp2_28k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_tp2_28k.toml new file mode 100644 index 0000000000..991b6dceb3 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_tp2_28k.toml @@ -0,0 +1,48 @@ +# 28k context - 8 nodes with TP=2 (Tensor Parallel) +# Testing if TP helps fit 28k by sharding attention + +[job] +dump_folder = "./outputs/8n_tp2_28k" +description = "28k context TP=2 EP=32" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 28672 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +tensor_parallel_degree = 2 +expert_parallel_degree = 32 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_activation_offload.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_activation_offload.toml new file mode 100644 index 0000000000..439ba2dcb9 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_activation_offload.toml @@ -0,0 +1,81 @@ +[job] 
+dump_folder = "./outputs" +description = "Kimi K2 1T model training with Activation Offloading" +print_config = false + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 5 +enable_memory_snapshot = true +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2_000 +decay_ratio = 0.8 +decay_type = "cosine" +min_lr_factor = 0.1 + +[training] +local_batch_size = 1 +seq_len = 256 +max_norm = 1.0 +steps = 20 +dataset = "c4_test" +dtype = "bfloat16" +# Disable parameter/gradient CPU offload to test activation offload specifically +enable_cpu_offload = false + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +pipeline_parallel_schedule = "Interleaved1F1B" +expert_parallel_degree = 64 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 500 +last_save_model_only = true +export_dtype = "bfloat16" +async_mode = "disabled" + +[activation_checkpoint] +mode = "full" +selective_ac_option = 'op' +# Enable CPU offloading for activations +cpu_offload = true + +[compile] +enable = false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] + +[quantize.grouped_mm.float8] +fqns = ["experts"] diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_baseline_ep32_40nodes_no_offload.toml 
b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_baseline_ep32_40nodes_no_offload.toml new file mode 100644 index 0000000000..4644ba6f40 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_baseline_ep32_40nodes_no_offload.toml @@ -0,0 +1,50 @@ +# BASELINE: Production-like config (NO optimizations) +# Test: Measure memory usage WITHOUT any optimizations +# Configuration: 40 nodes, 320 GPUs, EP=32, seq_len=512, batch_size=1 +# Purpose: Establish baseline memory consumption (like original config) + +[job] +dump_folder = "./outputs/kimi_1t_baseline_ep32_40nodes_no_offload" +description = "Baseline - No Optimizations - EP32 - 40 nodes" + +[profiling] +enable_profiling = false +profile_freq = 10 +enable_memory_snapshot = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 512 +steps = 5 +dataset = "c4_test" + +# BASELINE: No optimizations (like original) +# enable_cpu_offload = false (default) +# enable_detailed_memory_tracking = false (default) +# clear_cache_between_steps = false (default) + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 32 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_baseline_ep64_8nodes_no_offload.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_baseline_ep64_8nodes_no_offload.toml new file mode 100644 index 0000000000..0c3ec0b6e1 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_baseline_ep64_8nodes_no_offload.toml @@ -0,0 +1,51 @@ +# BASELINE: EP=64, 8 nodes, NO optimizations +# Test: Measure memory WITHOUT CPU offload or cache clearing +# Configuration: 8 nodes, 64 GPUs, EP=64 (6 experts/GPU), seq=512, batch=1 +# Purpose: Establish baseline at EP=64 
(should work, ~23-25 GB expected) + +[job] +dump_folder = "./outputs/kimi_1t_baseline_ep64_8n_no_offload" +description = "Baseline EP64 - No Optimizations - 8 nodes" + +[profiling] +enable_profiling = false +profile_freq = 10 +enable_memory_snapshot = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 512 +steps = 5 +dataset = "c4_test" +skip_optimizer_step = true +enable_detailed_memory_tracking = true + +# BASELINE: No optimizations +# enable_cpu_offload = false (default) +# clear_cache_between_steps = false (default) + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 # 384 experts / 64 = 6 experts per GPU + +[activation_checkpoint] +mode = "full" + +[compile] +enable = true # Production-like diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_cpuoffload_ep32_40nodes_with_offload.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_cpuoffload_ep32_40nodes_with_offload.toml new file mode 100644 index 0000000000..019ae9fa0b --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_cpuoffload_ep32_40nodes_with_offload.toml @@ -0,0 +1,52 @@ +# TEST: WITH CPU Offloading +# Test: Measure memory usage WITH CPU offloading +# Configuration: 40 nodes, 320 GPUs, EP=32, seq_len=512, batch_size=1 +# Purpose: Measure CPU offloading impact vs baseline + +[job] +dump_folder = "./outputs/kimi_1t_cpuoffload_ep32_40nodes_with_offload" +description = "CPU Offload Test - EP32 - 40 nodes" + +[profiling] +enable_profiling = true +profile_freq = 10 +enable_memory_snapshot = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 
0.8 + +[training] +local_batch_size = 1 +seq_len = 512 +steps = 5 +dataset = "c4_test" + +# TEST: CPU offload ENABLED +enable_cpu_offload = true + +# Memory tracking enabled for both tests +enable_detailed_memory_tracking = true +clear_cache_between_steps = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 32 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_debug_ep32_40nodes_with_offload.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_debug_ep32_40nodes_with_offload.toml new file mode 100644 index 0000000000..8e3e0af01d --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_debug_ep32_40nodes_with_offload.toml @@ -0,0 +1,53 @@ +# DEBUG: Deep profiling to find memory bottleneck +# Test: EP=32 with CPU offload - investigate why memory spikes to 73GB +# Configuration: 40 nodes, 320 GPUs, EP=32, seq_len=512, batch_size=1 +# Purpose: Find where memory is being allocated during forward pass + +[job] +dump_folder = "./outputs/kimi_1t_debug_ep32_40nodes_with_offload" +description = "DEBUG - Memory Investigation - EP32 - 40 nodes" + +[profiling] +enable_profiling = true +profile_freq = 1 # Profile at step 1 +enable_memory_snapshot = true # Enable memory snapshot +save_traces_folder = "profile_traces" +with_stack = true # Capture stack traces +with_modules = true + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 512 +steps = 2 # Just 2 steps for quick debug +dataset = "c4_test" + +# OPTIMIZATIONS ENABLED - Same as Job 1270 +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 32 # 12 
experts per GPU + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false # No compile for cleaner profiling diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_debug_ep32_8nodes_seq512.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_debug_ep32_8nodes_seq512.toml new file mode 100644 index 0000000000..3d3368cfc7 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_debug_ep32_8nodes_seq512.toml @@ -0,0 +1,52 @@ +# DEBUG: 8 nodes, EP=32, seq=512 to reproduce OOM +# Expected: Should OOM like Job 1270, but on stable 8-node setup +# Configuration: 8 nodes, 64 GPUs, EP=32 (12 experts/GPU), seq=512 + +[job] +dump_folder = "./outputs/kimi_1t_debug_ep32_8n_seq512" +description = "DEBUG - EP32 seq512 - 8 nodes - Expected OOM" + +[profiling] +enable_profiling = true +profile_freq = 1 +enable_memory_snapshot = true +save_traces_folder = "profile_traces" +with_stack = true +with_modules = true + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 512 # Same as Job 1270 +steps = 2 +dataset = "c4_test" + +# OPTIMIZATIONS ENABLED +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 32 # 384 experts / 32 = 12 experts per GPU + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_defrag_test.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_defrag_test.toml new file mode 100644 index 0000000000..2402915040 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_defrag_test.toml @@ -0,0 +1,60 @@ +# Kimi K2 1T - Memory Defragmentation Test +# Testing different allocator settings + 
selective AC + +[job] +dump_folder = "./outputs/kimi_1t_defrag_test" +description = "Kimi K2 1T - Test fragmentation fixes" + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 3 +profiler_warmup = 1 +profiler_active = 1 +enable_memory_snapshot = true +save_memory_snapshot_folder = "memory_snapshot" +with_stack = true +with_modules = true + +[metrics] +log_freq = 1 +enable_tensorboard = false +enable_wandb = false + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 256 +steps = 5 +dataset = "c4_test" +dtype = "bfloat16" +enable_cpu_offload = true + +[parallelism] +data_parallel_shard_degree = -1 +tensor_parallel_degree = 1 +pipeline_parallel_degree = 1 +expert_parallel_degree = 64 + +[checkpoint] +enable = false + +[activation_checkpoint] +# Try selective instead of full to reduce allocation churn +mode = "selective" +selective_ac_option = '2' # Checkpoint every 2 layers + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_detailed_memory_profiling.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_detailed_memory_profiling.toml new file mode 100644 index 0000000000..5af25527d7 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_detailed_memory_profiling.toml @@ -0,0 +1,78 @@ +# Kimi K2 1T - DETAILED MEMORY PROFILING Configuration +# +# Purpose: Get exact breakdown of where the 55 GB "activation memory" goes +# Approach: Enhanced memory snapshots + instrumented checkpoints +# Expected output: Detailed memory allocation report showing: +# - NCCL communication buffers +# - FSDP temporary storage +# - MoE all-to-all staging +# - CPU offload overhead +# - Actual checkpointed activations +# - PyTorch allocator overhead + +[job] +dump_folder = 
"./outputs/kimi_1t_detailed_memory_profiling" +description = "Kimi K2 1T - Detailed Memory Breakdown Profiling" +print_config = true # Print config for verification + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 3 # Profile step 3 (after warmup) +profiler_warmup = 1 +profiler_active = 1 +enable_memory_snapshot = true +save_memory_snapshot_folder = "memory_snapshot" +with_stack = true # Enable stack traces for memory allocations +with_modules = true # Enable module tracking + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +enable_wandb = false + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 +decay_type = "cosine" +min_lr_factor = 0.1 + +[training] +local_batch_size = 1 +seq_len = 256 +max_norm = 1.0 +steps = 5 # Just 5 steps for profiling +dataset = "c4_test" +dtype = "bfloat16" +enable_cpu_offload = true + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +expert_parallel_degree = 64 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false + +[activation_checkpoint] +mode = "full" +selective_ac_option = 'op' + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_detailed_tracking.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_detailed_tracking.toml new file mode 100644 index 0000000000..485b60a104 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_detailed_tracking.toml @@ -0,0 +1,47 @@ +# Detailed Memory Tracking Test +# Track memory at every phase with cache clearing between steps + +[job] +dump_folder = "./outputs/kimi_1t_detailed_tracking" + +[profiling] 
+enable_profiling = true +profile_freq = 10 # Don't take snapshot (expensive) +enable_memory_snapshot = false # Disable snapshot for cleaner logs + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 256 +steps = 5 # Just 5 steps for detailed tracking +dataset = "c4_test" +enable_cpu_offload = true + +# Detailed memory tracking +enable_detailed_memory_tracking = true +clear_cache_between_steps = true # Clear cache after each step + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_emozilla.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_emozilla.toml new file mode 100644 index 0000000000..17dc3a4538 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_emozilla.toml @@ -0,0 +1,116 @@ +[job] +dump_folder = "./outputs" +description = "Kimi K2 1T model training" +print_config = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" +#converters = ["quantize.linear.float8"] + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +eps = 1e-8 + +#[optimizer] +#name = "Muon" +##lr = 2.2e-4 +#lr = 1e-5 +#eps = 1e-8 +#weight_decay = 0.1 +# +## Muon-specific parameters +#mu = 0.95 # Momentum factor for Muon +#algorithm = "muon" # Main algorithm to use for 2D matrices +#nesterov = false # Whether to use Nesterov momentum 
+#adjust_lr = "rms_norm" # How to adjust LR: "spectral_norm", "rms_norm", or null +#flatten = false # Whether to flatten 3D+ tensors to 2D +#use_triton = true # Whether to use Triton kernel for Newton-Schulz +# +## Parameter-specific optimizer selection +#scalar_optimizer = "adamw" # For 1D parameters (biases, layer norms) +#embedding_optimizer = "adamw" # For embedding layers +#head_optimizer = "adamw" # For model head/output layers +#head_lr_scaling = false # Apply 1/sqrt(dim) scaling to head layers +# +## Learning rate scaling factors +#scalar_lr_factor = 1 # LR multiplier for scalar parameters +#embedding_lr_factor = 1 # LR multiplier for embedding parameters +#head_lr_factor = 1.0 # LR multiplier for head parameters (after head_lr_scaling) +#routing_lr_factor = 1.0 # LR multiplier for routing parameters +#expert_lr_factor = 1.0 + +[lr_scheduler] +warmup_steps = 2_000 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "cosine" +min_lr_factor = 0.1 + +[training] +local_batch_size = 1 +seq_len = 256 +max_norm = 1.0 # grad norm clipping +steps = 10_000 +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) +dtype = "bfloat16" +enable_cpu_offload = true + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +pipeline_parallel_schedule = "Interleaved1F1B" +expert_parallel_degree = 64 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 500 +last_save_model_only = true +export_dtype = "bfloat16" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "full" # ["none", "selective", "full"] +selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops 
policy + +[compile] +enable = false +components = ["model", "loss"] # ["model", "loss"] +# fullgraph = false + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] + +[quantize.grouped_mm.float8] +fqns = ["experts"] + +#[deepep] +#sync_comm_stream = false +#fused_weighted_scatter_add = false +#fused_silu_gate_prob = true +# +#[debug] +#moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_combined.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_combined.toml new file mode 100644 index 0000000000..f48ab381cc --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_combined.toml @@ -0,0 +1,50 @@ +# Test 3: Combined Approach +# Approach: Selective AC + Defragmentation + Better allocator + +[job] +dump_folder = "./outputs/kimi_1t_fix_combined" + +[profiling] +enable_profiling = true +profile_freq = 3 +enable_memory_snapshot = true +with_stack = true + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 256 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +# Moderate defragmentation (not every step) +enable_memory_defrag = true +defrag_freq = 2 # Every 2 steps +aggressive_defrag = false + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +# Selective AC +mode = "selective" +selective_ac_option = '2' + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_defrag.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_defrag.toml new file mode 100644 index 0000000000..acc42b24d6 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_defrag.toml 
@@ -0,0 +1,48 @@ +# Test 1: Aggressive Memory Defragmentation +# Approach: Clear cache every step to reduce fragmentation + +[job] +dump_folder = "./outputs/kimi_1t_fix_defrag" + +[profiling] +enable_profiling = true +profile_freq = 3 +enable_memory_snapshot = true +with_stack = true + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 256 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +# Enable aggressive defragmentation +enable_memory_defrag = true +defrag_freq = 1 # Every step +aggressive_defrag = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" # Keep full AC + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_selective_ac.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_selective_ac.toml new file mode 100644 index 0000000000..c41ba80ffc --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_selective_ac.toml @@ -0,0 +1,47 @@ +# Test 2: Selective Activation Checkpointing +# Approach: Checkpoint fewer layers to reduce allocation churn + +[job] +dump_folder = "./outputs/kimi_1t_fix_selective_ac" + +[profiling] +enable_profiling = true +profile_freq = 3 +enable_memory_snapshot = true +with_stack = true + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 256 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +enable_memory_defrag = false + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +# Selective 
AC: checkpoint every 2 layers instead of all +mode = "selective" +selective_ac_option = '2' + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_16k_ctx.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_16k_ctx.toml new file mode 100644 index 0000000000..38a0761336 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_16k_ctx.toml @@ -0,0 +1,46 @@ +# MEMORY TEST: 16k context length +# Purpose: Measure memory at 16384 seq_len + +[job] +dump_folder = "./outputs/kimi_1t_memory_16k" +description = "Memory test - 16k context - CPU Offload" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 16384 # 16k context +steps = 5 +dataset = "c4_test" + +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_1k_ctx.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_1k_ctx.toml new file mode 100644 index 0000000000..f4f5358293 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_1k_ctx.toml @@ -0,0 +1,43 @@ +# MEMORY TEST: 1k context length +[job] +dump_folder = "./outputs/kimi_1t_memory_1k" +description = "Memory test - 1k context" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + 
+[training] +local_batch_size = 1 +seq_len = 1024 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_2k_ctx.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_2k_ctx.toml new file mode 100644 index 0000000000..4895a24ab3 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_2k_ctx.toml @@ -0,0 +1,43 @@ +# MEMORY TEST: 2k context length +[job] +dump_folder = "./outputs/kimi_1t_memory_2k" +description = "Memory test - 2k context" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 2048 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_4k_ctx.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_4k_ctx.toml new file mode 100644 index 0000000000..9b3611ddd8 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_4k_ctx.toml @@ -0,0 +1,43 @@ +# MEMORY TEST: 4k context length +[job] +dump_folder = "./outputs/kimi_1t_memory_4k" +description = "Memory test - 4k context" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" 
+hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 4096 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_8k_ctx.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_8k_ctx.toml new file mode 100644 index 0000000000..2cb6ac136d --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_8k_ctx.toml @@ -0,0 +1,46 @@ +# MEMORY TEST: 8k context length +# Purpose: Measure memory at 8192 seq_len + +[job] +dump_folder = "./outputs/kimi_1t_memory_8k" +description = "Memory test - 8k context - CPU Offload" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 8192 # 8k context +steps = 5 +dataset = "c4_test" + +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_24k_flb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_24k_flb.toml new file mode 100644 index 0000000000..4fa07a6084 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_24k_flb.toml @@ -0,0 +1,53 @@ +# DEEP MEMORY 
PROFILING: 24k context with FORCE LOAD BALANCE (works) +# Compare with 28k to find exact OOM location + +[job] +dump_folder = "./outputs/memprof_24k_flb" +description = "Deep Memory Profiling - 24k context (force LB)" + +[profiling] +enable_profiling = true +profile_freq = 3 +profiler_warmup = 1 +profiler_active = 1 +enable_memory_snapshot = true +save_memory_snapshot_folder = "memory_snapshot_24k_flb" + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 24576 +steps = 2 +dataset = "c4_test" +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_28k_flb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_28k_flb.toml new file mode 100644 index 0000000000..c3ee3516fa --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_28k_flb.toml @@ -0,0 +1,53 @@ +# DEEP MEMORY PROFILING: 28k context with FORCE LOAD BALANCE (OOM) +# Compare with 24k to find exact OOM location + +[job] +dump_folder = "./outputs/memprof_28k_flb" +description = "Deep Memory Profiling - 28k context (force LB)" + +[profiling] +enable_profiling = true +profile_freq = 3 +profiler_warmup = 1 +profiler_active = 1 +enable_memory_snapshot = true +save_memory_snapshot_folder = "memory_snapshot_28k_flb" + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 
+decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 28672 +steps = 2 +dataset = "c4_test" +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_2k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_2k.toml new file mode 100644 index 0000000000..b0d8678ca1 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_2k.toml @@ -0,0 +1,50 @@ +# DEEP MEMORY PROFILING: 2k context (baseline that works) +# Purpose: Capture detailed memory snapshots to compare with 4k + +[job] +dump_folder = "./outputs/memprof_2k" +description = "Deep Memory Profiling - 2k context" + +[profiling] +enable_profiling = true +profile_freq = 5 +profiler_warmup = 1 +profiler_active = 2 +enable_memory_snapshot = true +save_memory_snapshot_folder = "memory_snapshot_2k" + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 2048 +steps = 3 +dataset = "c4_test" +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_2k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_2k_force_lb.toml new file mode 100644 index 0000000000..3571e2a0e8 --- /dev/null +++ 
b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_2k_force_lb.toml @@ -0,0 +1,53 @@ +# DEEP MEMORY PROFILING: 2k context with FORCE LOAD BALANCE +# Baseline for comparison with 4k + +[job] +dump_folder = "./outputs/memprof_2k_flb" +description = "Deep Memory Profiling - 2k context (force LB)" + +[profiling] +enable_profiling = true +profile_freq = 5 +profiler_warmup = 1 +profiler_active = 2 +enable_memory_snapshot = true +save_memory_snapshot_folder = "memory_snapshot_2k_flb" + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 2048 +steps = 3 +dataset = "c4_test" +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_4k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_4k.toml new file mode 100644 index 0000000000..cb829756b4 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_4k.toml @@ -0,0 +1,50 @@ +# DEEP MEMORY PROFILING: 4k context (OOMs - need to capture where) +# Purpose: Capture detailed memory snapshots to identify OOM point + +[job] +dump_folder = "./outputs/memprof_4k" +description = "Deep Memory Profiling - 4k context" + +[profiling] +enable_profiling = true +profile_freq = 5 +profiler_warmup = 1 +profiler_active = 2 +enable_memory_snapshot = true +save_memory_snapshot_folder = "memory_snapshot_4k" + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name 
= "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 4096 +steps = 3 +dataset = "c4_test" +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_4k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_4k_force_lb.toml new file mode 100644 index 0000000000..3f0b618255 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_4k_force_lb.toml @@ -0,0 +1,53 @@ +# DEEP MEMORY PROFILING: 4k context with FORCE LOAD BALANCE +# Compare with 2k to find where memory increases + +[job] +dump_folder = "./outputs/memprof_4k_flb" +description = "Deep Memory Profiling - 4k context (force LB)" + +[profiling] +enable_profiling = true +profile_freq = 5 +profiler_warmup = 1 +profiler_active = 2 +enable_memory_snapshot = true +save_memory_snapshot_folder = "memory_snapshot_4k_flb" + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 4096 +steps = 3 +dataset = "c4_test" +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false + +[debug] +moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_offload_no_cache_clear.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_offload_no_cache_clear.toml new 
file mode 100644 index 0000000000..bb1c50aad0 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_offload_no_cache_clear.toml @@ -0,0 +1,51 @@ +# CPU OFFLOAD ONLY: EP=64, 8 nodes, NO cache clearing +# Test: Measure memory WITH CPU offload but WITHOUT cache clearing +# Configuration: 8 nodes, 64 GPUs, EP=64 (6 experts/GPU), seq=512, batch=1 +# Purpose: Apple-to-apple comparison with cache clearing + +[job] +dump_folder = "./outputs/kimi_1t_offload_no_cache_clear" +description = "CPU Offload Only - No Cache Clearing - 8 nodes" + +[profiling] +enable_profiling = false +profile_freq = 10 +enable_memory_snapshot = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 512 +steps = 5 +dataset = "c4_test" +skip_optimizer_step = true +enable_detailed_memory_tracking = true + +# CPU OFFLOAD ONLY - NO CACHE CLEARING +enable_cpu_offload = true +clear_cache_between_steps = false + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_optimized_ep32_40nodes_with_offload.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_optimized_ep32_40nodes_with_offload.toml new file mode 100644 index 0000000000..7ad05e8bae --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_optimized_ep32_40nodes_with_offload.toml @@ -0,0 +1,50 @@ +# OPTIMIZED: WITH all memory optimizations +# Test: Measure memory usage WITH CPU offloading + tracking + cache clearing +# Configuration: 40 nodes, 320 GPUs, EP=32, seq_len=512, batch_size=1 +# Purpose: Measure optimization impact vs baseline + +[job] +dump_folder = "./outputs/kimi_1t_optimized_ep32_40nodes_with_offload" +description = 
"Optimized - CPU Offload + Tracking - EP32 - 40 nodes" + +[profiling] +enable_profiling = true +profile_freq = 10 +enable_memory_snapshot = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 512 +steps = 5 +dataset = "c4_test" + +# OPTIMIZATIONS ENABLED +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 32 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_optimized_ep64_8nodes_with_offload.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_optimized_ep64_8nodes_with_offload.toml new file mode 100644 index 0000000000..ff0d3bb4b0 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_optimized_ep64_8nodes_with_offload.toml @@ -0,0 +1,51 @@ +# OPTIMIZED: EP=64, 8 nodes, WITH all optimizations +# Test: Measure memory WITH CPU offload + cache clearing + tracking +# Configuration: 8 nodes, 64 GPUs, EP=64 (6 experts/GPU), seq=512, batch=1 +# Purpose: Measure optimization impact at EP=64 + +[job] +dump_folder = "./outputs/kimi_1t_optimized_ep64_8n_with_offload" +description = "Optimized EP64 - CPU Offload + Cache Clearing - 8 nodes" + +[profiling] +enable_profiling = true +profile_freq = 10 +enable_memory_snapshot = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 512 +steps = 5 +dataset = "c4_test" + +# OPTIMIZATIONS ENABLED +enable_cpu_offload = true +enable_detailed_memory_tracking = 
true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 # Same as baseline: 6 experts per GPU + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false # For cleaner memory tracking diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_profiling.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_profiling.toml new file mode 100644 index 0000000000..f7cda07a23 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_profiling.toml @@ -0,0 +1,116 @@ +[job] +dump_folder = "./outputs" +description = "Kimi K2 1T model training" +print_config = false + +[profiling] +enable_profiling = true +save_traces_folder = "profile_trace" +profile_freq = 5 +enable_memory_snapshot = true +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" +#converters = ["quantize.linear.float8"] + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +eps = 1e-8 + +#[optimizer] +#name = "Muon" +##lr = 2.2e-4 +#lr = 1e-5 +#eps = 1e-8 +#weight_decay = 0.1 +# +## Muon-specific parameters +#mu = 0.95 # Momentum factor for Muon +#algorithm = "muon" # Main algorithm to use for 2D matrices +#nesterov = false # Whether to use Nesterov momentum +#adjust_lr = "rms_norm" # How to adjust LR: "spectral_norm", "rms_norm", or null +#flatten = false # Whether to flatten 3D+ tensors to 2D +#use_triton = true # Whether to use Triton kernel for Newton-Schulz +# +## Parameter-specific optimizer selection +#scalar_optimizer = "adamw" # For 1D parameters (biases, layer norms) +#embedding_optimizer = "adamw" # For embedding layers +#head_optimizer = "adamw" # For model head/output layers +#head_lr_scaling = false # Apply 1/sqrt(dim) scaling to head layers +# +## Learning 
rate scaling factors +#scalar_lr_factor = 1 # LR multiplier for scalar parameters +#embedding_lr_factor = 1 # LR multiplier for embedding parameters +#head_lr_factor = 1.0 # LR multiplier for head parameters (after head_lr_scaling) +#routing_lr_factor = 1.0 # LR multiplier for routing parameters +#expert_lr_factor = 1.0 + +[lr_scheduler] +warmup_steps = 2_000 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "cosine" +min_lr_factor = 0.1 + +[training] +local_batch_size = 1 +seq_len = 256 +max_norm = 1.0 # grad norm clipping +steps = 20 +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) +dtype = "bfloat16" +enable_cpu_offload = true + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +pipeline_parallel_degree = 1 +pipeline_parallel_schedule = "Interleaved1F1B" +expert_parallel_degree = 64 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 500 +last_save_model_only = true +export_dtype = "bfloat16" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "full" # ["none", "selective", "full"] +selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable = false +components = ["model", "loss"] # ["model", "loss"] +# fullgraph = false + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] + +[quantize.grouped_mm.float8] +fqns = ["experts"] + +#[deepep] +#sync_comm_stream = false +#fused_weighted_scatter_add = false +#fused_silu_gate_prob = true +# +#[debug] +#moe_force_load_balance = true diff --git 
a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_profiling_ep64_8nodes.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_profiling_ep64_8nodes.toml new file mode 100644 index 0000000000..cdb16eac9c --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_profiling_ep64_8nodes.toml @@ -0,0 +1,51 @@ +# PROFILING: EP=64, 8 nodes, CPU offload + cache clear +# Purpose: Deep profiling to identify bottleneck operations +# Configuration: 8 nodes, 64 GPUs, EP=64 (6 experts/GPU), seq=512, batch=1 + +[job] +dump_folder = "./outputs/kimi_1t_profiling_ep64" +description = "Deep Profiling - CPU Offload + Cache Clear - 8 nodes" + +[profiling] +enable_profiling = true +profile_freq = 5 # Must be >= warmup (1) + active (3) +profiler_warmup = 1 +profiler_active = 3 + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "./assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +local_batch_size = 1 +seq_len = 512 +steps = 15 # More steps for better profiling data (need >= profile_freq for traces) +dataset = "c4_test" + +# OPTIMIZATIONS ENABLED +enable_cpu_offload = true +enable_detailed_memory_tracking = true +clear_cache_between_steps = true +skip_optimizer_step = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 64 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = false diff --git a/torchtitan/models/qwen3/train_configs/qwen3_1.7b_local_test_baseline.toml b/torchtitan/models/qwen3/train_configs/qwen3_1.7b_local_test_baseline.toml new file mode 100644 index 0000000000..140744642b --- /dev/null +++ b/torchtitan/models/qwen3/train_configs/qwen3_1.7b_local_test_baseline.toml @@ -0,0 +1,70 @@ +# Qwen3 1.7B - Local test WITHOUT activation offloading (BASELINE) + +[job] +dump_folder = "./outputs/qwen3_1.7b_baseline" +description = "Qwen 3 1.7B local test - BASELINE (no activation 
offload)" +print_config = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "qwen3" +flavor = "1.7B" +hf_assets_path = "./assets/hf/Qwen3-1.7B" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 5 + +[training] +local_batch_size = 1 +seq_len = 2048 +max_norm = 1.0 +steps = 10 +dataset = "c4_test" +dtype = "bfloat16" +enable_cpu_offload = false + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" +tensor_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 50 +last_save_model_only = false +export_dtype = "bfloat16" +async_mode = "disabled" + +[activation_checkpoint] +mode = "full" +selective_ac_option = "op" +# NO CPU OFFLOAD - BASELINE +cpu_offload = false + +[compile] +enable = false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] diff --git a/torchtitan/models/qwen3/train_configs/qwen3_1.7b_local_test_offload.toml b/torchtitan/models/qwen3/train_configs/qwen3_1.7b_local_test_offload.toml new file mode 100644 index 0000000000..188cc580c2 --- /dev/null +++ b/torchtitan/models/qwen3/train_configs/qwen3_1.7b_local_test_offload.toml @@ -0,0 +1,70 @@ +# Qwen3 1.7B - Local test WITH activation offloading + +[job] +dump_folder = "./outputs/qwen3_1.7b_offload" +description = "Qwen 3 1.7B local test - WITH activation offload" +print_config = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + 
+[model] +name = "qwen3" +flavor = "1.7B" +hf_assets_path = "./assets/hf/Qwen3-1.7B" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 5 + +[training] +local_batch_size = 1 +seq_len = 2048 +max_norm = 1.0 +steps = 10 +dataset = "c4_test" +dtype = "bfloat16" +enable_cpu_offload = false + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" +tensor_parallel_degree = 1 +context_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 50 +last_save_model_only = false +export_dtype = "bfloat16" +async_mode = "disabled" + +[activation_checkpoint] +mode = "full" +selective_ac_option = "op" +# ENABLE CPU OFFLOAD FOR ACTIVATIONS +cpu_offload = true + +[compile] +enable = false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] diff --git a/torchtitan/models/qwen3/train_configs/qwen3_30b_a3b_activation_offload_test.toml b/torchtitan/models/qwen3/train_configs/qwen3_30b_a3b_activation_offload_test.toml new file mode 100644 index 0000000000..3ad3d6321b --- /dev/null +++ b/torchtitan/models/qwen3/train_configs/qwen3_30b_a3b_activation_offload_test.toml @@ -0,0 +1,75 @@ +# Qwen3 30B A3B with Activation Offloading - Local Test Config + +[job] +dump_folder = "./outputs/qwen3_30b_act_offload_test" +description = "Qwen3 30B A3B with Activation Offloading Test" +print_config = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 100 + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "qwen3" +flavor = "30B-A3B" +hf_assets_path = "./assets/hf/Qwen3-30B-A3B-Instruct-2507" + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 10 +decay_ratio = 0.8 
+decay_type = "cosine" +min_lr_factor = 0.1 + +[training] +local_batch_size = 1 +seq_len = 2048 +max_norm = 1.0 +steps = 10 +dataset = "c4_test" +dtype = "bfloat16" +enable_cpu_offload = false + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +expert_parallel_degree = 2 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 500 +last_save_model_only = true +export_dtype = "bfloat16" +async_mode = "disabled" + +[activation_checkpoint] +mode = "full" +selective_ac_option = 'op' +# Enable CPU offloading for activations - THIS IS THE KEY TEST +cpu_offload = true + +[compile] +enable = false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output"] diff --git a/torchtitan/train.py b/torchtitan/train.py index 71c906a625..198d5cdc9f 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -537,10 +537,15 @@ def post_dataloading_process( extra_kwargs: dict[str, Any] = {} if getattr(self.model_args, "use_flex_attn", False): + # Pass CP mesh for Context Parallel + FlexAttention support + cp_mesh = None + if self.parallel_dims.cp_enabled: + cp_mesh = self.parallel_dims.world_mesh["cp"] extra_kwargs["attention_masks"] = self.model_parts[0].get_attention_masks( input_batch=inputs, tokenizer=self.tokenizer, extra_inputs=extra_inputs, + cp_mesh=cp_mesh, ) return inputs, labels, extra_inputs, extra_kwargs From b4f226d5a70259ff53db9845032d639200dd48b8 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Tue, 20 Jan 2026 11:37:42 -0800 Subject: [PATCH 03/18] backup --- torchtitan/config/job_config.py | 6 + torchtitan/models/attention.py | 3 +- torchtitan/models/deepseek_v3/__init__.py | 52 +++ torchtitan/train.py | 19 + torchtitan/utils/__init__.py | 19 + 
torchtitan/utils/nan_tracker.py | 495 ++++++++++++++++++++++ 6 files changed, 593 insertions(+), 1 deletion(-) create mode 100644 torchtitan/utils/__init__.py create mode 100644 torchtitan/utils/nan_tracker.py diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index fb936ff24b..77535e932c 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -1244,6 +1244,12 @@ class Debug: moe_force_load_balance: bool = False """If True, we force each experts to get the same amount of tokens via round-robin. This option is for debugging usage only.""" + enable_nan_tracker: bool = False + """If True, enable lightweight NaN/Inf tracking to find where NaN first appears in the model.""" + + nan_tracker_verbose: bool = False + """If True, print stats for every layer (very verbose output).""" + @dataclass class JobConfig: diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index 8873ef2f90..c2b3f51ef2 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -46,9 +46,10 @@ class FlexAttentionWrapper(torch.nn.Module): block_mask as a keyword argument to be compatible with _ContextParallel. 
""" + # Using dynamic=True to avoid CUDA crashes with CP, but need to debug NaN issue _compiled_flex_attn: ClassVar[Callable] = torch.compile( flex_attention, - mode="max-autotune-no-cudagraphs", + dynamic=True, ) def forward( diff --git a/torchtitan/models/deepseek_v3/__init__.py b/torchtitan/models/deepseek_v3/__init__.py index 165f2ba156..452a873fb0 100644 --- a/torchtitan/models/deepseek_v3/__init__.py +++ b/torchtitan/models/deepseek_v3/__init__.py @@ -121,6 +121,58 @@ use_flex_attn=True, attn_mask_type="block_causal", ), + # debugmodel_1b with FlexAttention for CP NaN testing + "debugmodel_1b_flex_attn": DeepSeekV3ModelArgs( + vocab_size=2048, + dim=1024, + inter_dim=4096, + moe_inter_dim=1024, + n_layers=24, + n_dense_layers=2, + n_heads=16, + moe_args=MoEArgs( + num_experts=8, + num_shared_experts=2, + top_k=3, + score_func="softmax", + route_norm=False, + score_before_experts=False, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + use_flex_attn=True, # Enable FlexAttention + attn_mask_type="block_causal", # Same as kimi_k2 + ), + # Debug model with FlexAttention + causal-only mask for CP testing + "debugmodel_flex_attn_causal": DeepSeekV3ModelArgs( + vocab_size=2048, + dim=256, + inter_dim=1024, + moe_inter_dim=256, + n_layers=6, + n_dense_layers=1, + n_heads=16, + moe_args=MoEArgs( + num_experts=8, + num_shared_experts=2, + top_k=3, + score_func="softmax", + route_norm=False, + score_before_experts=False, + ), + q_lora_rank=0, + kv_lora_rank=512, + qk_nope_head_dim=128, + qk_rope_head_dim=64, + v_head_dim=128, + mscale=0.70, + use_flex_attn=True, + attn_mask_type="causal", # Simple causal, no document mask + ), "16B": DeepSeekV3ModelArgs( vocab_size=129280, dim=2048, diff --git a/torchtitan/train.py b/torchtitan/train.py index 198d5cdc9f..9323bf428b 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -35,6 +35,7 @@ maybe_enable_memory_snapshot, maybe_enable_profiling, ) 
+from torchtitan.utils.nan_tracker import create_nan_tracker_for_deepseek class Trainer(torch.distributed.checkpoint.stateful.Stateful): @@ -308,6 +309,17 @@ def __init__(self, job_config: JobConfig): self.ft_manager.maybe_set_all_reduce_hook(self.model_parts) + # Initialize NaN tracker if enabled + self.nan_tracker = None + if job_config.debug.enable_nan_tracker: + rank = int(os.environ.get("RANK", 0)) + self.nan_tracker = create_nan_tracker_for_deepseek( + self.model_parts[0], + rank=rank, + verbose=job_config.debug.nan_tracker_verbose, + ) + logger.info("NaN tracker enabled - will track activations and gradients") + # initialize device memory monitor and get peak flops for MFU calculation device_memory_monitor = self.metrics_processor.device_memory_monitor gpu_peak_flops = utils.get_peak_flops(device_memory_monitor.device_name) @@ -744,6 +756,13 @@ def train_step( self.detailed_memory_tracker.measure("after_forward_backward", self.step) self.cuda_memory_tracker.measure_all("after_forward_backward", self.step) + # Check for NaN/Inf if tracker is enabled + if self.nan_tracker is not None: + if self.nan_tracker.has_nan(): + self.nan_tracker.print_nan_report() + logger.error(f"NaN detected at step {self.step}! See report above.") + self.nan_tracker.step() # Reset for next step + grad_norm = dist_utils.clip_grad_norm_( [p for m in self.model_parts for p in m.parameters()], self.job_config.training.max_norm, diff --git a/torchtitan/utils/__init__.py b/torchtitan/utils/__init__.py new file mode 100644 index 0000000000..1b4ed35bac --- /dev/null +++ b/torchtitan/utils/__init__.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +from torchtitan.utils.nan_tracker import ( + create_nan_tracker_for_deepseek, + LayerStats, + NaNTracker, + TensorStats, +) + +__all__ = [ + "NaNTracker", + "create_nan_tracker_for_deepseek", + "TensorStats", + "LayerStats", +] diff --git a/torchtitan/utils/nan_tracker.py b/torchtitan/utils/nan_tracker.py new file mode 100644 index 0000000000..c65a5bd26a --- /dev/null +++ b/torchtitan/utils/nan_tracker.py @@ -0,0 +1,495 @@ +""" +Lightweight NaN/Inf tracker for debugging training issues. + +This module provides hooks to track tensor statistics (min, max, mean, nan_count, inf_count) +at each layer without saving tensors, making it suitable for large model debugging. + +Usage: + from torchtitan.utils.nan_tracker import NaNTracker + + tracker = NaNTracker() + tracker.register_hooks(model) + + # In training loop: + loss = model(inputs) + loss.backward() + + # Check for NaN + if tracker.has_nan(): + tracker.print_nan_report() + + tracker.step() # Reset for next step +""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + +import torch +import torch.nn as nn + + +@dataclass +class TensorStats: + """Statistics for a single tensor.""" + + name: str + shape: Tuple[int, ...] 
+ dtype: str + min_val: float + max_val: float + mean_val: float + std_val: float + nan_count: int + inf_count: int + total_elements: int + + @property + def has_nan(self) -> bool: + return self.nan_count > 0 + + @property + def has_inf(self) -> bool: + return self.inf_count > 0 + + def __str__(self) -> str: + status = "" + if self.has_nan: + status += f" [NaN: {self.nan_count}/{self.total_elements}]" + if self.has_inf: + status += f" [Inf: {self.inf_count}/{self.total_elements}]" + return ( + f"{self.name}: shape={self.shape}, dtype={self.dtype}, " + f"min={self.min_val:.4g}, max={self.max_val:.4g}, " + f"mean={self.mean_val:.4g}, std={self.std_val:.4g}{status}" + ) + + +@dataclass +class LayerStats: + """Statistics for a layer's inputs and outputs.""" + + layer_name: str + layer_type: str + input_stats: List[TensorStats] = field(default_factory=list) + output_stats: List[TensorStats] = field(default_factory=list) + grad_input_stats: List[TensorStats] = field(default_factory=list) + grad_output_stats: List[TensorStats] = field(default_factory=list) + + @property + def has_nan(self) -> bool: + for stats_list in [ + self.input_stats, + self.output_stats, + self.grad_input_stats, + self.grad_output_stats, + ]: + for stats in stats_list: + if stats.has_nan: + return True + return False + + @property + def has_inf(self) -> bool: + for stats_list in [ + self.input_stats, + self.output_stats, + self.grad_input_stats, + self.grad_output_stats, + ]: + for stats in stats_list: + if stats.has_inf: + return True + return False + + +def compute_tensor_stats(tensor: torch.Tensor, name: str) -> Optional[TensorStats]: + """Compute statistics for a tensor without storing it.""" + if tensor is None: + return None + + if not isinstance(tensor, torch.Tensor): + return None + + # Handle DTensor by getting local tensor + if hasattr(tensor, "_local_tensor"): + tensor = tensor._local_tensor + + # Flatten for stats computation + flat = tensor.detach().float().flatten() + + # Count NaN and 
Inf + nan_mask = torch.isnan(flat) + inf_mask = torch.isinf(flat) + nan_count = nan_mask.sum().item() + inf_count = inf_mask.sum().item() + + # Compute stats on valid values only + valid_mask = ~(nan_mask | inf_mask) + valid_vals = flat[valid_mask] + + if valid_vals.numel() > 0: + min_val = valid_vals.min().item() + max_val = valid_vals.max().item() + mean_val = valid_vals.mean().item() + std_val = valid_vals.std().item() if valid_vals.numel() > 1 else 0.0 + else: + min_val = float("nan") + max_val = float("nan") + mean_val = float("nan") + std_val = float("nan") + + return TensorStats( + name=name, + shape=tuple(tensor.shape), + dtype=str(tensor.dtype), + min_val=min_val, + max_val=max_val, + mean_val=mean_val, + std_val=std_val, + nan_count=int(nan_count), + inf_count=int(inf_count), + total_elements=tensor.numel(), + ) + + +class NaNTracker: + """ + Lightweight tracker for NaN/Inf in model activations and gradients. + + Registers forward and backward hooks on model layers to compute statistics + without storing tensors. 
+ + Args: + track_forward: Track forward pass activations + track_backward: Track backward pass gradients + layer_types: Only track these layer types (default: all) + include_patterns: Only track layers matching these patterns + exclude_patterns: Exclude layers matching these patterns + log_every_layer: Print stats for every layer (verbose) + rank: Current process rank for distributed training + """ + + def __init__( + self, + track_forward: bool = True, + track_backward: bool = True, + layer_types: Optional[Tuple[type, ...]] = None, + include_patterns: Optional[List[str]] = None, + exclude_patterns: Optional[List[str]] = None, + log_every_layer: bool = False, + rank: int = 0, + ): + self.track_forward = track_forward + self.track_backward = track_backward + self.layer_types = layer_types + self.include_patterns = include_patterns or [] + self.exclude_patterns = exclude_patterns or [] + self.log_every_layer = log_every_layer + self.rank = rank + + self.step_num = 0 + self.layer_stats: Dict[str, LayerStats] = {} + self.hooks: List[torch.utils.hooks.RemovableHandle] = [] + self._first_nan_layer: Optional[str] = None + self._first_nan_phase: Optional[str] = None + self._forward_order: List[str] = [] + + def _should_track(self, name: str, module: nn.Module) -> bool: + """Check if this layer should be tracked.""" + # Check layer type filter + if self.layer_types is not None: + if not isinstance(module, self.layer_types): + return False + + # Check include patterns + if self.include_patterns: + if not any(p in name for p in self.include_patterns): + return False + + # Check exclude patterns + if self.exclude_patterns: + if any(p in name for p in self.exclude_patterns): + return False + + return True + + def _create_forward_hook(self, layer_name: str, layer_type: str): + """Create a forward hook for a layer.""" + + def hook(module, inputs, outputs): + if layer_name not in self.layer_stats: + self.layer_stats[layer_name] = LayerStats( + layer_name=layer_name, + 
layer_type=layer_type, + ) + self._forward_order.append(layer_name) + + stats = self.layer_stats[layer_name] + + # Process inputs + if isinstance(inputs, tuple): + for i, inp in enumerate(inputs): + if isinstance(inp, torch.Tensor): + tensor_stats = compute_tensor_stats(inp, f"input_{i}") + if tensor_stats: + stats.input_stats.append(tensor_stats) + if tensor_stats.has_nan and self._first_nan_layer is None: + self._first_nan_layer = layer_name + self._first_nan_phase = f"forward_input_{i}" + + # Process outputs + if isinstance(outputs, torch.Tensor): + tensor_stats = compute_tensor_stats(outputs, "output") + if tensor_stats: + stats.output_stats.append(tensor_stats) + if tensor_stats.has_nan and self._first_nan_layer is None: + self._first_nan_layer = layer_name + self._first_nan_phase = "forward_output" + elif isinstance(outputs, tuple): + for i, out in enumerate(outputs): + if isinstance(out, torch.Tensor): + tensor_stats = compute_tensor_stats(out, f"output_{i}") + if tensor_stats: + stats.output_stats.append(tensor_stats) + if tensor_stats.has_nan and self._first_nan_layer is None: + self._first_nan_layer = layer_name + self._first_nan_phase = f"forward_output_{i}" + + if self.log_every_layer and self.rank == 0: + self._print_layer_stats(layer_name, "forward") + + return hook + + def _create_backward_hook(self, layer_name: str, layer_type: str): + """Create a backward hook for a layer.""" + + def hook(module, grad_input, grad_output): + if layer_name not in self.layer_stats: + self.layer_stats[layer_name] = LayerStats( + layer_name=layer_name, + layer_type=layer_type, + ) + + stats = self.layer_stats[layer_name] + + # Process grad_output (gradient w.r.t. 
layer output) + if isinstance(grad_output, tuple): + for i, grad in enumerate(grad_output): + if isinstance(grad, torch.Tensor): + tensor_stats = compute_tensor_stats(grad, f"grad_output_{i}") + if tensor_stats: + stats.grad_output_stats.append(tensor_stats) + if tensor_stats.has_nan and self._first_nan_layer is None: + self._first_nan_layer = layer_name + self._first_nan_phase = f"backward_grad_output_{i}" + + # Process grad_input (gradient w.r.t. layer input) + if isinstance(grad_input, tuple): + for i, grad in enumerate(grad_input): + if isinstance(grad, torch.Tensor): + tensor_stats = compute_tensor_stats(grad, f"grad_input_{i}") + if tensor_stats: + stats.grad_input_stats.append(tensor_stats) + if tensor_stats.has_nan and self._first_nan_layer is None: + self._first_nan_layer = layer_name + self._first_nan_phase = f"backward_grad_input_{i}" + + if self.log_every_layer and self.rank == 0: + self._print_layer_stats(layer_name, "backward") + + return hook + + def register_hooks(self, model: nn.Module) -> None: + """Register forward and backward hooks on model layers.""" + for name, module in model.named_modules(): + if not self._should_track(name, module): + continue + + # Skip container modules + if isinstance(module, (nn.ModuleList, nn.ModuleDict, nn.Sequential)): + continue + + layer_type = type(module).__name__ + + if self.track_forward: + hook = module.register_forward_hook( + self._create_forward_hook(name, layer_type) + ) + self.hooks.append(hook) + + if self.track_backward: + hook = module.register_full_backward_hook( + self._create_backward_hook(name, layer_type) + ) + self.hooks.append(hook) + + def remove_hooks(self) -> None: + """Remove all registered hooks.""" + for hook in self.hooks: + hook.remove() + self.hooks.clear() + + def step(self) -> None: + """Reset statistics for next step.""" + self.step_num += 1 + self.layer_stats.clear() + self._first_nan_layer = None + self._first_nan_phase = None + self._forward_order.clear() + + def has_nan(self) -> 
bool: + """Check if any NaN was detected this step.""" + return self._first_nan_layer is not None + + def has_any_nan_or_inf(self) -> bool: + """Check if any NaN or Inf was detected this step.""" + for stats in self.layer_stats.values(): + if stats.has_nan or stats.has_inf: + return True + return False + + def get_first_nan_location(self) -> Optional[Tuple[str, str]]: + """Get the first layer and phase where NaN appeared.""" + if self._first_nan_layer: + return (self._first_nan_layer, self._first_nan_phase) + return None + + def _print_layer_stats(self, layer_name: str, phase: str) -> None: + """Print statistics for a single layer.""" + if layer_name not in self.layer_stats: + return + + stats = self.layer_stats[layer_name] + print(f"[Step {self.step_num}][{phase}] {layer_name} ({stats.layer_type}):") + + if phase == "forward": + for s in stats.input_stats: + print(f" IN: {s}") + for s in stats.output_stats: + print(f" OUT: {s}") + else: + for s in stats.grad_output_stats: + print(f" GRAD_OUT: {s}") + for s in stats.grad_input_stats: + print(f" GRAD_IN: {s}") + + def print_nan_report(self) -> None: + """Print a detailed report of where NaN/Inf occurred.""" + if self.rank != 0: + return + + print(f"\n{'='*80}") + print(f"NaN/Inf REPORT - Step {self.step_num}") + print(f"{'='*80}") + + if self._first_nan_layer: + print( + f"\n** FIRST NaN detected at: {self._first_nan_layer} ({self._first_nan_phase}) **\n" + ) + + # Print layers in forward order + nan_layers = [] + for layer_name in self._forward_order: + if layer_name in self.layer_stats: + stats = self.layer_stats[layer_name] + if stats.has_nan or stats.has_inf: + nan_layers.append(layer_name) + + if nan_layers: + print(f"Layers with NaN/Inf ({len(nan_layers)} total):") + for layer_name in nan_layers: + stats = self.layer_stats[layer_name] + print(f"\n {layer_name} ({stats.layer_type}):") + for s in stats.input_stats: + if s.has_nan or s.has_inf: + print(f" [FWD IN] {s}") + for s in stats.output_stats: + if 
s.has_nan or s.has_inf: + print(f" [FWD OUT] {s}") + for s in stats.grad_output_stats: + if s.has_nan or s.has_inf: + print(f" [BWD GRAD_OUT] {s}") + for s in stats.grad_input_stats: + if s.has_nan or s.has_inf: + print(f" [BWD GRAD_IN] {s}") + else: + print("No NaN/Inf detected in tracked layers.") + + print(f"\n{'='*80}\n") + + def print_summary(self) -> None: + """Print a summary of all layer statistics.""" + if self.rank != 0: + return + + print(f"\n{'='*80}") + print(f"LAYER STATISTICS SUMMARY - Step {self.step_num}") + print(f"{'='*80}") + print(f"Total layers tracked: {len(self.layer_stats)}") + + nan_count = sum(1 for s in self.layer_stats.values() if s.has_nan) + inf_count = sum(1 for s in self.layer_stats.values() if s.has_inf) + print(f"Layers with NaN: {nan_count}") + print(f"Layers with Inf: {inf_count}") + + if self._first_nan_layer: + print(f"\nFirst NaN at: {self._first_nan_layer} ({self._first_nan_phase})") + + print(f"{'='*80}\n") + + def get_stats_dict(self) -> Dict[str, Any]: + """Get statistics as a dictionary for logging.""" + result = { + "step": self.step_num, + "has_nan": self.has_nan(), + "first_nan_layer": self._first_nan_layer, + "first_nan_phase": self._first_nan_phase, + "layers_with_nan": [], + "layers_with_inf": [], + } + + for name, stats in self.layer_stats.items(): + if stats.has_nan: + result["layers_with_nan"].append(name) + if stats.has_inf: + result["layers_with_inf"].append(name) + + return result + + +def create_nan_tracker_for_deepseek( + model: nn.Module, + rank: int = 0, + verbose: bool = False, +) -> NaNTracker: + """ + Create a NaN tracker optimized for DeepSeek models. 
+ + Tracks key layers that are most likely to produce NaN: + - Attention layers (FlexAttention, MLA) + - MoE layers (router, experts) + - Normalization layers + - Output projection + """ + tracker = NaNTracker( + track_forward=True, + track_backward=True, + include_patterns=[ + "attention", + "moe", + "router", + "expert", + "norm", + "output", + "feed_forward", + "tok_embeddings", + ], + exclude_patterns=[ + "freqs_cis", + ], + log_every_layer=verbose, + rank=rank, + ) + + tracker.register_hooks(model) + return tracker From 82f2afe83facfc62587cbfc6b6e351d574b06d03 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Tue, 20 Jan 2026 11:43:26 -0800 Subject: [PATCH 04/18] Remove experimental configs and debug scripts from tracking Remove kimi_1t, debug, exp1a training configs, qwen3 test configs, and root-level debug scripts from git tracking. --- deep_memory_profiler.py | 245 ------------------ launch_kimi_1t_emozilla.slurm | 84 ------ test_single_node.slurm | 40 --- .../train_configs/debug_1b_baseline.toml | 84 ------ .../debug_1b_no_ac_baseline.toml | 54 ---- .../train_configs/debug_1b_offload.toml | 84 ------ .../train_configs/debug_1b_offload_only.toml | 55 ---- .../train_configs/debug_7b_baseline.toml | 55 ---- .../train_configs/debug_7b_offload.toml | 55 ---- .../debug_activation_offload.toml | 82 ------ .../exp1aa0_8n_EP64_CP2_LBS1_ctx16k.toml | 46 ---- .../exp1aa10_8n_EP64_CP2_LBS1_ctx32k.toml | 46 ---- .../exp1aa11_8n_EP64_CP2_LBS2_ctx32k.toml | 46 ---- .../exp1aa12_8n_EP64_CP2_LBS4_ctx32k.toml | 46 ---- .../exp1aa13_8n_EP64_CP2_LBS6_ctx32k.toml | 46 ---- .../exp1aa14_8n_EP64_CP2_LBS8_ctx32k.toml | 46 ---- .../exp1aa1_8n_EP64_CP2_LBS2_ctx16k.toml | 46 ---- .../exp1aa2_8n_EP64_CP2_LBS4_ctx16k.toml | 46 ---- .../exp1aa3_8n_EP64_CP2_LBS6_ctx16k.toml | 46 ---- .../exp1aa4_8n_EP64_CP2_LBS8_ctx16k.toml | 46 ---- .../exp1aa5_8n_EP64_CP2_LBS1_ctx24k.toml | 46 ---- .../exp1aa6_8n_EP64_CP2_LBS2_ctx24k.toml | 46 ---- .../exp1aa7_8n_EP64_CP2_LBS4_ctx24k.toml | 46 ---- 
.../exp1aa8_8n_EP64_CP2_LBS6_ctx24k.toml | 46 ---- .../exp1aa9_8n_EP64_CP2_LBS8_ctx24k.toml | 46 ---- .../exp1ab0_12n_EP96_CP2_LBS1_ctx16k.toml | 46 ---- .../exp1ab10_12n_EP96_CP2_LBS1_ctx32k.toml | 46 ---- .../exp1ab11_12n_EP96_CP2_LBS2_ctx32k.toml | 46 ---- .../exp1ab12_12n_EP96_CP2_LBS4_ctx32k.toml | 46 ---- .../exp1ab13_12n_EP96_CP2_LBS6_ctx32k.toml | 46 ---- .../exp1ab14_12n_EP96_CP2_LBS8_ctx32k.toml | 46 ---- .../exp1ab1_12n_EP96_CP2_LBS2_ctx16k.toml | 46 ---- .../exp1ab2_12n_EP96_CP2_LBS4_ctx16k.toml | 46 ---- .../exp1ab3_12n_EP96_CP2_LBS6_ctx16k.toml | 46 ---- .../exp1ab4_12n_EP96_CP2_LBS8_ctx16k.toml | 46 ---- .../exp1ab5_12n_EP96_CP2_LBS1_ctx24k.toml | 46 ---- .../exp1ab6_12n_EP96_CP2_LBS2_ctx24k.toml | 46 ---- .../exp1ab7_12n_EP96_CP2_LBS4_ctx24k.toml | 46 ---- .../exp1ab8_12n_EP96_CP2_LBS6_ctx24k.toml | 46 ---- .../exp1ab9_12n_EP96_CP2_LBS8_ctx24k.toml | 46 ---- .../train_configs/kimi_1t_10n_28k_flb.toml | 46 ---- .../train_configs/kimi_1t_10n_ep16_2k.toml | 45 ---- .../train_configs/kimi_1t_10n_ep16_4k.toml | 43 --- .../train_configs/kimi_1t_10n_ep16_8k.toml | 43 --- .../kimi_1t_12n_28k_ac_offload.toml | 48 ---- .../kimi_1t_12n_28k_selective_ac.toml | 49 ---- .../kimi_1t_12n_cp2_28k_flex_fix.toml | 48 ---- .../train_configs/kimi_1t_12n_cp2_30720.toml | 46 ---- .../train_configs/kimi_1t_12n_cp2_32768.toml | 46 ---- .../kimi_1t_12n_ep12_28k_flb.toml | 47 ---- .../kimi_1t_12n_ep12_32k_flb.toml | 47 ---- .../kimi_1t_12n_ep96_28k_flb.toml | 47 ---- .../kimi_1t_12n_ep96_32k_flb.toml | 47 ---- .../train_configs/kimi_1t_16k_force_lb.toml | 46 ---- .../train_configs/kimi_1t_16n_ep128_2k.toml | 45 ---- .../train_configs/kimi_1t_16n_ep128_4k.toml | 45 ---- .../train_configs/kimi_1t_16n_ep128_8k.toml | 45 ---- .../train_configs/kimi_1t_20k_force_lb.toml | 46 ---- .../train_configs/kimi_1t_24k_force_lb.toml | 46 ---- .../train_configs/kimi_1t_28k_ac_offload.toml | 48 ---- .../train_configs/kimi_1t_28k_force_lb.toml | 46 ---- .../kimi_1t_28k_selective_ac.toml | 49 
---- .../train_configs/kimi_1t_32k_force_lb.toml | 46 ---- .../train_configs/kimi_1t_4k_force_lb.toml | 47 ---- .../train_configs/kimi_1t_6k_force_lb.toml | 46 ---- .../train_configs/kimi_1t_8k_force_lb.toml | 46 ---- .../train_configs/kimi_1t_8n_cp2_28k_flb.toml | 48 ---- .../kimi_1t_8n_cp2_28k_flex_fix.toml | 48 ---- .../kimi_1t_8n_cp2_28k_sdpa.toml | 48 ---- .../train_configs/kimi_1t_8n_cp2_30720.toml | 46 ---- .../train_configs/kimi_1t_8n_cp2_32768.toml | 46 ---- .../train_configs/kimi_1t_8n_cp2_ep1_28k.toml | 48 ---- .../train_configs/kimi_1t_8n_tp2_28k.toml | 48 ---- .../kimi_1t_activation_offload.toml | 81 ------ ...i_1t_baseline_ep32_40nodes_no_offload.toml | 50 ---- ...mi_1t_baseline_ep64_8nodes_no_offload.toml | 51 ---- ..._cpuoffload_ep32_40nodes_with_offload.toml | 52 ---- ...mi_1t_debug_ep32_40nodes_with_offload.toml | 53 ---- .../kimi_1t_debug_ep32_8nodes_seq512.toml | 52 ---- .../train_configs/kimi_1t_defrag_test.toml | 60 ----- .../kimi_1t_detailed_memory_profiling.toml | 78 ------ .../kimi_1t_detailed_tracking.toml | 47 ---- .../train_configs/kimi_1t_emozilla.toml | 116 --------- .../train_configs/kimi_1t_fix_combined.toml | 50 ---- .../train_configs/kimi_1t_fix_defrag.toml | 48 ---- .../kimi_1t_fix_selective_ac.toml | 47 ---- .../train_configs/kimi_1t_memory_16k_ctx.toml | 46 ---- .../train_configs/kimi_1t_memory_1k_ctx.toml | 43 --- .../train_configs/kimi_1t_memory_2k_ctx.toml | 43 --- .../train_configs/kimi_1t_memory_4k_ctx.toml | 43 --- .../train_configs/kimi_1t_memory_8k_ctx.toml | 46 ---- .../kimi_1t_memprof_24k_flb.toml | 53 ---- .../kimi_1t_memprof_28k_flb.toml | 53 ---- .../train_configs/kimi_1t_memprof_2k.toml | 50 ---- .../kimi_1t_memprof_2k_force_lb.toml | 53 ---- .../train_configs/kimi_1t_memprof_4k.toml | 50 ---- .../kimi_1t_memprof_4k_force_lb.toml | 53 ---- .../kimi_1t_offload_no_cache_clear.toml | 51 ---- ...t_optimized_ep32_40nodes_with_offload.toml | 50 ---- ...1t_optimized_ep64_8nodes_with_offload.toml | 51 ---- 
.../train_configs/kimi_1t_profiling.toml | 116 --------- .../kimi_1t_profiling_ep64_8nodes.toml | 51 ---- .../qwen3_1.7b_local_test_baseline.toml | 70 ----- .../qwen3_1.7b_local_test_offload.toml | 70 ----- ...qwen3_30b_a3b_activation_offload_test.toml | 75 ------ 105 files changed, 5605 deletions(-) delete mode 100644 deep_memory_profiler.py delete mode 100755 launch_kimi_1t_emozilla.slurm delete mode 100755 test_single_node.slurm delete mode 100644 torchtitan/models/deepseek_v3/train_configs/debug_1b_baseline.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/debug_1b_no_ac_baseline.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/debug_1b_offload.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/debug_1b_offload_only.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/debug_7b_baseline.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/debug_7b_offload.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/debug_activation_offload.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa0_8n_EP64_CP2_LBS1_ctx16k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa10_8n_EP64_CP2_LBS1_ctx32k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa11_8n_EP64_CP2_LBS2_ctx32k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa12_8n_EP64_CP2_LBS4_ctx32k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa13_8n_EP64_CP2_LBS6_ctx32k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa14_8n_EP64_CP2_LBS8_ctx32k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa1_8n_EP64_CP2_LBS2_ctx16k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa2_8n_EP64_CP2_LBS4_ctx16k.toml delete mode 100644 
torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa3_8n_EP64_CP2_LBS6_ctx16k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa4_8n_EP64_CP2_LBS8_ctx16k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa5_8n_EP64_CP2_LBS1_ctx24k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa6_8n_EP64_CP2_LBS2_ctx24k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa7_8n_EP64_CP2_LBS4_ctx24k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa8_8n_EP64_CP2_LBS6_ctx24k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa9_8n_EP64_CP2_LBS8_ctx24k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab0_12n_EP96_CP2_LBS1_ctx16k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab10_12n_EP96_CP2_LBS1_ctx32k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab11_12n_EP96_CP2_LBS2_ctx32k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab12_12n_EP96_CP2_LBS4_ctx32k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab13_12n_EP96_CP2_LBS6_ctx32k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab14_12n_EP96_CP2_LBS8_ctx32k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab1_12n_EP96_CP2_LBS2_ctx16k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab2_12n_EP96_CP2_LBS4_ctx16k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab3_12n_EP96_CP2_LBS6_ctx16k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab4_12n_EP96_CP2_LBS8_ctx16k.toml delete mode 100644 
torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab5_12n_EP96_CP2_LBS1_ctx24k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab6_12n_EP96_CP2_LBS2_ctx24k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab7_12n_EP96_CP2_LBS4_ctx24k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab8_12n_EP96_CP2_LBS6_ctx24k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab9_12n_EP96_CP2_LBS8_ctx24k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_28k_flb.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_2k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_4k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_8k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_28k_ac_offload.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_28k_selective_ac.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_28k_flex_fix.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_30720.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_32768.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep12_28k_flb.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep12_32k_flb.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep96_28k_flb.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep96_32k_flb.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_16k_force_lb.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_2k.toml delete mode 100644 
torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_4k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_8k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_20k_force_lb.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_24k_force_lb.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_ac_offload.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_force_lb.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_selective_ac.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_32k_force_lb.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_4k_force_lb.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_6k_force_lb.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_8k_force_lb.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_flb.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_flex_fix.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_sdpa.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_30720.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_32768.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_ep1_28k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_tp2_28k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_activation_offload.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_baseline_ep32_40nodes_no_offload.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_baseline_ep64_8nodes_no_offload.toml delete mode 100644 
torchtitan/models/deepseek_v3/train_configs/kimi_1t_cpuoffload_ep32_40nodes_with_offload.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_debug_ep32_40nodes_with_offload.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_debug_ep32_8nodes_seq512.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_defrag_test.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_detailed_memory_profiling.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_detailed_tracking.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_emozilla.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_combined.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_defrag.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_selective_ac.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_16k_ctx.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_1k_ctx.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_2k_ctx.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_4k_ctx.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_8k_ctx.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_24k_flb.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_28k_flb.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_2k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_2k_force_lb.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_4k.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_4k_force_lb.toml delete mode 100644 
torchtitan/models/deepseek_v3/train_configs/kimi_1t_offload_no_cache_clear.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_optimized_ep32_40nodes_with_offload.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_optimized_ep64_8nodes_with_offload.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_profiling.toml delete mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_1t_profiling_ep64_8nodes.toml delete mode 100644 torchtitan/models/qwen3/train_configs/qwen3_1.7b_local_test_baseline.toml delete mode 100644 torchtitan/models/qwen3/train_configs/qwen3_1.7b_local_test_offload.toml delete mode 100644 torchtitan/models/qwen3/train_configs/qwen3_30b_a3b_activation_offload_test.toml diff --git a/deep_memory_profiler.py b/deep_memory_profiler.py deleted file mode 100644 index e6bfb93932..0000000000 --- a/deep_memory_profiler.py +++ /dev/null @@ -1,245 +0,0 @@ -#!/usr/bin/env python3 -""" -Deep Memory Profiler for Kimi K2 1T Model -Tracks memory allocation at each layer/operation to identify where OOM occurs. 
-""" - -import json -import sys -from collections import defaultdict -from typing import Dict, List - -import torch - - -class DeepMemoryProfiler: - def __init__(self, output_file: str = "memory_profile.json"): - self.output_file = output_file - self.memory_events: List[Dict] = [] - self.hooks = [] - self.current_step = 0 - self.current_phase = "init" - - def _get_memory_stats(self) -> Dict: - """Get current GPU memory statistics.""" - if not torch.cuda.is_available(): - return {} - - stats = torch.cuda.memory_stats() - return { - "allocated_gb": torch.cuda.memory_allocated() / 1e9, - "reserved_gb": torch.cuda.memory_reserved() / 1e9, - "max_allocated_gb": torch.cuda.max_memory_allocated() / 1e9, - "active_gb": stats.get("active_bytes.all.current", 0) / 1e9, - "inactive_gb": stats.get("inactive_split_bytes.all.current", 0) / 1e9, - "num_alloc_retries": stats.get("num_alloc_retries", 0), - "num_ooms": stats.get("num_ooms", 0), - } - - def log_memory(self, event_name: str, extra_info: Dict = None): - """Log memory at a specific event.""" - mem_stats = self._get_memory_stats() - event = { - "step": self.current_step, - "phase": self.current_phase, - "event": event_name, - "memory": mem_stats, - } - if extra_info: - event["extra"] = extra_info - self.memory_events.append(event) - - # Print for real-time monitoring - print( - f"[MemProf] Step {self.current_step} | {self.current_phase} | {event_name} | " - f"Alloc: {mem_stats.get('allocated_gb', 0):.2f} GB | " - f"Reserved: {mem_stats.get('reserved_gb', 0):.2f} GB" - ) - - def _make_forward_hook(self, layer_name: str): - """Create a forward hook for a layer.""" - - def hook(module, input, output): - input_shapes = [] - for inp in input: - if isinstance(inp, torch.Tensor): - input_shapes.append(list(inp.shape)) - - output_shapes = [] - if isinstance(output, torch.Tensor): - output_shapes.append(list(output.shape)) - elif isinstance(output, (tuple, list)): - for out in output: - if isinstance(out, torch.Tensor): - 
output_shapes.append(list(out.shape)) - - self.log_memory( - f"forward:{layer_name}", - { - "input_shapes": input_shapes, - "output_shapes": output_shapes, - }, - ) - - return hook - - def _make_backward_hook(self, layer_name: str): - """Create a backward hook for a layer.""" - - def hook(module, grad_input, grad_output): - self.log_memory(f"backward:{layer_name}") - - return hook - - def attach_hooks(self, model: torch.nn.Module, layers_to_track: List[str] = None): - """Attach memory tracking hooks to model layers.""" - if layers_to_track is None: - # Default: track key layers in DeepSeek/MoE model - layers_to_track = [ - "embed_tokens", - "layers.0", # First transformer layer - "layers.30", # Middle layer - "layers.60", # Last layer (if exists) - "moe", # MoE layers - "experts", # Expert modules - "norm", - "lm_head", - ] - - for name, module in model.named_modules(): - should_track = any(track_name in name for track_name in layers_to_track) - if should_track: - # Forward hook - handle = module.register_forward_hook(self._make_forward_hook(name)) - self.hooks.append(handle) - # Backward hook - handle = module.register_full_backward_hook( - self._make_backward_hook(name) - ) - self.hooks.append(handle) - print(f"[MemProf] Attached hooks to: {name}") - - def remove_hooks(self): - """Remove all hooks.""" - for hook in self.hooks: - hook.remove() - self.hooks = [] - - def set_step(self, step: int): - self.current_step = step - - def set_phase(self, phase: str): - self.current_phase = phase - - def save_profile(self): - """Save memory profile to JSON file.""" - with open(self.output_file, "w") as f: - json.dump(self.memory_events, f, indent=2) - print(f"[MemProf] Saved profile to {self.output_file}") - - def print_summary(self): - """Print memory profile summary.""" - print("\n" + "=" * 80) - print("MEMORY PROFILE SUMMARY") - print("=" * 80) - - # Group by event name and find max memory - event_max_mem = defaultdict(float) - event_counts = defaultdict(int) - - for 
event in self.memory_events: - name = event["event"] - mem = event["memory"].get("allocated_gb", 0) - event_max_mem[name] = max(event_max_mem[name], mem) - event_counts[name] += 1 - - # Sort by max memory - sorted_events = sorted(event_max_mem.items(), key=lambda x: x[1], reverse=True) - - print(f"\n{'Event':<60} {'Max Alloc (GB)':<15} {'Count':<10}") - print("-" * 85) - for event_name, max_mem in sorted_events[:30]: - print(f"{event_name:<60} {max_mem:<15.2f} {event_counts[event_name]:<10}") - - # Find peak memory point - if self.memory_events: - peak_event = max( - self.memory_events, key=lambda x: x["memory"].get("reserved_gb", 0) - ) - print(f"\n{'='*80}") - print( - f"PEAK MEMORY: {peak_event['memory'].get('reserved_gb', 0):.2f} GB reserved" - ) - print( - f" At: Step {peak_event['step']} | Phase: {peak_event['phase']} | Event: {peak_event['event']}" - ) - if "extra" in peak_event: - print(f" Extra: {peak_event['extra']}") - - -def analyze_memory_difference(profile_2k: str, profile_4k: str): - """Compare memory profiles between 2k and 4k to find differences.""" - with open(profile_2k) as f: - events_2k = json.load(f) - with open(profile_4k) as f: - events_4k = json.load(f) - - print("\n" + "=" * 80) - print("MEMORY COMPARISON: 2k vs 4k context") - print("=" * 80) - - # Build event maps - def build_event_map(events): - event_map = {} - for e in events: - key = (e["step"], e["phase"], e["event"]) - event_map[key] = e["memory"] - return event_map - - map_2k = build_event_map(events_2k) - map_4k = build_event_map(events_4k) - - # Find common events and compare - common_keys = set(map_2k.keys()) & set(map_4k.keys()) - - differences = [] - for key in common_keys: - mem_2k = map_2k[key].get("allocated_gb", 0) - mem_4k = map_4k[key].get("allocated_gb", 0) - diff = mem_4k - mem_2k - if abs(diff) > 0.1: # Only show significant differences - differences.append((key, mem_2k, mem_4k, diff)) - - # Sort by difference - differences.sort(key=lambda x: x[3], reverse=True) - - 
print(f"\n{'Event':<50} {'2k (GB)':<10} {'4k (GB)':<10} {'Diff (GB)':<10}") - print("-" * 80) - for key, mem_2k, mem_4k, diff in differences[:20]: - step, phase, event = key - event_short = event[:45] if len(event) > 45 else event - print(f"{event_short:<50} {mem_2k:<10.2f} {mem_4k:<10.2f} {diff:<+10.2f}") - - # Summary - total_2k = ( - max(e["memory"].get("reserved_gb", 0) for e in events_2k) if events_2k else 0 - ) - total_4k = ( - max(e["memory"].get("reserved_gb", 0) for e in events_4k) if events_4k else 0 - ) - - print(f"\n{'='*80}") - print("Peak Reserved Memory:") - print(f" 2k context: {total_2k:.2f} GB") - print(f" 4k context: {total_4k:.2f} GB") - print(f" Difference: {total_4k - total_2k:+.2f} GB") - - -if __name__ == "__main__": - if len(sys.argv) > 2: - # Compare mode - analyze_memory_difference(sys.argv[1], sys.argv[2]) - else: - print( - "Usage: python deep_memory_profiler.py " - ) diff --git a/launch_kimi_1t_emozilla.slurm b/launch_kimi_1t_emozilla.slurm deleted file mode 100755 index 0876b3f364..0000000000 --- a/launch_kimi_1t_emozilla.slurm +++ /dev/null @@ -1,84 +0,0 @@ -#!/bin/bash -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. - -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -# --- This script is optimized for AWS with EFA -# --- adjust NCCL_BUFFSIZE if you encounter memory -# --- constraint issues or to tune for improved performance. 
-# --- - -#SBATCH --job-name=kimi_1t_emozilla - -#SBATCH --ntasks=40 - -#SBATCH --nodes=40 - -#SBATCH --gpus-per-task=8 - -#SBATCH --cpus-per-task=64 - -#SBATCH --partition=batch - - -nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) -nodes_array=($nodes) -head_node=${nodes_array[0]} -head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) - -echo Node IP: $head_node_ip -export LOGLEVEL=INFO -# Enable for A100 -export FI_PROVIDER="efa" -# Ensure that P2P is available -# export NCCL_P2P_DISABLE=1 -# export NCCL_IB_DISABLE=1 - -# debugging flags (optional) -export NCCL_DEBUG=WARN -export PYTHONFAULTHANDLER=1 -# optional debug settings -# export NCCL_DEBUG=INFO -# NCCL_DEBUG_SUBSYS=INIT,GRAPH,ENV - -export LD_LIBRARY_PATH=/opt/amazon/efa/lib:$LD_LIBRARY_PATH -export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH -export CUDA_LAUNCH_BLOCKING=0 - -# on your cluster you might need these: -# set the network interface -export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond" -export NCCL_BUFFSIZE=2097152 -#export TORCH_DIST_INIT_BARRIER=1 -export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 - -export TRITON_HOME=/tmp/emotritoncache_$$_$SLURM_PROCID -export NUMBA_CACHE_DIR=/tmp/numbacache_$$_$SLURM_PROCID -mkdir -p $TRITON_HOME $NUMBA_CACHE_DIR - -export WANDB_ENTITY="nous_research" -#export WANDB_PROJECT="torchtune" -export WANDB_PROJECT="moe" - -export HF_HOME="/home/phuc/.cache/huggingface" -mkdir -p $HF_HOME -#export NVSHMEM_DISABLE_NIC_LOCKING=1 -#export NVSHMEM_VERBOSE=3 -#export PYTORCH_ALLOC_CONF="expandable_segments:True,max_split_size_mb:128,garbage_collection_threshold:0.95" - -# Activate conda environment on all nodes -export PATH="/home/phuc/kimi_1t/env/bin:$PATH" -export CONDA_PREFIX="/home/phuc/kimi_1t/env" - -CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/deepseek_v3/train_configs/kimi_1t_emozilla.toml"} -#CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/llama3_8b.toml"} 
-#CONFIG_FILE=${CONFIG_FILE:-"./torchtitan/models/llama3/train_configs/llama3_405b.toml"} - - -#dcgmi profile --pause -# adjust sbatch --ntasks and sbatch --nodes above and --nnodes below -# to your specific node count, and update target launch file. -srun --export=ALL torchrun --nnodes 40 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:29500" -m torchtitan.train --job.config_file ${CONFIG_FILE} "$@" -#dcgmi profile --resume diff --git a/test_single_node.slurm b/test_single_node.slurm deleted file mode 100755 index c90b0dc71f..0000000000 --- a/test_single_node.slurm +++ /dev/null @@ -1,40 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=kimi_1t_test_1node -#SBATCH --ntasks=1 -#SBATCH --nodes=1 -#SBATCH --gpus-per-task=8 -#SBATCH --cpus-per-task=64 -#SBATCH --partition=batch -#SBATCH --time=00:30:00 - -nodes=( $( scontrol show hostnames $SLURM_JOB_NODELIST ) ) -nodes_array=($nodes) -head_node=${nodes_array[0]} -head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) - -echo "Node IP: $head_node_ip" -export LOGLEVEL=INFO -export FI_PROVIDER="efa" -export NCCL_DEBUG=WARN -export PYTHONFAULTHANDLER=1 -export LD_LIBRARY_PATH=/opt/amazon/efa/lib:$LD_LIBRARY_PATH -export LD_LIBRARY_PATH=/usr/local/lib/:$LD_LIBRARY_PATH -export CUDA_LAUNCH_BLOCKING=0 -export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond" -export NCCL_BUFFSIZE=2097152 -export FI_EFA_SET_CUDA_SYNC_MEMOPS=0 -export TRITON_HOME=/tmp/emotritoncache -export NUMBA_CACHE_DIR=/tmp/numbacache -export WANDB_ENTITY="nous_research" -export WANDB_PROJECT="moe" -export HF_HOME="/home/phuc/.cache/huggingface" -mkdir -p $HF_HOME - -# Activate conda environment -export PATH="/home/phuc/kimi_1t/env/bin:$PATH" -export CONDA_PREFIX="/home/phuc/kimi_1t/env" - -CONFIG_FILE="./torchtitan/models/deepseek_v3/train_configs/kimi_1t_emozilla.toml" - -echo "Testing single node with 8 GPUs" -srun torchrun --nnodes 1 --nproc_per_node 8 --rdzv_id 101 --rdzv_backend c10d --rdzv_endpoint 
"$head_node_ip:29500" -m torchtitan.train --job.config_file ${CONFIG_FILE} "$@" diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_1b_baseline.toml b/torchtitan/models/deepseek_v3/train_configs/debug_1b_baseline.toml deleted file mode 100644 index 50c110d127..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/debug_1b_baseline.toml +++ /dev/null @@ -1,84 +0,0 @@ -# DeepSeek V3 ~1B debug model - BASELINE (no activation offload) - -[job] -dump_folder = "./outputs/debug_1b_baseline" -description = "DeepSeek-V3 ~1B debug - BASELINE (no activation offload)" -print_config = false - -[profiling] -enable_profiling = false -save_traces_folder = "profile_trace" -profile_freq = 10 -enable_memory_snapshot = false -save_memory_snapshot_folder = "memory_snapshot" - -[metrics] -log_freq = 1 -disable_color_printing = false -enable_tensorboard = false -save_tb_folder = "tb" -enable_wandb = false - -[model] -name = "deepseek_v3" -flavor = "debugmodel_1b" # ~1B parameters -# test tokenizer, for debug purpose only -hf_assets_path = "./tests/assets/tokenizer" - -[optimizer] -name = "AdamW" -lr = 8e-4 -eps = 1e-8 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 -decay_type = "linear" -min_lr_factor = 0.0 - -[training] -local_batch_size = 2 -seq_len = 4096 # Longer sequence to see activation memory -max_norm = 1.0 -steps = 10 -dataset = "c4_test" -dtype = "bfloat16" -enable_cpu_offload = false - -[parallelism] -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 -fsdp_reshard_after_forward = "default" -tensor_parallel_degree = 1 -enable_async_tensor_parallel = false -pipeline_parallel_degree = 1 -pipeline_parallel_schedule = "1F1B" -context_parallel_degree = 1 -expert_parallel_degree = 1 -expert_tensor_parallel_degree = 1 - -[checkpoint] -enable = false -folder = "checkpoint" -interval = 10 -last_save_model_only = false -export_dtype = "float32" -async_mode = "disabled" - -[activation_checkpoint] -mode = "full" -selective_ac_option = 'op' -# NO 
CPU OFFLOAD - BASELINE -cpu_offload = false - -[compile] -enable = false -components = ["model", "loss"] - -[quantize.linear.float8] -enable_fsdp_float8_all_gather = false -precompute_float8_dynamic_scale_for_fsdp = false -filter_fqns = ["output", "router.gate"] - -[quantize.grouped_mm.float8] -fqns = ["experts"] diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_1b_no_ac_baseline.toml b/torchtitan/models/deepseek_v3/train_configs/debug_1b_no_ac_baseline.toml deleted file mode 100644 index 932fbb7084..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/debug_1b_no_ac_baseline.toml +++ /dev/null @@ -1,54 +0,0 @@ -# DeepSeek V3 ~1B - NO AC, NO OFFLOAD (will show true activation memory) - -[job] -dump_folder = "./outputs/debug_1b_no_ac_baseline" -description = "DeepSeek-V3 ~1B - NO AC, NO OFFLOAD" -print_config = false - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 -disable_color_printing = false -enable_tensorboard = false -enable_wandb = false - -[model] -name = "deepseek_v3" -flavor = "debugmodel_1b" -hf_assets_path = "./tests/assets/tokenizer" - -[optimizer] -name = "AdamW" -lr = 8e-4 -eps = 1e-8 - -[lr_scheduler] -warmup_steps = 2 - -[training] -local_batch_size = 2 -seq_len = 4096 -max_norm = 1.0 -steps = 5 -dataset = "c4_test" -dtype = "bfloat16" -enable_cpu_offload = false - -[parallelism] -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 -fsdp_reshard_after_forward = "default" -tensor_parallel_degree = 1 -pipeline_parallel_degree = 1 -expert_parallel_degree = 1 - -[checkpoint] -enable = false - -[activation_checkpoint] -mode = "none" # NO ACTIVATION CHECKPOINTING - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_1b_offload.toml b/torchtitan/models/deepseek_v3/train_configs/debug_1b_offload.toml deleted file mode 100644 index 98f0ea4598..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/debug_1b_offload.toml +++ /dev/null @@ -1,84 +0,0 @@ -# 
DeepSeek V3 ~1B debug model - WITH activation offload - -[job] -dump_folder = "./outputs/debug_1b_offload" -description = "DeepSeek-V3 ~1B debug - WITH activation offload" -print_config = false - -[profiling] -enable_profiling = false -save_traces_folder = "profile_trace" -profile_freq = 10 -enable_memory_snapshot = false -save_memory_snapshot_folder = "memory_snapshot" - -[metrics] -log_freq = 1 -disable_color_printing = false -enable_tensorboard = false -save_tb_folder = "tb" -enable_wandb = false - -[model] -name = "deepseek_v3" -flavor = "debugmodel_1b" # ~1B parameters -# test tokenizer, for debug purpose only -hf_assets_path = "./tests/assets/tokenizer" - -[optimizer] -name = "AdamW" -lr = 8e-4 -eps = 1e-8 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 -decay_type = "linear" -min_lr_factor = 0.0 - -[training] -local_batch_size = 2 -seq_len = 4096 # Longer sequence to see activation memory -max_norm = 1.0 -steps = 10 -dataset = "c4_test" -dtype = "bfloat16" -enable_cpu_offload = false - -[parallelism] -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 -fsdp_reshard_after_forward = "default" -tensor_parallel_degree = 1 -enable_async_tensor_parallel = false -pipeline_parallel_degree = 1 -pipeline_parallel_schedule = "1F1B" -context_parallel_degree = 1 -expert_parallel_degree = 1 -expert_tensor_parallel_degree = 1 - -[checkpoint] -enable = false -folder = "checkpoint" -interval = 10 -last_save_model_only = false -export_dtype = "float32" -async_mode = "disabled" - -[activation_checkpoint] -mode = "full" -selective_ac_option = 'op' -# ENABLE CPU OFFLOAD FOR ACTIVATIONS -cpu_offload = true - -[compile] -enable = false -components = ["model", "loss"] - -[quantize.linear.float8] -enable_fsdp_float8_all_gather = false -precompute_float8_dynamic_scale_for_fsdp = false -filter_fqns = ["output", "router.gate"] - -[quantize.grouped_mm.float8] -fqns = ["experts"] diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_1b_offload_only.toml 
b/torchtitan/models/deepseek_v3/train_configs/debug_1b_offload_only.toml deleted file mode 100644 index a823e7e856..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/debug_1b_offload_only.toml +++ /dev/null @@ -1,55 +0,0 @@ -# DeepSeek V3 ~1B - Offload-only (NO AC, just offload to CPU) - -[job] -dump_folder = "./outputs/debug_1b_offload_only" -description = "DeepSeek-V3 ~1B - Offload-only (NO AC)" -print_config = false - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 -disable_color_printing = false -enable_tensorboard = false -enable_wandb = false - -[model] -name = "deepseek_v3" -flavor = "debugmodel_1b" -hf_assets_path = "./tests/assets/tokenizer" - -[optimizer] -name = "AdamW" -lr = 8e-4 -eps = 1e-8 - -[lr_scheduler] -warmup_steps = 2 - -[training] -local_batch_size = 2 -seq_len = 4096 -max_norm = 1.0 -steps = 5 -dataset = "c4_test" -dtype = "bfloat16" -enable_cpu_offload = false - -[parallelism] -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 -fsdp_reshard_after_forward = "default" -tensor_parallel_degree = 1 -pipeline_parallel_degree = 1 -expert_parallel_degree = 1 - -[checkpoint] -enable = false - -[activation_checkpoint] -mode = "none" # NO AC - just offload -cpu_offload = true # But enable offload - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_7b_baseline.toml b/torchtitan/models/deepseek_v3/train_configs/debug_7b_baseline.toml deleted file mode 100644 index e04aa9c7f7..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/debug_7b_baseline.toml +++ /dev/null @@ -1,55 +0,0 @@ -# DeepSeek V3 ~7B debug model - BASELINE (no activation offload) - -[job] -dump_folder = "./outputs/debug_7b_baseline" -description = "DeepSeek-V3 ~7B debug - BASELINE (no activation offload)" -print_config = false - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 -disable_color_printing = false -enable_tensorboard = false -enable_wandb = false - -[model] -name = 
"deepseek_v3" -flavor = "debugmodel_7b" -hf_assets_path = "./tests/assets/tokenizer" - -[optimizer] -name = "AdamW" -lr = 8e-4 -eps = 1e-8 - -[lr_scheduler] -warmup_steps = 2 - -[training] -local_batch_size = 1 -seq_len = 4096 -max_norm = 1.0 -steps = 10 -dataset = "c4_test" -dtype = "bfloat16" -enable_cpu_offload = false - -[parallelism] -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 -fsdp_reshard_after_forward = "default" -tensor_parallel_degree = 1 -pipeline_parallel_degree = 1 -expert_parallel_degree = 1 - -[checkpoint] -enable = false - -[activation_checkpoint] -mode = "full" -cpu_offload = false - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/debug_7b_offload.toml b/torchtitan/models/deepseek_v3/train_configs/debug_7b_offload.toml deleted file mode 100644 index 19a99d0843..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/debug_7b_offload.toml +++ /dev/null @@ -1,55 +0,0 @@ -# DeepSeek V3 ~7B debug model - WITH activation offload - -[job] -dump_folder = "./outputs/debug_7b_offload" -description = "DeepSeek-V3 ~7B debug - WITH activation offload" -print_config = false - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 -disable_color_printing = false -enable_tensorboard = false -enable_wandb = false - -[model] -name = "deepseek_v3" -flavor = "debugmodel_7b" -hf_assets_path = "./tests/assets/tokenizer" - -[optimizer] -name = "AdamW" -lr = 8e-4 -eps = 1e-8 - -[lr_scheduler] -warmup_steps = 2 - -[training] -local_batch_size = 1 -seq_len = 4096 -max_norm = 1.0 -steps = 10 -dataset = "c4_test" -dtype = "bfloat16" -enable_cpu_offload = false - -[parallelism] -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 -fsdp_reshard_after_forward = "default" -tensor_parallel_degree = 1 -pipeline_parallel_degree = 1 -expert_parallel_degree = 1 - -[checkpoint] -enable = false - -[activation_checkpoint] -mode = "full" -cpu_offload = true - -[compile] -enable = false diff --git 
a/torchtitan/models/deepseek_v3/train_configs/debug_activation_offload.toml b/torchtitan/models/deepseek_v3/train_configs/debug_activation_offload.toml deleted file mode 100644 index 7c4f45cbe0..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/debug_activation_offload.toml +++ /dev/null @@ -1,82 +0,0 @@ -[job] -dump_folder = "./outputs" -description = "DeepSeek-V3 debug training with Activation Offloading" -print_config = false - -[profiling] -enable_profiling = false -save_traces_folder = "profile_trace" -profile_freq = 10 -enable_memory_snapshot = false -save_memory_snapshot_folder = "memory_snapshot" - -[metrics] -log_freq = 1 -disable_color_printing = false -enable_tensorboard = false -save_tb_folder = "tb" -enable_wandb = false - -[model] -name = "deepseek_v3" -flavor = "debugmodel" -# test tokenizer, for debug purpose only -hf_assets_path = "./tests/assets/tokenizer" - -[optimizer] -name = "AdamW" -lr = 8e-4 -eps = 1e-8 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 -decay_type = "linear" -min_lr_factor = 0.0 - -[training] -local_batch_size = 2 -seq_len = 512 -max_norm = 1.0 -steps = 5 -dataset = "c4_test" -dtype = "bfloat16" -enable_cpu_offload = false - -[parallelism] -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 -fsdp_reshard_after_forward = "default" -tensor_parallel_degree = 1 -enable_async_tensor_parallel = false -pipeline_parallel_degree = 1 -pipeline_parallel_schedule = "1F1B" -context_parallel_degree = 1 -expert_parallel_degree = 1 -expert_tensor_parallel_degree = 1 - -[checkpoint] -enable = false -folder = "checkpoint" -interval = 10 -last_save_model_only = false -export_dtype = "float32" -async_mode = "disabled" - -[activation_checkpoint] -mode = "full" -selective_ac_option = 'op' -# Enable CPU offloading for activations -cpu_offload = true - -[compile] -enable = false -components = ["model", "loss"] - -[quantize.linear.float8] -enable_fsdp_float8_all_gather = false 
-precompute_float8_dynamic_scale_for_fsdp = false -filter_fqns = ["output", "router.gate"] - -[quantize.grouped_mm.float8] -fqns = ["experts"] diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa0_8n_EP64_CP2_LBS1_ctx16k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa0_8n_EP64_CP2_LBS1_ctx16k.toml deleted file mode 100644 index ad3249c845..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa0_8n_EP64_CP2_LBS1_ctx16k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1aa0: 8 nodes EP=64 CP=2 LBS=1 ctx=16k -[job] -dump_folder = "./outputs/exp1a/exp1aa0_8n_EP64_CP2_LBS1_ctx16k" -description = "exp1aa0_8n_EP64_CP2_LBS1_ctx16k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 16384 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa10_8n_EP64_CP2_LBS1_ctx32k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa10_8n_EP64_CP2_LBS1_ctx32k.toml deleted file mode 100644 index 1614e22495..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa10_8n_EP64_CP2_LBS1_ctx32k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1aa10: 8 nodes EP=64 CP=2 LBS=1 ctx=32k -[job] -dump_folder = "./outputs/exp1a/exp1aa10_8n_EP64_CP2_LBS1_ctx32k" -description = "exp1aa10_8n_EP64_CP2_LBS1_ctx32k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] 
-name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 32768 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa11_8n_EP64_CP2_LBS2_ctx32k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa11_8n_EP64_CP2_LBS2_ctx32k.toml deleted file mode 100644 index ff5eae2822..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa11_8n_EP64_CP2_LBS2_ctx32k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1aa11: 8 nodes EP=64 CP=2 LBS=2 ctx=32k -[job] -dump_folder = "./outputs/exp1a/exp1aa11_8n_EP64_CP2_LBS2_ctx32k" -description = "exp1aa11_8n_EP64_CP2_LBS2_ctx32k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 2 -seq_len = 32768 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa12_8n_EP64_CP2_LBS4_ctx32k.toml 
b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa12_8n_EP64_CP2_LBS4_ctx32k.toml deleted file mode 100644 index 227b601d57..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa12_8n_EP64_CP2_LBS4_ctx32k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1aa12: 8 nodes EP=64 CP=2 LBS=4 ctx=32k -[job] -dump_folder = "./outputs/exp1a/exp1aa12_8n_EP64_CP2_LBS4_ctx32k" -description = "exp1aa12_8n_EP64_CP2_LBS4_ctx32k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 4 -seq_len = 32768 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa13_8n_EP64_CP2_LBS6_ctx32k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa13_8n_EP64_CP2_LBS6_ctx32k.toml deleted file mode 100644 index 31467d95f7..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa13_8n_EP64_CP2_LBS6_ctx32k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1aa13: 8 nodes EP=64 CP=2 LBS=6 ctx=32k -[job] -dump_folder = "./outputs/exp1a/exp1aa13_8n_EP64_CP2_LBS6_ctx32k" -description = "exp1aa13_8n_EP64_CP2_LBS6_ctx32k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 6 -seq_len = 32768 
-steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa14_8n_EP64_CP2_LBS8_ctx32k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa14_8n_EP64_CP2_LBS8_ctx32k.toml deleted file mode 100644 index e269a0f35a..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa14_8n_EP64_CP2_LBS8_ctx32k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1aa14: 8 nodes EP=64 CP=2 LBS=8 ctx=32k -[job] -dump_folder = "./outputs/exp1a/exp1aa14_8n_EP64_CP2_LBS8_ctx32k" -description = "exp1aa14_8n_EP64_CP2_LBS8_ctx32k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 8 -seq_len = 32768 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa1_8n_EP64_CP2_LBS2_ctx16k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa1_8n_EP64_CP2_LBS2_ctx16k.toml deleted file mode 100644 index e0a8c4725d..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa1_8n_EP64_CP2_LBS2_ctx16k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1aa1: 8 nodes EP=64 CP=2 LBS=2 ctx=16k 
-[job] -dump_folder = "./outputs/exp1a/exp1aa1_8n_EP64_CP2_LBS2_ctx16k" -description = "exp1aa1_8n_EP64_CP2_LBS2_ctx16k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 2 -seq_len = 16384 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa2_8n_EP64_CP2_LBS4_ctx16k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa2_8n_EP64_CP2_LBS4_ctx16k.toml deleted file mode 100644 index a3b1845a19..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa2_8n_EP64_CP2_LBS4_ctx16k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1aa2: 8 nodes EP=64 CP=2 LBS=4 ctx=16k -[job] -dump_folder = "./outputs/exp1a/exp1aa2_8n_EP64_CP2_LBS4_ctx16k" -description = "exp1aa2_8n_EP64_CP2_LBS4_ctx16k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 4 -seq_len = 16384 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = 
true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa3_8n_EP64_CP2_LBS6_ctx16k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa3_8n_EP64_CP2_LBS6_ctx16k.toml deleted file mode 100644 index 92abbed703..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa3_8n_EP64_CP2_LBS6_ctx16k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1aa3: 8 nodes EP=64 CP=2 LBS=6 ctx=16k -[job] -dump_folder = "./outputs/exp1a/exp1aa3_8n_EP64_CP2_LBS6_ctx16k" -description = "exp1aa3_8n_EP64_CP2_LBS6_ctx16k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 6 -seq_len = 16384 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa4_8n_EP64_CP2_LBS8_ctx16k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa4_8n_EP64_CP2_LBS8_ctx16k.toml deleted file mode 100644 index dfb310182c..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa4_8n_EP64_CP2_LBS8_ctx16k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1aa4: 8 nodes EP=64 CP=2 LBS=8 ctx=16k -[job] -dump_folder = "./outputs/exp1a/exp1aa4_8n_EP64_CP2_LBS8_ctx16k" -description = "exp1aa4_8n_EP64_CP2_LBS8_ctx16k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - 
-[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 8 -seq_len = 16384 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa5_8n_EP64_CP2_LBS1_ctx24k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa5_8n_EP64_CP2_LBS1_ctx24k.toml deleted file mode 100644 index 33ffb89ef3..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa5_8n_EP64_CP2_LBS1_ctx24k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1aa5: 8 nodes EP=64 CP=2 LBS=1 ctx=24k -[job] -dump_folder = "./outputs/exp1a/exp1aa5_8n_EP64_CP2_LBS1_ctx24k" -description = "exp1aa5_8n_EP64_CP2_LBS1_ctx24k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 24576 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa6_8n_EP64_CP2_LBS2_ctx24k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa6_8n_EP64_CP2_LBS2_ctx24k.toml deleted file mode 100644 index 7201988625..0000000000 --- 
a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa6_8n_EP64_CP2_LBS2_ctx24k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1aa6: 8 nodes EP=64 CP=2 LBS=2 ctx=24k -[job] -dump_folder = "./outputs/exp1a/exp1aa6_8n_EP64_CP2_LBS2_ctx24k" -description = "exp1aa6_8n_EP64_CP2_LBS2_ctx24k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 2 -seq_len = 24576 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa7_8n_EP64_CP2_LBS4_ctx24k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa7_8n_EP64_CP2_LBS4_ctx24k.toml deleted file mode 100644 index 5127dafa97..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa7_8n_EP64_CP2_LBS4_ctx24k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1aa7: 8 nodes EP=64 CP=2 LBS=4 ctx=24k -[job] -dump_folder = "./outputs/exp1a/exp1aa7_8n_EP64_CP2_LBS4_ctx24k" -description = "exp1aa7_8n_EP64_CP2_LBS4_ctx24k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 4 -seq_len = 24576 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree 
= -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa8_8n_EP64_CP2_LBS6_ctx24k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa8_8n_EP64_CP2_LBS6_ctx24k.toml deleted file mode 100644 index f6a143b448..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa8_8n_EP64_CP2_LBS6_ctx24k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1aa8: 8 nodes EP=64 CP=2 LBS=6 ctx=24k -[job] -dump_folder = "./outputs/exp1a/exp1aa8_8n_EP64_CP2_LBS6_ctx24k" -description = "exp1aa8_8n_EP64_CP2_LBS6_ctx24k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 6 -seq_len = 24576 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa9_8n_EP64_CP2_LBS8_ctx24k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa9_8n_EP64_CP2_LBS8_ctx24k.toml deleted file mode 100644 index 0de9b9730c..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1aa/exp1aa9_8n_EP64_CP2_LBS8_ctx24k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1aa9: 8 nodes EP=64 CP=2 LBS=8 ctx=24k -[job] -dump_folder = "./outputs/exp1a/exp1aa9_8n_EP64_CP2_LBS8_ctx24k" -description = "exp1aa9_8n_EP64_CP2_LBS8_ctx24k" - -[profiling] -enable_profiling = false - 
-[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 8 -seq_len = 24576 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab0_12n_EP96_CP2_LBS1_ctx16k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab0_12n_EP96_CP2_LBS1_ctx16k.toml deleted file mode 100644 index 3a60dfc42a..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab0_12n_EP96_CP2_LBS1_ctx16k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1ab0: 12 nodes EP=96 CP=2 LBS=1 ctx=16k -[job] -dump_folder = "./outputs/exp1a/exp1ab0_12n_EP96_CP2_LBS1_ctx16k" -description = "exp1ab0_12n_EP96_CP2_LBS1_ctx16k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 16384 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab10_12n_EP96_CP2_LBS1_ctx32k.toml 
b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab10_12n_EP96_CP2_LBS1_ctx32k.toml deleted file mode 100644 index 2ae12c912d..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab10_12n_EP96_CP2_LBS1_ctx32k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1ab10: 12 nodes EP=96 CP=2 LBS=1 ctx=32k -[job] -dump_folder = "./outputs/exp1a/exp1ab10_12n_EP96_CP2_LBS1_ctx32k" -description = "exp1ab10_12n_EP96_CP2_LBS1_ctx32k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 32768 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab11_12n_EP96_CP2_LBS2_ctx32k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab11_12n_EP96_CP2_LBS2_ctx32k.toml deleted file mode 100644 index 95083976af..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab11_12n_EP96_CP2_LBS2_ctx32k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1ab11: 12 nodes EP=96 CP=2 LBS=2 ctx=32k -[job] -dump_folder = "./outputs/exp1a/exp1ab11_12n_EP96_CP2_LBS2_ctx32k" -description = "exp1ab11_12n_EP96_CP2_LBS2_ctx32k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 2 
-seq_len = 32768 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab12_12n_EP96_CP2_LBS4_ctx32k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab12_12n_EP96_CP2_LBS4_ctx32k.toml deleted file mode 100644 index 9d96209b09..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab12_12n_EP96_CP2_LBS4_ctx32k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1ab12: 12 nodes EP=96 CP=2 LBS=4 ctx=32k -[job] -dump_folder = "./outputs/exp1a/exp1ab12_12n_EP96_CP2_LBS4_ctx32k" -description = "exp1ab12_12n_EP96_CP2_LBS4_ctx32k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 4 -seq_len = 32768 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab13_12n_EP96_CP2_LBS6_ctx32k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab13_12n_EP96_CP2_LBS6_ctx32k.toml deleted file mode 100644 index 3f53673681..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab13_12n_EP96_CP2_LBS6_ctx32k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1ab13: 12 
nodes EP=96 CP=2 LBS=6 ctx=32k -[job] -dump_folder = "./outputs/exp1a/exp1ab13_12n_EP96_CP2_LBS6_ctx32k" -description = "exp1ab13_12n_EP96_CP2_LBS6_ctx32k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 6 -seq_len = 32768 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab14_12n_EP96_CP2_LBS8_ctx32k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab14_12n_EP96_CP2_LBS8_ctx32k.toml deleted file mode 100644 index b40dd6b684..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab14_12n_EP96_CP2_LBS8_ctx32k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1ab14: 12 nodes EP=96 CP=2 LBS=8 ctx=32k -[job] -dump_folder = "./outputs/exp1a/exp1ab14_12n_EP96_CP2_LBS8_ctx32k" -description = "exp1ab14_12n_EP96_CP2_LBS8_ctx32k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 8 -seq_len = 32768 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable 
= false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab1_12n_EP96_CP2_LBS2_ctx16k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab1_12n_EP96_CP2_LBS2_ctx16k.toml deleted file mode 100644 index 4d7a0bbf2d..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab1_12n_EP96_CP2_LBS2_ctx16k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1ab1: 12 nodes EP=96 CP=2 LBS=2 ctx=16k -[job] -dump_folder = "./outputs/exp1a/exp1ab1_12n_EP96_CP2_LBS2_ctx16k" -description = "exp1ab1_12n_EP96_CP2_LBS2_ctx16k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 2 -seq_len = 16384 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab2_12n_EP96_CP2_LBS4_ctx16k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab2_12n_EP96_CP2_LBS4_ctx16k.toml deleted file mode 100644 index 817ea3793a..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab2_12n_EP96_CP2_LBS4_ctx16k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1ab2: 12 nodes EP=96 CP=2 LBS=4 ctx=16k -[job] -dump_folder = "./outputs/exp1a/exp1ab2_12n_EP96_CP2_LBS4_ctx16k" -description = "exp1ab2_12n_EP96_CP2_LBS4_ctx16k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = 
"./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 4 -seq_len = 16384 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab3_12n_EP96_CP2_LBS6_ctx16k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab3_12n_EP96_CP2_LBS6_ctx16k.toml deleted file mode 100644 index 7592a5d554..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab3_12n_EP96_CP2_LBS6_ctx16k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1ab3: 12 nodes EP=96 CP=2 LBS=6 ctx=16k -[job] -dump_folder = "./outputs/exp1a/exp1ab3_12n_EP96_CP2_LBS6_ctx16k" -description = "exp1ab3_12n_EP96_CP2_LBS6_ctx16k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 6 -seq_len = 16384 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab4_12n_EP96_CP2_LBS8_ctx16k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab4_12n_EP96_CP2_LBS8_ctx16k.toml deleted file mode 100644 index 
5a6dbfa6c0..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab4_12n_EP96_CP2_LBS8_ctx16k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1ab4: 12 nodes EP=96 CP=2 LBS=8 ctx=16k -[job] -dump_folder = "./outputs/exp1a/exp1ab4_12n_EP96_CP2_LBS8_ctx16k" -description = "exp1ab4_12n_EP96_CP2_LBS8_ctx16k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 8 -seq_len = 16384 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab5_12n_EP96_CP2_LBS1_ctx24k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab5_12n_EP96_CP2_LBS1_ctx24k.toml deleted file mode 100644 index 4780f17a45..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab5_12n_EP96_CP2_LBS1_ctx24k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1ab5: 12 nodes EP=96 CP=2 LBS=1 ctx=24k -[job] -dump_folder = "./outputs/exp1a/exp1ab5_12n_EP96_CP2_LBS1_ctx24k" -description = "exp1ab5_12n_EP96_CP2_LBS1_ctx24k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 24576 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - 
-[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab6_12n_EP96_CP2_LBS2_ctx24k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab6_12n_EP96_CP2_LBS2_ctx24k.toml deleted file mode 100644 index 4d4a50f8d4..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab6_12n_EP96_CP2_LBS2_ctx24k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1ab6: 12 nodes EP=96 CP=2 LBS=2 ctx=24k -[job] -dump_folder = "./outputs/exp1a/exp1ab6_12n_EP96_CP2_LBS2_ctx24k" -description = "exp1ab6_12n_EP96_CP2_LBS2_ctx24k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 2 -seq_len = 24576 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab7_12n_EP96_CP2_LBS4_ctx24k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab7_12n_EP96_CP2_LBS4_ctx24k.toml deleted file mode 100644 index 353ae40ead..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab7_12n_EP96_CP2_LBS4_ctx24k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1ab7: 12 nodes EP=96 CP=2 LBS=4 ctx=24k -[job] -dump_folder = "./outputs/exp1a/exp1ab7_12n_EP96_CP2_LBS4_ctx24k" -description = 
"exp1ab7_12n_EP96_CP2_LBS4_ctx24k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 4 -seq_len = 24576 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab8_12n_EP96_CP2_LBS6_ctx24k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab8_12n_EP96_CP2_LBS6_ctx24k.toml deleted file mode 100644 index 24df0535b3..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab8_12n_EP96_CP2_LBS6_ctx24k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1ab8: 12 nodes EP=96 CP=2 LBS=6 ctx=24k -[job] -dump_folder = "./outputs/exp1a/exp1ab8_12n_EP96_CP2_LBS6_ctx24k" -description = "exp1ab8_12n_EP96_CP2_LBS6_ctx24k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 6 -seq_len = 24576 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git 
a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab9_12n_EP96_CP2_LBS8_ctx24k.toml b/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab9_12n_EP96_CP2_LBS8_ctx24k.toml deleted file mode 100644 index 50de7dac1e..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/exp1a/exp1ab/exp1ab9_12n_EP96_CP2_LBS8_ctx24k.toml +++ /dev/null @@ -1,46 +0,0 @@ -# Exp1ab9: 12 nodes EP=96 CP=2 LBS=8 ctx=24k -[job] -dump_folder = "./outputs/exp1a/exp1ab9_12n_EP96_CP2_LBS8_ctx24k" -description = "exp1ab9_12n_EP96_CP2_LBS8_ctx24k" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 8 -seq_len = 24576 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_28k_flb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_28k_flb.toml deleted file mode 100644 index e5eecac4d0..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_28k_flb.toml +++ /dev/null @@ -1,46 +0,0 @@ -# 28k context - 10 nodes EP=16 with FORCE LOAD BALANCE - -[job] -dump_folder = "./outputs/10n_28k_flb" -description = "28k context 10 nodes EP=16 force LB" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 
28672 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 16 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_2k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_2k.toml deleted file mode 100644 index cb66bc4a3d..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_2k.toml +++ /dev/null @@ -1,45 +0,0 @@ -# MEMORY TEST: 10 nodes, EP=16, 2k context -# 10 nodes × 8 GPUs = 80 GPUs, EP=16, 384/16 = 24 experts per GPU - -[job] -dump_folder = "./outputs/kimi_1t_10n_ep16_2k" -description = "Memory test - 10 nodes EP=16 - 2k context" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 2048 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 16 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_4k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_4k.toml deleted file mode 100644 index d72e81d9ea..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_4k.toml +++ /dev/null @@ -1,43 +0,0 @@ -# MEMORY TEST: 10 nodes, EP=16, 4k context -[job] -dump_folder = "./outputs/kimi_1t_10n_ep16_4k" -description = "Memory test - 10 nodes EP=16 - 4k context" - -[profiling] -enable_profiling = 
false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 4096 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 16 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_8k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_8k.toml deleted file mode 100644 index 79d285cc96..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_10n_ep16_8k.toml +++ /dev/null @@ -1,43 +0,0 @@ -# MEMORY TEST: 10 nodes, EP=16, 8k context -[job] -dump_folder = "./outputs/kimi_1t_10n_ep16_8k" -description = "Memory test - 10 nodes EP=16 - 8k context" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 8192 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 16 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_28k_ac_offload.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_28k_ac_offload.toml deleted file mode 100644 index 166f5594af..0000000000 --- 
a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_28k_ac_offload.toml +++ /dev/null @@ -1,48 +0,0 @@ -# 28k context with FORCE LOAD BALANCE + AC CPU OFFLOAD - 12 nodes EP=96 -# Testing if activation checkpoint CPU offload helps avoid OOM - -[job] -dump_folder = "./outputs/12n_28k_ac_offload" -description = "28k context 12n EP=96 with AC CPU offload" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 28672 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 - -[activation_checkpoint] -mode = "full" -cpu_offload = true - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_28k_selective_ac.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_28k_selective_ac.toml deleted file mode 100644 index 0dd5b8e9bf..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_28k_selective_ac.toml +++ /dev/null @@ -1,49 +0,0 @@ -# 28k context with FORCE LOAD BALANCE + SELECTIVE AC - 12 nodes EP=96 -# Testing if selective op-level AC helps avoid OOM - -[job] -dump_folder = "./outputs/12n_28k_selective_ac" -description = "28k context 12n EP=96 with selective AC" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 28672 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps 
= true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 - -[activation_checkpoint] -mode = "selective" -selective_ac_option = "op" -cpu_offload = true - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_28k_flex_fix.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_28k_flex_fix.toml deleted file mode 100644 index 642fa35563..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_28k_flex_fix.toml +++ /dev/null @@ -1,48 +0,0 @@ -# 28k context - 12 nodes EP=96 with CP=2 + FlexAttention (with fix) -# Testing create_cp_block_mask fix with more nodes - -[job] -dump_folder = "./outputs/12n_cp2_28k_flex_fix" -description = "28k context 12n EP=96 CP=2 FlexAttention with CP block mask fix" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 28672 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_30720.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_30720.toml deleted file mode 100644 index 06678106e6..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_30720.toml +++ /dev/null @@ -1,46 +0,0 @@ -# 30720 context - 12 nodes EP=96 with CP=2 + FlexAttention -[job] -dump_folder = "./outputs/12n_cp2_30720" -description = "30720 context 
12n EP=96 CP=2" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 30720 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_32768.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_32768.toml deleted file mode 100644 index fbd8c45b73..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_cp2_32768.toml +++ /dev/null @@ -1,46 +0,0 @@ -# 32768 context - 12 nodes EP=96 with CP=2 + FlexAttention -[job] -dump_folder = "./outputs/12n_cp2_32768" -description = "32768 context 12n EP=96 CP=2" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 32768 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep12_28k_flb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep12_28k_flb.toml deleted file mode 
100644 index d431b45824..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep12_28k_flb.toml +++ /dev/null @@ -1,47 +0,0 @@ -# 28k context - 12 nodes EP=12 with FORCE LOAD BALANCE -# 32 experts per GPU - -[job] -dump_folder = "./outputs/12n_ep12_28k_flb" -description = "28k context 12 nodes EP=12 force LB" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 28672 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 12 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep12_32k_flb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep12_32k_flb.toml deleted file mode 100644 index 1d45a540f3..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep12_32k_flb.toml +++ /dev/null @@ -1,47 +0,0 @@ -# 32k context - 12 nodes EP=12 with FORCE LOAD BALANCE -# 32 experts per GPU - -[job] -dump_folder = "./outputs/12n_ep12_32k_flb" -description = "32k context 12 nodes EP=12 force LB" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 32768 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 12 
- -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep96_28k_flb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep96_28k_flb.toml deleted file mode 100644 index bdb5b51596..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep96_28k_flb.toml +++ /dev/null @@ -1,47 +0,0 @@ -# 28k context - 12 nodes EP=96 with FORCE LOAD BALANCE -# 4 experts per GPU (better than 8 nodes with 6 experts/GPU) - -[job] -dump_folder = "./outputs/12n_ep96_28k_flb" -description = "28k context 12 nodes EP=96 force LB" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 28672 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep96_32k_flb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep96_32k_flb.toml deleted file mode 100644 index ba0bfd78c7..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_12n_ep96_32k_flb.toml +++ /dev/null @@ -1,47 +0,0 @@ -# 32k context - 12 nodes EP=96 with FORCE LOAD BALANCE -# 4 experts per GPU (better than 8 nodes with 6 experts/GPU) - -[job] -dump_folder = "./outputs/12n_ep96_32k_flb" -description = "32k context 12 nodes EP=96 force LB" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = 
"./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 32768 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 96 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16k_force_lb.toml deleted file mode 100644 index c2c2584d39..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16k_force_lb.toml +++ /dev/null @@ -1,46 +0,0 @@ -# 16k context with FORCE LOAD BALANCE - 8 nodes EP=64 - -[job] -dump_folder = "./outputs/16k_force_lb" -description = "16k context with force load balance" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 16384 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_2k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_2k.toml deleted file mode 100644 index c782b42092..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_2k.toml +++ /dev/null @@ -1,45 +0,0 @@ -# MEMORY TEST: 16 nodes, EP=128, 2k context -# 16 nodes × 8 GPUs = 128 GPUs, 
384 experts / 128 = 3 experts per GPU - -[job] -dump_folder = "./outputs/kimi_1t_16n_ep128_2k" -description = "Memory test - 16 nodes EP=128 - 2k context" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 2048 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 128 # 3 experts per GPU - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_4k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_4k.toml deleted file mode 100644 index 483f36cdc0..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_4k.toml +++ /dev/null @@ -1,45 +0,0 @@ -# MEMORY TEST: 16 nodes, EP=128, 4k context -# 16 nodes × 8 GPUs = 128 GPUs, 384 experts / 128 = 3 experts per GPU - -[job] -dump_folder = "./outputs/kimi_1t_16n_ep128_4k" -description = "Memory test - 16 nodes EP=128 - 4k context" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 4096 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 128 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false diff 
--git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_8k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_8k.toml deleted file mode 100644 index cf587b35df..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_16n_ep128_8k.toml +++ /dev/null @@ -1,45 +0,0 @@ -# MEMORY TEST: 16 nodes, EP=128, 8k context -# 16 nodes × 8 GPUs = 128 GPUs, 384 experts / 128 = 3 experts per GPU - -[job] -dump_folder = "./outputs/kimi_1t_16n_ep128_8k" -description = "Memory test - 16 nodes EP=128 - 8k context" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 8192 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 128 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_20k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_20k_force_lb.toml deleted file mode 100644 index 49d382f70e..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_20k_force_lb.toml +++ /dev/null @@ -1,46 +0,0 @@ -# 20k context with FORCE LOAD BALANCE - 8 nodes EP=64 - -[job] -dump_folder = "./outputs/20k_force_lb" -description = "20k context with force load balance" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 20480 -steps = 5 -dataset = 
"c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_24k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_24k_force_lb.toml deleted file mode 100644 index 36c664c81b..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_24k_force_lb.toml +++ /dev/null @@ -1,46 +0,0 @@ -# 24k context with FORCE LOAD BALANCE - 8 nodes EP=64 - -[job] -dump_folder = "./outputs/24k_force_lb" -description = "24k context with force load balance" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 24576 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_ac_offload.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_ac_offload.toml deleted file mode 100644 index 2aa63e98d7..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_ac_offload.toml +++ /dev/null @@ -1,48 +0,0 @@ -# 28k context with FORCE LOAD BALANCE + AC CPU OFFLOAD -# Testing if activation checkpoint CPU offload helps avoid OOM - -[job] -dump_folder = "./outputs/28k_ac_offload" -description = "28k context with AC CPU offload" - -[profiling] -enable_profiling = false - -[metrics] 
-log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 28672 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" -cpu_offload = true - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_force_lb.toml deleted file mode 100644 index b4022f85df..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_force_lb.toml +++ /dev/null @@ -1,46 +0,0 @@ -# 28k context with FORCE LOAD BALANCE - 8 nodes EP=64 - -[job] -dump_folder = "./outputs/28k_force_lb" -description = "28k context with force load balance" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 28672 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_selective_ac.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_selective_ac.toml deleted file mode 100644 index a0b5a40665..0000000000 --- 
a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_28k_selective_ac.toml +++ /dev/null @@ -1,49 +0,0 @@ -# 28k context with FORCE LOAD BALANCE + SELECTIVE AC (op level) -# Testing if selective op-level AC helps avoid OOM - -[job] -dump_folder = "./outputs/28k_selective_ac" -description = "28k context with selective op-level AC" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 28672 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "selective" -selective_ac_option = "op" -cpu_offload = true - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_32k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_32k_force_lb.toml deleted file mode 100644 index ffa6efdc54..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_32k_force_lb.toml +++ /dev/null @@ -1,46 +0,0 @@ -# 32k context with FORCE LOAD BALANCE - 8 nodes EP=64 - -[job] -dump_folder = "./outputs/32k_force_lb" -description = "32k context with force load balance" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 32768 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 
-expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_4k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_4k_force_lb.toml deleted file mode 100644 index beea82e717..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_4k_force_lb.toml +++ /dev/null @@ -1,47 +0,0 @@ -# 4k context with FORCE LOAD BALANCE - 8 nodes EP=64 -# Test if forced uniform expert distribution prevents OOM - -[job] -dump_folder = "./outputs/4k_force_lb" -description = "4k context with force load balance" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 4096 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_6k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_6k_force_lb.toml deleted file mode 100644 index ab7214f598..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_6k_force_lb.toml +++ /dev/null @@ -1,46 +0,0 @@ -# 6k context with FORCE LOAD BALANCE - 8 nodes EP=64 - -[job] -dump_folder = "./outputs/6k_force_lb" -description = "6k context with force load balance" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - 
-[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 6144 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8k_force_lb.toml deleted file mode 100644 index ebab4dda5d..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8k_force_lb.toml +++ /dev/null @@ -1,46 +0,0 @@ -# 8k context with FORCE LOAD BALANCE - 8 nodes EP=64 - -[job] -dump_folder = "./outputs/8k_force_lb" -description = "8k context with force load balance" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 8192 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_flb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_flb.toml deleted file mode 100644 index 44f9f7b80b..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_flb.toml +++ /dev/null @@ -1,48 +0,0 @@ -# 28k context - 8 nodes EP=64 with CP=2 (Context Parallel) -# Testing FlexAttention + Context Parallel - -[job] -dump_folder = "./outputs/8n_cp2_28k_flb" 
-description = "28k context 8 nodes EP=64 CP=2 force LB" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 28672 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_flex_fix.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_flex_fix.toml deleted file mode 100644 index 8e93ca2dc5..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_flex_fix.toml +++ /dev/null @@ -1,48 +0,0 @@ -# 28k context - 8 nodes EP=64 with CP=2 + FlexAttention (with fix) -# Testing create_cp_block_mask fix - -[job] -dump_folder = "./outputs/8n_cp2_28k_flex_fix" -description = "28k context CP=2 FlexAttention with CP block mask fix" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 28672 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git 
a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_sdpa.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_sdpa.toml deleted file mode 100644 index 9f9f93eba3..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_28k_sdpa.toml +++ /dev/null @@ -1,48 +0,0 @@ -# 28k context - 8 nodes EP=64 with CP=2 (Context Parallel) -# Using SDPA instead of FlexAttention for CP compatibility - -[job] -dump_folder = "./outputs/8n_cp2_28k_sdpa" -description = "28k context 8 nodes EP=64 CP=2 SDPA" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2_sdpa" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 28672 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_30720.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_30720.toml deleted file mode 100644 index c1b2172a88..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_30720.toml +++ /dev/null @@ -1,46 +0,0 @@ -# 30720 context - 8 nodes EP=64 with CP=2 + FlexAttention -[job] -dump_folder = "./outputs/8n_cp2_30720" -description = "30720 context 8n EP=64 CP=2" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 30720 
-steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_32768.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_32768.toml deleted file mode 100644 index 59a6394f09..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_32768.toml +++ /dev/null @@ -1,46 +0,0 @@ -# 32768 context - 8 nodes EP=64 with CP=2 + FlexAttention -[job] -dump_folder = "./outputs/8n_cp2_32768" -description = "32768 context 8n EP=64 CP=2" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 32768 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_ep1_28k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_ep1_28k.toml deleted file mode 100644 index 1d5f3d5277..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_cp2_ep1_28k.toml +++ /dev/null @@ -1,48 +0,0 @@ -# 28k context - 8 nodes with CP=2 but NO EP -# Testing if CP works without Expert Parallelism - -[job] -dump_folder = "./outputs/8n_cp2_ep1_28k" -description = "28k context CP=2 NO EP (isolate CP issue)" 
- -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2_sdpa" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 28672 -steps = 3 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 1 -context_parallel_degree = 2 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_tp2_28k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_tp2_28k.toml deleted file mode 100644 index 991b6dceb3..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_8n_tp2_28k.toml +++ /dev/null @@ -1,48 +0,0 @@ -# 28k context - 8 nodes with TP=2 (Tensor Parallel) -# Testing if TP helps fit 28k by sharding attention - -[job] -dump_folder = "./outputs/8n_tp2_28k" -description = "28k context TP=2 EP=32" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 28672 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -tensor_parallel_degree = 2 -expert_parallel_degree = 32 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_activation_offload.toml 
b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_activation_offload.toml deleted file mode 100644 index 439ba2dcb9..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_activation_offload.toml +++ /dev/null @@ -1,81 +0,0 @@ -[job] -dump_folder = "./outputs" -description = "Kimi K2 1T model training with Activation Offloading" -print_config = false - -[profiling] -enable_profiling = true -save_traces_folder = "profile_trace" -profile_freq = 5 -enable_memory_snapshot = true -save_memory_snapshot_folder = "memory_snapshot" - -[metrics] -log_freq = 1 -disable_color_printing = false -enable_tensorboard = false -save_tb_folder = "tb" -enable_wandb = false - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 -eps = 1e-8 - -[lr_scheduler] -warmup_steps = 2_000 -decay_ratio = 0.8 -decay_type = "cosine" -min_lr_factor = 0.1 - -[training] -local_batch_size = 1 -seq_len = 256 -max_norm = 1.0 -steps = 20 -dataset = "c4_test" -dtype = "bfloat16" -# Disable parameter/gradient CPU offload to test activation offload specifically -enable_cpu_offload = false - -[parallelism] -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 -fsdp_reshard_after_forward = "default" -tensor_parallel_degree = 1 -enable_async_tensor_parallel = false -pipeline_parallel_degree = 1 -pipeline_parallel_schedule = "Interleaved1F1B" -expert_parallel_degree = 64 -expert_tensor_parallel_degree = 1 - -[checkpoint] -enable = false -folder = "checkpoint" -interval = 500 -last_save_model_only = true -export_dtype = "bfloat16" -async_mode = "disabled" - -[activation_checkpoint] -mode = "full" -selective_ac_option = 'op' -# Enable CPU offloading for activations -cpu_offload = true - -[compile] -enable = false -components = ["model", "loss"] - -[quantize.linear.float8] -enable_fsdp_float8_all_gather = false -precompute_float8_dynamic_scale_for_fsdp = false -filter_fqns = ["output", 
"router.gate"] - -[quantize.grouped_mm.float8] -fqns = ["experts"] diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_baseline_ep32_40nodes_no_offload.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_baseline_ep32_40nodes_no_offload.toml deleted file mode 100644 index 4644ba6f40..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_baseline_ep32_40nodes_no_offload.toml +++ /dev/null @@ -1,50 +0,0 @@ -# BASELINE: Production-like config (NO optimizations) -# Test: Measure memory usage WITHOUT any optimizations -# Configuration: 40 nodes, 320 GPUs, EP=32, seq_len=512, batch_size=1 -# Purpose: Establish baseline memory consumption (like original config) - -[job] -dump_folder = "./outputs/kimi_1t_baseline_ep32_40nodes_no_offload" -description = "Baseline - No Optimizations - EP32 - 40 nodes" - -[profiling] -enable_profiling = false -profile_freq = 10 -enable_memory_snapshot = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 512 -steps = 5 -dataset = "c4_test" - -# BASELINE: No optimizations (like original) -# enable_cpu_offload = false (default) -# enable_detailed_memory_tracking = false (default) -# clear_cache_between_steps = false (default) - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 32 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_baseline_ep64_8nodes_no_offload.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_baseline_ep64_8nodes_no_offload.toml deleted file mode 100644 index 0c3ec0b6e1..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_baseline_ep64_8nodes_no_offload.toml +++ /dev/null @@ -1,51 +0,0 @@ -# BASELINE: EP=64, 8 nodes, NO optimizations -# 
Test: Measure memory WITHOUT CPU offload or cache clearing -# Configuration: 8 nodes, 64 GPUs, EP=64 (6 experts/GPU), seq=512, batch=1 -# Purpose: Establish baseline at EP=64 (should work, ~23-25 GB expected) - -[job] -dump_folder = "./outputs/kimi_1t_baseline_ep64_8n_no_offload" -description = "Baseline EP64 - No Optimizations - 8 nodes" - -[profiling] -enable_profiling = false -profile_freq = 10 -enable_memory_snapshot = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 512 -steps = 5 -dataset = "c4_test" -skip_optimizer_step = true -enable_detailed_memory_tracking = true - -# BASELINE: No optimizations -# enable_cpu_offload = false (default) -# clear_cache_between_steps = false (default) - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 # 384 experts / 64 = 6 experts per GPU - -[activation_checkpoint] -mode = "full" - -[compile] -enable = true # Production-like diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_cpuoffload_ep32_40nodes_with_offload.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_cpuoffload_ep32_40nodes_with_offload.toml deleted file mode 100644 index 019ae9fa0b..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_cpuoffload_ep32_40nodes_with_offload.toml +++ /dev/null @@ -1,52 +0,0 @@ -# TEST: WITH CPU Offloading -# Test: Measure memory usage WITH CPU offloading -# Configuration: 40 nodes, 320 GPUs, EP=32, seq_len=512, batch_size=1 -# Purpose: Measure CPU offloading impact vs baseline - -[job] -dump_folder = "./outputs/kimi_1t_cpuoffload_ep32_40nodes_with_offload" -description = "CPU Offload Test - EP32 - 40 nodes" - -[profiling] -enable_profiling = true -profile_freq = 10 -enable_memory_snapshot = false - -[metrics] -log_freq = 1 - -[model] -name = 
"deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 512 -steps = 5 -dataset = "c4_test" - -# TEST: CPU offload ENABLED -enable_cpu_offload = true - -# Memory tracking enabled for both tests -enable_detailed_memory_tracking = true -clear_cache_between_steps = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 32 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_debug_ep32_40nodes_with_offload.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_debug_ep32_40nodes_with_offload.toml deleted file mode 100644 index 8e3e0af01d..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_debug_ep32_40nodes_with_offload.toml +++ /dev/null @@ -1,53 +0,0 @@ -# DEBUG: Deep profiling to find memory bottleneck -# Test: EP=32 with CPU offload - investigate why memory spikes to 73GB -# Configuration: 40 nodes, 320 GPUs, EP=32, seq_len=512, batch_size=1 -# Purpose: Find where memory is being allocated during forward pass - -[job] -dump_folder = "./outputs/kimi_1t_debug_ep32_40nodes_with_offload" -description = "DEBUG - Memory Investigation - EP32 - 40 nodes" - -[profiling] -enable_profiling = true -profile_freq = 1 # Profile at step 1 -enable_memory_snapshot = true # Enable memory snapshot -save_traces_folder = "profile_traces" -with_stack = true # Capture stack traces -with_modules = true - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 512 -steps = 2 # Just 2 steps for quick debug -dataset = "c4_test" - -# OPTIMIZATIONS ENABLED - Same as Job 1270 
-enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 32 # 12 experts per GPU - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false # No compile for cleaner profiling diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_debug_ep32_8nodes_seq512.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_debug_ep32_8nodes_seq512.toml deleted file mode 100644 index 3d3368cfc7..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_debug_ep32_8nodes_seq512.toml +++ /dev/null @@ -1,52 +0,0 @@ -# DEBUG: 8 nodes, EP=32, seq=512 to reproduce OOM -# Expected: Should OOM like Job 1270, but on stable 8-node setup -# Configuration: 8 nodes, 64 GPUs, EP=32 (12 experts/GPU), seq=512 - -[job] -dump_folder = "./outputs/kimi_1t_debug_ep32_8n_seq512" -description = "DEBUG - EP32 seq512 - 8 nodes - Expected OOM" - -[profiling] -enable_profiling = true -profile_freq = 1 -enable_memory_snapshot = true -save_traces_folder = "profile_traces" -with_stack = true -with_modules = true - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 512 # Same as Job 1270 -steps = 2 -dataset = "c4_test" - -# OPTIMIZATIONS ENABLED -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 32 # 384 experts / 32 = 12 experts per GPU - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_defrag_test.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_defrag_test.toml deleted file mode 100644 index 2402915040..0000000000 --- 
a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_defrag_test.toml +++ /dev/null @@ -1,60 +0,0 @@ -# Kimi K2 1T - Memory Defragmentation Test -# Testing different allocator settings + selective AC - -[job] -dump_folder = "./outputs/kimi_1t_defrag_test" -description = "Kimi K2 1T - Test fragmentation fixes" - -[profiling] -enable_profiling = true -save_traces_folder = "profile_trace" -profile_freq = 3 -profiler_warmup = 1 -profiler_active = 1 -enable_memory_snapshot = true -save_memory_snapshot_folder = "memory_snapshot" -with_stack = true -with_modules = true - -[metrics] -log_freq = 1 -enable_tensorboard = false -enable_wandb = false - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 256 -steps = 5 -dataset = "c4_test" -dtype = "bfloat16" -enable_cpu_offload = true - -[parallelism] -data_parallel_shard_degree = -1 -tensor_parallel_degree = 1 -pipeline_parallel_degree = 1 -expert_parallel_degree = 64 - -[checkpoint] -enable = false - -[activation_checkpoint] -# Try selective instead of full to reduce allocation churn -mode = "selective" -selective_ac_option = '2' # Checkpoint every 2 layers - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_detailed_memory_profiling.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_detailed_memory_profiling.toml deleted file mode 100644 index 5af25527d7..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_detailed_memory_profiling.toml +++ /dev/null @@ -1,78 +0,0 @@ -# Kimi K2 1T - DETAILED MEMORY PROFILING Configuration -# -# Purpose: Get exact breakdown of where the 55 GB "activation memory" goes -# Approach: Enhanced memory snapshots + instrumented checkpoints -# Expected output: Detailed memory allocation report showing: -# - NCCL communication buffers -# - FSDP 
temporary storage -# - MoE all-to-all staging -# - CPU offload overhead -# - Actual checkpointed activations -# - PyTorch allocator overhead - -[job] -dump_folder = "./outputs/kimi_1t_detailed_memory_profiling" -description = "Kimi K2 1T - Detailed Memory Breakdown Profiling" -print_config = true # Print config for verification - -[profiling] -enable_profiling = true -save_traces_folder = "profile_trace" -profile_freq = 3 # Profile step 3 (after warmup) -profiler_warmup = 1 -profiler_active = 1 -enable_memory_snapshot = true -save_memory_snapshot_folder = "memory_snapshot" -with_stack = true # Enable stack traces for memory allocations -with_modules = true # Enable module tracking - -[metrics] -log_freq = 1 -disable_color_printing = false -enable_tensorboard = false -enable_wandb = false - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 -eps = 1e-8 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 -decay_type = "cosine" -min_lr_factor = 0.1 - -[training] -local_batch_size = 1 -seq_len = 256 -max_norm = 1.0 -steps = 5 # Just 5 steps for profiling -dataset = "c4_test" -dtype = "bfloat16" -enable_cpu_offload = true - -[parallelism] -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 -fsdp_reshard_after_forward = "default" -tensor_parallel_degree = 1 -enable_async_tensor_parallel = false -pipeline_parallel_degree = 1 -expert_parallel_degree = 64 -expert_tensor_parallel_degree = 1 - -[checkpoint] -enable = false - -[activation_checkpoint] -mode = "full" -selective_ac_option = 'op' - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_detailed_tracking.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_detailed_tracking.toml deleted file mode 100644 index 485b60a104..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_detailed_tracking.toml +++ /dev/null @@ -1,47 +0,0 @@ -# Detailed Memory 
Tracking Test -# Track memory at every phase with cache clearing between steps - -[job] -dump_folder = "./outputs/kimi_1t_detailed_tracking" - -[profiling] -enable_profiling = true -profile_freq = 10 # Don't take snapshot (expensive) -enable_memory_snapshot = false # Disable snapshot for cleaner logs - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 256 -steps = 5 # Just 5 steps for detailed tracking -dataset = "c4_test" -enable_cpu_offload = true - -# Detailed memory tracking -enable_detailed_memory_tracking = true -clear_cache_between_steps = true # Clear cache after each step - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_emozilla.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_emozilla.toml deleted file mode 100644 index 17dc3a4538..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_emozilla.toml +++ /dev/null @@ -1,116 +0,0 @@ -[job] -dump_folder = "./outputs" -description = "Kimi K2 1T model training" -print_config = false - -[profiling] -enable_profiling = false -save_traces_folder = "profile_trace" -profile_freq = 10 -enable_memory_snapshot = false -save_memory_snapshot_folder = "memory_snapshot" - -[metrics] -log_freq = 1 -disable_color_printing = false -enable_tensorboard = false -save_tb_folder = "tb" -enable_wandb = false - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" -#converters = ["quantize.linear.float8"] - -[optimizer] -name = "AdamW" -lr = 2.2e-4 -eps = 1e-8 - -#[optimizer] -#name = "Muon" -##lr = 2.2e-4 -#lr = 1e-5 -#eps = 1e-8 -#weight_decay = 0.1 -# -## Muon-specific 
parameters -#mu = 0.95 # Momentum factor for Muon -#algorithm = "muon" # Main algorithm to use for 2D matrices -#nesterov = false # Whether to use Nesterov momentum -#adjust_lr = "rms_norm" # How to adjust LR: "spectral_norm", "rms_norm", or null -#flatten = false # Whether to flatten 3D+ tensors to 2D -#use_triton = true # Whether to use Triton kernel for Newton-Schulz -# -## Parameter-specific optimizer selection -#scalar_optimizer = "adamw" # For 1D parameters (biases, layer norms) -#embedding_optimizer = "adamw" # For embedding layers -#head_optimizer = "adamw" # For model head/output layers -#head_lr_scaling = false # Apply 1/sqrt(dim) scaling to head layers -# -## Learning rate scaling factors -#scalar_lr_factor = 1 # LR multiplier for scalar parameters -#embedding_lr_factor = 1 # LR multiplier for embedding parameters -#head_lr_factor = 1.0 # LR multiplier for head parameters (after head_lr_scaling) -#routing_lr_factor = 1.0 # LR multiplier for routing parameters -#expert_lr_factor = 1.0 - -[lr_scheduler] -warmup_steps = 2_000 # lr scheduler warm up, normally 20% of the train steps -decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps -decay_type = "cosine" -min_lr_factor = 0.1 - -[training] -local_batch_size = 1 -seq_len = 256 -max_norm = 1.0 # grad norm clipping -steps = 10_000 -dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) -dtype = "bfloat16" -enable_cpu_offload = true - -[parallelism] -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 -fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 1 -enable_async_tensor_parallel = false -pipeline_parallel_degree = 1 -pipeline_parallel_schedule = "Interleaved1F1B" -expert_parallel_degree = 64 -expert_tensor_parallel_degree = 1 - -[checkpoint] -enable = false -folder = "checkpoint" -interval = 500 -last_save_model_only = true -export_dtype = "bfloat16" -async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" 
- -[activation_checkpoint] -mode = "full" # ["none", "selective", "full"] -selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy - -[compile] -enable = false -components = ["model", "loss"] # ["model", "loss"] -# fullgraph = false - -[quantize.linear.float8] -enable_fsdp_float8_all_gather = false -precompute_float8_dynamic_scale_for_fsdp = false -filter_fqns = ["output", "router.gate"] - -[quantize.grouped_mm.float8] -fqns = ["experts"] - -#[deepep] -#sync_comm_stream = false -#fused_weighted_scatter_add = false -#fused_silu_gate_prob = true -# -#[debug] -#moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_combined.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_combined.toml deleted file mode 100644 index f48ab381cc..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_combined.toml +++ /dev/null @@ -1,50 +0,0 @@ -# Test 3: Combined Approach -# Approach: Selective AC + Defragmentation + Better allocator - -[job] -dump_folder = "./outputs/kimi_1t_fix_combined" - -[profiling] -enable_profiling = true -profile_freq = 3 -enable_memory_snapshot = true -with_stack = true - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 256 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -# Moderate defragmentation (not every step) -enable_memory_defrag = true -defrag_freq = 2 # Every 2 steps -aggressive_defrag = false - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -# Selective AC -mode = "selective" -selective_ac_option = '2' - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_defrag.toml 
b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_defrag.toml deleted file mode 100644 index acc42b24d6..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_defrag.toml +++ /dev/null @@ -1,48 +0,0 @@ -# Test 1: Aggressive Memory Defragmentation -# Approach: Clear cache every step to reduce fragmentation - -[job] -dump_folder = "./outputs/kimi_1t_fix_defrag" - -[profiling] -enable_profiling = true -profile_freq = 3 -enable_memory_snapshot = true -with_stack = true - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 256 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -# Enable aggressive defragmentation -enable_memory_defrag = true -defrag_freq = 1 # Every step -aggressive_defrag = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" # Keep full AC - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_selective_ac.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_selective_ac.toml deleted file mode 100644 index c41ba80ffc..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_fix_selective_ac.toml +++ /dev/null @@ -1,47 +0,0 @@ -# Test 2: Selective Activation Checkpointing -# Approach: Checkpoint fewer layers to reduce allocation churn - -[job] -dump_folder = "./outputs/kimi_1t_fix_selective_ac" - -[profiling] -enable_profiling = true -profile_freq = 3 -enable_memory_snapshot = true -with_stack = true - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 
-seq_len = 256 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -enable_memory_defrag = false - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -# Selective AC: checkpoint every 2 layers instead of all -mode = "selective" -selective_ac_option = '2' - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_16k_ctx.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_16k_ctx.toml deleted file mode 100644 index 38a0761336..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_16k_ctx.toml +++ /dev/null @@ -1,46 +0,0 @@ -# MEMORY TEST: 16k context length -# Purpose: Measure memory at 16384 seq_len - -[job] -dump_folder = "./outputs/kimi_1t_memory_16k" -description = "Memory test - 16k context - CPU Offload" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 16384 # 16k context -steps = 5 -dataset = "c4_test" - -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_1k_ctx.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_1k_ctx.toml deleted file mode 100644 index f4f5358293..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_1k_ctx.toml +++ /dev/null @@ -1,43 +0,0 @@ -# MEMORY TEST: 1k context length -[job] -dump_folder = "./outputs/kimi_1t_memory_1k" -description = "Memory test - 1k context" - -[profiling] -enable_profiling = false - 
-[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 1024 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_2k_ctx.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_2k_ctx.toml deleted file mode 100644 index 4895a24ab3..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_2k_ctx.toml +++ /dev/null @@ -1,43 +0,0 @@ -# MEMORY TEST: 2k context length -[job] -dump_folder = "./outputs/kimi_1t_memory_2k" -description = "Memory test - 2k context" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 2048 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_4k_ctx.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_4k_ctx.toml deleted file mode 100644 index 9b3611ddd8..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_4k_ctx.toml +++ /dev/null @@ -1,43 +0,0 @@ -# MEMORY 
TEST: 4k context length -[job] -dump_folder = "./outputs/kimi_1t_memory_4k" -description = "Memory test - 4k context" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 4096 -steps = 5 -dataset = "c4_test" -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_8k_ctx.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_8k_ctx.toml deleted file mode 100644 index 2cb6ac136d..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memory_8k_ctx.toml +++ /dev/null @@ -1,46 +0,0 @@ -# MEMORY TEST: 8k context length -# Purpose: Measure memory at 8192 seq_len - -[job] -dump_folder = "./outputs/kimi_1t_memory_8k" -description = "Memory test - 8k context - CPU Offload" - -[profiling] -enable_profiling = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 8192 # 8k context -steps = 5 -dataset = "c4_test" - -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_24k_flb.toml 
b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_24k_flb.toml deleted file mode 100644 index 4fa07a6084..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_24k_flb.toml +++ /dev/null @@ -1,53 +0,0 @@ -# DEEP MEMORY PROFILING: 24k context with FORCE LOAD BALANCE (works) -# Compare with 28k to find exact OOM location - -[job] -dump_folder = "./outputs/memprof_24k_flb" -description = "Deep Memory Profiling - 24k context (force LB)" - -[profiling] -enable_profiling = true -profile_freq = 3 -profiler_warmup = 1 -profiler_active = 1 -enable_memory_snapshot = true -save_memory_snapshot_folder = "memory_snapshot_24k_flb" - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 24576 -steps = 2 -dataset = "c4_test" -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_28k_flb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_28k_flb.toml deleted file mode 100644 index c3ee3516fa..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_28k_flb.toml +++ /dev/null @@ -1,53 +0,0 @@ -# DEEP MEMORY PROFILING: 28k context with FORCE LOAD BALANCE (OOM) -# Compare with 24k to find exact OOM location - -[job] -dump_folder = "./outputs/memprof_28k_flb" -description = "Deep Memory Profiling - 28k context (force LB)" - -[profiling] -enable_profiling = true -profile_freq = 3 -profiler_warmup = 1 -profiler_active = 1 -enable_memory_snapshot = true 
-save_memory_snapshot_folder = "memory_snapshot_28k_flb" - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 28672 -steps = 2 -dataset = "c4_test" -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_2k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_2k.toml deleted file mode 100644 index b0d8678ca1..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_2k.toml +++ /dev/null @@ -1,50 +0,0 @@ -# DEEP MEMORY PROFILING: 2k context (baseline that works) -# Purpose: Capture detailed memory snapshots to compare with 4k - -[job] -dump_folder = "./outputs/memprof_2k" -description = "Deep Memory Profiling - 2k context" - -[profiling] -enable_profiling = true -profile_freq = 5 -profiler_warmup = 1 -profiler_active = 2 -enable_memory_snapshot = true -save_memory_snapshot_folder = "memory_snapshot_2k" - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 2048 -steps = 3 -dataset = "c4_test" -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false 
diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_2k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_2k_force_lb.toml deleted file mode 100644 index 3571e2a0e8..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_2k_force_lb.toml +++ /dev/null @@ -1,53 +0,0 @@ -# DEEP MEMORY PROFILING: 2k context with FORCE LOAD BALANCE -# Baseline for comparison with 4k - -[job] -dump_folder = "./outputs/memprof_2k_flb" -description = "Deep Memory Profiling - 2k context (force LB)" - -[profiling] -enable_profiling = true -profile_freq = 5 -profiler_warmup = 1 -profiler_active = 2 -enable_memory_snapshot = true -save_memory_snapshot_folder = "memory_snapshot_2k_flb" - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 2048 -steps = 3 -dataset = "c4_test" -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_4k.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_4k.toml deleted file mode 100644 index cb829756b4..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_4k.toml +++ /dev/null @@ -1,50 +0,0 @@ -# DEEP MEMORY PROFILING: 4k context (OOMs - need to capture where) -# Purpose: Capture detailed memory snapshots to identify OOM point - -[job] -dump_folder = "./outputs/memprof_4k" -description = "Deep Memory Profiling - 4k context" - -[profiling] -enable_profiling = true -profile_freq = 5 -profiler_warmup = 1 
-profiler_active = 2 -enable_memory_snapshot = true -save_memory_snapshot_folder = "memory_snapshot_4k" - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 4096 -steps = 3 -dataset = "c4_test" -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_4k_force_lb.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_4k_force_lb.toml deleted file mode 100644 index 3f0b618255..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_memprof_4k_force_lb.toml +++ /dev/null @@ -1,53 +0,0 @@ -# DEEP MEMORY PROFILING: 4k context with FORCE LOAD BALANCE -# Compare with 2k to find where memory increases - -[job] -dump_folder = "./outputs/memprof_4k_flb" -description = "Deep Memory Profiling - 4k context (force LB)" - -[profiling] -enable_profiling = true -profile_freq = 5 -profiler_warmup = 1 -profiler_active = 2 -enable_memory_snapshot = true -save_memory_snapshot_folder = "memory_snapshot_4k_flb" - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 4096 -steps = 3 -dataset = "c4_test" -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = 
"full" - -[compile] -enable = false - -[debug] -moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_offload_no_cache_clear.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_offload_no_cache_clear.toml deleted file mode 100644 index bb1c50aad0..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_offload_no_cache_clear.toml +++ /dev/null @@ -1,51 +0,0 @@ -# CPU OFFLOAD ONLY: EP=64, 8 nodes, NO cache clearing -# Test: Measure memory WITH CPU offload but WITHOUT cache clearing -# Configuration: 8 nodes, 64 GPUs, EP=64 (6 experts/GPU), seq=512, batch=1 -# Purpose: Apple-to-apple comparison with cache clearing - -[job] -dump_folder = "./outputs/kimi_1t_offload_no_cache_clear" -description = "CPU Offload Only - No Cache Clearing - 8 nodes" - -[profiling] -enable_profiling = false -profile_freq = 10 -enable_memory_snapshot = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 512 -steps = 5 -dataset = "c4_test" -skip_optimizer_step = true -enable_detailed_memory_tracking = true - -# CPU OFFLOAD ONLY - NO CACHE CLEARING -enable_cpu_offload = true -clear_cache_between_steps = false - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_optimized_ep32_40nodes_with_offload.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_optimized_ep32_40nodes_with_offload.toml deleted file mode 100644 index 7ad05e8bae..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_optimized_ep32_40nodes_with_offload.toml +++ /dev/null @@ -1,50 +0,0 @@ -# OPTIMIZED: WITH all memory optimizations -# Test: Measure memory usage 
WITH CPU offloading + tracking + cache clearing -# Configuration: 40 nodes, 320 GPUs, EP=32, seq_len=512, batch_size=1 -# Purpose: Measure optimization impact vs baseline - -[job] -dump_folder = "./outputs/kimi_1t_optimized_ep32_40nodes_with_offload" -description = "Optimized - CPU Offload + Tracking - EP32 - 40 nodes" - -[profiling] -enable_profiling = true -profile_freq = 10 -enable_memory_snapshot = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 512 -steps = 5 -dataset = "c4_test" - -# OPTIMIZATIONS ENABLED -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 32 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_optimized_ep64_8nodes_with_offload.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_optimized_ep64_8nodes_with_offload.toml deleted file mode 100644 index ff0d3bb4b0..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_optimized_ep64_8nodes_with_offload.toml +++ /dev/null @@ -1,51 +0,0 @@ -# OPTIMIZED: EP=64, 8 nodes, WITH all optimizations -# Test: Measure memory WITH CPU offload + cache clearing + tracking -# Configuration: 8 nodes, 64 GPUs, EP=64 (6 experts/GPU), seq=512, batch=1 -# Purpose: Measure optimization impact at EP=64 - -[job] -dump_folder = "./outputs/kimi_1t_optimized_ep64_8n_with_offload" -description = "Optimized EP64 - CPU Offload + Cache Clearing - 8 nodes" - -[profiling] -enable_profiling = true -profile_freq = 10 -enable_memory_snapshot = false - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - 
-[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 512 -steps = 5 -dataset = "c4_test" - -# OPTIMIZATIONS ENABLED -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 # Same as baseline: 6 experts per GPU - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false # For cleaner memory tracking diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_profiling.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_profiling.toml deleted file mode 100644 index f7cda07a23..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_profiling.toml +++ /dev/null @@ -1,116 +0,0 @@ -[job] -dump_folder = "./outputs" -description = "Kimi K2 1T model training" -print_config = false - -[profiling] -enable_profiling = true -save_traces_folder = "profile_trace" -profile_freq = 5 -enable_memory_snapshot = true -save_memory_snapshot_folder = "memory_snapshot" - -[metrics] -log_freq = 1 -disable_color_printing = false -enable_tensorboard = false -save_tb_folder = "tb" -enable_wandb = false - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" -#converters = ["quantize.linear.float8"] - -[optimizer] -name = "AdamW" -lr = 2.2e-4 -eps = 1e-8 - -#[optimizer] -#name = "Muon" -##lr = 2.2e-4 -#lr = 1e-5 -#eps = 1e-8 -#weight_decay = 0.1 -# -## Muon-specific parameters -#mu = 0.95 # Momentum factor for Muon -#algorithm = "muon" # Main algorithm to use for 2D matrices -#nesterov = false # Whether to use Nesterov momentum -#adjust_lr = "rms_norm" # How to adjust LR: "spectral_norm", "rms_norm", or null -#flatten = false # Whether to flatten 3D+ tensors to 2D -#use_triton = true # Whether to use Triton kernel for Newton-Schulz -# -## Parameter-specific optimizer 
selection -#scalar_optimizer = "adamw" # For 1D parameters (biases, layer norms) -#embedding_optimizer = "adamw" # For embedding layers -#head_optimizer = "adamw" # For model head/output layers -#head_lr_scaling = false # Apply 1/sqrt(dim) scaling to head layers -# -## Learning rate scaling factors -#scalar_lr_factor = 1 # LR multiplier for scalar parameters -#embedding_lr_factor = 1 # LR multiplier for embedding parameters -#head_lr_factor = 1.0 # LR multiplier for head parameters (after head_lr_scaling) -#routing_lr_factor = 1.0 # LR multiplier for routing parameters -#expert_lr_factor = 1.0 - -[lr_scheduler] -warmup_steps = 2_000 # lr scheduler warm up, normally 20% of the train steps -decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps -decay_type = "cosine" -min_lr_factor = 0.1 - -[training] -local_batch_size = 1 -seq_len = 256 -max_norm = 1.0 # grad norm clipping -steps = 20 -dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) -dtype = "bfloat16" -enable_cpu_offload = true - -[parallelism] -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 -fsdp_reshard_after_forward = "default" # default / never / always -tensor_parallel_degree = 1 -enable_async_tensor_parallel = false -pipeline_parallel_degree = 1 -pipeline_parallel_schedule = "Interleaved1F1B" -expert_parallel_degree = 64 -expert_tensor_parallel_degree = 1 - -[checkpoint] -enable = false -folder = "checkpoint" -interval = 500 -last_save_model_only = true -export_dtype = "bfloat16" -async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem]" - -[activation_checkpoint] -mode = "full" # ["none", "selective", "full"] -selective_ac_option = 'op' # 'int' = ac every positive int layer or 'op', ac based on ops policy - -[compile] -enable = false -components = ["model", "loss"] # ["model", "loss"] -# fullgraph = false - -[quantize.linear.float8] -enable_fsdp_float8_all_gather = false -precompute_float8_dynamic_scale_for_fsdp = false -filter_fqns = 
["output", "router.gate"] - -[quantize.grouped_mm.float8] -fqns = ["experts"] - -#[deepep] -#sync_comm_stream = false -#fused_weighted_scatter_add = false -#fused_silu_gate_prob = true -# -#[debug] -#moe_force_load_balance = true diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_profiling_ep64_8nodes.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_1t_profiling_ep64_8nodes.toml deleted file mode 100644 index cdb16eac9c..0000000000 --- a/torchtitan/models/deepseek_v3/train_configs/kimi_1t_profiling_ep64_8nodes.toml +++ /dev/null @@ -1,51 +0,0 @@ -# PROFILING: EP=64, 8 nodes, CPU offload + cache clear -# Purpose: Deep profiling to identify bottleneck operations -# Configuration: 8 nodes, 64 GPUs, EP=64 (6 experts/GPU), seq=512, batch=1 - -[job] -dump_folder = "./outputs/kimi_1t_profiling_ep64" -description = "Deep Profiling - CPU Offload + Cache Clear - 8 nodes" - -[profiling] -enable_profiling = true -profile_freq = 5 # Must be >= warmup (1) + active (3) -profiler_warmup = 1 -profiler_active = 3 - -[metrics] -log_freq = 1 - -[model] -name = "deepseek_v3" -flavor = "kimi_k2" -hf_assets_path = "./assets/hf/DeepSeek-V3-Base" - -[optimizer] -name = "AdamW" -lr = 2.2e-4 - -[lr_scheduler] -warmup_steps = 2 -decay_ratio = 0.8 - -[training] -local_batch_size = 1 -seq_len = 512 -steps = 15 # More steps for better profiling data (need >= profile_freq for traces) -dataset = "c4_test" - -# OPTIMIZATIONS ENABLED -enable_cpu_offload = true -enable_detailed_memory_tracking = true -clear_cache_between_steps = true -skip_optimizer_step = true - -[parallelism] -data_parallel_shard_degree = -1 -expert_parallel_degree = 64 - -[activation_checkpoint] -mode = "full" - -[compile] -enable = false diff --git a/torchtitan/models/qwen3/train_configs/qwen3_1.7b_local_test_baseline.toml b/torchtitan/models/qwen3/train_configs/qwen3_1.7b_local_test_baseline.toml deleted file mode 100644 index 140744642b..0000000000 --- 
a/torchtitan/models/qwen3/train_configs/qwen3_1.7b_local_test_baseline.toml +++ /dev/null @@ -1,70 +0,0 @@ -# Qwen3 1.7B - Local test WITHOUT activation offloading (BASELINE) - -[job] -dump_folder = "./outputs/qwen3_1.7b_baseline" -description = "Qwen 3 1.7B local test - BASELINE (no activation offload)" -print_config = false - -[profiling] -enable_profiling = false -save_traces_folder = "profile_trace" -profile_freq = 100 - -[metrics] -log_freq = 1 -disable_color_printing = false -enable_tensorboard = false -save_tb_folder = "tb" -enable_wandb = false - -[model] -name = "qwen3" -flavor = "1.7B" -hf_assets_path = "./assets/hf/Qwen3-1.7B" - -[optimizer] -name = "AdamW" -lr = 8e-4 -eps = 1e-8 - -[lr_scheduler] -warmup_steps = 5 - -[training] -local_batch_size = 1 -seq_len = 2048 -max_norm = 1.0 -steps = 10 -dataset = "c4_test" -dtype = "bfloat16" -enable_cpu_offload = false - -[parallelism] -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 -fsdp_reshard_after_forward = "default" -tensor_parallel_degree = 1 -context_parallel_degree = 1 - -[checkpoint] -enable = false -folder = "checkpoint" -interval = 50 -last_save_model_only = false -export_dtype = "bfloat16" -async_mode = "disabled" - -[activation_checkpoint] -mode = "full" -selective_ac_option = "op" -# NO CPU OFFLOAD - BASELINE -cpu_offload = false - -[compile] -enable = false -components = ["model", "loss"] - -[quantize.linear.float8] -enable_fsdp_float8_all_gather = false -precompute_float8_dynamic_scale_for_fsdp = false -filter_fqns = ["output"] diff --git a/torchtitan/models/qwen3/train_configs/qwen3_1.7b_local_test_offload.toml b/torchtitan/models/qwen3/train_configs/qwen3_1.7b_local_test_offload.toml deleted file mode 100644 index 188cc580c2..0000000000 --- a/torchtitan/models/qwen3/train_configs/qwen3_1.7b_local_test_offload.toml +++ /dev/null @@ -1,70 +0,0 @@ -# Qwen3 1.7B - Local test WITH activation offloading - -[job] -dump_folder = "./outputs/qwen3_1.7b_offload" -description = "Qwen 
3 1.7B local test - WITH activation offload" -print_config = false - -[profiling] -enable_profiling = false -save_traces_folder = "profile_trace" -profile_freq = 100 - -[metrics] -log_freq = 1 -disable_color_printing = false -enable_tensorboard = false -save_tb_folder = "tb" -enable_wandb = false - -[model] -name = "qwen3" -flavor = "1.7B" -hf_assets_path = "./assets/hf/Qwen3-1.7B" - -[optimizer] -name = "AdamW" -lr = 8e-4 -eps = 1e-8 - -[lr_scheduler] -warmup_steps = 5 - -[training] -local_batch_size = 1 -seq_len = 2048 -max_norm = 1.0 -steps = 10 -dataset = "c4_test" -dtype = "bfloat16" -enable_cpu_offload = false - -[parallelism] -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 -fsdp_reshard_after_forward = "default" -tensor_parallel_degree = 1 -context_parallel_degree = 1 - -[checkpoint] -enable = false -folder = "checkpoint" -interval = 50 -last_save_model_only = false -export_dtype = "bfloat16" -async_mode = "disabled" - -[activation_checkpoint] -mode = "full" -selective_ac_option = "op" -# ENABLE CPU OFFLOAD FOR ACTIVATIONS -cpu_offload = true - -[compile] -enable = false -components = ["model", "loss"] - -[quantize.linear.float8] -enable_fsdp_float8_all_gather = false -precompute_float8_dynamic_scale_for_fsdp = false -filter_fqns = ["output"] diff --git a/torchtitan/models/qwen3/train_configs/qwen3_30b_a3b_activation_offload_test.toml b/torchtitan/models/qwen3/train_configs/qwen3_30b_a3b_activation_offload_test.toml deleted file mode 100644 index 3ad3d6321b..0000000000 --- a/torchtitan/models/qwen3/train_configs/qwen3_30b_a3b_activation_offload_test.toml +++ /dev/null @@ -1,75 +0,0 @@ -# Qwen3 30B A3B with Activation Offloading - Local Test Config - -[job] -dump_folder = "./outputs/qwen3_30b_act_offload_test" -description = "Qwen3 30B A3B with Activation Offloading Test" -print_config = false - -[profiling] -enable_profiling = false -save_traces_folder = "profile_trace" -profile_freq = 100 - -[metrics] -log_freq = 1 
-disable_color_printing = false -enable_tensorboard = false -save_tb_folder = "tb" -enable_wandb = false - -[model] -name = "qwen3" -flavor = "30B-A3B" -hf_assets_path = "./assets/hf/Qwen3-30B-A3B-Instruct-2507" - -[optimizer] -name = "AdamW" -lr = 8e-4 -eps = 1e-8 - -[lr_scheduler] -warmup_steps = 10 -decay_ratio = 0.8 -decay_type = "cosine" -min_lr_factor = 0.1 - -[training] -local_batch_size = 1 -seq_len = 2048 -max_norm = 1.0 -steps = 10 -dataset = "c4_test" -dtype = "bfloat16" -enable_cpu_offload = false - -[parallelism] -data_parallel_replicate_degree = 1 -data_parallel_shard_degree = -1 -fsdp_reshard_after_forward = "default" -tensor_parallel_degree = 1 -enable_async_tensor_parallel = false -expert_parallel_degree = 2 -expert_tensor_parallel_degree = 1 - -[checkpoint] -enable = false -folder = "checkpoint" -interval = 500 -last_save_model_only = true -export_dtype = "bfloat16" -async_mode = "disabled" - -[activation_checkpoint] -mode = "full" -selective_ac_option = 'op' -# Enable CPU offloading for activations - THIS IS THE KEY TEST -cpu_offload = true - -[compile] -enable = false -components = ["model", "loss"] - -[quantize.linear.float8] -enable_fsdp_float8_all_gather = false -precompute_float8_dynamic_scale_for_fsdp = false -filter_fqns = ["output"] From e04c0f68aa204f0f93f27fe7ae9ea644f970605f Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Tue, 20 Jan 2026 11:50:55 -0800 Subject: [PATCH 05/18] remove assert for cp, and removed new activation checkpointing --- torchtitan/config/job_config.py | 8 - .../distributed/activation_checkpoint.py | 32 +- .../activation_checkpoint_offload.py | 313 ------------------ .../models/deepseek_v3/infra/parallelize.py | 3 - 4 files changed, 2 insertions(+), 354 deletions(-) delete mode 100644 torchtitan/distributed/activation_checkpoint_offload.py diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 77535e932c..7bc20b1d62 100644 --- a/torchtitan/config/job_config.py +++ 
b/torchtitan/config/job_config.py @@ -861,14 +861,6 @@ class ActivationCheckpoint: https://docs.pytorch.org/docs/stable/checkpoint.html for details. """ - cpu_offload: bool = False - """ - Enable CPU offloading for activation checkpoints. When enabled, saved activations - are moved to CPU RAM during forward pass and brought back to GPU during backward pass. - This trades memory for PCIe bandwidth, saving GPU memory at the cost of data transfer time. - Only applies when mode is 'full' or 'selective'. - """ - @dataclass class Compile: diff --git a/torchtitan/distributed/activation_checkpoint.py b/torchtitan/distributed/activation_checkpoint.py index 3cb5378637..8359f71730 100644 --- a/torchtitan/distributed/activation_checkpoint.py +++ b/torchtitan/distributed/activation_checkpoint.py @@ -67,14 +67,6 @@ def _apply_op_sac( Returns: nn.Module: The module with selective activation checkpointing applied. """ - # Use CPU offload if enabled - if ac_config.cpu_offload: - from torchtitan.distributed.activation_checkpoint_offload import ( - apply_selective_ac_with_cpu_offload, - ) - - return apply_selective_ac_with_cpu_offload(module, ac_config, base_fqn=base_fqn) - from torch.utils.checkpoint import ( CheckpointPolicy, create_selective_checkpoint_contexts, @@ -154,14 +146,6 @@ def _apply_full_ac(module: nn.Module, ac_config: ACConfig) -> nn.Module: Returns: nn.Module: The module with full activation checkpointing applied. 
""" - # Use CPU offload if enabled - if ac_config.cpu_offload: - from torchtitan.distributed.activation_checkpoint_offload import ( - apply_full_ac_with_cpu_offload, - ) - - return apply_full_ac_with_cpu_offload(module, ac_config) - return ptd_checkpoint_wrapper( module, preserve_rng_state=ac_config.preserve_rng_state, @@ -324,18 +308,6 @@ def apply_ac( Returns: None """ - # Special case: CPU offload without activation checkpointing - if ac_config.mode == "none" and ac_config.cpu_offload: - from torchtitan.distributed.activation_checkpoint_offload import ( - apply_offload_wrapper_only, - ) - - for layer_id, transformer_block in model.layers.named_children(): - transformer_block = apply_offload_wrapper_only(transformer_block) - model.layers.register_module(layer_id, transformer_block) - logger.info("Applied activation offloading WITHOUT checkpointing to the model") - return - if ac_config.mode == "memory_budget": assert model_compile_enabled, "Memory budget mode requires model to be compiled" if ac_config.visualize_memory_budget_pareto: @@ -347,7 +319,7 @@ def apply_ac( torch._functorch.config.activation_memory_budget = ac_config.memory_budget logger.info(f"Selected {ac_config.memory_budget} budget option") - elif ac_config.mode != "none": + else: for layer_id, transformer_block in model.layers.named_children(): transformer_block = _apply_ac_to_transformer_block( transformer_block, @@ -359,4 +331,4 @@ def apply_ac( ) model.layers.register_module(layer_id, transformer_block) - logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") + logger.info(f"Applied {ac_config.mode} activation checkpointing to the model") diff --git a/torchtitan/distributed/activation_checkpoint_offload.py b/torchtitan/distributed/activation_checkpoint_offload.py deleted file mode 100644 index dce14874cf..0000000000 --- a/torchtitan/distributed/activation_checkpoint_offload.py +++ /dev/null @@ -1,313 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -""" -Activation Checkpointing with CPU Offloading Support - -This module extends torchtitan's activation checkpointing with CPU offloading capability, -inspired by DeepSpeed's CPU checkpointing implementation. - -CPU offloading moves activation tensors to CPU RAM during the forward pass and brings them -back to GPU during the backward pass, trading memory for PCIe bandwidth. -""" - -from collections import defaultdict - -import torch -import torch.nn as nn -from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( - checkpoint_wrapper as ptd_checkpoint_wrapper, -) -from torch.utils.checkpoint import ( - CheckpointPolicy, - create_selective_checkpoint_contexts, -) - -from torchtitan.config.job_config import ActivationCheckpoint as ACConfig -from torchtitan.tools.logging import logger - - -def _cpu_offload_context_fn(): - """ - Create a context function for CPU offloading of activation checkpoints. - - This function returns a tuple of contexts that uses saved_tensors_hooks to automatically - offload tensors to CPU when they're saved during forward pass and reload them - to GPU during backward pass. 
- - Returns: - A tuple of (forward_context, recompute_context) - """ - - def pack_to_cpu(tensor): - """Move tensor to CPU during forward pass""" - if not isinstance(tensor, torch.Tensor): - return tensor - # Only offload CUDA tensors that are floating point and large enough - if tensor.is_cuda and tensor.is_floating_point() and tensor.numel() > 0: - # Use non-blocking transfer for better performance - return tensor.to("cpu", non_blocking=True) - return tensor - - def unpack_from_cpu(tensor): - """Move tensor back to GPU during backward pass""" - if not isinstance(tensor, torch.Tensor): - return tensor - # If tensor is on CPU, move it back to the current CUDA device - if tensor.device.type == "cpu": - return tensor.to(torch.cuda.current_device(), non_blocking=True) - return tensor - - # Return the same context for both forward and recompute phases - ctx = torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_from_cpu) - return (ctx, ctx) - - -def _cpu_offload_selective_context_fn(ac_config: ACConfig, mm_recompute_shapes: set): - """ - Create a selective checkpoint context with CPU offloading support. - - This combines selective activation checkpointing (choosing which ops to save vs recompute) - with CPU offloading (moving saved tensors to CPU). 
- - Args: - ac_config: Activation checkpoint configuration - mm_recompute_shapes: Set of matrix multiplication shapes to force recompute - - Returns: - A context function for selective checkpointing with CPU offloading - """ - # Get the default op save list for selective AC - op_sac_save_list = { - torch.ops.aten.mm.default, - torch.ops.aten.addmm.default, - torch.ops.aten.bmm.default, - torch.ops.aten.linear.default, - } - - def _get_custom_policy(meta): - def _custom_policy(ctx, func, *args, **kwargs): - # Always save CPU offload ops - if ( - func == torch.ops.aten._to_copy.default - and "cuda" in str(args[0].device) - and "device" in kwargs - and str(kwargs["device"]) == "cpu" - ): - return CheckpointPolicy.MUST_SAVE - - mode = "recompute" if ctx.is_recompute else "forward" - mm_count_key = f"{mode}_mm_count" - if func == torch.ops.aten.mm.default: - if args[1].shape in mm_recompute_shapes: - return CheckpointPolicy.PREFER_RECOMPUTE - meta[mm_count_key] += 1 - - # Saves output of all compute ops, except every second mm - to_save = func in op_sac_save_list and not ( - func == torch.ops.aten.mm.default and meta[mm_count_key] % 2 == 0 - ) - return ( - CheckpointPolicy.MUST_SAVE - if to_save - else CheckpointPolicy.PREFER_RECOMPUTE - ) - - return _custom_policy - - def selective_checkpointing_with_cpu_offload(): - """Combined context for selective AC + CPU offload""" - meta = defaultdict(int) - ( - selective_forward_ctx, - selective_recompute_ctx, - ) = create_selective_checkpoint_contexts(_get_custom_policy(meta)) - cpu_offload_forward_ctx, cpu_offload_recompute_ctx = _cpu_offload_context_fn() - - # Stack both contexts for forward phase - class CombinedForwardContext: - def __enter__(self): - self.selective = selective_forward_ctx.__enter__() - self.cpu_offload = cpu_offload_forward_ctx.__enter__() - return self - - def __exit__(self, *args): - self.cpu_offload.__exit__(*args) - self.selective.__exit__(*args) - - # Stack both contexts for recompute phase - class 
CombinedRecomputeContext: - def __enter__(self): - self.selective = selective_recompute_ctx.__enter__() - self.cpu_offload = cpu_offload_recompute_ctx.__enter__() - return self - - def __exit__(self, *args): - self.cpu_offload.__exit__(*args) - self.selective.__exit__(*args) - - return (CombinedForwardContext(), CombinedRecomputeContext()) - - return selective_checkpointing_with_cpu_offload - - -def apply_full_ac_with_cpu_offload(module: nn.Module, ac_config: ACConfig) -> nn.Module: - """ - Apply full activation checkpointing with CPU offloading to the module. - - This will checkpoint all activations and offload them to CPU RAM. - - Args: - module: The module to apply full AC with CPU offload to - ac_config: The activation checkpointing config - - Returns: - The wrapped module with full AC + CPU offload applied - """ - logger.info("Applying full activation checkpointing with CPU offload") - - return ptd_checkpoint_wrapper( - module, - context_fn=_cpu_offload_context_fn, - preserve_rng_state=ac_config.preserve_rng_state, - determinism_check=ac_config.determinism_check, - early_stop=ac_config.early_stop, - debug=ac_config.debug, - ) - - -def apply_selective_ac_with_cpu_offload( - module: nn.Module, - ac_config: ACConfig, - *, - base_fqn: str | None = None, -) -> nn.Module: - """ - Apply selective activation checkpointing with CPU offloading to the module. - - This selectively checkpoints certain operations while offloading saved tensors to CPU. 
- - Args: - module: The module to apply selective AC with CPU offload to - ac_config: The activation checkpointing config - base_fqn: The base fully qualified name of the module - - Returns: - The wrapped module with selective AC + CPU offload applied - """ - logger.info("Applying selective activation checkpointing with CPU offload") - - # Collect mm shapes to force recompute if configured - mm_recompute_shapes = set() - if len(ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns) > 0: - for module_fqn, submod in module.named_modules(): - fqn = module_fqn - if base_fqn is not None: - fqn = f"{base_fqn}.{module_fqn}" - if not any( - filter_fqn in fqn - for filter_fqn in ac_config.per_op_sac_force_recompute_mm_shapes_by_fqns - ): - continue - if not isinstance(submod, nn.Linear): - raise ValueError( - "per_op_sac_force_recompute_mm_shapes_by_fqns expected to match " - f"a nn.Linear, but got: {submod}" - ) - out_f, in_f = submod.weight.shape - mm_recompute_shapes.add((in_f, out_f)) - - def context_fn_wrapper(): - return _cpu_offload_selective_context_fn(ac_config, mm_recompute_shapes) - - return ptd_checkpoint_wrapper( - module, - context_fn=context_fn_wrapper, - preserve_rng_state=ac_config.preserve_rng_state, - determinism_check=ac_config.determinism_check, - early_stop=ac_config.early_stop, - debug=ac_config.debug, - ) - - -class ActivationOffloadWrapper(nn.Module): - """ - Wrapper that offloads layer activations to CPU without checkpointing/recomputation. - - This keeps all activations but moves them to CPU RAM to save GPU memory. 
- """ - - def __init__(self, module: nn.Module): - super().__init__() - self.module = module - self._cpu_activations = [] - - def forward(self, *args, **kwargs): - # Move inputs to GPU if they were offloaded - args = tuple( - self._to_gpu(arg) if isinstance(arg, torch.Tensor) else arg for arg in args - ) - kwargs = { - k: self._to_gpu(v) if isinstance(v, torch.Tensor) else v - for k, v in kwargs.items() - } - - # Run forward pass - output = self.module(*args, **kwargs) - - # Offload output to CPU during forward pass - if isinstance(output, torch.Tensor): - output_cpu = output.to("cpu", non_blocking=True) - # Register hook to bring it back for backward - output.register_hook(lambda grad: self._backward_hook(grad, output_cpu)) - return output_cpu - elif isinstance(output, tuple): - output_cpu = tuple( - o.to("cpu", non_blocking=True) if isinstance(o, torch.Tensor) else o - for o in output - ) - # Register hooks for tensor outputs - for i, (o, o_cpu) in enumerate(zip(output, output_cpu)): - if isinstance(o, torch.Tensor): - o.register_hook( - lambda grad, oc=o_cpu: self._backward_hook(grad, oc) - ) - return output_cpu - return output - - def _to_gpu(self, tensor): - """Move tensor from CPU to GPU""" - if tensor.device.type == "cpu": - return tensor.to(torch.cuda.current_device(), non_blocking=True) - return tensor - - def _backward_hook(self, grad, cpu_activation): - """Called during backward to move activation back to GPU""" - return cpu_activation.to(grad.device, non_blocking=True) - - def __getattr__(self, name): - """Forward attribute access to the wrapped module""" - try: - return super().__getattr__(name) - except AttributeError: - return getattr(self.module, name) - - -def apply_offload_wrapper_only(module: nn.Module) -> nn.Module: - """ - Apply activation offloading WITHOUT checkpointing. - - This wraps the module to offload all activations to CPU, keeping them in memory - but freeing GPU RAM. 
No recomputation happens - activations are transferred - back to GPU during backward pass. - - Args: - module: The module to wrap - - Returns: - The wrapped module with activation offloading - """ - return ActivationOffloadWrapper(module) diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index a54ac81be8..fc6064500b 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -62,9 +62,6 @@ def parallelize_deepseekv3( """ use_flex_attn = getattr(model.model_args, "use_flex_attn", False) - # NOTE: CP + FlexAttention now supported in PyTorch 2.9+ (PRs #145896, #146397) - # if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: - # raise NotImplementedError("CP support for FlexAttention is still in progress.") if parallel_dims.tp_enabled: enable_float8_linear = "float8" in job_config.model.converters From e42846c86bd186d6725bdcff8350532258578a5e Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Tue, 20 Jan 2026 12:00:36 -0800 Subject: [PATCH 06/18] remove MemoryDefragManager --- torchtitan/memory_defrag.py | 101 -------------------- torchtitan/models/deepseek_v3/model/args.py | 6 -- torchtitan/train.py | 11 --- 3 files changed, 118 deletions(-) delete mode 100644 torchtitan/memory_defrag.py diff --git a/torchtitan/memory_defrag.py b/torchtitan/memory_defrag.py deleted file mode 100644 index b5f3761bcf..0000000000 --- a/torchtitan/memory_defrag.py +++ /dev/null @@ -1,101 +0,0 @@ -"""Memory defragmentation utilities for training""" -import logging -from typing import Optional - -import torch -import torch.distributed as dist - -logger = logging.getLogger(__name__) - - -class MemoryDefragManager: - """Manages memory defragmentation during training""" - - def __init__( - self, - enabled: bool = True, - defrag_freq: int = 10, # Defrag every N steps - aggressive: bool = False, - ): - self.enabled = enabled - self.defrag_freq = 
defrag_freq - self.aggressive = aggressive - self.step_count = 0 - - if self.enabled: - logger.info( - f"MemoryDefragManager enabled: freq={defrag_freq}, aggressive={aggressive}" - ) - - def step(self, step_num: int): - """Called after each training step""" - if not self.enabled: - return - - self.step_count += 1 - - if self.step_count % self.defrag_freq == 0: - self._defragment() - - def _defragment(self): - """Perform memory defragmentation""" - if not self.enabled: - return - - device = torch.cuda.current_device() - - # Get memory stats before - before_reserved = torch.cuda.memory_reserved(device) - before_allocated = torch.cuda.memory_allocated(device) - - # Method 1: Empty cache (basic) - torch.cuda.empty_cache() - - if self.aggressive: - # Method 2: Synchronize and empty cache again - torch.cuda.synchronize() - if dist.is_initialized(): - dist.barrier() - torch.cuda.empty_cache() - - # Get memory stats after - after_reserved = torch.cuda.memory_reserved(device) - after_allocated = torch.cuda.memory_allocated(device) - - freed_mb = (before_reserved - after_reserved) / (1024**2) - - if freed_mb > 0: - logger.info( - f"[Defrag] Freed {freed_mb:.2f} MB " - f"(reserved: {before_reserved/(1024**3):.2f} GB → {after_reserved/(1024**3):.2f} GB, " - f"allocated: {after_allocated/(1024**2):.2f} MB)" - ) - - -def setup_allocator_config( - max_split_size_mb: Optional[int] = None, - garbage_collection_threshold: Optional[float] = None, - roundup_power2_divisions: Optional[int] = None, -): - """Configure PyTorch CUDA allocator for reduced fragmentation""" - import os - - config_parts = ["expandable_segments:True"] - - if max_split_size_mb is not None: - config_parts.append(f"max_split_size_mb:{max_split_size_mb}") - - if garbage_collection_threshold is not None: - config_parts.append( - f"garbage_collection_threshold:{garbage_collection_threshold}" - ) - - if roundup_power2_divisions is not None: - config_parts.append(f"roundup_power2_divisions:{roundup_power2_divisions}") 
- - config = ",".join(config_parts) - os.environ["PYTORCH_CUDA_ALLOC_CONF"] = config - - logger.info(f"Allocator config: {config}") - - return config diff --git a/torchtitan/models/deepseek_v3/model/args.py b/torchtitan/models/deepseek_v3/model/args.py index cbf7548a72..dd9056ee91 100644 --- a/torchtitan/models/deepseek_v3/model/args.py +++ b/torchtitan/models/deepseek_v3/model/args.py @@ -104,12 +104,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None: ) self.moe_args.use_grouped_mm = False - # NOTE: CP + FlexAttention now supported in PyTorch 2.9+ (PRs #145896, #146397) - # if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn: - # raise NotImplementedError( - # "CP support for FlexAttention is still in progress." - # ) - self.moe_args._debug_force_load_balance = ( job_config.debug.moe_force_load_balance ) diff --git a/torchtitan/train.py b/torchtitan/train.py index 9323bf428b..e99b26aafe 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -25,7 +25,6 @@ ) from torchtitan.config import ConfigManager, JobConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims, utils as dist_utils -from torchtitan.memory_defrag import MemoryDefragManager from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils from torchtitan.tools.cuda_memory_tracker import CUDAMemoryTracker @@ -108,13 +107,6 @@ def __init__(self, job_config: JobConfig): gc_freq=job_config.training.gc_freq, debug=job_config.training.gc_debug ) - # Initialize memory defragmentation manager - self.defrag_manager = MemoryDefragManager( - enabled=getattr(job_config.training, "enable_memory_defrag", False), - defrag_freq=getattr(job_config.training, "defrag_freq", 1), - aggressive=getattr(job_config.training, "aggressive_defrag", False), - ) - # Initialize detailed memory tracker self.detailed_memory_tracker = DetailedMemoryTracker( enabled=getattr( @@ -900,9 +892,6 @@ def train(self): if 
memory_profiler: memory_profiler.step() - # Run memory defragmentation if enabled - self.defrag_manager.step(self.step) - # Track memory at step end and optionally clear cache self.detailed_memory_tracker.step_complete(self.step) self.cuda_memory_tracker.measure_all("step_end", self.step) From 94e59dc36c72eeea0b2cd96b6cc443b4eb8c815e Mon Sep 17 00:00:00 2001 From: emozilla Date: Wed, 21 Jan 2026 17:54:47 +0000 Subject: [PATCH 07/18] fast path for initing bfloat16 params on cpu --- torchtitan/models/deepseek_v3/model/model.py | 10 +-- torchtitan/models/moe/__init__.py | 4 +- torchtitan/models/moe/moe.py | 65 ++++++++++++++++---- 3 files changed, 59 insertions(+), 20 deletions(-) diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 0c8917edbc..ab18d11626 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -30,7 +30,7 @@ get_document_mask_mod, ScaledDotProductAttentionWrapper, ) -from torchtitan.models.moe import FeedForward, MoE +from torchtitan.models.moe import FeedForward, MoE, fast_init_trunc_normal_, fast_init_normal_ from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol @@ -333,8 +333,8 @@ def init_weights(self, init_std: float): linear_list.append(self.wq) for linear in linear_list: - nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) - nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + fast_init_trunc_normal_(linear.weight, mean=0.0, std=0.02) + fast_init_trunc_normal_(self.wo.weight, mean=0.0, std=init_std) self.kv_norm.reset_parameters() if self.q_lora_rank > 0: @@ -464,7 +464,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: with torch.device(buffer_device): self.freqs_cis = precompute_freqs_cis(self.model_args) if self.tok_embeddings is not None: - nn.init.normal_(self.tok_embeddings.weight) + 
fast_init_normal_(self.tok_embeddings.weight) for layer in self.layers.values(): if layer is not None: layer.init_weights(buffer_device=buffer_device) @@ -473,7 +473,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: final_out_std = self.model_args.dim**-0.5 cutoff_factor = 3 if self.output is not None: - nn.init.trunc_normal_( + fast_init_trunc_normal_( self.output.weight, mean=0.0, std=final_out_std, diff --git a/torchtitan/models/moe/__init__.py b/torchtitan/models/moe/__init__.py index c932f6aa83..e0d4a45609 100644 --- a/torchtitan/models/moe/__init__.py +++ b/torchtitan/models/moe/__init__.py @@ -4,6 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .moe import ExpertRoutingHistogram, FeedForward, MoE, MoEArgs +from .moe import ExpertRoutingHistogram, FeedForward, MoE, MoEArgs, fast_init_trunc_normal_, fast_init_normal_ -__all__ = ["FeedForward", "MoE", "MoEArgs", "ExpertRoutingHistogram"] +__all__ = ["FeedForward", "MoE", "MoEArgs", "ExpertRoutingHistogram", "fast_init_trunc_normal_", "fast_init_normal_"] diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index 46c2aa4484..92adcc8aa0 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -38,6 +38,45 @@ def moe_init_std(dim_in: int, n_layers: int) -> float: return (2 / (dim_in * n_layers)) ** 0.5 +def fast_init_trunc_normal_( + tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +) -> None: + """ + Fast truncated normal initialization that handles bfloat16 tensors on CPU. + + When tensors are bfloat16 on CPU, nn.init.trunc_normal_ is extremely slow + because CPUs don't have native bfloat16 support. This function temporarily + converts to float32 for the initialization, then converts back. 
+ """ + if tensor.device.type == "cpu" and tensor.dtype == torch.bfloat16: + with torch.no_grad(): + # Initialize in float32 for CPU performance + temp = torch.empty_like(tensor, dtype=torch.float32) + nn.init.trunc_normal_(temp, mean=mean, std=std, a=a, b=b) + tensor.copy_(temp.to(torch.bfloat16)) + else: + nn.init.trunc_normal_(tensor, mean=mean, std=std, a=a, b=b) + + +def fast_init_normal_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0 +) -> None: + """ + Fast normal initialization that handles bfloat16 tensors on CPU. + """ + if tensor.device.type == "cpu" and tensor.dtype == torch.bfloat16: + with torch.no_grad(): + temp = torch.empty_like(tensor, dtype=torch.float32) + nn.init.normal_(temp, mean=mean, std=std) + tensor.copy_(temp.to(torch.bfloat16)) + else: + nn.init.normal_(tensor, mean=mean, std=std) + + @dataclass class MoEArgs: num_experts: int = 8 @@ -174,9 +213,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x)) def init_weights(self, init_std: float = 0.02): - nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + fast_init_trunc_normal_(self.w1.weight, mean=0.0, std=0.02) for linear in (self.w2, self.w3): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + fast_init_trunc_normal_(linear.weight, mean=0.0, std=init_std) # NOTE: keeping this for-loop implementation for comparison @@ -456,9 +495,9 @@ def forward( def init_weights(self, init_std: float, n_layers: int): std_in = moe_init_std(self.w1.shape[-1], n_layers) std_out = moe_init_std(self.w2.shape[0], n_layers) - nn.init.trunc_normal_(self.w1, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w2, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w3, mean=0.0, std=std_out) + fast_init_trunc_normal_(self.w1, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w2, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w3, mean=0.0, std=std_out) def _groupmm(x, w, offs): @@ -602,12 +641,12 @@ def forward( def 
init_weights(self, init_std: float, n_layers: int): std_in = moe_init_std(self.w1.shape[-1], n_layers) std_out = moe_init_std(self.w2.shape[0], n_layers) - nn.init.trunc_normal_(self.w1, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w2, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w3, mean=0.0, std=std_out) - nn.init.trunc_normal_(self.w1_lora_a, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w2_lora_a, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w3_lora_a, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w1, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w2, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w3, mean=0.0, std=std_out) + fast_init_trunc_normal_(self.w1_lora_a, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w2_lora_a, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w3_lora_a, mean=0.0, std=std_in) nn.init.zeros_(self.w1_lora_b) nn.init.zeros_(self.w2_lora_b) nn.init.zeros_(self.w3_lora_b) @@ -732,7 +771,7 @@ def init_weights(self, init_std: float, n_layers: int): # DTensor, direct .data assignment (e.g., self.gate.weight.data = x) is # silently ignored, leaving weights uninitialized. This causes NaN loss # when CPU offload is enabled with 3+ GPUs. 
- nn.init.normal_(self.gate.weight, mean=0.0, std=1.0) + fast_init_normal_(self.gate.weight, mean=0.0, std=1.0) # Normalize rows in-place with torch.no_grad(): @@ -1053,7 +1092,7 @@ def init_weights(self, init_std: float, buffer_device: torch.device, n_layers: i if self.shared_experts is not None: self.shared_experts.init_weights(init_std) if self.shared_gate is not None: - nn.init.trunc_normal_( + fast_init_trunc_normal_( self.shared_gate.weight, mean=0.0, std=moe_init_std(self.shared_gate.weight.shape[1], n_layers), From 6dd01ddb48347bb347c342834949e0fe9823518d Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Thu, 22 Jan 2026 08:41:44 -0800 Subject: [PATCH 08/18] add bfloat16 optim states, fix page cache --- torchtitan/components/optimizer.py | 140 ++++++++++++++++++- torchtitan/config/job_config.py | 14 ++ torchtitan/models/deepseek_v3/model/model.py | 15 +- torchtitan/models/moe/__init__.py | 18 ++- torchtitan/models/moe/moe.py | 65 +++++++-- torchtitan/train.py | 28 ++++ 6 files changed, 255 insertions(+), 25 deletions(-) diff --git a/torchtitan/components/optimizer.py b/torchtitan/components/optimizer.py index 77d317c0b9..a17258f8ea 100644 --- a/torchtitan/components/optimizer.py +++ b/torchtitan/components/optimizer.py @@ -21,6 +21,7 @@ from torchtitan.components.ft import FTManager, has_torchft from torchtitan.config import Optimizer as OptimizerConfig, TORCH_DTYPE_MAP from torchtitan.distributed import ParallelDims +from torchtitan.tools.logging import logger # Dion optimizer availability will be checked lazily when needed DION_AVAILABLE = None @@ -32,11 +33,11 @@ def _check_dion_availability(): global DION_AVAILABLE if DION_AVAILABLE is None: try: - from torchtitan.experiments.dion_optimizer.dion import ( + from torchtitan.experiments.dion_optimizer.dion import ( # noqa: F401 Dion, DionMixedPrecisionConfig, ) - from torchtitan.experiments.dion_optimizer.titan_dion import ( + from torchtitan.experiments.dion_optimizer.titan_dion import ( # noqa: F401 
DionOptimizersContainer, ) @@ -51,8 +52,8 @@ def _check_muon_availability(): global MUON_AVAILABLE if MUON_AVAILABLE is None: try: - from torchtitan.experiments.dion_optimizer.muon import Muon - from torchtitan.experiments.dion_optimizer.titan_muon import ( + from torchtitan.experiments.dion_optimizer.muon import Muon # noqa: F401 + from torchtitan.experiments.dion_optimizer.titan_muon import ( # noqa: F401 MuonOptimizersContainer, ) @@ -76,6 +77,127 @@ def _check_muon_availability(): T = TypeVar("T", bound=Optimizer) +def preinit_optimizer_states_bf16(optimizers_container: "OptimizersContainer") -> None: + """ + Pre-initialize optimizer states (exp_avg, exp_avg_sq) directly in bfloat16. + This MUST be called BEFORE the first optimizer.step() to avoid fp32 allocation spike. + + This reduces optimizer state memory by ~50% (from fp32 to bf16). + States are allocated in bf16 from the start, avoiding the memory spike from fp32 allocation. + """ + total_params = 0 + total_bytes = 0 + + # For detailed logging + dtype_device_samples = [] + + rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0 + + for opt_idx, optimizer in enumerate(optimizers_container.optimizers): + for pg_idx, param_group in enumerate(optimizer.param_groups): + for p_idx, p in enumerate(param_group["params"]): + if p.requires_grad: + # Log first few params for debugging + if total_params < 5: + dtype_device_samples.append( + f"param[{opt_idx}][{pg_idx}][{p_idx}]: dtype={p.dtype}, device={p.device}, shape={list(p.shape)}" + ) + + # Initialize state dict for this parameter + state = optimizer.state[p] + if len(state) == 0: # Only initialize if not already initialized + state["step"] = torch.tensor(0, dtype=torch.float32) + # Allocate exp_avg and exp_avg_sq in SAME dtype as param for fused Adam compatibility + state["exp_avg"] = torch.zeros_like( + p, dtype=p.dtype, device=p.device + ) + state["exp_avg_sq"] = torch.zeros_like( + p, dtype=p.dtype, device=p.device + ) + 
total_params += 1 + # Calculate bytes based on actual param dtype + bytes_per_element = 2 if p.dtype == torch.bfloat16 else 4 + total_bytes += p.numel() * 2 * bytes_per_element # 2 states + + # Log first few state allocations + if total_params <= 3: + logger.info( + f"[Rank {rank}] State init sample: param dtype={p.dtype}, device={p.device}, " + f"exp_avg dtype={state['exp_avg'].dtype}, device={state['exp_avg'].device}" + ) + + # Log dtype/device samples + for sample in dtype_device_samples: + logger.info(f"[Rank {rank}] {sample}") + + logger.info( + f"[Rank {rank}] Pre-initialized {total_params} optimizer states matching param dtype, " + f"this rank: {total_bytes / 1e9:.2f} GB" + ) + + +class BF16StateOptimizersContainer(Generic[T]): + """ + Wrapper that pre-initializes optimizer states in bfloat16 BEFORE first step. + This prevents the memory spike from fp32 state allocation. + + IMPORTANT: Call init_bf16_states() BEFORE the first step() to avoid + rank skew during state allocation. This should be called after model + setup but before training starts, ideally with a barrier afterwards. + """ + + def __init__( + self, + base_container: "OptimizersContainer", + state_dtype: torch.dtype = torch.bfloat16, + ): + self._base = base_container + self._state_dtype = state_dtype + self._states_initialized = False + + def init_bf16_states(self): + """ + Pre-initialize optimizer states in bf16. + Call this BEFORE training starts, then call a distributed barrier. + This avoids rank skew during the first optimizer.step(). + """ + if not self._states_initialized: + logger.info("Pre-initializing optimizer states in bfloat16...") + preinit_optimizer_states_bf16(self._base) + self._states_initialized = True + logger.info("BF16 optimizer state pre-initialization complete.") + + def step(self, *args, **kwargs) -> None: + # If states weren't pre-initialized, do it now (fallback) + if not self._states_initialized: + logger.warning( + "BF16 optimizer states not pre-initialized! 
" + "Call init_bf16_states() before training to avoid rank skew." + ) + self.init_bf16_states() + # Call base step + self._base.step(*args, **kwargs) + + def zero_grad(self, *args, **kwargs) -> None: + self._base.zero_grad(*args, **kwargs) + + def state_dict(self) -> dict[str, Any]: + return self._base.state_dict() + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + self._base.load_state_dict(state_dict) + + def __iter__(self): + return iter(self._base) + + def __len__(self): + return len(self._base) + + def __getattr__(self, name): + # Delegate all other attributes to base container + return getattr(self._base, name) + + class OptimizersContainer(Optimizer, Stateful, Generic[T]): """A container for multiple optimizers. @@ -504,7 +626,15 @@ def build_optimizers( use_ft_optimizer=ft_manager.use_async_quorum, ) - return OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) + container = OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs) + + # Wrap with BF16 state container if configured + state_dtype = getattr(optimizer_config, "state_dtype", "float32") + if state_dtype == "bfloat16": + logger.info("Using bfloat16 optimizer states (will convert after first step)") + return BF16StateOptimizersContainer(container, torch.bfloat16) + + return container def build_optimizers_with_moe_load_balancing( diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 7bc20b1d62..458fffd9c2 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -287,6 +287,13 @@ class Optimizer: use_triton: bool = False """Whether to use Triton kernel for Newton-Schulz in Muon optimizer.""" + state_dtype: Literal["float32", "bfloat16"] = "float32" + """ + Dtype for optimizer states (exp_avg, exp_avg_sq for Adam/AdamW). + Using bfloat16 reduces memory by ~50% but may affect training stability. + Only applies to Adam/AdamW optimizers. 
+ """ + @dataclass class LRScheduler: @@ -380,6 +387,13 @@ class Training: Whether to clear CUDA cache between training steps to measure minimum memory requirements """ + drop_page_cache_before_training: bool = False + """ + Whether to drop Linux page cache before training starts (after model/optimizer init). + This helps when using FSDP2 CPU offload with mmap'd model weights that fill page cache. + Requires root/sudo or appropriate permissions to write to /proc/sys/vm/drop_caches. + """ + skip_optimizer_step: bool = False """ Whether to skip the optimizer step (for memory profiling purposes only) diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 0c8917edbc..4b259da6a2 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -30,7 +30,12 @@ get_document_mask_mod, ScaledDotProductAttentionWrapper, ) -from torchtitan.models.moe import FeedForward, MoE +from torchtitan.models.moe import ( + fast_init_normal_, + fast_init_trunc_normal_, + FeedForward, + MoE, +) from torchtitan.protocols.model import AttentionMasksType from torchtitan.protocols.train_spec import ModelProtocol @@ -333,8 +338,8 @@ def init_weights(self, init_std: float): linear_list.append(self.wq) for linear in linear_list: - nn.init.trunc_normal_(linear.weight, mean=0.0, std=0.02) - nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + fast_init_trunc_normal_(linear.weight, mean=0.0, std=0.02) + fast_init_trunc_normal_(self.wo.weight, mean=0.0, std=init_std) self.kv_norm.reset_parameters() if self.q_lora_rank > 0: @@ -464,7 +469,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: with torch.device(buffer_device): self.freqs_cis = precompute_freqs_cis(self.model_args) if self.tok_embeddings is not None: - nn.init.normal_(self.tok_embeddings.weight) + fast_init_normal_(self.tok_embeddings.weight) for layer in self.layers.values(): if layer is not None: 
layer.init_weights(buffer_device=buffer_device) @@ -473,7 +478,7 @@ def init_weights(self, buffer_device: torch.device | None = None) -> None: final_out_std = self.model_args.dim**-0.5 cutoff_factor = 3 if self.output is not None: - nn.init.trunc_normal_( + fast_init_trunc_normal_( self.output.weight, mean=0.0, std=final_out_std, diff --git a/torchtitan/models/moe/__init__.py b/torchtitan/models/moe/__init__.py index c932f6aa83..fffd491ae4 100644 --- a/torchtitan/models/moe/__init__.py +++ b/torchtitan/models/moe/__init__.py @@ -4,6 +4,20 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from .moe import ExpertRoutingHistogram, FeedForward, MoE, MoEArgs +from .moe import ( + ExpertRoutingHistogram, + fast_init_normal_, + fast_init_trunc_normal_, + FeedForward, + MoE, + MoEArgs, +) -__all__ = ["FeedForward", "MoE", "MoEArgs", "ExpertRoutingHistogram"] +__all__ = [ + "FeedForward", + "MoE", + "MoEArgs", + "ExpertRoutingHistogram", + "fast_init_trunc_normal_", + "fast_init_normal_", +] diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index 46c2aa4484..92adcc8aa0 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -38,6 +38,45 @@ def moe_init_std(dim_in: int, n_layers: int) -> float: return (2 / (dim_in * n_layers)) ** 0.5 +def fast_init_trunc_normal_( + tensor: torch.Tensor, + mean: float = 0.0, + std: float = 1.0, + a: float = -2.0, + b: float = 2.0, +) -> None: + """ + Fast truncated normal initialization that handles bfloat16 tensors on CPU. + + When tensors are bfloat16 on CPU, nn.init.trunc_normal_ is extremely slow + because CPUs don't have native bfloat16 support. This function temporarily + converts to float32 for the initialization, then converts back. 
+ """ + if tensor.device.type == "cpu" and tensor.dtype == torch.bfloat16: + with torch.no_grad(): + # Initialize in float32 for CPU performance + temp = torch.empty_like(tensor, dtype=torch.float32) + nn.init.trunc_normal_(temp, mean=mean, std=std, a=a, b=b) + tensor.copy_(temp.to(torch.bfloat16)) + else: + nn.init.trunc_normal_(tensor, mean=mean, std=std, a=a, b=b) + + +def fast_init_normal_( + tensor: torch.Tensor, mean: float = 0.0, std: float = 1.0 +) -> None: + """ + Fast normal initialization that handles bfloat16 tensors on CPU. + """ + if tensor.device.type == "cpu" and tensor.dtype == torch.bfloat16: + with torch.no_grad(): + temp = torch.empty_like(tensor, dtype=torch.float32) + nn.init.normal_(temp, mean=mean, std=std) + tensor.copy_(temp.to(torch.bfloat16)) + else: + nn.init.normal_(tensor, mean=mean, std=std) + + @dataclass class MoEArgs: num_experts: int = 8 @@ -174,9 +213,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w2(F.silu(self.w1(x)) * self.w3(x)) def init_weights(self, init_std: float = 0.02): - nn.init.trunc_normal_(self.w1.weight, mean=0.0, std=0.02) + fast_init_trunc_normal_(self.w1.weight, mean=0.0, std=0.02) for linear in (self.w2, self.w3): - nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + fast_init_trunc_normal_(linear.weight, mean=0.0, std=init_std) # NOTE: keeping this for-loop implementation for comparison @@ -456,9 +495,9 @@ def forward( def init_weights(self, init_std: float, n_layers: int): std_in = moe_init_std(self.w1.shape[-1], n_layers) std_out = moe_init_std(self.w2.shape[0], n_layers) - nn.init.trunc_normal_(self.w1, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w2, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w3, mean=0.0, std=std_out) + fast_init_trunc_normal_(self.w1, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w2, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w3, mean=0.0, std=std_out) def _groupmm(x, w, offs): @@ -602,12 +641,12 @@ def forward( def 
init_weights(self, init_std: float, n_layers: int): std_in = moe_init_std(self.w1.shape[-1], n_layers) std_out = moe_init_std(self.w2.shape[0], n_layers) - nn.init.trunc_normal_(self.w1, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w2, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w3, mean=0.0, std=std_out) - nn.init.trunc_normal_(self.w1_lora_a, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w2_lora_a, mean=0.0, std=std_in) - nn.init.trunc_normal_(self.w3_lora_a, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w1, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w2, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w3, mean=0.0, std=std_out) + fast_init_trunc_normal_(self.w1_lora_a, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w2_lora_a, mean=0.0, std=std_in) + fast_init_trunc_normal_(self.w3_lora_a, mean=0.0, std=std_in) nn.init.zeros_(self.w1_lora_b) nn.init.zeros_(self.w2_lora_b) nn.init.zeros_(self.w3_lora_b) @@ -732,7 +771,7 @@ def init_weights(self, init_std: float, n_layers: int): # DTensor, direct .data assignment (e.g., self.gate.weight.data = x) is # silently ignored, leaving weights uninitialized. This causes NaN loss # when CPU offload is enabled with 3+ GPUs. 
- nn.init.normal_(self.gate.weight, mean=0.0, std=1.0) + fast_init_normal_(self.gate.weight, mean=0.0, std=1.0) # Normalize rows in-place with torch.no_grad(): @@ -1053,7 +1092,7 @@ def init_weights(self, init_std: float, buffer_device: torch.device, n_layers: i if self.shared_experts is not None: self.shared_experts.init_weights(init_std) if self.shared_gate is not None: - nn.init.trunc_normal_( + fast_init_trunc_normal_( self.shared_gate.weight, mean=0.0, std=moe_init_std(self.shared_gate.weight.shape[1], n_layers), diff --git a/torchtitan/train.py b/torchtitan/train.py index e99b26aafe..75c88f684f 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -768,7 +768,25 @@ def train_step( # Skip optimizer step if configured (for memory profiling) if not self.job_config.training.skip_optimizer_step: + import datetime + import time as _time + + # Log step start with timestamp for correlation with vmstat + if self.device.index == 0: + _ts = datetime.datetime.now().strftime("%H:%M:%S") + logger.info(f"[STEP {self.step}] optimizer.step() START @ {_ts}") + + _optim_start = _time.time() self.optimizers.step() + _optim_elapsed = _time.time() - _optim_start + + # Log step end with timing + if self.device.index == 0: + _ts = datetime.datetime.now().strftime("%H:%M:%S") + logger.info( + f"[STEP {self.step}] optimizer.step() END @ {_ts} | Duration: {_optim_elapsed:.2f}s" + ) + self.lr_schedulers.step() else: logger.info("Skipping optimizer step (skip_optimizer_step=True)") @@ -820,6 +838,16 @@ def train(self): job_config = self.job_config self.checkpointer.load(step=job_config.checkpoint.load_step) + + # Pre-initialize bf16 optimizer states if configured + # This must happen BEFORE training to avoid rank skew during first step + if hasattr(self.optimizers, "init_bf16_states"): + self.optimizers.init_bf16_states() + # Barrier to ensure all ranks finish before training starts + if torch.distributed.is_initialized(): + torch.distributed.barrier() + logger.info("All ranks 
synchronized after bf16 optimizer state init") + logger.info(f"Training starts at step {self.step + 1}") leaf_folder = ( From f18db98bcf50becc9a458e07fb591442c0095a3c Mon Sep 17 00:00:00 2001 From: emozilla Date: Thu, 22 Jan 2026 20:45:54 +0000 Subject: [PATCH 09/18] add reference for init scheme --- torchtitan/models/moe/moe.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index 92adcc8aa0..af1f32a7bf 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -34,6 +34,7 @@ class ExpertRoutingHistogram: counts: list[float] +# see https://arxiv.org/pdf/2310.10837 def moe_init_std(dim_in: int, n_layers: int) -> float: return (2 / (dim_in * n_layers)) ** 0.5 @@ -767,6 +768,10 @@ def forward( return top_scores, selected_experts_indices, num_tokens_per_expert def init_weights(self, init_std: float, n_layers: int): + # Init gate with each row normalized + # From "Approximating Two-Layer Feedforward Networks for Efficient Transformers" + # https://arxiv.org/pdf/2310.10837 + # NOTE: Must use in-place operations here. When FSDP wraps parameters as # DTensor, direct .data assignment (e.g., self.gate.weight.data = x) is # silently ignored, leaving weights uninitialized. 
This causes NaN loss From 2d60a0142ad7a2c9c3175df6ea7c0cb4ff31b14d Mon Sep 17 00:00:00 2001 From: emozilla Date: Fri, 23 Jan 2026 07:24:19 +0000 Subject: [PATCH 10/18] error if cp set but can't import --- torchtitan/models/deepseek_v3/model/model.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 9a215c6fae..245f5d7be5 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -521,7 +521,9 @@ def get_attention_masks( H = self.model_args.n_heads # Number of attention heads # Use CP-aware block mask when Context Parallel is enabled - if cp_mesh is not None and create_cp_block_mask is not None: + if cp_mesh is not None: + if create_cp_block_mask is None: + raise RuntimeError("Cannot do context parallel without a PyTorch that supports `create_cp_block_mask`") return create_cp_block_mask( mask_mod=combined_mask_mod, B=B, From 53eea6b31a3fdf39bde4bebde447ba4d5a66c60c Mon Sep 17 00:00:00 2001 From: emozilla Date: Fri, 23 Jan 2026 07:24:53 +0000 Subject: [PATCH 11/18] overlapped cpu offload muon --- torchtitan/experiments/dion_optimizer/muon.py | 575 +++++++++++++----- 1 file changed, 434 insertions(+), 141 deletions(-) diff --git a/torchtitan/experiments/dion_optimizer/muon.py b/torchtitan/experiments/dion_optimizer/muon.py index 432ab1399f..0ac1602f71 100644 --- a/torchtitan/experiments/dion_optimizer/muon.py +++ b/torchtitan/experiments/dion_optimizer/muon.py @@ -609,27 +609,17 @@ def muon_update_batch_dim_sharded_async( - This is mathematically equivalent to orthogonalizing each expert's weights independently This function processes all params locally without all-to-all or all-gather. 
+ + Optimized for CPU offloading with: + - Double-buffered CUDA streams to overlap transfer and compute + - Batched Newton-Schulz for fewer kernel launches + - Single sync point at end (no intermediate cuda.synchronize()) """ - U = muon_update_pre_orthogonalize( - G=G, - M=M, - momentum=momentum, - nesterov=nesterov, - ) - - # Orthogonalize each tensor locally - # Newton-Schulz treats dim 0 as batch, processing each slice independently - U = [ - muon_update_newton_schulz( - u, - newton_schulz_func=newton_schulz_func, - flatten=flatten, - epsilon=epsilon, - ) - for u in U - ] + # Check if we need CPU offloading (tensors are on CPU) + original_device = G[0].device + needs_gpu_transfer = original_device.type != "cuda" - # Compute scaled learning rate + # Compute scaled learning rate upfront # Use the first tensor's shape (they should all be the same shape within a batch) if adjust_lr is None: adjusted_lr = lr @@ -640,16 +630,132 @@ def muon_update_batch_dim_sharded_async( else: raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") - # Update model parameters with orthogonalized output - muon_update_post_orthogonalize( - X=X, - U=U, - base_lr=lr, - adjusted_lr=adjusted_lr, - weight_decay=weight_decay, - ) + if needs_gpu_transfer: + # PIPELINED MODE: Double-buffered streams for maximum overlap + # Timeline: transfer[i+1] overlaps with compute[i] overlaps with writeback[i-1] + cuda_device = torch.device("cuda") + dtype = M[0].dtype + n_tensors = len(X) + + # Mini-batch size for batched Newton-Schulz (fewer kernel launches) + BATCH_SIZE = 4 + + # Create streams: one for H2D transfers, one for compute, one for D2H transfers + h2d_stream = torch.cuda.Stream() + compute_stream = torch.cuda.Stream() + d2h_stream = torch.cuda.Stream() + + # Double buffer: prefetch next batch while computing current + prefetch_data = None # Will hold (g_batch, m_batch, x_batch, indices) for next iteration + + def prefetch_batch(start_idx): + """Prefetch a batch of tensors to GPU 
(non-blocking).""" + end_idx = min(start_idx + BATCH_SIZE, n_tensors) + indices = list(range(start_idx, end_idx)) + with torch.cuda.stream(h2d_stream): + g_batch = [G[i].to(dtype=dtype).to(cuda_device, non_blocking=True) for i in indices] + m_batch = [M[i].to(cuda_device, non_blocking=True) for i in indices] + x_batch = [X[i].to(cuda_device, non_blocking=True) for i in indices] + return (g_batch, m_batch, x_batch, indices) + + def compute_batch(g_batch, m_batch, x_batch, indices): + """Compute momentum update and Newton-Schulz on GPU.""" + with torch.cuda.stream(compute_stream): + # Wait for H2D transfer to complete (lightweight stream sync) + compute_stream.wait_stream(h2d_stream) + + u_batch = [] + for j in range(len(indices)): + g_gpu, m_gpu = g_batch[j], m_batch[j] + # Update momentum: M = mu * M + G + m_gpu.mul_(momentum) + m_gpu.add_(g_gpu) + # Compute U + if nesterov: + u_gpu = m_gpu * momentum + g_gpu + else: + u_gpu = m_gpu.clone() + u_batch.append(u_gpu.to(dtype=torch.bfloat16)) + + # Batched Newton-Schulz: stack same-shape tensors for single kernel + if len(u_batch) > 1 and all(u.shape == u_batch[0].shape for u in u_batch): + u_stacked = torch.stack(u_batch, dim=0) + u_stacked = muon_update_newton_schulz(u_stacked, newton_schulz_func, flatten, epsilon) + u_batch = list(u_stacked.unbind(0)) + else: + u_batch = [muon_update_newton_schulz(u, newton_schulz_func, flatten, epsilon) for u in u_batch] + + # Apply weight decay and update + for j in range(len(indices)): + x_batch[j].mul_(1 - lr * weight_decay) + x_batch[j].sub_(u_batch[j] * adjusted_lr) + + return m_batch, x_batch + + def writeback_batch(m_batch, x_batch, indices): + """Write results back to CPU (non-blocking).""" + with torch.cuda.stream(d2h_stream): + # Wait for compute to complete + d2h_stream.wait_stream(compute_stream) + for j, i in enumerate(indices): + M[i].copy_(m_batch[j], non_blocking=True) + X[i].copy_(x_batch[j], non_blocking=True) + + # Pipeline: prefetch first batch + if n_tensors > 
0: + prefetch_data = prefetch_batch(0) + + # Main loop with double buffering + for batch_start in range(0, n_tensors, BATCH_SIZE): + # Get current batch (already prefetched) + g_batch, m_batch, x_batch, indices = prefetch_data + + # Start prefetching NEXT batch (overlaps with current compute) + next_start = batch_start + BATCH_SIZE + if next_start < n_tensors: + prefetch_data = prefetch_batch(next_start) + + # Compute current batch + m_batch, x_batch = compute_batch(g_batch, m_batch, x_batch, indices) + + # Writeback current batch (overlaps with next iteration's prefetch/compute) + writeback_batch(m_batch, x_batch, indices) + + # Single sync at end to ensure all D2H transfers complete + torch.cuda.synchronize() + + yield # Single yield to make this a generator + else: + # STANDARD GPU MODE: Process all tensors together (original behavior) + U = muon_update_pre_orthogonalize( + G=G, + M=M, + momentum=momentum, + nesterov=nesterov, + ) + + # Orthogonalize each tensor locally + # Newton-Schulz treats dim 0 as batch, processing each slice independently + U = [ + muon_update_newton_schulz( + u, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) + for u in U + ] - yield # Single yield to make this a generator + # Update model parameters with orthogonalized output + muon_update_post_orthogonalize( + X=X, + U=U, + base_lr=lr, + adjusted_lr=adjusted_lr, + weight_decay=weight_decay, + ) + + yield # Single yield to make this a generator def muon_update_batch_async( @@ -673,146 +779,333 @@ def muon_update_batch_async( Batched version of Muon update. Batch size should be equal to number of GPUs. All tensors in a batch should have identical shape, sharding, and dtype. Identical hyperparameters are used for all tensors in the batch. + + Memory-optimized for CPU offloading: when tensors are on CPU, moves ALL computation + to GPU (momentum update, all_to_all, Newton-Schulz, weight update) then copies back. 
""" assert len(X) == len(G) assert len(X) == len(M) assert len(X) == world_size - # Update momentum and compute the inputs for orthogonalization - U = muon_update_pre_orthogonalize( - G=to_local(G), - M=to_local(M), - momentum=momentum, - nesterov=nesterov, - ) - - # Get one whole matrix for each device to orthogonalize - if shard_dim is not None: - # Use all-to-all to transform from a batch of shards to a single whole matrix - # https://www.essential.ai/blog/infra - assert ( - process_group is not None - ), "process_group must be provided for sharded DTensors" - assert isinstance(X[0], DTensor), "X should contain DTensors" - assert not isinstance(U[0], DTensor), "U should contain local shards" - - # Debug: print full tensor info before the divisibility check - x0 = X[0] - x0_mesh = x0.device_mesh - x0_mesh_sizes = {name: x0_mesh.size(i) for i, name in enumerate(x0_mesh.mesh_dim_names)} - - assert ( - X[0].size(shard_dim) % world_size == 0 - ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}. 
" \ - f"Tensor info: global_shape={tuple(X[0].shape)}, local_shape={X[0].to_local().shape}, " \ - f"mesh={X[0].device_mesh.mesh_dim_names}, mesh_sizes={x0_mesh_sizes}, placements={X[0].placements}" - - # Allocate buffers to receive shards of one whole matrix from other devices - single_matrix_shards = [torch.empty_like(u) for u in U] - - # Redistribute the shards to form one unique full tensor on each device - # Sync CUDA before collective to ensure all prior GPU ops are complete - # This can prevent NCCL hangs due to async GPU operations + # Check early if we're in CPU offloading mode + G_local = to_local(G) + M_local = to_local(M) + X_local = to_local(X) + original_device = M_local[0].device + needs_gpu_transfer = original_device.type != "cuda" + + if needs_gpu_transfer: + # ====== CPU OFFLOADING PATH: Do ALL computation on GPU ====== + # This avoids slow CPU foreach operations for momentum and weight updates + cuda_device = torch.device("cuda") + dtype = M_local[0].dtype + + # Transfer G, M to GPU for momentum update + G_gpu = [g.to(dtype=dtype).to(cuda_device, non_blocking=True) for g in G_local] + M_gpu = [m.to(cuda_device, non_blocking=True) for m in M_local] torch.cuda.synchronize() - # N sequential all_gathers - only keep result for our assigned param - single_matrix_shards = None - for param_idx in range(world_size): - # Allocate output buffer for this all_gather - gathered = [torch.empty_like(U[param_idx]) for _ in range(world_size)] + # Momentum update on GPU (equivalent to muon_update_pre_orthogonalize) + torch._foreach_mul_(M_gpu, momentum) + torch._foreach_add_(M_gpu, G_gpu) - # All ranks send their shard of param_idx - dist.all_gather(gathered, U[param_idx].contiguous(), group=process_group) + if nesterov: + U_gpu = torch._foreach_mul(M_gpu, momentum) + torch._foreach_add_(U_gpu, G_gpu) + else: + # U shares memory with M when not using nesterov + U_gpu = M_gpu + + # Free G_gpu - no longer needed + del G_gpu + + # Convert to bfloat16 for communication 
+ U_gpu = [u.to(dtype=torch.bfloat16) for u in U_gpu] + + # Get one whole matrix for each device to orthogonalize + if shard_dim is not None: + # Use all-to-all to transform from a batch of shards to a single whole matrix + assert process_group is not None, "process_group must be provided for sharded DTensors" + assert isinstance(X[0], DTensor), "X should contain DTensors" + + # Validation + x0 = X[0] + x0_mesh = x0.device_mesh + x0_mesh_sizes = {name: x0_mesh.size(i) for i, name in enumerate(x0_mesh.mesh_dim_names)} + assert ( + X[0].size(shard_dim) % world_size == 0 + ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}. " \ + f"Tensor info: global_shape={tuple(X[0].shape)}, local_shape={X[0].to_local().shape}, " \ + f"mesh={X[0].device_mesh.mesh_dim_names}, mesh_sizes={x0_mesh_sizes}, placements={X[0].placements}" + + # Make contiguous for all_to_all + U_gpu = [u.contiguous() for u in U_gpu] + + # First all_to_all: batch of shards -> single whole matrix + single_matrix_shards = [torch.empty_like(U_gpu[0]) for _ in range(world_size)] + dist.all_to_all(single_matrix_shards, U_gpu, group=process_group) + del U_gpu - # Only keep if this is our assigned parameter - if param_idx == device_rank: - single_matrix_shards = gathered - # Otherwise 'gathered' goes out of scope and memory can be freed + yield - yield + # Concatenate shards to form whole matrix + single_matrix = torch.cat(single_matrix_shards, dim=shard_dim) + del single_matrix_shards - # Concatentate shards to form a whole matrix to orthogonalize - single_matrix = torch.cat(single_matrix_shards, dim=shard_dim) - single_matrix = muon_update_newton_schulz( - single_matrix, - newton_schulz_func=newton_schulz_func, - flatten=flatten, - epsilon=epsilon, - ) + # Newton-Schulz orthogonalization (on GPU) + single_matrix = muon_update_newton_schulz( + single_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) - # Split result 
back into shards - # Contiguous is needed for communication to work correctly - orth_shards = [ - x.contiguous() - for x in torch.tensor_split(single_matrix, world_size, dim=shard_dim) - ] + # Split result back into shards + orth_shards = [ + x.contiguous() + for x in torch.tensor_split(single_matrix, world_size, dim=shard_dim) + ] + del single_matrix - # N sequential all_gathers - collect results as we go - for shard_idx in range(world_size): - # Allocate output buffer for this all_gather - gathered = [torch.empty_like(orth_shards[shard_idx]) for _ in range(world_size)] + # Second all_to_all to redistribute orthogonalized shards + U_orth_gpu = [torch.empty_like(orth_shards[0]) for _ in range(world_size)] + dist.all_to_all(U_orth_gpu, orth_shards, group=process_group) + del orth_shards - # All ranks send their shard at index shard_idx - dist.all_gather(gathered, orth_shards[shard_idx].contiguous(), group=process_group) + yield - # gathered[r] = rank r's orth_shards[shard_idx] = O^r_{shard_idx} - # We need U[r] = O^r_{device_rank} - # So when shard_idx == device_rank: U[r] = gathered[r] for all r - if shard_idx == device_rank: - for r in range(world_size): - U[r].copy_(gathered[r]) + else: + # Matrices are not sharded, orthogonalize directly + single_matrix = U_gpu[device_rank] + + single_matrix = muon_update_newton_schulz( + single_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, + ) + + if process_group is not None and process_group.size() > 1: + U_orth_gpu = [torch.empty_like(single_matrix) for _ in range(world_size)] + work = dist.all_gather( + U_orth_gpu, single_matrix.contiguous(), group=process_group, async_op=True + ) + yield + work.wait() + del single_matrix + else: + assert world_size == 1 + U_orth_gpu = [single_matrix] + + # Compute scaled learning rate (use full tensor shape from X[0]) + if adjust_lr is None: + adjusted_lr = lr + elif adjust_lr == "spectral_norm": + adjusted_lr = adjust_lr_spectral_norm(lr, 
X[0].shape) + elif adjust_lr == "rms_norm": + adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape) + else: + raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") + + # Transfer X to GPU for weight update + X_gpu = [x.to(cuda_device, non_blocking=True) for x in X_local] + torch.cuda.synchronize() + + # Weight update on GPU (equivalent to muon_update_post_orthogonalize) + torch._foreach_mul_(X_gpu, 1 - lr * weight_decay) + U_scaled = torch._foreach_mul(U_orth_gpu, adjusted_lr) + torch._foreach_sub_(X_gpu, U_scaled) + del U_scaled, U_orth_gpu + + # Copy M and X back to CPU + for i in range(world_size): + M_local[i].copy_(M_gpu[i], non_blocking=True) + X_local[i].copy_(X_gpu[i], non_blocking=True) - yield + torch.cuda.synchronize() + del M_gpu, X_gpu else: - # Matrices are not sharded, so we can directly orthogonalize - # Get a single matrix corresponding to this device - single_matrix = U[device_rank] - assert not isinstance(single_matrix, DTensor) - - single_matrix = muon_update_newton_schulz( - single_matrix, - newton_schulz_func=newton_schulz_func, - flatten=flatten, - epsilon=epsilon, + # ====== STANDARD GPU PATH ====== + # Update momentum and compute the inputs for orthogonalization + U = muon_update_pre_orthogonalize( + G=G_local, + M=M_local, + momentum=momentum, + nesterov=nesterov, ) - if process_group is not None and process_group.size() > 1: - # Allocate empty tensors to receive updates from other devices - U = [torch.empty_like(u) for u in U] + # Get one whole matrix for each device to orthogonalize + # JQ: This is the N sequential gather version + # if shard_dim is not None: + # # Use all-to-all to transform from a batch of shards to a single whole matrix + # # https://www.essential.ai/blog/infra + # assert ( + # process_group is not None + # ), "process_group must be provided for sharded DTensors" + # assert isinstance(X[0], DTensor), "X should contain DTensors" + # assert not isinstance(U[0], DTensor), "U should contain local shards" + + # # Debug: 
print full tensor info before the divisibility check + # x0 = X[0] + # x0_mesh = x0.device_mesh + # x0_mesh_sizes = {name: x0_mesh.size(i) for i, name in enumerate(x0_mesh.mesh_dim_names)} + + # assert ( + # X[0].size(shard_dim) % world_size == 0 + # ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}. " \ + # f"Tensor info: global_shape={tuple(X[0].shape)}, local_shape={X[0].to_local().shape}, " \ + # f"mesh={X[0].device_mesh.mesh_dim_names}, mesh_sizes={x0_mesh_sizes}, placements={X[0].placements}" + + # # Allocate buffers to receive shards of one whole matrix from other devices + # single_matrix_shards = [torch.empty_like(u) for u in U] + + # # Redistribute the shards to form one unique full tensor on each device + # # Sync CUDA before collective to ensure all prior GPU ops are complete + # # This can prevent NCCL hangs due to async GPU operations + # torch.cuda.synchronize() + + # # N sequential all_gathers - only keep result for our assigned param + # single_matrix_shards = None + # for param_idx in range(world_size): + # # Allocate output buffer for this all_gather + # gathered = [torch.empty_like(U[param_idx]) for _ in range(world_size)] + + # # All ranks send their shard of param_idx + # dist.all_gather(gathered, U[param_idx].contiguous(), group=process_group) + + # # Only keep if this is our assigned parameter + # if param_idx == device_rank: + # single_matrix_shards = gathered + # # Otherwise 'gathered' goes out of scope and memory can be freed + + # yield + + # # Concatentate shards to form a whole matrix to orthogonalize + # single_matrix = torch.cat(single_matrix_shards, dim=shard_dim) + # single_matrix = muon_update_newton_schulz( + # single_matrix, + # newton_schulz_func=newton_schulz_func, + # flatten=flatten, + # epsilon=epsilon, + # ) + + # # Split result back into shards + # # Contiguous is needed for communication to work correctly + # orth_shards = [ + # x.contiguous() + # for x in 
torch.tensor_split(single_matrix, world_size, dim=shard_dim) + # ] + + # # N sequential all_gathers - collect results as we go + # for shard_idx in range(world_size): + # # Allocate output buffer for this all_gather + # gathered = [torch.empty_like(orth_shards[shard_idx]) for _ in range(world_size)] + + # # All ranks send their shard at index shard_idx + # dist.all_gather(gathered, orth_shards[shard_idx].contiguous(), group=process_group) + + # # gathered[r] = rank r's orth_shards[shard_idx] = O^r_{shard_idx} + # # We need U[r] = O^r_{device_rank} + # # So when shard_idx == device_rank: U[r] = gathered[r] for all r + # if shard_idx == device_rank: + # for r in range(world_size): + # U[r].copy_(gathered[r]) + + # yield + + # Get one whole matrix for each device to orthogonalize + if shard_dim is not None: + assert process_group is not None, "process_group must be provided for sharded DTensors" + assert isinstance(X[0], DTensor), "X should contain DTensors" + assert not isinstance(U[0], DTensor), "U should contain local shards" + + x0 = X[0] + x0_mesh = x0.device_mesh + x0_mesh_sizes = {name: x0_mesh.size(i) for i, name in enumerate(x0_mesh.mesh_dim_names)} + assert ( + X[0].size(shard_dim) % world_size == 0 + ), f"Shard dimension {shard_dim} size {X[0].size(shard_dim)} is not divisible by world size {world_size}. 
" \ + f"Tensor info: global_shape={tuple(X[0].shape)}, local_shape={X[0].to_local().shape}, " \ + f"mesh={X[0].device_mesh.mesh_dim_names}, mesh_sizes={x0_mesh_sizes}, placements={X[0].placements}" + + # Sync CUDA before collective to prevent NCCL hangs from async GPU ops + torch.cuda.synchronize() + + single_matrix_shards = [torch.empty_like(U[0]) for _ in range(world_size)] + dist.all_to_all(single_matrix_shards, [u.contiguous() for u in U], group=process_group) - # All gather orthogonalized results from other devices into buffer - work = dist.all_gather( - U, single_matrix.contiguous(), group=process_group, async_op=True + yield + + single_matrix = torch.cat(single_matrix_shards, dim=shard_dim) + del single_matrix_shards + + single_matrix = muon_update_newton_schulz( + single_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + epsilon=epsilon, ) + + orth_shards = [ + x.contiguous() + for x in torch.tensor_split(single_matrix, world_size, dim=shard_dim) + ] + del single_matrix + + output_shards = [torch.empty_like(orth_shards[0]) for _ in range(world_size)] + dist.all_to_all(output_shards, orth_shards, group=process_group) + del orth_shards + + for i in range(world_size): + U[i].copy_(output_shards[i]) + del output_shards + yield - work.wait() else: - # Single GPU case, no need to gather - assert world_size == 1 - U = [single_matrix] - - # Compute scaled learning rate - # Do this before to_local(X) because we use the full tensor shape, not the shard shape - if adjust_lr is None: - adjusted_lr = lr - elif adjust_lr == "spectral_norm": - adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape) - elif adjust_lr == "rms_norm": - adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape) - else: - raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") + single_matrix = U[device_rank] + assert not isinstance(single_matrix, DTensor) + + single_matrix = muon_update_newton_schulz( + single_matrix, + newton_schulz_func=newton_schulz_func, + flatten=flatten, + 
epsilon=epsilon, + ) - # Update model parameters with orthogonalized output - muon_update_post_orthogonalize( - X=to_local(X), - U=U, - base_lr=lr, - adjusted_lr=adjusted_lr, - weight_decay=weight_decay, - ) + if process_group is not None and process_group.size() > 1: + U_gathered = [torch.empty_like(single_matrix) for _ in range(world_size)] + work = dist.all_gather( + U_gathered, single_matrix.contiguous(), group=process_group, async_op=True + ) + yield + work.wait() + del single_matrix + U = U_gathered + else: + assert world_size == 1 + U = [single_matrix] + + # Compute scaled learning rate + if adjust_lr is None: + adjusted_lr = lr + elif adjust_lr == "spectral_norm": + adjusted_lr = adjust_lr_spectral_norm(lr, X[0].shape) + elif adjust_lr == "rms_norm": + adjusted_lr = adjust_lr_rms_norm(lr, X[0].shape) + else: + raise ValueError(f"Unknown adjust_lr value: {adjust_lr}") + + # Update model parameters with orthogonalized output + muon_update_post_orthogonalize( + X=X_local, + U=U, + base_lr=lr, + adjusted_lr=adjusted_lr, + weight_decay=weight_decay, + ) def adamw_update_foreach_async( From e8e2cf919cdae17628737274e01e4fbb6f03d256 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Sat, 31 Jan 2026 05:47:50 -0800 Subject: [PATCH 12/18] Add FSDP enhancements: partial resharding, bucket size, and prefetch control - fsdp_reshard_after_forward now accepts integer N for partial resharding to N-GPU groups (e.g., N=8 for intra-node NVLink communication) - Add fsdp_bucket_cap_mb config to control gradient reduction bucket size - Add fsdp_disable_prefetch config to disable forward/backward prefetching - Pass new options through to apply_fsdp() in deepseek_v3 and llama4 --- torchtitan/config/job_config.py | 19 +++++- .../models/deepseek_v3/infra/parallelize.py | 2 + torchtitan/models/llama4/infra/parallelize.py | 60 ++++++++++++++----- 3 files changed, 63 insertions(+), 18 deletions(-) diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index 
458fffd9c2..c00fcce6f6 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -457,19 +457,22 @@ class Parallelism: only `data_parallel_shard_degree` can be negative. 1 means disabled. """ - fsdp_reshard_after_forward: Literal["default", "always", "never"] = "default" + fsdp_reshard_after_forward: Literal["default", "always", "never"] | int = "default" """ `reshard_after_forward` specifies the policy for applying `reshard_after_forward` within an FSDP setup. `reshard_after_forward` controls parameter behavior after forward, trading off memory and communication. See torch's `fully_shard` API for more documentation on `reshard_after_forward`. - The supported policies include "default", "always" and "never": + The supported policies include "default", "always", "never", or an integer: - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. - "always" will enable `reshard_after_forward` for all forward passes. - "never" will disable `reshard_after_forward` for all forward passes. + - integer N: Partially reshard to groups of N GPUs after forward. Must be a factor of + the FSDP shard world size. Use N=8 for intra-node resharding (reduces memory while + keeping communication fast via NVLink). This trades memory for communication. """ tensor_parallel_degree: int = 1 @@ -584,6 +587,18 @@ class Parallelism: Note that this is still an experimental feature. """ + fsdp_bucket_cap_mb: int | None = None + """ + FSDP bucket size in MB for gradient reduction. None means use PyTorch default (25MB). + Smaller values (e.g., 20) can reduce peak memory at the cost of more communication overhead. + """ + + fsdp_disable_prefetch: bool = False + """ + Whether to disable FSDP forward/backward prefetching. Disabling prefetch can reduce memory + at the cost of performance (less overlap of communication and computation). 
+ """ + @dataclass class DeepEP: diff --git a/torchtitan/models/deepseek_v3/infra/parallelize.py b/torchtitan/models/deepseek_v3/infra/parallelize.py index fc6064500b..5009609281 100644 --- a/torchtitan/models/deepseek_v3/infra/parallelize.py +++ b/torchtitan/models/deepseek_v3/infra/parallelize.py @@ -154,6 +154,8 @@ def parallelize_deepseekv3( else None ), gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, + bucket_cap_mb=job_config.parallelism.fsdp_bucket_cap_mb, + disable_prefetch=job_config.parallelism.fsdp_disable_prefetch, ) if parallel_dims.dp_replicate_enabled: diff --git a/torchtitan/models/llama4/infra/parallelize.py b/torchtitan/models/llama4/infra/parallelize.py index efc73ccc0e..f5899e2de1 100644 --- a/torchtitan/models/llama4/infra/parallelize.py +++ b/torchtitan/models/llama4/infra/parallelize.py @@ -177,6 +177,8 @@ def parallelize_llama( else None ), gradient_divide_factor=parallel_dims.fsdp_gradient_divide_factor, + bucket_cap_mb=job_config.parallelism.fsdp_bucket_cap_mb, + disable_prefetch=job_config.parallelism.fsdp_disable_prefetch, ) if parallel_dims.dp_replicate_enabled: @@ -298,10 +300,12 @@ def apply_fsdp( reduce_dtype: torch.dtype, pp_enabled: bool, cpu_offload: bool = False, - reshard_after_forward_policy: str = "default", + reshard_after_forward_policy: str | int = "default", ep_degree: int = 1, dp_mod_ep_mesh: DeviceMesh | None = None, gradient_divide_factor: int | None = None, + bucket_cap_mb: int | None = None, + disable_prefetch: bool = False, ): """ Apply data parallelism (via FSDP2) to the model. @@ -313,31 +317,50 @@ def apply_fsdp( reduce_dtype (torch.dtype): The data type to use for reduction operations. pp_enabled (bool): Whether pipeline parallelism is enabled. cpu_offload (bool, optional): Whether to offload model parameters to CPU. Defaults to False. - reshard_after_forward_policy (str, optional): The policy to use for resharding after forward pass. Defaults to "default". - Other options: "never", "always". 
+ reshard_after_forward_policy (str | int, optional): The policy to use for + resharding after forward pass. Defaults to "default". + String options: "never", "always", "default". - "default" applies default resharding behavior, implementing "smart defaults" for known optimal scenarios. - "always" will enable `reshard_after_forward` for all forward passes. - "never" will disable `reshard_after_forward` for all forward passes. + Integer option: N (e.g., 8) for partial resharding to N-GPU groups. + - Reduces peak memory by limiting all-gather buffer size to N GPUs instead of full DP world. + - Use N=8 for intra-node resharding (fast NVLink communication). + - N must be a factor of the FSDP shard world size. + bucket_cap_mb (int | None, optional): FSDP bucket size in MB for gradient + reduction. None means use PyTorch default (25MB). Defaults to None. + disable_prefetch (bool, optional): Whether to disable FSDP forward/backward prefetching. Defaults to False. """ mp_policy = MixedPrecisionPolicy(param_dtype=param_dtype, reduce_dtype=reduce_dtype) fsdp_config = {"mesh": dp_mesh, "mp_policy": mp_policy} if cpu_offload: fsdp_config["offload_policy"] = CPUOffloadPolicy() + if bucket_cap_mb is not None: + logger.warning( + f"bucket_cap_mb={bucket_cap_mb} requested but not supported by FSDP2 fully_shard() API - ignoring" + ) - match reshard_after_forward_policy: - case "always": - reshard_after_forward = True - case "never": - reshard_after_forward = False - case "default": - # For PP, by default do not reshard after forward to avoid per-microbatch - # all-gathers, which can be expensive and non-overlapped - reshard_after_forward = not pp_enabled - case _: - raise ValueError( - f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." 
- ) + # Handle integer reshard_after_forward (partial resharding to N-GPU groups) + if isinstance(reshard_after_forward_policy, int): + reshard_after_forward = reshard_after_forward_policy + logger.info( + f"Using partial reshard_after_forward={reshard_after_forward} (resharding to {reshard_after_forward}-GPU groups)" + ) + else: + match reshard_after_forward_policy: + case "always": + reshard_after_forward = True + case "never": + reshard_after_forward = False + case "default": + # For PP, by default do not reshard after forward to avoid per-microbatch + # all-gathers, which can be expensive and non-overlapped + reshard_after_forward = not pp_enabled + case _: + raise ValueError( + f"Invalid reshard_after_forward_policy: {reshard_after_forward_policy}." + ) if model.tok_embeddings is not None: fully_shard( @@ -461,6 +484,11 @@ def apply_fsdp( if ep_degree == 1: return + # Skip prefetch setup if disabled + if disable_prefetch: + logger.info("FSDP prefetching is disabled") + return + # forward transformer_blocks = list(model.layers.values()) next_transformer_blocks = transformer_blocks[1:] + [None] From 413377fc949c61d3cabb89bd8755245ddbe896e7 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Sat, 31 Jan 2026 05:48:58 -0800 Subject: [PATCH 13/18] Add enhanced metrics and memory monitoring - Add nvidia-smi memory reporting for verification against PyTorch stats - Display active memory (actual tensor usage) as primary metric instead of reserved - Log detailed memory breakdown (active/reserved/nvidia-smi) on rank 0 - Enable profile_memory=True in profiler to track allocations per operation --- torchtitan/components/metrics.py | 64 +++++++++++++++++++++++++++++++- torchtitan/tools/profiling.py | 1 + 2 files changed, 63 insertions(+), 2 deletions(-) diff --git a/torchtitan/components/metrics.py b/torchtitan/components/metrics.py index 62ec4331ef..1ec434bfdf 100644 --- a/torchtitan/components/metrics.py +++ b/torchtitan/components/metrics.py @@ -36,6 +36,8 @@ 
"max_reserved_pct", "num_alloc_retries", "num_ooms", + "nvidia_smi_used_gib", # nvidia-smi reported memory for verification + "nvidia_smi_used_pct", ], ) @@ -62,6 +64,49 @@ def _to_gib(self, memory_in_bytes): def _to_pct(self, memory): return 100 * memory / self.device_capacity + def _get_nvidia_smi_memory(self): + """Get GPU memory usage from nvidia-smi for verification.""" + try: + import os + import subprocess + + # In SLURM with CUDA_VISIBLE_DEVICES, PyTorch device index 0-7 maps to + # physical GPUs listed in CUDA_VISIBLE_DEVICES. We need the physical GPU index. + cuda_visible = os.environ.get("CUDA_VISIBLE_DEVICES", "") + if cuda_visible: + # CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" means device 0 is physical GPU 0 + # But it could also be "4,5,6,7,0,1,2,3" meaning device 0 is physical GPU 4 + visible_gpus = [ + int(x.strip()) for x in cuda_visible.split(",") if x.strip() + ] + if self.device_index < len(visible_gpus): + physical_gpu_index = visible_gpus[self.device_index] + else: + physical_gpu_index = self.device_index + else: + physical_gpu_index = self.device_index + + result = subprocess.run( + [ + "nvidia-smi", + "--query-gpu=memory.used", + "--format=csv,noheader,nounits", + f"--id={physical_gpu_index}", + ], + capture_output=True, + text=True, + timeout=5, + ) + if result.returncode == 0: + # nvidia-smi reports in MiB + used_mib = float(result.stdout.strip()) + used_gib = used_mib / 1024 + used_pct = (used_mib * 1024 * 1024) / self.device_capacity * 100 + return used_gib, used_pct + except Exception: + pass + return -1.0, -1.0 + def get_peak_stats(self): device_info = device_module.memory_stats(self.device) @@ -83,6 +128,9 @@ def get_peak_stats(self): if num_ooms > 0: logger.warning(f"{num_ooms} {device_type.upper()} OOM errors thrown.") + # Get nvidia-smi memory for verification + nvidia_smi_gib, nvidia_smi_pct = self._get_nvidia_smi_memory() + return DeviceMemStats( max_active_gib, max_active_pct, @@ -90,6 +138,8 @@ def get_peak_stats(self): 
max_reserved_pct, num_retries, num_ooms, + nvidia_smi_gib, + nvidia_smi_pct, ) def reset_peak_stats(self): @@ -480,16 +530,26 @@ def log( self.logger.log(metrics, step) color = self.color + # Show ACTIVE memory (actual tensor usage) as primary, with reserved and nvidia-smi for verification + # active = actual tensors, reserved = active + cached, nvidia-smi = OS-level GPU memory logger.info( f"{color.red}step: {step:2} " f"{color.green}loss: {global_avg_loss:7.4f} " f"{color.orange}grad_norm: {grad_norm:7.4f} " - f"{color.turquoise}memory: {device_mem_stats.max_reserved_gib:5.2f}GiB" - f"({device_mem_stats.max_reserved_pct:.2f}%) " + f"{color.turquoise}memory: {device_mem_stats.max_active_gib:5.2f}GiB" + f"({device_mem_stats.max_active_pct:.2f}%) " f"{color.blue}tps: {round(tps):,} " f"{color.cyan}tflops: {tflops:,.2f} " f"{color.magenta}mfu: {mfu:.2f}%{color.reset}" ) + # Log detailed memory breakdown for verification (only on rank 0, step 2+) + if step >= 2 and torch.distributed.get_rank() == 0: + logger.info( + f" [Memory Detail] active={device_mem_stats.max_active_gib:.2f}GiB " + f"reserved={device_mem_stats.max_reserved_gib:.2f}GiB " + f"nvidia-smi={device_mem_stats.nvidia_smi_used_gib:.2f}GiB " + f"retries={device_mem_stats.num_alloc_retries}" + ) self.ntokens_since_last_log = 0 self.data_loading_times.clear() diff --git a/torchtitan/tools/profiling.py b/torchtitan/tools/profiling.py index c37345f8d3..4e4d03606c 100644 --- a/torchtitan/tools/profiling.py +++ b/torchtitan/tools/profiling.py @@ -76,6 +76,7 @@ def trace_handler(prof): schedule=torch.profiler.schedule(wait=wait, warmup=warmup, active=active), on_trace_ready=trace_handler, record_shapes=True, + profile_memory=True, # Track memory allocations per operation with_stack=profiling_config.with_stack, with_modules=profiling_config.with_modules, ) as torch_profiler: From 936510c22335342fd5f74c1abda23a5d0571f3f1 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Sat, 31 Jan 2026 05:54:07 -0800 Subject: [PATCH 
14/18] Add aggressive memory manager to reduce CUDA fragmentation - New AggressiveMemoryManager with 4 modes: minimal, balanced, aggressive, maximum - Clears CUDA cache at strategic points (post-backward, post-optimizer) - Add aggressive_memory_mode and aggressive_memory_verbose config options - Integrate into training loop with post_backward(), post_optimizer(), step_complete() hooks --- torchtitan/config/job_config.py | 20 + torchtitan/tools/aggressive_memory_manager.py | 414 ++++++++++++++++++ torchtitan/train.py | 53 +++ 3 files changed, 487 insertions(+) create mode 100644 torchtitan/tools/aggressive_memory_manager.py diff --git a/torchtitan/config/job_config.py b/torchtitan/config/job_config.py index c00fcce6f6..643e6e27e8 100644 --- a/torchtitan/config/job_config.py +++ b/torchtitan/config/job_config.py @@ -431,6 +431,26 @@ class Training: many temporary files. """ + aggressive_memory_mode: Literal[ + "minimal", "balanced", "aggressive", "maximum" + ] | None = None + """ + Enable aggressive memory management to reduce CUDA memory fragmentation. + This clears CUDA cache and Python GC at strategic points (post-backward, post-optimizer). + Modes: + - None: Disabled (default) + - "minimal": Only clear on high fragmentation (<1% overhead) + - "balanced": Clear after backward and optimizer (2-3% overhead) + - "aggressive": Clear frequently with sync (5-8% overhead) + - "maximum": Clear after every operation (10-15% overhead, for debugging) + """ + + aggressive_memory_verbose: bool = False + """ + Enable verbose logging for aggressive memory manager. + Logs detailed memory stats after each clear operation. + """ + @dataclass class Parallelism: diff --git a/torchtitan/tools/aggressive_memory_manager.py b/torchtitan/tools/aggressive_memory_manager.py new file mode 100644 index 0000000000..1c4861cb74 --- /dev/null +++ b/torchtitan/tools/aggressive_memory_manager.py @@ -0,0 +1,414 @@ +""" +Aggressive Memory Manager for reducing CUDA memory fragmentation. 
+ +This module provides aggressive memory clearing strategies to minimize +fragmentation and allocation retries during distributed training. + +Usage: + from torchtitan.tools.aggressive_memory_manager import AggressiveMemoryManager + + # Initialize at start of training + mem_manager = AggressiveMemoryManager( + clear_after_backward=True, + clear_after_optimizer=True, + sync_before_clear=True, + defrag_threshold_mb=1000, # Defrag if fragmentation > 1GB + ) + + # In training loop: + loss.backward() + mem_manager.post_backward() + + optimizer.step() + mem_manager.post_optimizer() + + mem_manager.step_complete() +""" + +import gc +import os +import time +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.distributed as dist + +from torchtitan.tools.logging import logger + + +@dataclass +class MemoryStats: + """Current memory statistics""" + + allocated: int + reserved: int + active: int + fragmentation: int + fragmentation_pct: float + num_alloc_retries: int + + +class AggressiveMemoryManager: + """ + Aggressive memory management to minimize CUDA memory fragmentation. + + Key strategies: + 1. Clear cache at strategic points (post-backward, post-optimizer) + 2. Synchronize before clearing to ensure all async ops complete + 3. Force garbage collection to release Python references + 4. Monitor fragmentation and trigger defrag when threshold exceeded + 5. 
Set optimal allocator configuration + + Args: + clear_after_backward: Clear cache after backward pass + clear_after_optimizer: Clear cache after optimizer step + clear_every_n_steps: Only clear every N steps (1 = every step) + sync_before_clear: Synchronize CUDA before clearing cache + defrag_threshold_mb: Trigger defrag if fragmentation exceeds this (MB) + gc_generation: Python GC generation to collect (0-2, higher = more thorough) + verbose: Log detailed memory stats + rank: Distributed rank (auto-detected if None) + """ + + def __init__( + self, + clear_after_backward: bool = True, + clear_after_optimizer: bool = True, + clear_every_n_steps: int = 1, + sync_before_clear: bool = True, + defrag_threshold_mb: float = 500.0, + gc_generation: int = 1, + verbose: bool = False, + rank: Optional[int] = None, + ): + self.clear_after_backward = clear_after_backward + self.clear_after_optimizer = clear_after_optimizer + self.clear_every_n_steps = clear_every_n_steps + self.sync_before_clear = sync_before_clear + self.defrag_threshold_mb = defrag_threshold_mb + self.gc_generation = gc_generation + self.verbose = verbose + + self.rank = ( + rank + if rank is not None + else (dist.get_rank() if dist.is_initialized() else 0) + ) + + self.step_count = 0 + self.total_clears = 0 + self.total_defrag_time_ms = 0.0 + + # Disable automatic GC - we'll control it manually + gc.disable() + + # Initial cleanup + self._aggressive_clear("initialization") + + if self.rank == 0: + logger.info( + f"[AggressiveMemoryManager] Initialized: " + f"clear_backward={clear_after_backward}, " + f"clear_optimizer={clear_after_optimizer}, " + f"every_n_steps={clear_every_n_steps}, " + f"sync={sync_before_clear}, " + f"defrag_threshold={defrag_threshold_mb}MB" + ) + + @staticmethod + def configure_allocator( + expandable_segments: bool = True, + max_split_size_mb: int = 128, + garbage_collection_threshold: float = 0.8, + roundup_power2_divisions: int = 4, + ) -> str: + """ + Configure PyTorch CUDA 
allocator for minimal fragmentation. + + Call this BEFORE any CUDA operations (before model creation). + + Args: + expandable_segments: Enable expandable memory segments + max_split_size_mb: Max size of memory splits (smaller = less fragmentation) + garbage_collection_threshold: Trigger GC when this fraction of memory is fragmented + roundup_power2_divisions: Memory rounding granularity + + Returns: + The PYTORCH_CUDA_ALLOC_CONF string that was set + """ + config_parts = [] + + if expandable_segments: + config_parts.append("expandable_segments:True") + + config_parts.append(f"max_split_size_mb:{max_split_size_mb}") + config_parts.append( + f"garbage_collection_threshold:{garbage_collection_threshold}" + ) + config_parts.append(f"roundup_power2_divisions:{roundup_power2_divisions}") + + config_str = ",".join(config_parts) + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = config_str + + return config_str + + def get_memory_stats(self) -> MemoryStats: + """Get current memory statistics""" + if not torch.cuda.is_available(): + return MemoryStats(0, 0, 0, 0, 0.0, 0) + + stats = torch.cuda.memory_stats() + allocated = torch.cuda.memory_allocated() + reserved = torch.cuda.memory_reserved() + active = stats.get("active_bytes.all.current", 0) + fragmentation = reserved - allocated + fragmentation_pct = (fragmentation / reserved * 100) if reserved > 0 else 0.0 + num_retries = stats.get("num_alloc_retries", 0) + + return MemoryStats( + allocated=allocated, + reserved=reserved, + active=active, + fragmentation=fragmentation, + fragmentation_pct=fragmentation_pct, + num_alloc_retries=num_retries, + ) + + def _should_clear(self) -> bool: + """Check if we should clear cache this step""" + return self.step_count % self.clear_every_n_steps == 0 + + def _aggressive_clear(self, reason: str) -> float: + """ + Perform aggressive memory clearing. + + Returns: + Time taken in milliseconds + """ + if not torch.cuda.is_available(): + return 0.0 + + start = time.perf_counter() + + # 1. 
Synchronize all CUDA streams to ensure ops complete + if self.sync_before_clear: + torch.cuda.synchronize() + + # 2. Python garbage collection (releases tensor references) + gc.collect(self.gc_generation) + + # 3. Clear CUDA cache (releases unused cached memory) + torch.cuda.empty_cache() + + # 4. Optional: Force synchronization after clear + if self.sync_before_clear: + torch.cuda.synchronize() + + elapsed_ms = (time.perf_counter() - start) * 1000 + self.total_clears += 1 + self.total_defrag_time_ms += elapsed_ms + + if self.verbose and self.rank == 0: + stats = self.get_memory_stats() + logger.info( + f"[AggressiveMemoryManager] {reason}: " + f"cleared in {elapsed_ms:.1f}ms, " + f"frag={stats.fragmentation_pct:.1f}%, " + f"reserved={stats.reserved/1e9:.2f}GB" + ) + + return elapsed_ms + + def _check_and_defrag(self, phase: str) -> bool: + """ + Check fragmentation and defrag if needed. + + Returns: + True if defrag was triggered + """ + stats = self.get_memory_stats() + fragmentation_mb = stats.fragmentation / (1024 * 1024) + + if fragmentation_mb > self.defrag_threshold_mb: + self._aggressive_clear(f"defrag_{phase}_frag={fragmentation_mb:.0f}MB") + return True + + return False + + def post_backward(self): + """Call after backward pass completes""" + if self.clear_after_backward and self._should_clear(): + self._check_and_defrag("post_backward") + self._aggressive_clear("post_backward") + + def post_optimizer(self): + """Call after optimizer step completes""" + if self.clear_after_optimizer and self._should_clear(): + self._check_and_defrag("post_optimizer") + self._aggressive_clear("post_optimizer") + + def step_complete(self): + """Call at the end of each training step""" + self.step_count += 1 + + # Always check for high fragmentation + self._check_and_defrag("step_end") + + def get_summary(self) -> str: + """Get summary of memory management activity""" + avg_time = self.total_defrag_time_ms / max(1, self.total_clears) + return ( + f"AggressiveMemoryManager 
Summary:\n" + f" Total clears: {self.total_clears}\n" + f" Total defrag time: {self.total_defrag_time_ms:.1f}ms\n" + f" Avg time per clear: {avg_time:.2f}ms\n" + f" Steps processed: {self.step_count}" + ) + + +class BackwardMemoryHook: + """ + Register hooks on model parameters to clear memory during backward pass. + + This clears memory incrementally as gradients are computed, rather than + waiting until the end of backward. + + Args: + clear_every_n_params: Clear cache after every N parameter gradients + sync_on_clear: Synchronize before clearing (slower but more thorough) + """ + + def __init__( + self, + clear_every_n_params: int = 10, + sync_on_clear: bool = False, + ): + self.clear_every_n_params = clear_every_n_params + self.sync_on_clear = sync_on_clear + self.param_count = 0 + self.handles = [] + + def _backward_hook(self, grad): + """Hook called when gradient is computed for a parameter""" + self.param_count += 1 + + if self.param_count % self.clear_every_n_params == 0: + if self.sync_on_clear: + torch.cuda.synchronize() + gc.collect(0) # Fast GC (generation 0 only) + torch.cuda.empty_cache() + + return grad + + def register(self, model: torch.nn.Module): + """Register hooks on all model parameters""" + for name, param in model.named_parameters(): + if param.requires_grad: + handle = param.register_post_accumulate_grad_hook( + lambda p, name=name: self._backward_hook(p.grad) + ) + self.handles.append(handle) + + logger.info( + f"[BackwardMemoryHook] Registered on {len(self.handles)} parameters, " + f"clearing every {self.clear_every_n_params} params" + ) + + def remove(self): + """Remove all registered hooks""" + for handle in self.handles: + handle.remove() + self.handles.clear() + + def reset_count(self): + """Reset parameter count (call at start of each backward)""" + self.param_count = 0 + + +def setup_aggressive_memory_environment(): + """ + Set up environment variables for aggressive memory management. 
+ + Call this BEFORE importing torch or creating any CUDA tensors. + """ + # Optimal allocator settings for minimal fragmentation + os.environ["PYTORCH_CUDA_ALLOC_CONF"] = ( + "expandable_segments:True," + "max_split_size_mb:128," + "garbage_collection_threshold:0.8," + "roundup_power2_divisions:4" + ) + + # Disable NCCL async error handling (can cause memory issues) + os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "0" + + # Force synchronous CUDA operations for debugging + # os.environ["CUDA_LAUNCH_BLOCKING"] = "1" # Uncomment for debugging + + return os.environ.get("PYTORCH_CUDA_ALLOC_CONF") + + +# Convenience function for quick setup +def create_aggressive_memory_manager( + mode: str = "balanced", + verbose: bool = False, +) -> AggressiveMemoryManager: + """ + Create an AggressiveMemoryManager with preset configurations. + + Args: + mode: One of: + - "minimal": Only clear on high fragmentation + - "balanced": Clear after backward and optimizer + - "aggressive": Clear frequently with sync + - "maximum": Clear after every operation + verbose: Enable verbose logging + + Returns: + Configured AggressiveMemoryManager + """ + if mode == "minimal": + return AggressiveMemoryManager( + clear_after_backward=False, + clear_after_optimizer=False, + clear_every_n_steps=10, + sync_before_clear=False, + defrag_threshold_mb=2000, + gc_generation=0, + verbose=verbose, + ) + elif mode == "balanced": + return AggressiveMemoryManager( + clear_after_backward=True, + clear_after_optimizer=True, + clear_every_n_steps=1, + sync_before_clear=False, + defrag_threshold_mb=500, + gc_generation=1, + verbose=verbose, + ) + elif mode == "aggressive": + return AggressiveMemoryManager( + clear_after_backward=True, + clear_after_optimizer=True, + clear_every_n_steps=1, + sync_before_clear=True, + defrag_threshold_mb=200, + gc_generation=2, + verbose=verbose, + ) + elif mode == "maximum": + return AggressiveMemoryManager( + clear_after_backward=True, + clear_after_optimizer=True, + 
clear_every_n_steps=1, + sync_before_clear=True, + defrag_threshold_mb=100, + gc_generation=2, + verbose=verbose, + ) + else: + raise ValueError( + f"Unknown mode: {mode}. Use minimal/balanced/aggressive/maximum" + ) diff --git a/torchtitan/train.py b/torchtitan/train.py index 75c88f684f..ca82996c2f 100644 --- a/torchtitan/train.py +++ b/torchtitan/train.py @@ -27,6 +27,7 @@ from torchtitan.distributed import ParallelDims, utils as dist_utils from torchtitan.protocols.model_converter import build_model_converters from torchtitan.tools import utils +from torchtitan.tools.aggressive_memory_manager import create_aggressive_memory_manager from torchtitan.tools.cuda_memory_tracker import CUDAMemoryTracker from torchtitan.tools.detailed_memory_tracker import DetailedMemoryTracker from torchtitan.tools.logging import init_logger, logger @@ -124,6 +125,24 @@ def __init__(self, job_config: JobConfig): ), ) + # Initialize aggressive memory manager to reduce CUDA fragmentation + # This clears cache after backward/optimizer to prevent allocation retries + aggressive_mem_mode = getattr( + job_config.training, "aggressive_memory_mode", None + ) + if aggressive_mem_mode: + self.aggressive_mem_manager = create_aggressive_memory_manager( + mode=aggressive_mem_mode, + verbose=getattr( + job_config.training, "aggressive_memory_verbose", False + ), + ) + logger.info( + f"Aggressive memory manager enabled (mode={aggressive_mem_mode})" + ) + else: + self.aggressive_mem_manager = None + # Set random seed, and maybe enable deterministic mode # (mainly for debugging, expect perf loss). 
dist_utils.set_determinism( @@ -719,11 +738,37 @@ def forward_backward_step( del pred loss.backward() + # Aggressive memory clearing after backward to reduce fragmentation + if self.aggressive_mem_manager is not None: + self.aggressive_mem_manager.post_backward() + return loss def train_step( self, data_iterator: Iterable[tuple[dict[str, torch.Tensor], torch.Tensor]] ): + # AGGRESSIVE cache clearing before step for accurate memory measurements + # Without this, cached memory from previous steps inflates readings + if self.job_config.training.aggressive_memory_mode: + import gc + + # 1. Synchronize all CUDA streams + torch.cuda.synchronize() + # 2. Python garbage collection (all generations for thorough cleanup) + gc.collect(0) + gc.collect(1) + gc.collect(2) + # 3. Clear CUDA cache (releases cached memory back to GPU) + torch.cuda.empty_cache() + # 4. Synchronize again to ensure cache clear completed + torch.cuda.synchronize() + # 5. Second round of clearing (catches any stragglers) + gc.collect(2) + torch.cuda.empty_cache() + # 6. 
Reset peak stats so we measure THIS step's peak only + torch.cuda.reset_peak_memory_stats() + self.metrics_processor.device_memory_monitor.reset_peak_stats() + self.optimizers.zero_grad() # Save the current step learning rate for logging lr = self.lr_schedulers.schedulers[0].get_last_lr()[0] @@ -780,6 +825,10 @@ def train_step( self.optimizers.step() _optim_elapsed = _time.time() - _optim_start + # Aggressive memory clearing after optimizer to reduce fragmentation + if self.aggressive_mem_manager is not None: + self.aggressive_mem_manager.post_optimizer() + # Log step end with timing if self.device.index == 0: _ts = datetime.datetime.now().strftime("%H:%M:%S") @@ -833,6 +882,10 @@ def train_step( extra_metrics=extra_metrics, ) + # Signal step complete to aggressive memory manager (triggers defrag check) + if self.aggressive_mem_manager is not None: + self.aggressive_mem_manager.step_complete() + @record def train(self): job_config = self.job_config From 29d89cd48b40d7d0b41c7ea1cf314a2a69025da1 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Sat, 31 Jan 2026 05:57:18 -0800 Subject: [PATCH 15/18] Add DeepEP tuning enhancements with model presets and CLI args - Add model presets: qwen3 (2048 dim, 128 experts) and kimi_k2 (7168 dim, 384 experts) - Add init_dist_torchrun() for torchrun environment compatibility - Add CLI arguments: --model, --hidden, --num-experts, --num-topk - Change from group-based to uniform token distribution for routing - Fix MASTER_PORT to be consistent across ranks --- .../torchtitan_deepep_tune/tune_internode.py | 153 ++++++++++++++---- .../tune_intranode_v2.py | 51 +++++- .../torchtitan_deepep_tune/tune_singlenode.py | 5 +- 3 files changed, 168 insertions(+), 41 deletions(-) diff --git a/scripts/deepep/torchtitan_deepep_tune/tune_internode.py b/scripts/deepep/torchtitan_deepep_tune/tune_internode.py index cc7f5fa7c3..7c8b7d9767 100755 --- a/scripts/deepep/torchtitan_deepep_tune/tune_internode.py +++ 
b/scripts/deepep/torchtitan_deepep_tune/tune_internode.py @@ -29,8 +29,50 @@ print("ERROR: deep_ep not found") sys.exit(1) -sys.path.insert(0, "/home/phuc/workspace/moe/DeepEP/tests") -from utils import bench_kineto, init_dist +DEEPEP_TESTS_PATH = os.environ.get( + "DEEPEP_TESTS_PATH", "/home/phuc/kimi_1t/deepep/tests" +) +sys.path.insert(0, DEEPEP_TESTS_PATH) +from utils import bench_kineto + + +def init_dist_torchrun(local_rank: int, num_local_ranks: int): + """ + Initialize distributed for torchrun environment. + torchrun sets: WORLD_SIZE=total_procs, RANK=global_rank, LOCAL_RANK, LOCAL_WORLD_SIZE + But init_dist expects: WORLD_SIZE=num_nodes, RANK=node_rank + """ + import inspect + + world_size = int(os.environ.get("WORLD_SIZE", 1)) + global_rank = int(os.environ.get("RANK", 0)) + + # Calculate node info from torchrun env vars + num_nodes = world_size // num_local_ranks + node_rank = global_rank // num_local_ranks + + ip = os.getenv("MASTER_ADDR", "127.0.0.1") + port = int(os.getenv("MASTER_PORT", "29500")) + + sig = inspect.signature(dist.init_process_group) + params = { + "backend": "nccl", + "init_method": f"tcp://{ip}:{port}", + "world_size": num_nodes * num_local_ranks, + "rank": node_rank * num_local_ranks + local_rank, + } + if "device_id" in sig.parameters: + params["device_id"] = torch.device(f"cuda:{local_rank}") + dist.init_process_group(**params) + torch.set_default_dtype(torch.bfloat16) + torch.set_default_device("cuda") + torch.cuda.set_device(local_rank) + + return ( + dist.get_rank(), + dist.get_world_size(), + dist.new_group(list(range(num_local_ranks * num_nodes))), + ) @dataclass @@ -71,15 +113,19 @@ def __init__( self.hidden = hidden self.num_experts = num_experts self.num_topk = num_topk - self.num_topk_groups = num_topk_groups - # Init distributed + # Init distributed (using torchrun-compatible init) self.local_rank = int(os.environ.get("LOCAL_RANK", 0)) num_local_ranks = int(os.environ.get("LOCAL_WORLD_SIZE", 8)) - self.rank, 
self.num_ranks, self.group = init_dist( + self.rank, self.num_ranks, self.group = init_dist_torchrun( self.local_rank, num_local_ranks ) - self.num_nodes = int(os.environ.get("WORLD_SIZE", 1)) + # Calculate num_nodes from torchrun env vars + world_size = int(os.environ.get("WORLD_SIZE", 1)) + self.num_nodes = world_size // num_local_ranks + + # num_topk_groups must be <= num_nodes + self.num_topk_groups = min(num_topk_groups, self.num_nodes) self.num_sms = 24 # Buffer sizes (from benchmark_internode.py) @@ -124,39 +170,31 @@ def setup_data(self): * self.rank ) - # Random scores with group-based routing (like Qwen3) + # UNIFORM distribution across all ranks + # Use same seed for reproducibility across all ranks + torch.manual_seed(42) scores = ( torch.randn( (self.num_tokens, self.num_experts), dtype=torch.float32, device="cuda" ).abs() + 1 ) - group_scores = scores.view(self.num_tokens, self.num_nodes, -1).amax(dim=-1) - group_idx = torch.topk( - group_scores, k=self.num_topk_groups, dim=-1, sorted=False - ).indices - - # Create grouped scores (group-limited routing) - masked_scores = scores.clone() - for i in range(self.num_nodes): - mask = (group_idx == i).any(dim=-1, keepdim=True) - node_mask = torch.zeros( - self.num_tokens, self.num_experts, dtype=torch.bool, device="cuda" - ) - start_expert = i * (self.num_experts // self.num_nodes) - end_expert = (i + 1) * (self.num_experts // self.num_nodes) - node_mask[:, start_expert:end_expert] = True - masked_scores = torch.where( - mask & node_mask, - masked_scores, - torch.tensor(-float("inf"), device="cuda"), - ) - self.topk_idx = torch.topk( - masked_scores, self.num_topk, dim=-1, largest=True, sorted=False + scores, self.num_topk, dim=-1, largest=True, sorted=False )[1] self.topk_idx = self.topk_idx.to(deep_ep.topk_idx_t) + # Verify distribution (only on rank 0) + if self.is_rank0(): + rank_idx_check = self.topk_idx // (self.num_experts // self.num_ranks) + tokens_per_rank = [ + (rank_idx_check == r).sum().item() 
for r in range(self.num_ranks) + ] + print( + f"[uniform] Tokens per rank: min={min(tokens_per_rank)}, max={max(tokens_per_rank)}, " + f"ratio={max(tokens_per_rank)/max(min(tokens_per_rank), 1):.2f}x" + ) + # Get layout ( self.num_tokens_per_rank, @@ -347,6 +385,13 @@ def cleanup(self): dist.destroy_process_group() +# Model presets +MODEL_CONFIGS = { + "qwen3": {"hidden": 2048, "num_experts": 128, "num_topk": 8}, + "kimi_k2": {"hidden": 7168, "num_experts": 384, "num_topk": 8}, +} + + def main(): parser = argparse.ArgumentParser( description="Tune DeepEP for actual torchtitan setup" @@ -354,15 +399,55 @@ def main(): parser.add_argument("--ep-size", type=int, required=True) parser.add_argument("--mode", choices=["quick", "medium", "full"], default="medium") parser.add_argument("--output-dir", default="results") + parser.add_argument( + "--model", + type=str, + choices=["qwen3", "kimi_k2", "custom"], + default="qwen3", + help="Model preset: qwen3 (dim=2048, 128 experts), kimi_k2 (dim=7168, 384 experts)", + ) + parser.add_argument("--num-tokens", type=int, default=4096, help="Number of tokens") + parser.add_argument( + "--hidden", + type=int, + default=None, + help="Hidden dimension (overrides model preset)", + ) + parser.add_argument( + "--num-experts", + type=int, + default=None, + help="Number of experts (overrides model preset)", + ) + parser.add_argument( + "--num-topk", + type=int, + default=None, + help="Top-k experts (overrides model preset)", + ) args = parser.parse_args() - # Qwen3-30B-A3B parameters + # Apply model preset, allow overrides + if args.model in MODEL_CONFIGS: + preset = MODEL_CONFIGS[args.model] + hidden = args.hidden if args.hidden else preset["hidden"] + num_experts = args.num_experts if args.num_experts else preset["num_experts"] + num_topk = args.num_topk if args.num_topk else preset["num_topk"] + else: + hidden = args.hidden if args.hidden else 2048 + num_experts = args.num_experts if args.num_experts else 128 + num_topk = args.num_topk if 
args.num_topk else 8 + + print( + f"Model: {args.model}, hidden={hidden}, experts={num_experts}, topk={num_topk}, tokens={args.num_tokens}" + ) + tuner = DeepEPTuner( - num_tokens=4096, - hidden=2048, # Qwen3-30B dim - num_experts=128, # Qwen3-30B-A3B - num_topk=8, + num_tokens=args.num_tokens, + hidden=hidden, + num_experts=num_experts, + num_topk=num_topk, num_topk_groups=4, ) diff --git a/scripts/deepep/torchtitan_deepep_tune/tune_intranode_v2.py b/scripts/deepep/torchtitan_deepep_tune/tune_intranode_v2.py index 441a5f1012..9405fc386e 100755 --- a/scripts/deepep/torchtitan_deepep_tune/tune_intranode_v2.py +++ b/scripts/deepep/torchtitan_deepep_tune/tune_intranode_v2.py @@ -40,7 +40,7 @@ def init_dist(local_rank: int, num_local_ranks: int): os.environ["RANK"] = str(local_rank) os.environ["WORLD_SIZE"] = str(num_local_ranks) os.environ["MASTER_ADDR"] = "localhost" - os.environ["MASTER_PORT"] = str(29500 + local_rank) + os.environ["MASTER_PORT"] = "29500" # Same port for all ranks dist.init_process_group(backend="nccl", rank=local_rank, world_size=num_local_ranks) torch.cuda.set_device(local_rank) @@ -426,6 +426,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): dist.destroy_process_group() +# Model presets +MODEL_CONFIGS = { + "qwen3": {"hidden": 2048, "num_experts": 128, "num_topk": 8}, + "kimi_k2": {"hidden": 7168, "num_experts": 384, "num_topk": 8}, +} + + if __name__ == "__main__": parser = argparse.ArgumentParser( description="Tune DeepEP intranode configs for TorchTitan" @@ -436,26 +443,58 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace): parser.add_argument( "--num-tokens", type=int, default=4096, help="Number of tokens (default: 4096)" ) + parser.add_argument( + "--model", + type=str, + choices=["qwen3", "kimi_k2", "custom"], + default="qwen3", + help="Model preset: qwen3 (dim=2048, 128 experts), kimi_k2 (dim=7168, 384 experts)", + ) parser.add_argument( "--hidden", type=int, - 
default=2048,
-        help="Hidden dimension - Qwen3-30B (default: 2048)",
+        default=None,
+        help="Hidden dimension (overrides model preset)",
     )
     parser.add_argument(
-        "--num-topk", type=int, default=8, help="Number of top-k experts (default: 8)"
+        "--num-topk",
+        type=int,
+        default=None,
+        help="Number of top-k experts (overrides model preset)",
     )
     parser.add_argument(
         "--num-experts",
         type=int,
-        default=128,
-        help="Number of experts - Qwen3-30B-A3B (default: 128)",
+        default=None,
+        help="Number of experts (overrides model preset)",
     )
     parser.add_argument(
         "--output-dir", type=str, default="results", help="Output directory for results"
     )
     args = parser.parse_args()
 
+    # Apply model preset, allow overrides
+    if args.model in MODEL_CONFIGS:
+        preset = MODEL_CONFIGS[args.model]
+        if args.hidden is None:
+            args.hidden = preset["hidden"]
+        if args.num_experts is None:
+            args.num_experts = preset["num_experts"]
+        if args.num_topk is None:
+            args.num_topk = preset["num_topk"]
+    else:
+        # Custom mode - fall back to qwen3-like defaults when not given explicitly
+        if args.hidden is None:
+            args.hidden = 2048
+        if args.num_experts is None:
+            args.num_experts = 128
+        if args.num_topk is None:
+            args.num_topk = 8
+
+    print(
+        f"Model: {args.model}, hidden={args.hidden}, experts={args.num_experts}, topk={args.num_topk}"
+    )
+
     num_processes = args.num_processes
     torch.multiprocessing.spawn(
         test_loop, args=(num_processes, args), nprocs=num_processes
diff --git a/scripts/deepep/torchtitan_deepep_tune/tune_singlenode.py b/scripts/deepep/torchtitan_deepep_tune/tune_singlenode.py
index d8164967f2..c73f61f691 100755
--- a/scripts/deepep/torchtitan_deepep_tune/tune_singlenode.py
+++ b/scripts/deepep/torchtitan_deepep_tune/tune_singlenode.py
@@ -31,7 +31,10 @@
     sys.exit(1)
 
 # Import from DeepEP tests
-sys.path.insert(0, "/home/phuc/workspace/moe/DeepEP/tests")
+DEEPEP_TESTS_PATH = os.environ.get(
+    "DEEPEP_TESTS_PATH", "/home/phuc/kimi_1t/deepep/tests"
+)
+sys.path.insert(0, DEEPEP_TESTS_PATH)
 from utils import bench, init_dist,
inplace_unique From 5a091a269c4ee1481ba02dada82b004cf4f3de37 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Sat, 31 Jan 2026 06:04:34 -0800 Subject: [PATCH 16/18] Add device mesh visualizer for distributed training - Visualize GPU allocation across DP, PP, TP, CP, EP dimensions - Show mesh structure, submeshes, and coordinate mappings - Visualize expert parallel and context parallel group allocation - Display FSDP sharding details for expert vs non-expert parameters --- torchtitan/tools/mesh_visualizer.py | 415 ++++++++++++++++++++++++++++ 1 file changed, 415 insertions(+) create mode 100644 torchtitan/tools/mesh_visualizer.py diff --git a/torchtitan/tools/mesh_visualizer.py b/torchtitan/tools/mesh_visualizer.py new file mode 100644 index 0000000000..0ba8fecb03 --- /dev/null +++ b/torchtitan/tools/mesh_visualizer.py @@ -0,0 +1,415 @@ +""" +Device Mesh Visualizer for Distributed Training + +Creates comprehensive visualization of how GPUs are allocated across +all parallelism dimensions: DP, PP, TP, CP, EP. +""" + +import os +from typing import Dict + +import torch.distributed as dist +from torch.distributed.device_mesh import DeviceMesh + +from torchtitan.tools.logging import logger + + +def get_rank_info() -> Dict: + """Get current rank's information across all process groups.""" + info = { + "global_rank": dist.get_rank() if dist.is_initialized() else 0, + "world_size": dist.get_world_size() if dist.is_initialized() else 1, + "local_rank": int(os.environ.get("LOCAL_RANK", 0)), + "node_rank": int(os.environ.get("GROUP_RANK", os.environ.get("NODE_RANK", 0))), + } + return info + + +def visualize_mesh_structure( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Create a detailed text visualization of the device mesh structure. 
+ + Args: + mesh: The DeviceMesh object + parallel_dims: ParallelDims object with all parallelism settings + rank: Current rank (only rank 0 prints full visualization) + + Returns: + String visualization of the mesh + """ + lines = [] + lines.append("=" * 100) + lines.append("DEVICE MESH VISUALIZATION") + lines.append("=" * 100) + + # Basic info + lines.append("\n[CLUSTER INFO]") + lines.append(f" Total GPUs: {parallel_dims.world_size}") + lines.append(f" Nodes: {parallel_dims.world_size // 8} (assuming 8 GPUs/node)") + + # Parallelism dimensions + lines.append("\n[PARALLELISM DIMENSIONS]") + lines.append(f" DP Replicate (HSDP): {parallel_dims.dp_replicate}") + lines.append(f" DP Shard (FSDP): {parallel_dims.dp_shard}") + lines.append(f" Context Parallel: {parallel_dims.cp}") + lines.append(f" Tensor Parallel: {parallel_dims.tp}") + lines.append(f" Pipeline Parallel: {parallel_dims.pp}") + lines.append(f" Expert Parallel: {parallel_dims.ep}") + lines.append(f" Expert TP: {parallel_dims.etp}") + + # Mesh structure + lines.append("\n[MESH STRUCTURE]") + lines.append(f" Mesh dim names: {mesh.mesh_dim_names}") + lines.append(f" Mesh shape: {mesh.mesh.shape}") + + # Log each dimension + for i, (name, size) in enumerate(zip(mesh.mesh_dim_names, mesh.mesh.shape)): + lines.append(f" Dim {i}: {name:20s} = {size}") + + # EP-specific derived dimensions + if parallel_dims.ep > 1: + if parallel_dims.etp == parallel_dims.tp: + dp_shard_mod_ep = ( + parallel_dims.dp_shard * parallel_dims.cp // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // parallel_dims.cp + else: + dp_shard_mod_ep = ( + parallel_dims.dp_shard + * parallel_dims.cp + * parallel_dims.tp + // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // (parallel_dims.cp * parallel_dims.tp) + + lines.append("\n[EXPERT PARALLEL DERIVED DIMENSIONS]") + lines.append(f" dp_shard_mod_ep (DP for non-experts): {dp_shard_mod_ep}") + lines.append(f" dp_shard_in_ep (DP within EP group): {dp_shard_in_ep}") + 
lines.append(f" ep_group_size (EP degree): {parallel_dims.ep}") + lines.append("") + lines.append(" Formula: dp_shard = dp_shard_mod_ep * dp_shard_in_ep") + lines.append( + f" {parallel_dims.dp_shard} = {dp_shard_mod_ep} * {dp_shard_in_ep}" + ) + lines.append("") + lines.append(" Formula: ep = dp_shard_in_ep * cp") + lines.append( + f" {parallel_dims.ep} = {dp_shard_in_ep} * {parallel_dims.cp}" + ) + + # Submesh info + lines.append("\n[SUBMESHES]") + + # Try to get submesh info + submesh_names = ["dp", "dp_shard_cp", "dp_cp", "ep", "cp", "tp", "pp"] + for name in submesh_names: + try: + submesh = mesh[name] + lines.append( + f" {name:15s}: size={submesh.size():4d}, dim_names={submesh.mesh_dim_names}" + ) + except (KeyError, RuntimeError): + pass + + return "\n".join(lines) + + +def visualize_gpu_allocation( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Create a grid visualization showing GPU allocation. + + For 16 nodes (128 GPUs) with EP=64, CP=8: + - Shows how each GPU maps to (dp_shard_mod_ep, dp_shard_in_ep, cp) coordinates + """ + lines = [] + lines.append("\n" + "=" * 100) + lines.append("GPU ALLOCATION GRID") + lines.append("=" * 100) + + world_size = parallel_dims.world_size + num_nodes = world_size // 8 + + # For EP-enabled config + if parallel_dims.ep > 1: + if parallel_dims.etp == parallel_dims.tp: + dp_shard_mod_ep = ( + parallel_dims.dp_shard * parallel_dims.cp // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // parallel_dims.cp + else: + dp_shard_mod_ep = ( + parallel_dims.dp_shard + * parallel_dims.cp + * parallel_dims.tp + // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // (parallel_dims.cp * parallel_dims.tp) + + lines.append( + f"\nMesh: [{dp_shard_mod_ep}] x [{dp_shard_in_ep}] x [{parallel_dims.cp}] = {world_size} GPUs" + ) + lines.append(" [dp_shard_mod_ep] x [dp_shard_in_ep] x [cp]") + lines.append("") + + # Create mapping from global rank to mesh coordinates + lines.append("GPU -> Mesh 
Coordinate Mapping:") + lines.append("-" * 80) + lines.append( + f"{'Node':>6} | {'GPU':>4} | {'Rank':>5} | {'dp_mod_ep':>10} | {'dp_in_ep':>10} | {'cp':>4} | {'EP Group':>10}" + ) + lines.append("-" * 80) + + # The mesh is laid out as: dp_shard_mod_ep (slowest) x dp_shard_in_ep x cp (fastest) + for node in range(num_nodes): + for local_gpu in range(8): + global_rank = node * 8 + local_gpu + + # Compute mesh coordinates (assuming row-major ordering) + # Total size = dp_shard_mod_ep * dp_shard_in_ep * cp + cp_coord = global_rank % parallel_dims.cp + dp_in_ep_coord = (global_rank // parallel_dims.cp) % dp_shard_in_ep + dp_mod_ep_coord = global_rank // (parallel_dims.cp * dp_shard_in_ep) + + # EP group = dp_in_ep_coord * cp + cp_coord (within each dp_shard_mod_ep group) + ep_group = dp_in_ep_coord * parallel_dims.cp + cp_coord + + row = ( + f"{node:>6} | {local_gpu:>4} | {global_rank:>5} | " + f"{dp_mod_ep_coord:>10} | {dp_in_ep_coord:>10} | " + f"{cp_coord:>4} | {ep_group:>10}" + ) + lines.append(row) + + if node < num_nodes - 1: + lines.append("-" * 80) + else: + lines.append( + f"\nMesh: [{parallel_dims.dp_shard}] x [{parallel_dims.cp}] = {world_size} GPUs" + ) + lines.append(" [dp_shard] x [cp]") + + return "\n".join(lines) + + +def visualize_expert_parallel_groups( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Visualize which GPUs belong to which Expert Parallel group. 
+ """ + lines = [] + lines.append("\n" + "=" * 100) + lines.append("EXPERT PARALLEL GROUP ALLOCATION") + lines.append("=" * 100) + + if parallel_dims.ep <= 1: + lines.append("Expert Parallel is disabled (EP=1)") + return "\n".join(lines) + + world_size = parallel_dims.world_size + + if parallel_dims.etp == parallel_dims.tp: + dp_shard_mod_ep = parallel_dims.dp_shard * parallel_dims.cp // parallel_dims.ep + dp_shard_in_ep = parallel_dims.ep // parallel_dims.cp + else: + dp_shard_mod_ep = ( + parallel_dims.dp_shard + * parallel_dims.cp + * parallel_dims.tp + // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // (parallel_dims.cp * parallel_dims.tp) + + lines.append(f"\nEP={parallel_dims.ep} experts distributed across GPUs") + lines.append( + f"Each EP group has {parallel_dims.ep} GPUs working on different experts" + ) + lines.append( + f"There are {dp_shard_mod_ep} such EP groups (for FSDP replication of experts)" + ) + lines.append("") + + # Group GPUs by their dp_shard_mod_ep coordinate + lines.append("EP Groups (GPUs that share the same set of experts):") + lines.append("-" * 80) + + for dp_mod_ep_idx in range(dp_shard_mod_ep): + # Find all ranks in this dp_shard_mod_ep group + ranks_in_group = [] + for global_rank in range(world_size): + dp_mod_ep_coord = global_rank // (parallel_dims.cp * dp_shard_in_ep) + if dp_mod_ep_coord == dp_mod_ep_idx: + ranks_in_group.append(global_rank) + + lines.append(f"\nDP_SHARD_MOD_EP group {dp_mod_ep_idx}:") + lines.append( + f" GPUs: {ranks_in_group[:16]}{'...' if len(ranks_in_group) > 16 else ''}" + ) + lines.append(f" Total: {len(ranks_in_group)} GPUs") + lines.append(" These GPUs have IDENTICAL expert parameters (FSDP sharded)") + + return "\n".join(lines) + + +def visualize_context_parallel_groups( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Visualize Context Parallel groups - GPUs that work on different parts of the sequence. 
+ """ + lines = [] + lines.append("\n" + "=" * 100) + lines.append("CONTEXT PARALLEL GROUP ALLOCATION") + lines.append("=" * 100) + + if parallel_dims.cp <= 1: + lines.append("Context Parallel is disabled (CP=1)") + return "\n".join(lines) + + world_size = parallel_dims.world_size + cp = parallel_dims.cp + + lines.append(f"\nCP={cp} - Each sequence is split into {cp} chunks") + lines.append( + "GPUs with the same (dp_shard, ep) coordinates but different cp coordinates" + ) + lines.append("work on different parts of the same sequence.") + lines.append("") + + # Show a few example CP groups + lines.append("Example CP groups (first few):") + lines.append("-" * 80) + + num_cp_groups = world_size // cp + for cp_group_idx in range(min(4, num_cp_groups)): + ranks_in_group = [cp_group_idx * cp + i for i in range(cp)] + lines.append(f"\nCP group {cp_group_idx}:") + lines.append(f" GPUs: {ranks_in_group}") + lines.append(f" These {cp} GPUs process different chunks of the same sequence") + + if num_cp_groups > 4: + lines.append(f"\n... and {num_cp_groups - 4} more CP groups") + + return "\n".join(lines) + + +def visualize_fsdp_sharding( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """ + Visualize FSDP sharding - which GPUs share which parameters. 
+ """ + lines = [] + lines.append("\n" + "=" * 100) + lines.append("FSDP SHARDING VISUALIZATION") + lines.append("=" * 100) + + if parallel_dims.ep > 1: + if parallel_dims.etp == parallel_dims.tp: + dp_shard_mod_ep = ( + parallel_dims.dp_shard * parallel_dims.cp // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // parallel_dims.cp + else: + dp_shard_mod_ep = ( + parallel_dims.dp_shard + * parallel_dims.cp + * parallel_dims.tp + // parallel_dims.ep + ) + dp_shard_in_ep = parallel_dims.ep // (parallel_dims.cp * parallel_dims.tp) + + dp_shard_cp_size = parallel_dims.dp_shard * parallel_dims.cp + + lines.append("\n[NON-EXPERT PARAMETERS (Attention, Embeddings, etc.)]") + lines.append(" FSDP mesh: dp_shard_cp") + lines.append(f" FSDP group size: {dp_shard_cp_size} GPUs") + lines.append(f" Each parameter is sharded across {dp_shard_cp_size} GPUs") + lines.append( + f" All-gather buffer size per param: original_size / {dp_shard_cp_size}" + ) + + lines.append("\n[EXPERT PARAMETERS (MoE experts)]") + lines.append(" FSDP mesh: dp_shard_mod_ep") + lines.append(f" FSDP group size: {dp_shard_mod_ep} GPUs") + lines.append( + f" Each expert's parameters are sharded across {dp_shard_mod_ep} GPUs" + ) + lines.append( + f" All-gather buffer size per expert param: original_size / {dp_shard_mod_ep}" + ) + + lines.append("\n[MEMORY IMPLICATIONS]") + lines.append( + f" Non-expert params: sharded {dp_shard_cp_size}x -> small per-GPU footprint" + ) + lines.append( + f" Expert params: sharded only {dp_shard_mod_ep}x -> larger per-GPU footprint" + ) + lines.append(" ") + lines.append(" As DP increases:") + lines.append( + " - dp_shard_cp increases -> non-expert params get more sharded" + ) + lines.append( + " - dp_shard_mod_ep increases -> expert params get more sharded" + ) + lines.append( + " - BUT: all-gather/reduce-scatter buffers scale with group size!" 
+ ) + + else: + dp_shard_cp_size = parallel_dims.dp_shard * parallel_dims.cp + lines.append("\n[ALL PARAMETERS]") + lines.append(" FSDP mesh: dp_shard_cp") + lines.append(f" FSDP group size: {dp_shard_cp_size} GPUs") + lines.append(f" Each parameter is sharded across {dp_shard_cp_size} GPUs") + + return "\n".join(lines) + + +def create_full_visualization( + mesh: DeviceMesh, + parallel_dims, + rank: int = 0, +) -> str: + """Create a comprehensive visualization of the entire mesh structure.""" + parts = [ + visualize_mesh_structure(mesh, parallel_dims, rank), + visualize_gpu_allocation(mesh, parallel_dims, rank), + visualize_expert_parallel_groups(mesh, parallel_dims, rank), + visualize_context_parallel_groups(mesh, parallel_dims, rank), + visualize_fsdp_sharding(mesh, parallel_dims, rank), + ] + + full_viz = "\n".join(parts) + full_viz += "\n" + "=" * 100 + full_viz += "\nEND OF DEVICE MESH VISUALIZATION" + full_viz += "\n" + "=" * 100 + + return full_viz + + +def log_mesh_visualization(mesh: DeviceMesh, parallel_dims): + """Log the full mesh visualization (only on rank 0).""" + rank = dist.get_rank() if dist.is_initialized() else 0 + + if rank == 0: + viz = create_full_visualization(mesh, parallel_dims, rank) + # Log each line separately for better formatting + for line in viz.split("\n"): + logger.info(f"[MESH-VIZ] {line}") From f6bf1ec61e7b7944e75d147058b66ee6295f7b24 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Sat, 31 Jan 2026 06:05:01 -0800 Subject: [PATCH 17/18] Add pipeline parallelism support to DeepSeek V3 model - Add return_outputs parameter for PP compatibility - Accept **kwargs to handle additional PP arguments --- torchtitan/models/deepseek_v3/model/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchtitan/models/deepseek_v3/model/model.py b/torchtitan/models/deepseek_v3/model/model.py index 9a215c6fae..582260169a 100644 --- a/torchtitan/models/deepseek_v3/model/model.py +++ b/torchtitan/models/deepseek_v3/model/model.py @@ -538,6 
+538,8 @@ def forward( tokens: torch.Tensor, attention_masks: AttentionMasksType | None = None, position_ids: torch.Tensor | None = None, + return_outputs: bool = False, # For pipeline parallelism compatibility + **kwargs, # Accept additional kwargs for PP compatibility ): """ Forward pass for the Transformer model. From 86cf636d87e77740038a6093d15b9f5dc4573dbc Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Sat, 31 Jan 2026 06:39:26 -0800 Subject: [PATCH 18/18] Add Kimi K2 training configs for 12n baseline and 36n HSDP - kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml: 12-node baseline config - EP=96, CP=16, DP=1, LBS=11, 32k context - Expected: 402 TPS, 67.55 GiB (85.2%), 17.72% MFU - kimi_k2_36n_ep96_cp16_32k_ctx_hsdp_replicate3_shard6_lbs10.toml: 36-node HSDP config - EP=96, CP=16, dp_replicate=3, dp_shard=6, LBS=10, 32k context - Expected: 378 TPS, 69.45 GiB (87.6%), 16.64% MFU Both configs include aggressive memory management (mode=maximum). Co-Authored-By: Claude Opus 4.5 --- .../kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml | 72 ++++++++++++++++++ ..._32k_ctx_hsdp_replicate3_shard6_lbs10.toml | 73 +++++++++++++++++++ 2 files changed, 145 insertions(+) create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml create mode 100644 torchtitan/models/deepseek_v3/train_configs/kimi_k2_36n_ep96_cp16_32k_ctx_hsdp_replicate3_shard6_lbs10.toml diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml new file mode 100644 index 0000000000..e3e1c9d497 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_12n_ep96_cp16_32k_ctx_lbs11.toml @@ -0,0 +1,72 @@ +# Kimi K2 - 12 nodes - EP=96, CP=16, LBS=11 +# +# Original config: /home/phuc/worklogs/2026-01-30/cp16_sweep/configs/exp1acd_12n_ep96_cp16_lbs11.toml +# Job ID: 2307 +# +# Expected Performance: +# - TPS: 402 +# - Memory: 67.55 GiB (85.2%) +# - 
MFU: 17.72% +# - TFLOPS: ~175 +# +# Parallelism: EP=96, CP=16, DP=1 (dp_replicate=1, dp_shard=1) +# Nodes: 12 (96 GPUs) +# Seq Length: 32768 +# Local Batch Size: 11 +# + +[job] +dump_folder = "./outputs/kimi_k2/12n_ep96_cp16_lbs11" +description = "Kimi K2 - 12n EP=96 CP=16 LBS=11" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "/home/phuc/kimi_1t/torchtitan/assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +state_dtype = "bfloat16" + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +dtype = "bfloat16" +local_batch_size = 11 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = false +# Aggressive memory management to reduce CUDA fragmentation +aggressive_memory_mode = "maximum" +aggressive_memory_verbose = true + +[parallelism] +data_parallel_shard_degree = -1 +expert_parallel_degree = 96 +context_parallel_degree = 16 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = true +components = ["loss"] + +[debug] +moe_force_load_balance = true + +[comm] +init_timeout_seconds = 1800 +train_timeout_seconds = 1800 diff --git a/torchtitan/models/deepseek_v3/train_configs/kimi_k2_36n_ep96_cp16_32k_ctx_hsdp_replicate3_shard6_lbs10.toml b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_36n_ep96_cp16_32k_ctx_hsdp_replicate3_shard6_lbs10.toml new file mode 100644 index 0000000000..7ee95edec6 --- /dev/null +++ b/torchtitan/models/deepseek_v3/train_configs/kimi_k2_36n_ep96_cp16_32k_ctx_hsdp_replicate3_shard6_lbs10.toml @@ -0,0 +1,73 @@ +# Kimi K2 - 36 nodes - EP=96, CP=16, HSDP (dp_replicate=3, dp_shard=6), LBS=10 +# +# Original config: /home/phuc/worklogs/2026-01-30/cp16_sweep_dp/configs/exp1aj_HSDP_r3_s6_lbs10.toml +# Job ID: 2485 +# +# Expected Performance: +# - TPS: 378 +# - Memory: 69.45 GiB (87.6%) +# - MFU: 16.64% +# +# Parallelism: EP=96, CP=16, 
dp_replicate=3, dp_shard=6 +# HSDP: Shard within 12 nodes, all-reduce between 3 replica groups +# Nodes: 36 (288 GPUs) +# Seq Length: 32768 +# Local Batch Size: 10 +# + +[job] +dump_folder = "./outputs/kimi_k2/36n_ep96_cp16_hsdp_replicate3_shard6_lbs10" +description = "Kimi K2 - 36n HSDP dp_replicate=3 dp_shard=6 LBS=10" + +[profiling] +enable_profiling = false + +[metrics] +log_freq = 1 + +[model] +name = "deepseek_v3" +flavor = "kimi_k2" +hf_assets_path = "/home/phuc/kimi_1t/torchtitan/assets/hf/DeepSeek-V3-Base" + +[optimizer] +name = "AdamW" +lr = 2.2e-4 +state_dtype = "bfloat16" + +[lr_scheduler] +warmup_steps = 2 +decay_ratio = 0.8 + +[training] +dtype = "bfloat16" +local_batch_size = 10 +seq_len = 32768 +steps = 5 +dataset = "c4_test" +enable_cpu_offload = true +clear_cache_between_steps = true +skip_optimizer_step = false +# Aggressive memory management to reduce CUDA fragmentation +aggressive_memory_mode = "maximum" +aggressive_memory_verbose = true + +[parallelism] +data_parallel_replicate_degree = 3 +data_parallel_shard_degree = 6 +expert_parallel_degree = 96 +context_parallel_degree = 16 + +[activation_checkpoint] +mode = "full" + +[compile] +enable = true +components = ["loss"] + +[debug] +moe_force_load_balance = true + +[comm] +init_timeout_seconds = 1800 +train_timeout_seconds = 1800