Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions flagscale/runner/backend/backend_megatron.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def generate_run_script(
f.write(f"mkdir -p {system_config.logging.details_dir}\n")
f.write(f"mkdir -p {system_config.logging.tensorboard_dir}\n")
f.write(f"mkdir -p {system_config.logging.wandb_save_dir}\n")
f.write(f"mkdir -p {system_config.logging.straggler_dir}\n")
if system_config.get("straggler_log_dir", None):
f.write(f"mkdir -p {system_config.straggler_log_dir}\n")
f.write("\n")
f.write(f"cd {pkg_dir}\n")
f.write("\n")
Expand Down
16 changes: 15 additions & 1 deletion flagscale/runner/runner_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _get_args_megatron(config: DictConfig):
new_config_dict.update(config_dict["model"])
new_config_dict.update(config_dict["data"])

ignore_keys = ["log_dir", "details_dir", "scripts_dir", "pids_dir"]
ignore_keys = ["log_dir", "details_dir", "scripts_dir", "pids_dir", "straggler_dir"]
# Flatten the dictionary to a list of arguments
args = flatten_dict_to_args(new_config_dict, ignore_keys)

Expand Down Expand Up @@ -118,6 +118,17 @@ def _update_config_train(config: DictConfig):
else os.path.join(exp_dir, "wandb")
)

system.logging.straggler_dir = (
resolve_path(system.logging.straggler_dir, "logging.straggler_dir")
if system.logging.get("straggler_dir", None)
else os.path.join(log_dir, "straggler")
)
system.straggler_log_dir = (
resolve_path(system.straggler_log_dir, "system.straggler_log_dir")
if system.get("straggler_log_dir", None)
else system.logging.straggler_dir
)

