Open
Changes from 15 commits
28 commits
c761be5
llama3 mfu experiment
gagank1 Apr 4, 2026
7d60786
Add CP golden value tests, fix RoPE bug, and improve MFU scripts
gagank1 Apr 6, 2026
aa296da
Switch bandwidth measurement to P2P send/recv
gagank1 Apr 6, 2026
f412e3b
Add CLI to compare_mfu_common.py for standalone utility access
gagank1 Apr 6, 2026
231845e
Generalize MFU/FLOPs module across recipes with log_mfu training hook
gagank1 Apr 11, 2026
3b891f0
Merge remote-tracking branch 'origin/main' into gkaushik/mfu_experiment
gagank1 Apr 11, 2026
1930644
Remove first_principles.md from repository
gagank1 Apr 11, 2026
7094d17
Remove first_principles.md reference from flops.py docstring
gagank1 Apr 11, 2026
8ea45ac
Update GPU TFLOPS table: fix RTX values, add B300/GB300, add sources
gagank1 Apr 11, 2026
d853f96
Clean up flops.py: remove dead code from deleted benchmark scripts
gagank1 Apr 13, 2026
1635aab
Add flops tests and move source from models/esm2 to llama3_native_te
gagank1 Apr 13, 2026
27255a1
Add MFU tracking documentation to recipe READMEs
gagank1 Apr 13, 2026
e6468fc
Consolidate MFU tracking into perf_logger, address PR review feedback
gagank1 Apr 18, 2026
e2f4934
MFU: count configured-shape tokens, not attention_mask.sum()
gagank1 Apr 18, 2026
afc8450
Merge remote-tracking branch 'origin/main' into gkaushik/mfu_experiment
gagank1 Apr 19, 2026
f4858eb
MFU: use Σ(Lᵢ²) for attention work; fix ESM-2 grad-acc undercount
gagank1 Apr 22, 2026
b6685f5
MFU: fix BSHD+CP attention-FLOP undercount (factor cp²)
gagank1 Apr 22, 2026
ab46239
perf_logger: report true peak memory, not post-step resting
gagank1 Apr 22, 2026
f8f84cb
perf_logger: split MFU into unpadded vs padded variants
gagank1 Apr 23, 2026
909c1d7
perf_logger: remove legacy _compute_per_token_flops back-compat shim
gagank1 Apr 23, 2026
423eab7
docs: update MFU tracking sections in recipe READMEs
gagank1 Apr 23, 2026
b979eed
MFU: use padded_vocab_size for mfu_padded_pct LM-head FLOPs
gagank1 Apr 23, 2026
44172ae
test(esm2): update perf_logger tests for split _attn_work_*_accum buf…
gagank1 Apr 24, 2026
29121fe
docs(perf_logger): note pad_to_multiple_of / cu_seq_lens_q collapse (…
gagank1 Apr 24, 2026
245d6e0
Merge remote-tracking branch 'origin/main' into gkaushik/mfu_experiment
gagank1 Apr 24, 2026
74fc4b6
Merge remote-tracking branch 'origin/main' into gkaushik/mfu_experiment
gagank1 Apr 27, 2026
ff0410d
esm2: revert log_micro_step split (no grad-acc in ESM-2)
gagank1 Apr 27, 2026
b9f31ae
perf_logger: move log_mfu gate inside PerfLogger across all recipes
gagank1 Apr 27, 2026
19 changes: 19 additions & 0 deletions bionemo-recipes/recipes/codonfm_native_te/README.md
@@ -177,6 +177,25 @@ python train_fsdp2.py \
A final model suitable for uploading to the Hugging Face Hub can be exported at the end of training by setting
`checkpoint.save_final_model=true`.

## MFU Tracking

Enable per-step Model FLOPs Utilization (MFU) logging during training by adding `log_mfu=true`:

```bash
torchrun --nproc_per_node=1 train_fsdp2.py --config-name encodon_1b log_mfu=true
```

This adds two metrics at each logging interval, emitted alongside existing metrics via WANDB and
stdout:

- `train/tflops_per_gpu` — achieved BF16 TFLOPS per GPU
- `train/mfu_pct` — MFU as a percentage of the GPU's peak dense BF16 TFLOPS

The FLOPs formula reads the model architecture from the model config (MHA, standard FFN,
vocabulary size) and multiplies a per-token cost by the number of tokens each rank processes
(`input_ids.numel()`, padding included). Gradient accumulation, data parallelism, BSHD, and THD
(sequence packing) are therefore handled without per-strategy code paths. `train/mfu_pct` is only
reported when the GPU name matches an entry in the peak-TFLOPS table. The implementation lives in
`perf_logger.py`.
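
As a rough illustration of how the two metrics relate, the sketch below reproduces the arithmetic
with made-up numbers. The helper names `achieved_tflops` and `mfu_pct` and all input values are
hypothetical; only the formula (per-token FLOPs times tokens, divided by step time and then by the
peak) follows `perf_logger.py`.

```python
def achieved_tflops(per_token_flops: float, tokens_per_step: int, step_time_s: float) -> float:
    """Achieved TFLOPS per GPU: FLOPs executed in one step divided by its wall time."""
    return per_token_flops * tokens_per_step / step_time_s / 1e12


def mfu_pct(tflops_per_gpu: float, peak_tflops: float) -> float:
    """MFU as a percentage of the GPU's dense BF16 peak."""
    return tflops_per_gpu / peak_tflops * 100.0


# Hypothetical inputs, chosen only to make the arithmetic easy to follow.
per_token_flops = 6.0e9   # training FLOPs per token (forward + backward)
tokens_per_step = 16_384  # e.g. 8 sequences x 2048 positions on this rank
step_time_s = 0.5         # seconds per optimizer step
peak_tflops = 989.0       # H100 entry from the peak table in perf_logger.py

tflops = achieved_tflops(per_token_flops, tokens_per_step, step_time_s)
print(f"{tflops:.1f} TFLOPS/GPU, {mfu_pct(tflops, peak_tflops):.1f}% MFU")
```

With these inputs the script prints `196.6 TFLOPS/GPU, 19.9% MFU`.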

## Developer Guide

### Running Tests
@@ -65,3 +65,4 @@ quant_stats_config:
fp8_layers: null
fp4_layers: null
use_fp32_master_weights: null
log_mfu: false
109 changes: 105 additions & 4 deletions bionemo-recipes/recipes/codonfm_native_te/perf_logger.py
@@ -36,6 +36,69 @@
PAD_TOKEN_ID = 3


# Dense BF16 tensor core peak TFLOPS (without sparsity). Product pages often list
# the 2x sparse number; dense = sparse / 2. Sources: NVIDIA datasheets for each GPU.
_GPU_PEAK_TFLOPS_BF16 = {
    "H100": 989.0,
    "H200": 989.0,
    "A100": 312.0,
    "A6000": 155.0,
    "L40": 181.0,
    "GH200": 989.0,
    "B200": 2250.0,
    "GB200": 2250.0,
    "B300": 2500.0,
    "GB300": 2500.0,
}

# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections vs. standard FFN with 2.
_GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"})


def _detect_peak_tflops_bf16():
    """Auto-detect dense BF16 peak TFLOPS for the local GPU. Returns (peak, device_name)."""
    if not torch.cuda.is_available():
        return None, "unknown"
    name = torch.cuda.get_device_name(0)
    for key, tflops in _GPU_PEAK_TFLOPS_BF16.items():
        if key.lower() in name.lower():
            return tflops, name
    return None, name


def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int:
    """Training FLOPs per token for a transformer (forward + backward = 3x forward).

    First-principles matmul count: Q/K/V/O projections (GQA-aware), attention
    logits/values (the S^2 cost expressed per-token as 4*S*H), 2-or-3-projection
    MLP (SwiGLU detected via model_type), and LM head. The returned value is
    multiplied by the per-rank token count (padding included) at log time, so it
    naturally handles BSHD, THD (sequence packing), gradient accumulation, DP, and
    CP: the tokens on each rank already reflect that rank's share of work.
    """
    h = model_config_dict["hidden_size"]
    n_heads = model_config_dict["num_attention_heads"]
    n_kv = model_config_dict.get("num_key_value_heads", n_heads)
    head_dim = h // n_heads
    kv_dim = n_kv * head_dim
    ffn = model_config_dict["intermediate_size"]
    vocab = model_config_dict.get("vocab_size", 0)
    num_layers = model_config_dict["num_hidden_layers"]
    model_type = model_config_dict.get("model_type", "")
    num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2

    per_layer = (
        2 * h * h  # Q projection
        + 4 * h * kv_dim  # K + V projections (GQA-aware)
        + 2 * h * h  # O projection
        + 4 * seq_len * h  # attention logits + values (S^2 -> S per token)
        + 2 * num_mlp_proj * h * ffn  # MLP (2 or 3 projections)
    )
    lm_head = 2 * h * vocab if vocab > 0 else 0
    per_token_fwd = num_layers * per_layer + lm_head
    return 3 * per_token_fwd


class PerfLogger:
    """Performance logger for CodonFM training.

@@ -44,17 +107,39 @@ class PerfLogger:
    Args:
        dist_config: The distributed configuration.
        args: The Hydra arguments.
        model_config_dict: Optional HF-style model config dict. When supplied together with
            ``args.log_mfu`` set to True, the logger computes per-step Model FLOPs Utilization
            (``train/mfu_pct``) and throughput (``train/tflops_per_gpu``) on each logging step.
    """

    def __init__(self, dist_config: DistributedConfig, args: DictConfig):
    def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_config_dict: dict | None = None):
        """Initialize the logger."""
        self._dist_config = dist_config
        self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True)

        self.min_loss = torch.tensor(float("inf"), device=torch.device(f"cuda:{dist_config.local_rank}"))
        self._device = torch.device(f"cuda:{dist_config.local_rank}")
        self.min_loss = torch.tensor(float("inf"), device=self._device)

        self.logging_frequency = args.logger.frequency

        # MFU setup: compute per-token FLOPs and peak TFLOPS once at init. Actual FLOPs per
        # step are derived at log time from the tracked token count (padding included), which
        # already reflects each rank's share under DP and sequence packing.
        self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None
        self._per_token_flops = 0
        self._peak_tflops: float | None = None
        if self._log_mfu:
            self._per_token_flops = _compute_per_token_flops(model_config_dict, args.dataset.max_seq_length)
            self._peak_tflops, gpu_name = _detect_peak_tflops_bf16()
            if dist_config.local_rank == 0:
                logger.info(
                    "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d",
                    gpu_name,
                    f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown",
                    float(self._per_token_flops),
                    args.dataset.max_seq_length,
                )

        metrics_dict = {
            "train/loss": torchmetrics.MeanMetric(),
            "train/grad_norm": torchmetrics.MeanMetric(),
@@ -66,9 +151,13 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
            "train/gpu_memory_allocated_max_gb": torchmetrics.MaxMetric(),
            "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(),
        }
        if self._log_mfu:
            metrics_dict["train/tflops_per_gpu"] = torchmetrics.MeanMetric()
            if self._peak_tflops is not None:
                metrics_dict["train/mfu_pct"] = torchmetrics.MeanMetric()

        self.metrics = torchmetrics.MetricCollection(metrics_dict)
        self.metrics.to(torch.device(f"cuda:{dist_config.local_rank}"))
        self.metrics.to(self._device)
        self.previous_step_time = time.perf_counter()

        if self._dist_config.is_main_process():
@@ -79,7 +168,6 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
        self.quant_stats_config = args.quant_stats_config.enabled

        # Gradient accumulation tracking
        self._device = torch.device(f"cuda:{dist_config.local_rank}")
        self.num_tokens = 0
        self.num_unpadded_tokens = torch.tensor(0, dtype=torch.int64, device=self._device)
        self.running_loss = torch.tensor(0.0, device=self._device)
@@ -155,6 +243,19 @@ def log_step(
        self.metrics["train/tokens_per_second_per_gpu"].update(self.num_tokens / step_time)
        self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(self.num_unpadded_tokens / step_time)

        if self._log_mfu:
            # PaLM/Megatron/MosaicML convention: count the configured-shape token budget
            # (input_ids.numel() = B * S_padded for BSHD, or total packed tokens for THD),
            # not attention_mask.sum(). The hardware executes matmuls over every position
            # regardless of masking, and this matches published MFU numbers.
            # num_tokens is accumulated over the grad-acc micro-batches of one optimizer
            # step (the last step in the logging window). step_time is per-step average.
            flops_per_step = self._per_token_flops * self.num_tokens
            tflops_per_gpu = flops_per_step / step_time / 1e12
            self.metrics["train/tflops_per_gpu"].update(tflops_per_gpu)
            if self._peak_tflops is not None:
                self.metrics["train/mfu_pct"].update(tflops_per_gpu / self._peak_tflops * 100.0)

        memory_allocated = torch.cuda.memory_allocated() / (1024**3)
        self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated)
        self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated)
6 changes: 5 additions & 1 deletion bionemo-recipes/recipes/codonfm_native_te/train_fsdp2.py
@@ -163,7 +163,11 @@ def main(args: DictConfig) -> float | None:
    start_step = 0
    epoch = 0

    perf_logger = PerfLogger(dist_config, args)
    perf_logger = PerfLogger(
        dist_config,
        args,
        model_config_dict=config.to_dict() if args.get("log_mfu", False) else None,
    )

    # Training loop
    step = start_step
19 changes: 19 additions & 0 deletions bionemo-recipes/recipes/esm2_native_te/README.md
@@ -374,6 +374,25 @@ output = model(**inputs)

- [ESM-2 Training with Accelerate](../esm2_accelerate_te/README.md)

## MFU Tracking

Enable per-step Model FLOPs Utilization (MFU) logging during training by adding `log_mfu=true`:

```bash
torchrun --nproc_per_node=2 train_fsdp2.py --config-name L1_3B log_mfu=true
```

This adds two metrics at each logging interval, emitted alongside existing metrics via WANDB and
stdout:

- `train/tflops_per_gpu` — achieved BF16 TFLOPS per GPU
- `train/mfu_pct` — MFU as a percentage of the GPU's peak dense BF16 TFLOPS

The FLOPs formula reads the model architecture from the HF config (MHA vs. GQA, gated vs.
standard FFN, LM head presence) and multiplies a per-token cost by the number of tokens each rank
processes (`input_ids.numel()`, padding included). Data parallelism, context parallelism, BSHD, and
THD (sequence packing) are therefore handled without per-strategy code paths. `train/mfu_pct` is
only reported when the GPU name matches an entry in the peak-TFLOPS table. The implementation lives
in `perf_logger.py`.
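
To make the architecture dependence concrete, here is a small sketch of the per-layer forward cost
for two hypothetical configurations. The function name `per_layer_fwd_flops`, its explicit `gated`
flag, and the shapes are assumptions for illustration; the terms mirror the per-layer sum in
`perf_logger.py`, where the gated case is instead inferred from `model_type`.

```python
def per_layer_fwd_flops(h: int, n_heads: int, n_kv_heads: int, ffn: int, seq_len: int, gated: bool) -> int:
    """Forward matmul FLOPs per token for one transformer layer."""
    head_dim = h // n_heads
    kv_dim = n_kv_heads * head_dim  # shrinks under GQA (fewer KV heads than query heads)
    num_mlp_proj = 3 if gated else 2  # a gated MLP has an extra projection
    return (
        2 * h * h  # Q projection
        + 4 * h * kv_dim  # K + V projections
        + 2 * h * h  # O projection
        + 4 * seq_len * h  # attention logits + values
        + 2 * num_mlp_proj * h * ffn  # MLP
    )


# Hypothetical shapes: same width and depth-independent cost, different attention/MLP variants.
mha_standard = per_layer_fwd_flops(h=1024, n_heads=16, n_kv_heads=16, ffn=4096, seq_len=512, gated=False)
gqa_gated = per_layer_fwd_flops(h=1024, n_heads=16, n_kv_heads=4, ffn=4096, seq_len=512, gated=True)
print(mha_standard, gqa_gated)
```

In this example GQA with 4 KV heads shrinks the K/V projection term by 4x, while the third MLP
projection adds 50% to the MLP term, so the gated GQA layer still costs more overall.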

## Developer Guide

### Running Tests
@@ -12,6 +12,8 @@ use_torch_compile: false

cp_size: 1

log_mfu: false

use_sequence_packing: false
dataset:
  tokenizer_name: ${config_name_or_path}
101 changes: 100 additions & 1 deletion bionemo-recipes/recipes/esm2_native_te/perf_logger.py
@@ -32,18 +32,84 @@
logger = logging.getLogger(__name__)


# Dense BF16 tensor core peak TFLOPS (without sparsity). Product pages often list
# the 2x sparse number; dense = sparse / 2. Sources: NVIDIA datasheets for each GPU.
_GPU_PEAK_TFLOPS_BF16 = {
    "H100": 989.0,
    "H200": 989.0,
    "A100": 312.0,
    "A6000": 155.0,
    "L40": 181.0,
    "GH200": 989.0,
    "B200": 2250.0,
    "GB200": 2250.0,
    "B300": 2500.0,
    "GB300": 2500.0,
}

# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections vs. standard FFN with 2.
_GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"})


def _detect_peak_tflops_bf16():
    """Auto-detect dense BF16 peak TFLOPS for the local GPU. Returns (peak, device_name)."""
    if not torch.cuda.is_available():
        return None, "unknown"
    name = torch.cuda.get_device_name(0)
    for key, tflops in _GPU_PEAK_TFLOPS_BF16.items():
        if key.lower() in name.lower():
            return tflops, name
    return None, name


def _compute_per_token_flops(model_config_dict: dict, seq_len: int) -> int:
    """Training FLOPs per token for a transformer (forward + backward = 3x forward).

    First-principles matmul count: Q/K/V/O projections (GQA-aware), attention
    logits/values (the S^2 cost expressed per-token as 4*S*H), 2-or-3-projection
    MLP (SwiGLU detected via model_type), and LM head. The returned value is
    multiplied by the per-rank token count (padding included) at log time, so it
    naturally handles BSHD, THD (sequence packing), DP, and CP: the tokens on each
    rank already reflect that rank's share of work.
    """
    h = model_config_dict["hidden_size"]
    n_heads = model_config_dict["num_attention_heads"]
    n_kv = model_config_dict.get("num_key_value_heads", n_heads)
    head_dim = h // n_heads
    kv_dim = n_kv * head_dim
    ffn = model_config_dict["intermediate_size"]
    vocab = model_config_dict.get("vocab_size", 0)
    num_layers = model_config_dict["num_hidden_layers"]
    model_type = model_config_dict.get("model_type", "")
    num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2

    per_layer = (
        2 * h * h  # Q projection
        + 4 * h * kv_dim  # K + V projections (GQA-aware)
        + 2 * h * h  # O projection
        + 4 * seq_len * h  # attention logits + values (S^2 -> S per token)
        + 2 * num_mlp_proj * h * ffn  # MLP (2 or 3 projections)
    )
    lm_head = 2 * h * vocab if vocab > 0 else 0
    per_token_fwd = num_layers * per_layer + lm_head
    return 3 * per_token_fwd


class PerfLogger:
    """Class to log performance metrics to stdout and wandb, and print final averaged metrics at the end of training.

    Args:
        dist_config: The distributed configuration.
        args: The arguments.
        model_config_dict: Optional HF-style model config dict. When supplied together with
            ``args.log_mfu`` set to True, the logger computes per-step Model FLOPs Utilization
            (``train/mfu_pct``) and throughput (``train/tflops_per_gpu``) on each logging step.

    Attributes:
        min_loss: The minimum loss seen so far.
    """

    def __init__(self, dist_config: DistributedConfig, args: DictConfig):
    def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_config_dict: dict | None = None):
        """Initialize the logger."""
        self._dist_config = dist_config
        self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True)
@@ -53,6 +119,24 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
        self.logging_frequency = args.logger.frequency
        # Track whether to collect memory stats (disabled by default for max performance)

        # MFU setup: compute per-token FLOPs and peak TFLOPS once at init. Actual FLOPs per
        # step are derived at log time from the current batch's token count (padding included),
        # which already reflects each rank's share under DP/CP and sequence packing.
        self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None
        self._per_token_flops = 0
        self._peak_tflops: float | None = None
        if self._log_mfu:
            self._per_token_flops = _compute_per_token_flops(model_config_dict, args.dataset.max_seq_length)
            self._peak_tflops, gpu_name = _detect_peak_tflops_bf16()
            if dist_config.local_rank == 0:
                logger.info(
                    "MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, per-token FLOPs=%.3e, seq_len=%d",
                    gpu_name,
                    f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown",
                    float(self._per_token_flops),
                    args.dataset.max_seq_length,
                )

        metrics_dict = {
            "train/loss": torchmetrics.MeanMetric(),
            "train/grad_norm": torchmetrics.MeanMetric(),
@@ -65,6 +149,10 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
            "train/gpu_memory_allocated_max_gb": torchmetrics.MaxMetric(),
            "train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(),
        }
        if self._log_mfu:
            metrics_dict["train/tflops_per_gpu"] = torchmetrics.MeanMetric()
            if self._peak_tflops is not None:
                metrics_dict["train/mfu_pct"] = torchmetrics.MeanMetric()

        self.metrics = torchmetrics.MetricCollection(metrics_dict)
        # We move metrics to a GPU device so we can use torch.distributed to aggregate them before logging.
@@ -124,6 +212,17 @@ def log_step(
        self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(num_unpadded_tokens / step_time)
        self.metrics["train/total_unpadded_tokens_per_batch"].update(num_unpadded_tokens)

        if self._log_mfu:
            # PaLM/Megatron/MosaicML convention: count the configured-shape token budget
            # (input_ids.numel() = B * S_padded for BSHD, or total packed tokens for THD),
            # not the attention-mask count. The hardware executes matmuls over every
            # position regardless of masking, and this matches published MFU numbers.
            flops_per_step = self._per_token_flops * num_tokens
            tflops_per_gpu = flops_per_step / step_time / 1e12
            self.metrics["train/tflops_per_gpu"].update(tflops_per_gpu)
            if self._peak_tflops is not None:
                self.metrics["train/mfu_pct"].update(tflops_per_gpu / self._peak_tflops * 100.0)

        # Handle sequence packing for torchmetrics calculation.
        if outputs.logits.dim() < 3:
            outputs.logits = outputs.logits.unsqueeze(0)
6 changes: 5 additions & 1 deletion bionemo-recipes/recipes/esm2_native_te/train_ddp.py
@@ -156,7 +156,11 @@ def main(args: DictConfig) -> float | None:
    start_step = 0
    epoch = 0

    perf_logger = PerfLogger(dist_config, args)
    perf_logger = PerfLogger(
        dist_config,
        args,
        model_config_dict=config.to_dict() if args.get("log_mfu", False) else None,

Review comment from @pstjohn (Collaborator), Apr 27, 2026:

throughout, this could fit on the same line if you did this check inside the perf logger. In general, we want to keep these training scripts as clean as possible.

    perf_logger = PerfLogger(dist_config, args, model_config=config)

inside perf_logger.py:

if args.log_mfu:
    ...

Do we need args.get()? we should just make this default to false in the hydra default.yaml

Reply from the Contributor Author:

Done in b9f31ae. Train scripts now just call PerfLogger(dist_config, args, model_config_dict=config.to_dict()) — the args.log_mfu gate moved inside PerfLogger.__init__ across all 4 recipes (esm2, llama3, og2, codonfm). Also dropped args.get("log_mfu", False) for args.log_mfu since the default already lives in each recipe's hydra_config/defaults.yaml. 11 train scripts touched in total.
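
A minimal sketch of the pattern discussed in this thread, assuming a stripped-down stand-in class
(`PerfLoggerSketch` is hypothetical, and `SimpleNamespace` stands in for the Hydra config): the
training script always passes the model config, and the logger alone checks `args.log_mfu`.

```python
from types import SimpleNamespace


class PerfLoggerSketch:
    """Hypothetical stand-in for PerfLogger; only the MFU gate is modelled."""

    def __init__(self, args, model_config_dict: dict):
        # The caller passes the config unconditionally; the gate lives here.
        self.log_mfu = bool(args.log_mfu)
        # Only touch the model config when MFU logging is enabled.
        self.hidden_size = model_config_dict["hidden_size"] if self.log_mfu else None


# Stand-in for the Hydra DictConfig, with log_mfu defaulting to false.
args = SimpleNamespace(log_mfu=False)
perf_logger = PerfLoggerSketch(args, {"hidden_size": 1024})
print(perf_logger.log_mfu, perf_logger.hidden_size)
```

The call site stays on one line either way, which is the cleanup the review asked for.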

    )

    # Training loop
    step = start_step
6 changes: 5 additions & 1 deletion bionemo-recipes/recipes/esm2_native_te/train_ddp_cp.py
@@ -165,7 +165,11 @@ def main(args: DictConfig) -> float | None:
    start_step = 0
    epoch = 0

    perf_logger = PerfLogger(dist_config, args)
    perf_logger = PerfLogger(
        dist_config,
        args,
        model_config_dict=config.to_dict() if args.get("log_mfu", False) else None,
    )

    # Training loop
    step = start_step