Draft
Changes from all commits
20 commits
114b8af
hacky
Jan 19, 2026
1fda9bb
Add Kimi 1T training configs, activation offload, and memory defrag
Jan 20, 2026
b4f226d
backup
xrsrke Jan 20, 2026
82f2afe
Remove experimental configs and debug scripts from tracking
xrsrke Jan 20, 2026
e04c0f6
remove assert for cp, and remove new activation checkpointing
xrsrke Jan 20, 2026
e42846c
remove MemoryDefragManager
xrsrke Jan 20, 2026
94e59dc
fast path for initing bfloat16 params on cpu
jquesnelle Jan 21, 2026
6dd01dd
add bfloat16 optim states, fix page cache
xrsrke Jan 22, 2026
d36c5d3
Merge remote-tracking branch 'origin/phuc/kimi1t_training'
xrsrke Jan 22, 2026
f18db98
add reference for init scheme
jquesnelle Jan 22, 2026
2d60a01
error if cp set but can't import
jquesnelle Jan 23, 2026
53eea6b
overlapped cpu offload muon
jquesnelle Jan 23, 2026
e8e2cf9
Add FSDP enhancements: partial resharding, bucket size, and prefetch …
xrsrke Jan 31, 2026
413377f
Add enhanced metrics and memory monitoring
xrsrke Jan 31, 2026
936510c
Add aggressive memory manager to reduce CUDA fragmentation
xrsrke Jan 31, 2026
29d89cd
Add DeepEP tuning enhancements with model presets and CLI args
xrsrke Jan 31, 2026
5a091a2
Add device mesh visualizer for distributed training
xrsrke Jan 31, 2026
f6bf1ec
Add pipeline parallelism support to DeepSeek V3 model
xrsrke Jan 31, 2026
6037d27
Merge remote-tracking branch 'temp/remote-branch' into phuc/kimi1t_tr…
xrsrke Jan 31, 2026
86cf636
Add Kimi K2 training configs for 12n baseline and 36n HSDP
xrsrke Jan 31, 2026
153 changes: 119 additions & 34 deletions scripts/deepep/torchtitan_deepep_tune/tune_internode.py
@@ -29,8 +29,50 @@
print("ERROR: deep_ep not found")
sys.exit(1)

sys.path.insert(0, "/home/phuc/workspace/moe/DeepEP/tests")
from utils import bench_kineto, init_dist
DEEPEP_TESTS_PATH = os.environ.get(
"DEEPEP_TESTS_PATH", "/home/phuc/kimi_1t/deepep/tests"
)
sys.path.insert(0, DEEPEP_TESTS_PATH)
from utils import bench_kineto


def init_dist_torchrun(local_rank: int, num_local_ranks: int):
"""
Initialize distributed for torchrun environment.
torchrun sets: WORLD_SIZE=total_procs, RANK=global_rank, LOCAL_RANK, LOCAL_WORLD_SIZE
But init_dist expects: WORLD_SIZE=num_nodes, RANK=node_rank
"""
import inspect

world_size = int(os.environ.get("WORLD_SIZE", 1))
global_rank = int(os.environ.get("RANK", 0))

# Calculate node info from torchrun env vars
num_nodes = world_size // num_local_ranks
node_rank = global_rank // num_local_ranks

ip = os.getenv("MASTER_ADDR", "127.0.0.1")
port = int(os.getenv("MASTER_PORT", "29500"))

sig = inspect.signature(dist.init_process_group)
params = {
"backend": "nccl",
"init_method": f"tcp://{ip}:{port}",
"world_size": num_nodes * num_local_ranks,
"rank": node_rank * num_local_ranks + local_rank,
}
if "device_id" in sig.parameters:
params["device_id"] = torch.device(f"cuda:{local_rank}")
dist.init_process_group(**params)
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device("cuda")
torch.cuda.set_device(local_rank)

return (
dist.get_rank(),
dist.get_world_size(),
dist.new_group(list(range(num_local_ranks * num_nodes))),
)


@dataclass
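A quick sanity check of the rank arithmetic the helper relies on, using hypothetical 2-node x 8-GPU values (torchrun would set WORLD_SIZE=16 and LOCAL_WORLD_SIZE=8):

# Hypothetical values: 2 nodes x 8 GPUs under torchrun.
world_size, num_local_ranks = 16, 8
num_nodes = world_size // num_local_ranks  # 2
for global_rank in range(world_size):
    node_rank = global_rank // num_local_ranks
    local_rank = global_rank % num_local_ranks
    # init_dist_torchrun passes node_rank * num_local_ranks + local_rank as the rank,
    # which reconstructs exactly the global rank torchrun assigned.
    assert node_rank * num_local_ranks + local_rank == global_rank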
@@ -71,15 +113,19 @@ def __init__(
self.hidden = hidden
self.num_experts = num_experts
self.num_topk = num_topk
self.num_topk_groups = num_topk_groups

# Init distributed
# Init distributed (using torchrun-compatible init)
self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
num_local_ranks = int(os.environ.get("LOCAL_WORLD_SIZE", 8))
self.rank, self.num_ranks, self.group = init_dist(
self.rank, self.num_ranks, self.group = init_dist_torchrun(
self.local_rank, num_local_ranks
)
self.num_nodes = int(os.environ.get("WORLD_SIZE", 1))
# Calculate num_nodes from torchrun env vars
world_size = int(os.environ.get("WORLD_SIZE", 1))
self.num_nodes = world_size // num_local_ranks

# num_topk_groups must be <= num_nodes
self.num_topk_groups = min(num_topk_groups, self.num_nodes)
self.num_sms = 24

# Buffer sizes (from benchmark_internode.py)
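The min() clamp above exists because group-limited routing calls torch.topk over a dimension of size num_nodes, and k may not exceed that size. A minimal repro of the failure it prevents (hypothetical shapes):

import torch

group_scores = torch.randn(4, 2)         # hypothetical: 4 tokens, 2 node groups
torch.topk(group_scores, k=2, dim=-1)    # fine: k equals the group count
# torch.topk(group_scores, k=4, dim=-1)  # RuntimeError: selected index k out of range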
@@ -124,39 +170,31 @@ def setup_data(self):
* self.rank
)

# Random scores with group-based routing (like Qwen3)
# UNIFORM distribution across all ranks
# Use same seed for reproducibility across all ranks
torch.manual_seed(42)
scores = (
torch.randn(
(self.num_tokens, self.num_experts), dtype=torch.float32, device="cuda"
).abs()
+ 1
)
group_scores = scores.view(self.num_tokens, self.num_nodes, -1).amax(dim=-1)
group_idx = torch.topk(
group_scores, k=self.num_topk_groups, dim=-1, sorted=False
).indices

# Create grouped scores (group-limited routing)
masked_scores = scores.clone()
for i in range(self.num_nodes):
mask = (group_idx == i).any(dim=-1, keepdim=True)
node_mask = torch.zeros(
self.num_tokens, self.num_experts, dtype=torch.bool, device="cuda"
)
start_expert = i * (self.num_experts // self.num_nodes)
end_expert = (i + 1) * (self.num_experts // self.num_nodes)
node_mask[:, start_expert:end_expert] = True
masked_scores = torch.where(
mask & node_mask,
masked_scores,
torch.tensor(-float("inf"), device="cuda"),
)

self.topk_idx = torch.topk(
masked_scores, self.num_topk, dim=-1, largest=True, sorted=False
scores, self.num_topk, dim=-1, largest=True, sorted=False
)[1]
self.topk_idx = self.topk_idx.to(deep_ep.topk_idx_t)

# Verify distribution (only on rank 0)
if self.is_rank0():
rank_idx_check = self.topk_idx // (self.num_experts // self.num_ranks)
tokens_per_rank = [
(rank_idx_check == r).sum().item() for r in range(self.num_ranks)
]
print(
f"[uniform] Tokens per rank: min={min(tokens_per_rank)}, max={max(tokens_per_rank)}, "
f"ratio={max(tokens_per_rank)/max(min(tokens_per_rank), 1):.2f}x"
)

# Get layout
(
self.num_tokens_per_rank,
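The routing change above drops the group-limited masking in favor of plain top-k over uniform random scores. A standalone sketch (CPU-only, sizes taken from the qwen3 defaults, 16 ranks assumed) showing that this yields near-uniform tokens per rank:

import torch

torch.manual_seed(42)
num_tokens, num_experts, num_topk, num_ranks = 4096, 128, 8, 16
scores = torch.randn(num_tokens, num_experts).abs() + 1
topk_idx = torch.topk(scores, num_topk, dim=-1, largest=True, sorted=False)[1]
experts_per_rank = num_experts // num_ranks
tokens_per_rank = torch.bincount(
    (topk_idx // experts_per_rank).flatten(), minlength=num_ranks
)
# Max/min ratio comes out close to 1.0x, i.e. balanced load across ranks.
print(tokens_per_rank.float().max() / tokens_per_rank.float().min())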
@@ -347,22 +385,69 @@ def cleanup(self):
dist.destroy_process_group()


# Model presets
MODEL_CONFIGS = {
"qwen3": {"hidden": 2048, "num_experts": 128, "num_topk": 8},
"kimi_k2": {"hidden": 7168, "num_experts": 384, "num_topk": 8},
}


def main():
parser = argparse.ArgumentParser(
description="Tune DeepEP for actual torchtitan setup"
)
parser.add_argument("--ep-size", type=int, required=True)
parser.add_argument("--mode", choices=["quick", "medium", "full"], default="medium")
parser.add_argument("--output-dir", default="results")
parser.add_argument(
"--model",
type=str,
choices=["qwen3", "kimi_k2", "custom"],
default="qwen3",
help="Model preset: qwen3 (dim=2048, 128 experts), kimi_k2 (dim=7168, 384 experts)",
)
parser.add_argument("--num-tokens", type=int, default=4096, help="Number of tokens")
parser.add_argument(
"--hidden",
type=int,
default=None,
help="Hidden dimension (overrides model preset)",
)
parser.add_argument(
"--num-experts",
type=int,
default=None,
help="Number of experts (overrides model preset)",
)
parser.add_argument(
"--num-topk",
type=int,
default=None,
help="Top-k experts (overrides model preset)",
)

args = parser.parse_args()

# Qwen3-30B-A3B parameters
# Apply model preset, allow overrides
if args.model in MODEL_CONFIGS:
preset = MODEL_CONFIGS[args.model]
hidden = args.hidden if args.hidden else preset["hidden"]
num_experts = args.num_experts if args.num_experts else preset["num_experts"]
num_topk = args.num_topk if args.num_topk else preset["num_topk"]
else:
hidden = args.hidden if args.hidden else 2048
num_experts = args.num_experts if args.num_experts else 128
num_topk = args.num_topk if args.num_topk else 8

print(
f"Model: {args.model}, hidden={hidden}, experts={num_experts}, topk={num_topk}, tokens={args.num_tokens}"
)

tuner = DeepEPTuner(
num_tokens=4096,
hidden=2048, # Qwen3-30B dim
num_experts=128, # Qwen3-30B-A3B
num_topk=8,
num_tokens=args.num_tokens,
hidden=hidden,
num_experts=num_experts,
num_topk=num_topk,
num_topk_groups=4,
)
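A condensed sketch of the preset/override resolution in main() (hypothetical standalone helper; explicit CLI values win over the preset, and "custom" falls back to the Qwen3-style defaults):

MODEL_CONFIGS = {
    "qwen3": {"hidden": 2048, "num_experts": 128, "num_topk": 8},
    "kimi_k2": {"hidden": 7168, "num_experts": 384, "num_topk": 8},
}

def resolve(model, hidden=None, num_experts=None, num_topk=None):
    # Unknown model names fall back to Qwen3-style defaults, mirroring the else branch.
    preset = MODEL_CONFIGS.get(model, {"hidden": 2048, "num_experts": 128, "num_topk": 8})
    return (hidden or preset["hidden"],
            num_experts or preset["num_experts"],
            num_topk or preset["num_topk"])

assert resolve("kimi_k2") == (7168, 384, 8)
assert resolve("kimi_k2", hidden=4096) == (4096, 384, 8)  # explicit flag overrides preset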

51 changes: 45 additions & 6 deletions scripts/deepep/torchtitan_deepep_tune/tune_intranode_v2.py
@@ -40,7 +40,7 @@ def init_dist(local_rank: int, num_local_ranks: int):
os.environ["RANK"] = str(local_rank)
os.environ["WORLD_SIZE"] = str(num_local_ranks)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(29500 + local_rank)
os.environ["MASTER_PORT"] = "29500"  # Same port for all ranks so they rendezvous on one store

dist.init_process_group(backend="nccl", rank=local_rank, world_size=num_local_ranks)
torch.cuda.set_device(local_rank)
@@ -426,6 +426,13 @@ def test_loop(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
dist.destroy_process_group()


# Model presets
MODEL_CONFIGS = {
"qwen3": {"hidden": 2048, "num_experts": 128, "num_topk": 8},
"kimi_k2": {"hidden": 7168, "num_experts": 384, "num_topk": 8},
}


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Tune DeepEP intranode configs for TorchTitan"
@@ -436,26 +443,58 @@
parser.add_argument(
"--num-tokens", type=int, default=4096, help="Number of tokens (default: 4096)"
)
parser.add_argument(
"--model",
type=str,
choices=["qwen3", "kimi_k2", "custom"],
default="qwen3",
help="Model preset: qwen3 (dim=2048, 128 experts), kimi_k2 (dim=7168, 384 experts)",
)
parser.add_argument(
"--hidden",
type=int,
default=2048,
help="Hidden dimension - Qwen3-30B (default: 2048)",
default=None,
help="Hidden dimension (overrides model preset)",
)
parser.add_argument(
"--num-topk", type=int, default=8, help="Number of top-k experts (default: 8)"
"--num-topk",
type=int,
default=None,
help="Number of top-k experts (overrides model preset)",
)
parser.add_argument(
"--num-experts",
type=int,
default=128,
help="Number of experts - Qwen3-30B-A3B (default: 128)",
default=None,
help="Number of experts (overrides model preset)",
)
parser.add_argument(
"--output-dir", type=str, default="results", help="Output directory for results"
)
args = parser.parse_args()

# Apply model preset, allow overrides
if args.model in MODEL_CONFIGS:
preset = MODEL_CONFIGS[args.model]
if args.hidden is None:
args.hidden = preset["hidden"]
if args.num_experts is None:
args.num_experts = preset["num_experts"]
if args.num_topk is None:
args.num_topk = preset["num_topk"]
else:
# Custom mode - fall back to Qwen3-style defaults for any omitted flag
if args.hidden is None:
args.hidden = 2048
if args.num_experts is None:
args.num_experts = 128
if args.num_topk is None:
args.num_topk = 8

print(
f"Model: {args.model}, hidden={args.hidden}, experts={args.num_experts}, topk={args.num_topk}"
)

num_processes = args.num_processes
torch.multiprocessing.spawn(
test_loop, args=(num_processes, args), nprocs=num_processes
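For context on the spawn call above: torch.multiprocessing.spawn invokes its target as fn(i, *args) with i ranging over nprocs, which is how test_loop receives local_rank as its first parameter even though it never appears in args. A minimal sketch with a hypothetical fn:

import torch.multiprocessing as mp

def fn(local_rank, num_local_ranks):
    # mp.spawn supplies local_rank as the process index; args fills the remaining parameters.
    print(f"process {local_rank} of {num_local_ranks}")

if __name__ == "__main__":
    mp.spawn(fn, args=(4,), nprocs=4)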
5 changes: 4 additions & 1 deletion scripts/deepep/torchtitan_deepep_tune/tune_singlenode.py
@@ -31,7 +31,10 @@
sys.exit(1)

# Import from DeepEP tests
sys.path.insert(0, "/home/phuc/workspace/moe/DeepEP/tests")
DEEPEP_TESTS_PATH = os.environ.get(
"DEEPEP_TESTS_PATH", "/home/phuc/kimi_1t/deepep/tests"
)
sys.path.insert(0, DEEPEP_TESTS_PATH)
from utils import bench, init_dist, inplace_unique
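All three tuning scripts now resolve the DeepEP tests checkout through the same environment variable, so a single override redirects them. A hypothetical wrapper (the path and script invocation are placeholders):

import os
import subprocess

# Point the tuner at a different DeepEP checkout for one run, without editing the script.
env = {**os.environ, "DEEPEP_TESTS_PATH": "/opt/DeepEP/tests"}  # hypothetical path
subprocess.run(["python", "tune_singlenode.py"], env=env, check=True)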

