diff --git a/flagscale/runner/backend/backend_megatron.py b/flagscale/runner/backend/backend_megatron.py index 9ff7e2bf43..39262c9819 100644 --- a/flagscale/runner/backend/backend_megatron.py +++ b/flagscale/runner/backend/backend_megatron.py @@ -10,6 +10,18 @@ ) from flagscale.runner.utils import get_pkg_dir, logger, parse_hostfile, resolve_path +PERF_MONITOR_RUNNER_KEYS = ( + "enable_perf_monitor", + "perf_log_interval", + "perf_log_dir", + "perf_console_output", + "perf_log_format", + "perf_memory_tracking", + "perf_breakdown", + "perf_max_log_files", + "perf_model_type", +) + class MegatronBackend(BackendBase): def __init__(self, config: DictConfig): @@ -20,6 +32,7 @@ def __init__(self, config: DictConfig): def _prepare(self): _update_config_train(self.config) + self._prepare_perf_monitor_config() self.user_args = _get_args_megatron(self.config) self.rdzv_id = datetime.now().strftime("%Y%m%d_%H%M%S.%f") self.user_envs = self.config.experiment.get("envs", {}) @@ -30,6 +43,22 @@ def _prepare(self): logger.info("\n************** configuration **************") logger.info(f"\n{OmegaConf.to_yaml(self.config)}") + def _prepare_perf_monitor_config(self): + system_config = self.config.train.system + runner_config = self.config.experiment.runner + + OmegaConf.set_struct(system_config, False) + for key in PERF_MONITOR_RUNNER_KEYS: + if runner_config.get(key, None) is not None and system_config.get(key, None) is None: + system_config[key] = runner_config.get(key) + + if system_config.get("perf_log_dir", None) is not None: + system_config.perf_log_dir = resolve_path( + system_config.perf_log_dir, "system.perf_log_dir" + ) + elif system_config.get("enable_perf_monitor", False): + system_config.perf_log_dir = os.path.join(system_config.logging.log_dir, "perf_monitor") + def generate_run_script( self, config, @@ -78,6 +107,8 @@ def generate_run_script( f.write(f"mkdir -p {system_config.logging.details_dir}\n") f.write(f"mkdir -p {system_config.logging.tensorboard_dir}\n") f.write(f"mkdir -p {system_config.logging.wandb_save_dir}\n") + if system_config.get("perf_log_dir", None): + f.write(f"mkdir -p {system_config.perf_log_dir}\n") f.write("\n") f.write(f"cd {pkg_dir}\n") f.write("\n") diff --git a/flagscale/runner/launcher/launcher_ssh.py b/flagscale/runner/launcher/launcher_ssh.py index 82eb3d044d..f238cdc8c2 100644 --- a/flagscale/runner/launcher/launcher_ssh.py +++ b/flagscale/runner/launcher/launcher_ssh.py @@ -116,6 +116,24 @@ def _get_runner_cmd_train( del runner_args["nsys_rep_file_path"] if "deploy" in runner_args: del runner_args["deploy"] + if "enable_perf_monitor" in runner_args: + del runner_args["enable_perf_monitor"] + if "perf_log_interval" in runner_args: + del runner_args["perf_log_interval"] + if "perf_log_dir" in runner_args: + del runner_args["perf_log_dir"] + if "perf_console_output" in runner_args: + del runner_args["perf_console_output"] + if "perf_log_format" in runner_args: + del runner_args["perf_log_format"] + if "perf_memory_tracking" in runner_args: + del runner_args["perf_memory_tracking"] + if "perf_breakdown" in runner_args: + del runner_args["perf_breakdown"] + if "perf_max_log_files" in runner_args: + del runner_args["perf_max_log_files"] + if "perf_model_type" in runner_args: + del runner_args["perf_model_type"] runner_args["rdzv_id"] = rdzv_id # runner_args["master_addr"] = master_addr # runner_args["master_port"] = master_port diff --git a/flagscale/train/megatron/training/arguments_fs.py b/flagscale/train/megatron/training/arguments_fs.py index 
1bf27fc328..0765b27c30 100644 --- a/flagscale/train/megatron/training/arguments_fs.py +++ b/flagscale/train/megatron/training/arguments_fs.py @@ -763,6 +763,73 @@ def _add_regularization_args(parser): help='If set, disable Nesterov momentum for muon') return parser +def _add_perf_monitor_args(parser): + group = parser.add_argument_group(title="flagscale perf monitor") + + group.add_argument( + "--enable-perf-monitor", + action="store_true", + default=False, + help="Enable FlagScale performance monitoring during training.", + ) + group.add_argument( + "--perf-log-interval", + type=int, + default=10, + help="Log performance metrics every N iterations.", + ) + group.add_argument( + "--perf-log-dir", + type=str, + default=None, + help="Directory used to save performance monitor logs.", + ) + group.add_argument( + "--perf-console-output", + action="store_true", + default=False, + help="Also emit performance monitor logs to stdout on rank 0.", + ) + group.add_argument( + "--perf-log-format", + type=str, + choices=["text", "json", "both"], + default="both", + help="Output format for performance monitor files.", + ) + group.add_argument( + "--perf-memory-tracking", + dest="perf_memory_tracking", + action="store_true", + help="Track CUDA memory usage in the performance monitor.", + ) + group.add_argument( + "--no-perf-memory-tracking", + dest="perf_memory_tracking", + action="store_false", + help="Disable CUDA memory tracking in the performance monitor.", + ) + group.set_defaults(perf_memory_tracking=True) + group.add_argument( + "--perf-breakdown", + action="store_true", + default=False, + help="Include estimated component breakdowns in performance logs.", + ) + group.add_argument( + "--perf-max-log-files", + type=int, + default=10, + help="Maximum number of historical performance log files to keep.", + ) + group.add_argument( + "--perf-model-type", + type=str, + choices=["auto", "gpt", "llama", "qwen", "mixtral", "aquila", "moe"], + default="auto", + help="Model type hint used for FLOPS estimation.", + ) + return parser def _add_flagos_args(parser): group = parser.add_argument_group(title="flagscale transformer engine fl") @@ -878,6 +945,7 @@ def add_flagscale_arguments(parser): parser = _add_auto_skip_spiky_loss(parser) parser = _add_peft_args(parser) parser = _add_regularization_args(parser) + parser = _add_perf_monitor_args(parser) parser = _add_flagos_args(parser) parser = _add_engram_args(parser) return parser diff --git a/flagscale/train/megatron/training/training.py b/flagscale/train/megatron/training/training.py index 674403e3dc..4fa38b8545 100644 --- a/flagscale/train/megatron/training/training.py +++ b/flagscale/train/megatron/training/training.py @@ -7,9 +7,11 @@ import functools import gc import inspect +import json import logging import math import os +import socket import sys from typing import Any, Optional @@ -138,6 +140,12 @@ from megatron.training.global_vars import get_spiky_loss_detector from megatron.training.fs_theoretical_memory_usage import report_theoretical_memory as fs_report_theoretical_memory from megatron.plugin.hetero.parallel_context import get_parallel_context +from flagscale.train.perf_monitor.hooks import ( + initialize_perf_monitor, + perf_monitor_end_iteration, + perf_monitor_end_training, + perf_monitor_start_iteration, +) stimer = StragglerDetector() @@ -2453,6 +2461,7 @@ def train( timers('interval-time', log_level=0).start(barrier=True) print_datetime('before the start of training step') report_memory_flag = True + perf_callback = 
initialize_perf_monitor(args) pre_hook_enabled = False should_exit = False exit_code = 0 @@ -2483,6 +2492,7 @@ def train( num_microbatches = get_num_microbatches() + writer = get_tensorboard_writer() wandb_writer = get_wandb_writer() if wandb_writer and args.wandb_log_model: # wandb.watch's log_freg needs to take the accumulated number of microbatches into account @@ -2665,7 +2675,8 @@ def get_e2e_base_metrics(): model, optimizer, iteration, ref_state_dict, ) train_data_iterator = buffered_rollouts - + if perf_callback is not None: + perf_monitor_start_iteration(iteration) ft_integration.on_training_step_start() ( loss_dict, @@ -2679,6 +2690,8 @@ def get_e2e_base_metrics(): forward_step_func, train_data_iterator, model, optimizer, opt_param_scheduler, config, forward_backward_func ) ft_integration.on_training_step_end() + if perf_callback is not None: + perf_monitor_end_iteration(iteration, writer, wandb_writer) if should_checkpoint: save_checkpoint_and_time( iteration, @@ -2891,6 +2904,7 @@ def get_e2e_base_metrics(): # Flush TensorBoard, WandB writers and one-logger. writer = get_tensorboard_writer() + perf_monitor_end_training(writer, wandb_writer) if writer: writer.flush() diff --git a/flagscale/train/perf_monitor/READ_PERF_MONITOR.md b/flagscale/train/perf_monitor/READ_PERF_MONITOR.md new file mode 100644 index 0000000000..670551d47c --- /dev/null +++ b/flagscale/train/perf_monitor/READ_PERF_MONITOR.md @@ -0,0 +1,120 @@ +# Perf Monitor on Metax C550 + +## Scope + +This note documents the current `perf_monitor` path for Metax C550 on `main-legacy`. + +Important: + +- The current runnable branch is `main-legacy`. +- The active code path is the legacy runner: + - `flagscale/runner/runner_train.py` +- The new runner launcher path is not active in this branch. + +## Current Code Path + +- Runner integration: + - `flagscale/runner/runner_train.py` + - `run.py` + - `flagscale/runner/auto_tuner/tuner.py` +- Monitor service: + - `flagscale/runner/elastic/monitor_launcher.py` + - `flagscale/runner/elastic/monitor_service.py` + - `flagscale/runner/elastic/diagnostic.py` + +## Metax-Specific Notes + +- This monitor path is mostly process/log based and does not depend on `nvidia-smi`. +- Metax-specific diagnostic keywords were added for: + - `maca out of memory` + - `mxkw` + - `ioctl create queue block timeout` + +## Compatibility Aliases + +For convenience, the legacy monitor path also accepts: + +- `++experiment.runner.enable_perf_monitor=true` +- `++experiment.runner.perf_monitor_interval=5` + +These are mapped internally to the legacy keys: + +- `enable_perf_monitor` -> `enable_monitoring` +- `perf_monitor_interval` -> `monitor_interval` + +## Known Pitfalls + +- In the legacy runner, monitor enablement must be propagated to each node. This path was fixed in `runner_train.py`; do not bypass it with custom launch wrappers. +- Use a fresh `exp_dir` and a missing `checkpoint.load` during validation to avoid resume mismatches. +- Validate this first on the mini Aquila config, not on the original full 7B config. 
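
The alias handling described above under "Compatibility Aliases" boils down to a small key translation applied to the runner config before the legacy monitor reads it. The sketch below is illustrative only: the helper name and the plain-dict config access are assumptions, not the actual `runner_train.py` code.

```python
# Illustrative only: maps the perf_monitor-style runner keys onto the
# legacy monitor keys understood by the main-legacy runner.
LEGACY_MONITOR_ALIASES = {
    "enable_perf_monitor": "enable_monitoring",
    "perf_monitor_interval": "monitor_interval",
}


def apply_legacy_monitor_aliases(runner_config: dict) -> dict:
    """Return a copy of the runner config with aliased keys translated."""
    translated = dict(runner_config)
    for new_key, legacy_key in LEGACY_MONITOR_ALIASES.items():
        if new_key in translated and legacy_key not in translated:
            translated[legacy_key] = translated.pop(new_key)
    return translated


# Example: the overrides used in the smoke test below.
print(apply_legacy_monitor_aliases({"enable_perf_monitor": True, "perf_monitor_interval": 5}))
# -> {'enable_monitoring': True, 'monitor_interval': 5}
```
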
+ +## Smoke Test + +```bash +cd /workspace/muxi-flagscale-legacy/build/Metax_C550/muxi-flagscale-legacy + +TS=$(date +%Y%m%d_%H%M%S) + +python run.py \ + --config-path ./examples/aquila/conf \ + --config-name train \ + action=test \ + experiment.exp_dir=/workspace/exp/aquila_perf_smoke_${TS} \ + train.system.checkpoint.load=/workspace/exp/__no_ckpt__/does_not_exist \ + train.system.checkpoint.save=/workspace/exp/aquila_perf_smoke_${TS}/checkpoints \ + train.system.use_flash_attn=false \ + train.model.attention_backend=unfused \ + train.model.num_layers=8 \ + train.model.hidden_size=1024 \ + train.model.num_attention_heads=16 \ + train.model.seq_length=512 \ + train.model.max_position_embeddings=512 \ + train.model.multiple_of=128 \ + train.model.micro_batch_size=1 \ + train.model.global_batch_size=8 \ + train.model.train_samples=16 \ + ++experiment.runner.enable_perf_monitor=true \ + ++experiment.runner.perf_monitor_interval=5 +``` + +## Expected Result + +- The short training run completes successfully. +- Monitor outputs are written under: + +```bash +/workspace/exp/aquila_perf_smoke_${TS}/logs/monitor +``` + +Typical files: + +- `status.log` +- `host_*_diagnostic.txt` +- `host_*_current.log` + +## Full Run Example + +```bash +TS=$(date +%Y%m%d_%H%M%S) + +python run.py \ + --config-path ./examples/aquila/conf \ + --config-name train \ + action=run \ + experiment.exp_dir=/workspace/exp/aquila_perf_run_${TS} \ + train.system.checkpoint.load=/workspace/exp/__no_ckpt__/does_not_exist \ + train.system.checkpoint.save=/workspace/exp/aquila_perf_run_${TS}/checkpoints \ + train.system.use_flash_attn=false \ + train.model.attention_backend=unfused \ + train.model.num_layers=8 \ + train.model.hidden_size=1024 \ + train.model.num_attention_heads=16 \ + train.model.seq_length=512 \ + train.model.max_position_embeddings=512 \ + train.model.multiple_of=128 \ + train.model.micro_batch_size=1 \ + train.model.global_batch_size=8 \ + train.model.train_samples=1600 \ + ++experiment.runner.enable_perf_monitor=true \ + ++experiment.runner.perf_monitor_interval=5 +``` diff --git a/flagscale/train/perf_monitor/__init__.py b/flagscale/train/perf_monitor/__init__.py new file mode 100644 index 0000000000..bc36983289 --- /dev/null +++ b/flagscale/train/perf_monitor/__init__.py @@ -0,0 +1,20 @@ +"""FlagScale performance monitor utilities.""" + +from .hooks import ( + get_perf_monitor, + initialize_perf_monitor, + perf_monitor_end_iteration, + perf_monitor_end_training, + perf_monitor_start_iteration, +) +from .perf_metrics import FLOPSMeasurementCallback, PerformanceMonitor + +__all__ = [ + "FLOPSMeasurementCallback", + "PerformanceMonitor", + "get_perf_monitor", + "initialize_perf_monitor", + "perf_monitor_end_iteration", + "perf_monitor_end_training", + "perf_monitor_start_iteration", +] diff --git a/flagscale/train/perf_monitor/flops_calculator.py b/flagscale/train/perf_monitor/flops_calculator.py new file mode 100644 index 0000000000..77eb3d90ac --- /dev/null +++ b/flagscale/train/perf_monitor/flops_calculator.py @@ -0,0 +1,63 @@ +"""FLOPS estimation helpers for performance monitoring.""" + +from __future__ import annotations + + +class FLOPSFormulas: + """Collection of approximate transformer FLOPS formulas.""" + + @staticmethod + def attention_flops(batch_size, seq_length, hidden_size, num_attention_heads): + head_dim = hidden_size // max(num_attention_heads, 1) + qkv_flops = 3 * 2 * batch_size * seq_length * hidden_size * hidden_size + score_flops = 2 * batch_size * num_attention_heads * seq_length * 
seq_length * head_dim + value_flops = 2 * batch_size * num_attention_heads * seq_length * seq_length * head_dim + out_flops = 2 * batch_size * seq_length * hidden_size * hidden_size + return qkv_flops + score_flops + value_flops + out_flops + + @staticmethod + def gqa_attention_flops( + batch_size, seq_length, hidden_size, num_attention_heads, num_query_groups + ): + head_dim = hidden_size // max(num_attention_heads, 1) + kv_hidden_size = head_dim * max(num_query_groups, 1) + q_flops = 2 * batch_size * seq_length * hidden_size * hidden_size + kv_flops = 2 * 2 * batch_size * seq_length * hidden_size * kv_hidden_size + score_flops = 2 * batch_size * num_attention_heads * seq_length * seq_length * head_dim + value_flops = 2 * batch_size * num_attention_heads * seq_length * seq_length * head_dim + out_flops = 2 * batch_size * seq_length * hidden_size * hidden_size + return q_flops + kv_flops + score_flops + value_flops + out_flops + + @staticmethod + def ffn_flops(batch_size, seq_length, hidden_size, ffn_hidden_size, use_swiglu=False): + if use_swiglu: + gate_flops = 2 * batch_size * seq_length * hidden_size * ffn_hidden_size + up_flops = 2 * batch_size * seq_length * hidden_size * ffn_hidden_size + swiglu_flops = batch_size * seq_length * ffn_hidden_size + down_flops = 2 * batch_size * seq_length * ffn_hidden_size * hidden_size + return gate_flops + up_flops + swiglu_flops + down_flops + + up_flops = 2 * batch_size * seq_length * hidden_size * ffn_hidden_size + down_flops = 2 * batch_size * seq_length * ffn_hidden_size * hidden_size + return up_flops + down_flops + + @staticmethod + def moe_flops( + batch_size, + seq_length, + hidden_size, + ffn_hidden_size, + num_experts, + top_k, + use_swiglu=False, + ): + router_flops = 2 * batch_size * seq_length * hidden_size * num_experts + active_tokens = batch_size * seq_length * top_k + if use_swiglu: + expert_flops = ( + 3 * 2 * active_tokens * hidden_size * ffn_hidden_size + + active_tokens * ffn_hidden_size + ) + else: + expert_flops = 4 * active_tokens * hidden_size * ffn_hidden_size + return router_flops + expert_flops diff --git a/flagscale/train/perf_monitor/hooks.py b/flagscale/train/perf_monitor/hooks.py new file mode 100644 index 0000000000..44027f12f4 --- /dev/null +++ b/flagscale/train/perf_monitor/hooks.py @@ -0,0 +1,45 @@ +"""Training-loop hooks for the performance monitor.""" + +from __future__ import annotations + +import torch + +from flagscale.train.perf_monitor.perf_metrics import FLOPSMeasurementCallback + +_perf_monitor_callback = None + + +def initialize_perf_monitor(args): + """Initialize the global performance monitor if enabled.""" + global _perf_monitor_callback + + if not getattr(args, "enable_perf_monitor", False): + _perf_monitor_callback = None + return None + + log_interval = getattr(args, "perf_log_interval", 10) + _perf_monitor_callback = FLOPSMeasurementCallback(args, log_interval=log_interval) + + if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: + print(f"[Performance Monitor] Initialized with log interval: {log_interval}") + + return _perf_monitor_callback + + +def get_perf_monitor(): + return _perf_monitor_callback + + +def perf_monitor_start_iteration(iteration): + if _perf_monitor_callback is not None: + _perf_monitor_callback.on_train_batch_start(iteration) + + +def perf_monitor_end_iteration(iteration, writer=None, wandb_writer=None): + if _perf_monitor_callback is not None: + _perf_monitor_callback.on_train_batch_end(iteration, writer, wandb_writer) + + +def 
perf_monitor_end_training(writer=None, wandb_writer=None): + if _perf_monitor_callback is not None: + _perf_monitor_callback.on_train_end(writer, wandb_writer) diff --git a/flagscale/train/perf_monitor/perf_logger.py b/flagscale/train/perf_monitor/perf_logger.py new file mode 100644 index 0000000000..c526c71023 --- /dev/null +++ b/flagscale/train/perf_monitor/perf_logger.py @@ -0,0 +1,152 @@ +"""File logger for performance monitor output.""" + +from __future__ import annotations + +import json +import logging +from datetime import datetime +from pathlib import Path + + +class PerfMonitorLogger: + """Rank-0 logger that writes human-readable and json summaries.""" + + def __init__( + self, + log_dir="logs/perf_monitor", + log_level=logging.INFO, + enable_console=False, + max_log_files=10, + log_format="both", + ): + self.rank = 0 + try: + import torch.distributed as dist + + if dist.is_initialized(): + self.rank = dist.get_rank() + except ImportError: + pass + + self.enabled = self.rank == 0 + if not self.enabled: + return + + self.log_dir = Path(log_dir) + self.log_dir.mkdir(parents=True, exist_ok=True) + self.max_log_files = max_log_files + self.log_format = log_format + self.session_timestamp = datetime.now().astimezone().strftime("%Y%m%d_%H%M%S") + + self.metrics_file = self.log_dir / f"perf_metrics_{self.session_timestamp}.log" + self.summary_file = self.log_dir / f"perf_summary_{self.session_timestamp}.json" + self.realtime_file = self.log_dir / "perf_realtime.log" + + self.logger = logging.getLogger(f"perf_monitor_{self.session_timestamp}") + self.logger.setLevel(log_level) + self.logger.handlers = [] + + if self.log_format in ("text", "both"): + file_handler = logging.FileHandler(self.metrics_file) + file_handler.setFormatter(logging.Formatter("%(message)s")) + self.logger.addHandler(file_handler) + + if enable_console: + console_handler = logging.StreamHandler() + console_handler.setFormatter(logging.Formatter("%(message)s")) + self.logger.addHandler(console_handler) + + self.json_data = [] + self._write_header() + + def _write_header(self): + if not self.enabled or self.log_format not in ("text", "both"): + return + + header = "=" * 96 + "\n" + header += ( + f"Performance Monitor Session Started: " + f"{datetime.now().astimezone().strftime('%Y-%m-%d %H:%M:%S')}\n" + ) + header += "=" * 96 + "\n" + header += ( + f"{'Timestamp':<20} {'Step':<8} {'TFLOPS/GPU':<12} {'TFLOPS':<10} " + f"{'Samples/s':<12} {'Tokens/s':<12} {'Time(ms)':<10} {'Memory(GB)':<10}\n" + ) + header += "-" * 96 + self.logger.info(header) + self.realtime_file.write_text(f"{header}\n") + + def log_metrics(self, iteration, metrics_dict): + if not self.enabled: + return + + timestamp = datetime.now().astimezone().strftime("%Y-%m-%d %H:%M:%S") + if self.log_format in ("text", "both"): + log_line = ( + f"{timestamp:<20} {iteration:<8} " + f"{metrics_dict.get('TFLOPS_per_GPU', 0.0):<12.2f} " + f"{metrics_dict.get('TFLOPS_total', 0.0):<10.2f} " + f"{metrics_dict.get('samples_per_sec', 0.0):<12.1f} " + f"{metrics_dict.get('tokens_per_sec', 0.0):<12.0f} " + f"{metrics_dict.get('step_time_ms', 0.0):<10.1f} " + f"{metrics_dict.get('memory_GB', 0.0):<10.2f}" + ) + self.logger.info(log_line) + with self.realtime_file.open("a") as file_obj: + file_obj.write(f"{log_line}\n") + + if self.log_format in ("json", "both"): + self.json_data.append( + { + "iteration": iteration, + "timestamp": timestamp, + **metrics_dict, + } + ) + + def log_breakdown(self, iteration, breakdown): + if not self.enabled or self.log_format not in ("text", 
"both"): + return + + lines = [f"Estimated FLOPS Breakdown (Iteration {iteration}):"] + for key, value in breakdown.items(): + lines.append(f" {key}: {value / 1e12:.2f} TFLOPS") + self.logger.info("\n".join(lines)) + + def save_summary(self, final_stats=None): + if not self.enabled: + return + + summary = { + "session_info": { + "start_time": self.session_timestamp, + "end_time": datetime.now().astimezone().isoformat(), + "total_iterations": len(self.json_data), + }, + "final_statistics": final_stats or {}, + "iteration_logs": self.json_data, + } + if self.log_format in ("json", "both"): + with self.summary_file.open("w") as file_obj: + json.dump(summary, file_obj, indent=2) + self._cleanup_old_logs() + + def _cleanup_old_logs(self): + if not self.enabled or self.max_log_files <= 0: + return + + log_files = sorted(self.log_dir.glob("perf_metrics_*.log")) + if len(log_files) <= self.max_log_files: + return + + for old_file in log_files[: -self.max_log_files]: + try: + old_file.unlink() + summary_file = self.log_dir / ( + f"perf_summary_{old_file.stem.replace('perf_metrics_', '')}.json" + ) + if summary_file.exists(): + summary_file.unlink() + except OSError: + continue diff --git a/flagscale/train/perf_monitor/perf_metrics.py b/flagscale/train/perf_monitor/perf_metrics.py new file mode 100644 index 0000000000..d5a95ef9a2 --- /dev/null +++ b/flagscale/train/perf_monitor/perf_metrics.py @@ -0,0 +1,297 @@ +"""Performance metrics collection for training.""" + +from __future__ import annotations + +import statistics +import time +from dataclasses import dataclass + +import torch + +from flagscale.train.perf_monitor.flops_calculator import FLOPSFormulas +from flagscale.train.perf_monitor.perf_logger import PerfMonitorLogger + +try: + from megatron.core.num_microbatches_calculator import get_num_microbatches +except ImportError: + get_num_microbatches = None + + +@dataclass +class TFLOPSMetrics: + tflops_per_gpu: float = 0.0 + tflops_total: float = 0.0 + model_flops: float = 0.0 + avg_step_time: float = 0.0 + samples_per_second: float = 0.0 + tokens_per_second: float = 0.0 + forward_flops: float = 0.0 + backward_flops: float = 0.0 + optimizer_flops: float = 0.0 + min_step_time: float = float("inf") + max_step_time: float = 0.0 + std_step_time: float = 0.0 + + +class ModelFLOPSCalculator: + """Estimate per-step FLOPS from model hyperparameters.""" + + def __init__(self, args): + self.args = args + self.formulas = FLOPSFormulas() + self.model_type = self._determine_model_type() + + def _determine_model_type(self): + model_type = getattr(self.args, "perf_model_type", "auto") + if model_type != "auto": + return model_type + + model_name = getattr(self.args, "model_name", "") or getattr( + self.args, "wandb_exp_name", "" + ) + model_name = str(model_name).lower() + if "qwen" in model_name: + return "qwen" + if "llama" in model_name: + return "llama" + if "aquila" in model_name: + return "aquila" + if "mixtral" in model_name or getattr(self.args, "num_experts", None): + return "moe" + return "gpt" + + def _get_batch_size(self): + if get_num_microbatches is not None: + num_micro_batches = get_num_microbatches() + else: + num_micro_batches = getattr(self.args, "num_micro_batches", 1) + micro_batch_size = getattr(self.args, "micro_batch_size", 1) + return max(1, micro_batch_size * num_micro_batches) + + def calculate_total_flops(self, batch_size=None): + if batch_size is None: + batch_size = self._get_batch_size() + + seq_length = getattr(self.args, "seq_length", 512) + hidden_size = getattr(self.args, 
"hidden_size", 768) + num_layers = getattr(self.args, "num_layers", 12) + vocab_size = getattr( + self.args, "vocab_size", getattr(self.args, "padded_vocab_size", 50257) + ) + num_attention_heads = getattr(self.args, "num_attention_heads", 12) + ffn_hidden_size = getattr(self.args, "ffn_hidden_size", 4 * hidden_size) + use_swiglu = getattr(self.args, "swiglu", False) + + if self.model_type in ("llama", "qwen"): + attention_flops = self.formulas.gqa_attention_flops( + batch_size=batch_size, + seq_length=seq_length, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + num_query_groups=getattr(self.args, "num_query_groups", num_attention_heads), + ) + use_swiglu = True + else: + attention_flops = self.formulas.attention_flops( + batch_size=batch_size, + seq_length=seq_length, + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + ) + + if self.model_type == "moe": + ffn_flops = self.formulas.moe_flops( + batch_size=batch_size, + seq_length=seq_length, + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_experts=getattr(self.args, "num_experts", 8), + top_k=getattr(self.args, "moe_router_topk", 2), + use_swiglu=use_swiglu, + ) + else: + ffn_flops = self.formulas.ffn_flops( + batch_size=batch_size, + seq_length=seq_length, + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + use_swiglu=use_swiglu, + ) + + embedding_flops = 2 * batch_size * seq_length * hidden_size * vocab_size + layer_flops = (attention_flops + ffn_flops) * num_layers + return 3 * (layer_flops + embedding_flops) + + def get_flops_breakdown(self): + batch_size = self._get_batch_size() + total = self.calculate_total_flops(batch_size=batch_size) + forward = total / 3 + backward = 2 * forward + return { + "forward": forward, + "backward": backward, + "optimizer": 0.0, + "total": total, + } + + +class PerformanceMonitor: + """Track step time, estimated FLOPS and throughput.""" + + def __init__(self, args, enable_memory_tracking=True): + self.args = args + self.enable_memory_tracking = enable_memory_tracking + self.iteration_start_time = None + self.step_times = [] + self.current_memory_gb = 0.0 + self.peak_memory_gb = 0.0 + self.metrics = TFLOPSMetrics() + self.flops_calculator = ModelFLOPSCalculator(args) + self.file_logger = PerfMonitorLogger( + log_dir=getattr(args, "perf_log_dir", "logs/perf_monitor"), + enable_console=getattr(args, "perf_console_output", False), + max_log_files=getattr(args, "perf_max_log_files", 10), + log_format=getattr(args, "perf_log_format", "both"), + ) + + def start_iteration(self): + self.iteration_start_time = time.time() + + def end_iteration(self): + if self.iteration_start_time is None: + return + step_time = time.time() - self.iteration_start_time + self.step_times.append(step_time) + self.iteration_start_time = None + self.metrics.min_step_time = min(self.metrics.min_step_time, step_time) + self.metrics.max_step_time = max(self.metrics.max_step_time, step_time) + + def update_memory_stats(self): + if not self.enable_memory_tracking or not torch.cuda.is_available(): + return + self.current_memory_gb = torch.cuda.memory_allocated() / (1024**3) + self.peak_memory_gb = max( + self.peak_memory_gb, torch.cuda.max_memory_allocated() / (1024**3) + ) + + def calculate_metrics(self): + if not self.step_times: + return self.metrics + + half_idx = len(self.step_times) // 2 + recent_times = self.step_times[half_idx:] if half_idx > 0 else self.step_times + avg_step_time = statistics.median(recent_times) + self.metrics.avg_step_time = avg_step_time + 
self.metrics.std_step_time = ( + statistics.pstdev(recent_times) if len(recent_times) > 1 else 0.0 + ) + + batch_size = self.flops_calculator._get_batch_size() + model_flops = self.flops_calculator.calculate_total_flops(batch_size=batch_size) + self.metrics.model_flops = model_flops + + if avg_step_time > 0: + world_size = max(1, getattr(self.args, "world_size", 1)) + self.metrics.tflops_total = model_flops / (1e12 * avg_step_time) + self.metrics.tflops_per_gpu = self.metrics.tflops_total / world_size + self.metrics.samples_per_second = batch_size / avg_step_time + self.metrics.tokens_per_second = self.metrics.samples_per_second * getattr( + self.args, "seq_length", 0 + ) + + breakdown = self.flops_calculator.get_flops_breakdown() + self.metrics.forward_flops = breakdown.get("forward", 0.0) + self.metrics.backward_flops = breakdown.get("backward", 0.0) + self.metrics.optimizer_flops = breakdown.get("optimizer", 0.0) + return self.metrics + + def log_metrics(self, iteration, writer=None, wandb_writer=None): + metrics = self.calculate_metrics() + metrics_dict = { + "TFLOPS_per_GPU": metrics.tflops_per_gpu, + "TFLOPS_total": metrics.tflops_total, + "samples_per_sec": metrics.samples_per_second, + "tokens_per_sec": metrics.tokens_per_second, + "step_time_ms": metrics.avg_step_time * 1000, + } + if self.enable_memory_tracking: + metrics_dict["memory_GB"] = self.current_memory_gb + metrics_dict["peak_memory_GB"] = self.peak_memory_gb + self.file_logger.log_metrics(iteration, metrics_dict) + + if getattr(self.args, "perf_breakdown", False): + self.file_logger.log_breakdown( + iteration, + { + "forward": metrics.forward_flops, + "backward": metrics.backward_flops, + "optimizer": metrics.optimizer_flops, + "total": metrics.model_flops, + }, + ) + + if writer is not None: + writer.add_scalar("performance/tflops_per_gpu", metrics.tflops_per_gpu, iteration) + writer.add_scalar("performance/tflops_total", metrics.tflops_total, iteration) + writer.add_scalar( + "performance/avg_step_time_ms", metrics.avg_step_time * 1000, iteration + ) + writer.add_scalar( + "performance/samples_per_second", metrics.samples_per_second, iteration + ) + writer.add_scalar("performance/tokens_per_second", metrics.tokens_per_second, iteration) + if self.enable_memory_tracking: + writer.add_scalar("memory/current_gb", self.current_memory_gb, iteration) + writer.add_scalar("memory/peak_gb", self.peak_memory_gb, iteration) + + if wandb_writer is not None: + wandb_writer.log( + { + "performance/tflops_per_gpu": metrics.tflops_per_gpu, + "performance/tflops_total": metrics.tflops_total, + "performance/avg_step_time_ms": metrics.avg_step_time * 1000, + "performance/samples_per_second": metrics.samples_per_second, + "performance/tokens_per_second": metrics.tokens_per_second, + "memory/current_gb": self.current_memory_gb, + "memory/peak_gb": self.peak_memory_gb, + }, + iteration, + ) + + +class FLOPSMeasurementCallback: + """Train-loop callback wrapper around :class:`PerformanceMonitor`.""" + + def __init__(self, args, log_interval=100): + self.args = args + self.log_interval = max(1, log_interval) + self.monitor = PerformanceMonitor( + args, enable_memory_tracking=getattr(args, "perf_memory_tracking", True) + ) + + def on_train_batch_start(self, iteration): + self.monitor.start_iteration() + + def on_train_batch_end(self, iteration, writer=None, wandb_writer=None): + self.monitor.end_iteration() + self.monitor.update_memory_stats() + if iteration > 0 and (iteration == 1 or iteration % self.log_interval == 0): + 
self.monitor.log_metrics(iteration, writer, wandb_writer) + + def on_train_end(self, writer=None, wandb_writer=None): + if self.monitor.step_times and not self.monitor.file_logger.json_data: + self.monitor.log_metrics(len(self.monitor.step_times), writer, wandb_writer) + + metrics = self.monitor.calculate_metrics() + self.monitor.file_logger.save_summary( + { + "avg_tflops_per_gpu": metrics.tflops_per_gpu, + "avg_tflops_total": metrics.tflops_total, + "avg_step_time_ms": metrics.avg_step_time * 1000, + "min_step_time_ms": metrics.min_step_time * 1000 + if metrics.min_step_time != float("inf") + else 0.0, + "max_step_time_ms": metrics.max_step_time * 1000, + "peak_memory_gb": self.monitor.peak_memory_gb, + } + ) diff --git a/tests/unit_tests/runner/test_backend_megatron_perf.py b/tests/unit_tests/runner/test_backend_megatron_perf.py new file mode 100644 index 0000000000..3e051accfe --- /dev/null +++ b/tests/unit_tests/runner/test_backend_megatron_perf.py @@ -0,0 +1,111 @@ +import os +import sys +import tempfile +import types +from unittest.mock import patch + +from omegaconf import OmegaConf + +hydra_module = types.ModuleType("hydra") +hydra_core_module = types.ModuleType("hydra.core") +hydra_config_module = types.ModuleType("hydra.core.hydra_config") + + +class _HydraConfig: + @staticmethod + def get(): + raise RuntimeError("HydraConfig.get() is not expected in this test") + + +hydra_config_module.HydraConfig = _HydraConfig +sys.modules.setdefault("hydra", hydra_module) +sys.modules.setdefault("hydra.core", hydra_core_module) +sys.modules.setdefault("hydra.core.hydra_config", hydra_config_module) + +from flagscale.runner.backend.backend_megatron import MegatronBackend + + +def _make_config(): + return OmegaConf.create( + { + "experiment": { + "exp_dir": "/tmp/test_exp", + "task": { + "type": "train", + "backend": "megatron", + "entrypoint": "flagscale/train/megatron/train_gpt.py", + }, + "runner": { + "hostfile": None, + "enable_perf_monitor": True, + "perf_log_interval": 5, + }, + "envs": {}, + }, + "train": { + "system": { + "checkpoint": { + "save": "/tmp/test_exp/checkpoints", + "load": "/tmp/test_exp/checkpoints", + }, + "logging": { + "log_dir": "/tmp/test_exp/logs", + "scripts_dir": "/tmp/test_exp/logs/scripts", + "pids_dir": "/tmp/test_exp/logs/pids", + "details_dir": "/tmp/test_exp/logs/details", + "tensorboard_dir": "/tmp/test_exp/tensorboard", + "wandb_save_dir": "/tmp/test_exp/wandb", + "straggler_dir": "/tmp/test_exp/logs/straggler", + }, + "straggler_log_dir": "/tmp/test_exp/logs/straggler", + }, + "model": {}, + "data": {}, + }, + } + ) + + +def test_backend_prepare_copies_perf_monitor_config_from_runner(): + config = _make_config() + with ( + patch("flagscale.runner.backend.backend_megatron._get_args_megatron", return_value=[]), + patch("flagscale.runner.backend.backend_megatron._update_config_train"), + patch("flagscale.runner.backend.backend_megatron.parse_hostfile", return_value=None), + patch("flagscale.runner.backend.backend_megatron.logger"), + ): + backend = MegatronBackend(config) + + assert backend.config.train.system.enable_perf_monitor is True + assert backend.config.train.system.perf_log_interval == 5 + assert backend.config.train.system.perf_log_dir == "/tmp/test_exp/logs/perf_monitor" + + +def test_generate_run_script_creates_perf_log_dir(): + config = _make_config() + with ( + patch("flagscale.runner.backend.backend_megatron._get_args_megatron", return_value=[]), + patch("flagscale.runner.backend.backend_megatron._update_config_train"), + 
patch("flagscale.runner.backend.backend_megatron.parse_hostfile", return_value=None), + patch("flagscale.runner.backend.backend_megatron.logger"), + ): + backend = MegatronBackend(config) + + with tempfile.TemporaryDirectory() as tmpdir: + config.train.system.logging.scripts_dir = os.path.join(tmpdir, "scripts") + config.train.system.logging.log_dir = os.path.join(tmpdir, "logs") + config.train.system.logging.pids_dir = os.path.join(tmpdir, "pids") + config.train.system.perf_log_dir = os.path.join(tmpdir, "logs", "perf_monitor") + + with ( + patch("os.path.exists", return_value=True), + patch("flagscale.runner.backend.backend_megatron.get_pkg_dir", return_value=tmpdir), + ): + script_path = backend.generate_run_script( + config, "localhost", 0, "python train.py", background=True + ) + + with open(script_path, "r") as file_obj: + content = file_obj.read() + + assert f"mkdir -p {config.train.system.perf_log_dir}" in content diff --git a/tests/unit_tests/runner/test_launcher_ssh_perf.py b/tests/unit_tests/runner/test_launcher_ssh_perf.py new file mode 100644 index 0000000000..a3ad7521fd --- /dev/null +++ b/tests/unit_tests/runner/test_launcher_ssh_perf.py @@ -0,0 +1,39 @@ +from omegaconf import OmegaConf + +from flagscale.runner.launcher.launcher_ssh import _get_runner_cmd_train + + +def test_get_runner_cmd_train_strips_perf_monitor_runner_keys(): + config = OmegaConf.create( + { + "experiment": { + "runner": { + "backend": "torchrun", + "nnodes": 1, + "nproc_per_node": 8, + "rdzv_backend": "static", + "enable_perf_monitor": True, + "perf_log_interval": 5, + "perf_log_dir": "/tmp/perf_monitor", + "perf_console_output": True, + } + }, + "train": { + "system": { + "logging": { + "details_dir": "/tmp/details", + } + } + }, + } + ) + + cmd = _get_runner_cmd_train("localhost", "127.0.0.1", 29500, 1, 0, 8, config) + + assert cmd[0] == "torchrun" + assert "--enable_perf_monitor" not in cmd + assert "--perf_log_interval" not in cmd + assert "--perf_log_dir" not in cmd + assert "--perf_console_output" not in cmd + assert "--log_dir" in cmd + assert "--rdzv_endpoint" in cmd diff --git a/tests/unit_tests/train/test_perf_monitor.py b/tests/unit_tests/train/test_perf_monitor.py new file mode 100644 index 0000000000..bcdf8de137 --- /dev/null +++ b/tests/unit_tests/train/test_perf_monitor.py @@ -0,0 +1,54 @@ +import time +from pathlib import Path +from types import SimpleNamespace + +import pytest + +torch = pytest.importorskip("torch") + +from flagscale.train.perf_monitor.hooks import ( + initialize_perf_monitor, + perf_monitor_end_iteration, + perf_monitor_end_training, + perf_monitor_start_iteration, +) + + +def test_perf_monitor_smoke_writes_summary(tmp_path): + args = SimpleNamespace( + enable_perf_monitor=True, + perf_log_interval=2, + perf_log_dir=str(tmp_path), + perf_console_output=False, + perf_log_format="both", + perf_memory_tracking=False, + perf_breakdown=False, + perf_max_log_files=10, + perf_model_type="gpt", + world_size=1, + seq_length=512, + hidden_size=1024, + num_layers=8, + num_attention_heads=16, + ffn_hidden_size=4096, + padded_vocab_size=50257, + micro_batch_size=1, + num_micro_batches=1, + swiglu=False, + ) + + callback = initialize_perf_monitor(args) + assert callback is not None + + for iteration in range(1, 5): + perf_monitor_start_iteration(iteration) + time.sleep(0.001) + perf_monitor_end_iteration(iteration) + + perf_monitor_end_training() + + realtime_log = Path(tmp_path) / "perf_realtime.log" + summary_files = list(Path(tmp_path).glob("perf_summary_*.json")) + + assert 
realtime_log.exists() + assert summary_files diff --git a/tools/perf_monitor/perf_smoke.py b/tools/perf_monitor/perf_smoke.py new file mode 100644 index 0000000000..a60290fcef --- /dev/null +++ b/tools/perf_monitor/perf_smoke.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 +"""Standalone smoke test for the FlagScale performance monitor.""" + +from __future__ import annotations + +import argparse +import os +import time +from pathlib import Path +from types import SimpleNamespace + +import torch +import torch.distributed as dist + +from flagscale.train.perf_monitor.hooks import ( + initialize_perf_monitor, + perf_monitor_end_iteration, + perf_monitor_end_training, + perf_monitor_start_iteration, +) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Standalone FlagScale perf monitor smoke test.") + parser.add_argument("--steps", type=int, default=12, help="Synthetic training steps to run.") + parser.add_argument( + "--log-interval", + type=int, + default=5, + help="How often the perf monitor logs metrics.", + ) + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Directory used to save performance monitor logs.", + ) + parser.add_argument("--matmul-size", type=int, default=4096, help="Synthetic workload size.") + parser.add_argument( + "--sleep-ms", + type=float, + default=8.0, + help="Extra per-step CPU sleep to make timings visible.", + ) + parser.add_argument("--seq-length", type=int, default=512, help="Model seq length hint.") + parser.add_argument("--hidden-size", type=int, default=1024, help="Model hidden size hint.") + parser.add_argument("--num-layers", type=int, default=8, help="Model layer count hint.") + parser.add_argument( + "--num-attention-heads", + type=int, + default=16, + help="Model attention head hint.", + ) + parser.add_argument("--micro-batch-size", type=int, default=1, help="Micro batch size hint.") + parser.add_argument( + "--perf-breakdown", + action="store_true", + help="Emit estimated FLOPS breakdown to the text log.", + ) + return parser.parse_args() + + +def init_dist(): + if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ: + return 0, 1, 0, False + + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + use_cuda = torch.cuda.is_available() + backend = "nccl" if use_cuda else "gloo" + + if use_cuda: + torch.cuda.set_device(local_rank) + + dist.init_process_group(backend=backend, init_method="env://") + return rank, world_size, local_rank, use_cuda + + +def cleanup_dist(): + if dist.is_available() and dist.is_initialized(): + dist.barrier() + dist.destroy_process_group() + + +def build_perf_args(cli_args, world_size): + return SimpleNamespace( + enable_perf_monitor=True, + perf_log_interval=cli_args.log_interval, + perf_log_dir=str(Path(cli_args.output_dir).expanduser().resolve()), + perf_console_output=True, + perf_log_format="both", + perf_memory_tracking=True, + perf_breakdown=cli_args.perf_breakdown, + perf_max_log_files=10, + perf_model_type="gpt", + world_size=world_size, + seq_length=cli_args.seq_length, + hidden_size=cli_args.hidden_size, + num_layers=cli_args.num_layers, + num_attention_heads=cli_args.num_attention_heads, + ffn_hidden_size=4 * cli_args.hidden_size, + padded_vocab_size=50257, + micro_batch_size=cli_args.micro_batch_size, + num_micro_batches=1, + swiglu=False, + ) + + +def allocate_work_tensors(size, use_cuda, local_rank): + if not use_cuda: + return None, None + device = torch.device("cuda", local_rank) + a = 
torch.randn(size, size, device=device, dtype=torch.bfloat16) + b = torch.randn(size, size, device=device, dtype=torch.bfloat16) + return a, b + + +def run_step(a, b, sleep_ms, use_cuda): + if use_cuda: + _ = a @ b + _ = a @ b + torch.cuda.synchronize() + if sleep_ms > 0: + time.sleep(sleep_ms / 1000.0) + + +def main(): + cli_args = parse_args() + output_dir = Path(cli_args.output_dir).expanduser().resolve() + output_dir.mkdir(parents=True, exist_ok=True) + + rank, world_size, local_rank, use_cuda = init_dist() + perf_args = build_perf_args(cli_args, world_size) + perf_callback = initialize_perf_monitor(perf_args) + a, b = allocate_work_tensors(cli_args.matmul_size, use_cuda, local_rank) + + if rank == 0: + print( + f"[rank0] Starting perf monitor smoke test: world_size={world_size}, " + f"log_interval={cli_args.log_interval}, output_dir={output_dir}" + ) + + try: + for iteration in range(1, cli_args.steps + 1): + if perf_callback is not None: + perf_monitor_start_iteration(iteration) + run_step(a, b, cli_args.sleep_ms, use_cuda) + if dist.is_available() and dist.is_initialized(): + dist.barrier() + if perf_callback is not None: + perf_monitor_end_iteration(iteration) + finally: + perf_monitor_end_training() + cleanup_dist() + + +if __name__ == "__main__": + main()
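
To sanity-check the throughput numbers that `perf_smoke.py` and the training hooks log, the dense-model FLOPS estimate can be reproduced standalone. The sketch below re-derives `ModelFLOPSCalculator.calculate_total_flops` for the non-GQA, non-SwiGLU (`gpt`) path using the mini hyperparameters from `tests/unit_tests/train/test_perf_monitor.py`; the step time and world size are placeholders, not measured values.

```python
# Standalone re-derivation of the dense-model FLOPS estimate used by the
# perf monitor (gpt path: standard attention, no SwiGLU). Illustrative only.
def estimate_step_flops(batch, seq, hidden, heads, layers, ffn, vocab):
    head_dim = hidden // heads
    attn = (
        3 * 2 * batch * seq * hidden * hidden        # QKV projections
        + 2 * batch * heads * seq * seq * head_dim   # attention scores
        + 2 * batch * heads * seq * seq * head_dim   # attention-weighted values
        + 2 * batch * seq * hidden * hidden          # output projection
    )
    ffn_flops = 4 * batch * seq * hidden * ffn       # up + down projections
    embedding = 2 * batch * seq * hidden * vocab     # embedding / LM head
    # Factor of 3 approximates forward plus backward (backward ~ 2x forward).
    return 3 * (layers * (attn + ffn_flops) + embedding)


# Mini configuration from the unit test.
flops = estimate_step_flops(
    batch=1, seq=512, hidden=1024, heads=16, layers=8, ffn=4096, vocab=50257
)

step_time_s = 0.5  # placeholder step time, not a measurement
world_size = 1     # placeholder
tflops_total = flops / (1e12 * step_time_s)
print(f"estimated TFLOPS total: {tflops_total:.2f}, per GPU: {tflops_total / world_size:.2f}")
```
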