# Tokenizer file paths — resolve before passing to the training subprocess,
# which may run with a different cwd (e.g. site-packages when pip-installed).
data = config.train.get("data", None)
Expand Down Expand Up @@ -248,6 +259,9 @@ def _generate_run_script_train(
f.write(f"mkdir -p {system_config.logging.details_dir}\n")
f.write(f"mkdir -p {system_config.logging.tensorboard_dir}\n")
f.write(f"mkdir -p {system_config.logging.wandb_save_dir}\n")
f.write(f"mkdir -p {system_config.logging.straggler_dir}\n")
if system_config.get("straggler_log_dir", None):
f.write(f"mkdir -p {system_config.straggler_log_dir}\n")
f.write("\n")
f.write(f"cd {pkg_dir}\n")
f.write("\n")
Expand Down
105 changes: 105 additions & 0 deletions flagscale/runner/straggler/README_STRAGGLER.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Straggler on Metax C550

## Scope

This note documents the current `straggler` integration for the Metax C550 training path on `main-legacy`.

Important:

- The active branch in use is `main-legacy`.
- The current runnable path is still the legacy training stack.
- The new launcher path `flagscale/runner/launcher/launcher_ssh.py` is not used here.

## Current Code Path

- FlagScale detector logic:
- `flagscale/runner/straggler/`
- Train-side integration:
- `flagscale/train/train.py`

This is a FlagScale-side detector layered on top of the existing Megatron training loop. It does not replace Megatron's native `log_straggler`.

## Metax-Specific Notes

- The implementation avoids new CUDA-only assumptions.
- GPU event profiling is not required by default.
- The practical validation path on Metax should use the mini Aquila config that already passed smoke training.

## Known Pitfalls

- Do not run the first validation on the original full 7B config; use the mini Aquila config, which is the path already stabilized on Metax.
- Do not reuse old checkpoints while testing. Use a fresh `exp_dir` and force a missing `checkpoint.load`.
- If the straggler keys are not present in the YAML config, add them on the command line with the `++` override prefix (as in the examples below).

## Smoke Test

Run from the generated Metax build tree:

```bash
cd /workspace/muxi-flagscale-legacy/build/Metax_C550/muxi-flagscale-legacy

TS=$(date +%Y%m%d_%H%M%S)

python run.py \
--config-path ./examples/aquila/conf \
--config-name train \
action=test \
experiment.exp_dir=/workspace/exp/aquila_straggler_smoke_${TS} \
train.system.checkpoint.load=/workspace/exp/__no_ckpt__/does_not_exist \
train.system.checkpoint.save=/workspace/exp/aquila_straggler_smoke_${TS}/checkpoints \
train.system.use_flash_attn=false \
train.model.attention_backend=unfused \
train.model.num_layers=8 \
train.model.hidden_size=1024 \
train.model.num_attention_heads=16 \
train.model.seq_length=512 \
train.model.max_position_embeddings=512 \
train.model.multiple_of=128 \
train.model.micro_batch_size=1 \
train.model.global_batch_size=8 \
train.model.train_samples=16 \
++train.system.enable_straggler_detection=true \
++train.system.straggler_report_interval=2 \
++train.system.straggler_threshold=1.5 \
++train.system.straggler_warmup_steps=0
```

## Expected Result

- Training starts from random initialization.
- `iteration 1/2` and `iteration 2/2` both complete.
- A straggler report is printed near the end of the short run.
- Report files are written under:

```bash
/workspace/exp/aquila_straggler_smoke_${TS}/logs/straggler
```

## Full Run Example

```bash
TS=$(date +%Y%m%d_%H%M%S)

python run.py \
--config-path ./examples/aquila/conf \
--config-name train \
action=run \
experiment.exp_dir=/workspace/exp/aquila_straggler_run_${TS} \
train.system.checkpoint.load=/workspace/exp/__no_ckpt__/does_not_exist \
train.system.checkpoint.save=/workspace/exp/aquila_straggler_run_${TS}/checkpoints \
train.system.use_flash_attn=false \
train.model.attention_backend=unfused \
train.model.num_layers=8 \
train.model.hidden_size=1024 \
train.model.num_attention_heads=16 \
train.model.seq_length=512 \
train.model.max_position_embeddings=512 \
train.model.multiple_of=128 \
train.model.micro_batch_size=1 \
train.model.global_batch_size=8 \
train.model.train_samples=1600 \
++train.system.enable_straggler_detection=true \
++train.system.straggler_report_interval=20 \
++train.system.straggler_threshold=1.5 \
++train.system.straggler_warmup_steps=10
```
29 changes: 29 additions & 0 deletions flagscale/runner/straggler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
"""FlagScale straggler detection utilities."""

from .comm import CommProfiler, CommStatsCollector, GlooCommHook, NCCLCommHook
from .config import StragglerConfig
from .detector import StragglerDetector
from .healthcheck import ElasticTrainingHealthChecker, NetworkHealthChecker
from .report import StragglerReport
from .section import (
OptionalSectionContext,
SectionContext,
SectionProfiler,
create_section_decorator,
)

__all__ = [
"CommProfiler",
"CommStatsCollector",
"ElasticTrainingHealthChecker",
"GlooCommHook",
"NCCLCommHook",
"NetworkHealthChecker",
"OptionalSectionContext",
"SectionContext",
"SectionProfiler",
"StragglerConfig",
"StragglerDetector",
"StragglerReport",
"create_section_decorator",
]
193 changes: 193 additions & 0 deletions flagscale/runner/straggler/comm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
"""Communication monitoring helpers for straggler analysis."""

import time
from collections import defaultdict
from typing import Any


class CommStatsCollector:
"""Collect basic communication timings."""

def __init__(self, enabled: bool = True):
self.enabled = enabled
self.operation_stats: dict[str, dict[str, Any]] = defaultdict(
lambda: {
"count": 0,
"total_time": 0.0,
"min_time": float("inf"),
"max_time": 0.0,
"rank_times": defaultdict(list),
}
)
self.backend = "unknown"
self.world_size = 1
self.rank = 0

def set_backend_info(self, backend: str, world_size: int, rank: int):
self.backend = backend
self.world_size = world_size
self.rank = rank

def record_operation(
self,
op_type: str,
op_name: str,
start_time: float,
end_time: float,
data_size: int | None = None,
target_ranks: list | None = None,
):
if not self.enabled:
return

duration = end_time - start_time
key = f"{op_type}_{op_name}"
stats = self.operation_stats[key]
stats["count"] += 1
stats["total_time"] += duration
stats["min_time"] = min(stats["min_time"], duration)
stats["max_time"] = max(stats["max_time"], duration)
stats["rank_times"][self.rank].append(duration)

if data_size is not None:
stats["total_data_size"] = stats.get("total_data_size", 0) + data_size

if target_ranks is not None:
stats["target_ranks"] = target_ranks

def get_operation_stats(self, op_type: str, op_name: str) -> dict[str, Any]:
return self.operation_stats[f"{op_type}_{op_name}"].copy()

def get_all_stats(self) -> dict[str, dict[str, Any]]:
return dict(self.operation_stats)

def get_straggler_operations(self, threshold: float = 2.0) -> list:
stragglers = []
for op_key, stats in self.operation_stats.items():
if stats["count"] == 0:
continue
avg_time = stats["total_time"] / stats["count"]
max_time = stats["max_time"]
if avg_time > 0 and max_time / avg_time >= threshold:
stragglers.append(
{
"operation": op_key,
"avg_time": avg_time,
"max_time": max_time,
"slowdown_ratio": max_time / avg_time,
"count": stats["count"],
}
)
return stragglers


class NCCLCommHook:
    """Attach wall-clock timing to NCCL collective calls.

    Each ``wrap_*`` method returns a closure that forwards all arguments to
    the original collective and reports the elapsed time to the shared
    collector under the ``"default"`` operation name.
    """

    def __init__(self, collector: CommStatsCollector):
        self.collector = collector

    def wrap_all_reduce(self, op_func):
        """Return ``op_func`` instrumented as an ``all_reduce`` timing."""
        return self._timed("all_reduce", op_func)

    def wrap_broadcast(self, op_func):
        """Return ``op_func`` instrumented as a ``broadcast`` timing."""
        return self._timed("broadcast", op_func)

    def _timed(self, op_type, op_func):
        # Shared closure factory: time the call, report it, pass the result on.
        def instrumented(*args, **kwargs):
            began = time.perf_counter()
            result = op_func(*args, **kwargs)
            finished = time.perf_counter()
            self.collector.record_operation(op_type, "default", began, finished)
            return result

        return instrumented


class GlooCommHook:
    """Attach wall-clock timing to Gloo collective calls."""

    def __init__(self, collector: CommStatsCollector):
        self.collector = collector

    def wrap_all_reduce(self, op_func):
        """Return ``op_func`` wrapped so every call is timed as ``all_reduce``."""

        def timed_all_reduce(*args, **kwargs):
            began = time.perf_counter()
            result = op_func(*args, **kwargs)
            # Report the span to the collector, then hand the result back.
            self.collector.record_operation(
                "all_reduce", "default", began, time.perf_counter()
            )
            return result

        return timed_all_reduce


class CommProfiler:
    """Backend-aware communication profiler.

    Resolves the communication backend (optionally auto-detecting it),
    installs the matching timing hook, and exposes the underlying collector
    for externally timed spans.
    """

    #: backend name -> hook class used to instrument that backend
    _HOOK_CLASSES = {"nccl": NCCLCommHook, "gloo": GlooCommHook}

    def __init__(self, backend: str = "auto", enabled: bool = True):
        self.collector = CommStatsCollector(enabled=enabled)
        self.hooks = {}

        resolved = self._detect_backend() if backend == "auto" else backend

        hook_cls = self._HOOK_CLASSES.get(resolved)
        if hook_cls is not None:
            self.hooks[resolved] = hook_cls(self.collector)

        # NOTE(review): world_size/rank default to 1/0 here — presumably the
        # caller refines them later via collector.set_backend_info; confirm.
        self.collector.set_backend_info(resolved, 1, 0)

    def _detect_backend(self) -> str:
        """Best-effort guess: NCCL when CUDA torch is importable, else Gloo."""
        try:
            import torch
        except ImportError:
            return "gloo"
        return "nccl" if torch.cuda.is_available() else "gloo"

    def wrap_operation(self, op_type: str, op_func):
        """Wrap ``op_func`` with timing when the active backend supports it.

        Falls back to returning ``op_func`` untouched for unknown backends or
        unsupported operation types.
        """
        hook = self.hooks.get(self.collector.backend)
        if hook is None:
            return op_func
        if op_type == "all_reduce":
            return hook.wrap_all_reduce(op_func)
        if op_type == "broadcast" and hasattr(hook, "wrap_broadcast"):
            return hook.wrap_broadcast(op_func)
        return op_func

    def record_custom_operation(
        self,
        op_type: str,
        op_name: str,
        start_time: float,
        end_time: float,
        data_size: int | None = None,
    ):
        """Forward an externally timed span straight to the collector."""
        self.collector.record_operation(
            op_type, op_name, start_time, end_time, data_size=data_size
        )
Loading
Loading