Skip to content
Open
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
c761be5
llama3 mfu experiment
gagank1 Apr 4, 2026
7d60786
Add CP golden value tests, fix RoPE bug, and improve MFU scripts
gagank1 Apr 6, 2026
aa296da
Switch bandwidth measurement to P2P send/recv
gagank1 Apr 6, 2026
f412e3b
Add CLI to compare_mfu_common.py for standalone utility access
gagank1 Apr 6, 2026
231845e
Generalize MFU/FLOPs module across recipes with log_mfu training hook
gagank1 Apr 11, 2026
3b891f0
Merge remote-tracking branch 'origin/main' into gkaushik/mfu_experiment
gagank1 Apr 11, 2026
1930644
Remove first_principles.md from repository
gagank1 Apr 11, 2026
7094d17
Remove first_principles.md reference from flops.py docstring
gagank1 Apr 11, 2026
8ea45ac
Update GPU TFLOPS table: fix RTX values, add B300/GB300, add sources
gagank1 Apr 11, 2026
d853f96
Clean up flops.py: remove dead code from deleted benchmark scripts
gagank1 Apr 13, 2026
1635aab
Add flops tests and move source from models/esm2 to llama3_native_te
gagank1 Apr 13, 2026
27255a1
Add MFU tracking documentation to recipe READMEs
gagank1 Apr 13, 2026
e6468fc
Consolidate MFU tracking into perf_logger, address PR review feedback
gagank1 Apr 18, 2026
e2f4934
MFU: count configured-shape tokens, not attention_mask.sum()
gagank1 Apr 18, 2026
afc8450
Merge remote-tracking branch 'origin/main' into gkaushik/mfu_experiment
gagank1 Apr 19, 2026
f4858eb
MFU: use Σ(Lᵢ²) for attention work; fix ESM-2 grad-acc undercount
gagank1 Apr 22, 2026
b6685f5
MFU: fix BSHD+CP attention-FLOP undercount (factor cp²)
gagank1 Apr 22, 2026
ab46239
perf_logger: report true peak memory, not post-step resting
gagank1 Apr 22, 2026
f8f84cb
perf_logger: split MFU into unpadded vs padded variants
gagank1 Apr 23, 2026
909c1d7
perf_logger: remove legacy _compute_per_token_flops back-compat shim
gagank1 Apr 23, 2026
423eab7
docs: update MFU tracking sections in recipe READMEs
gagank1 Apr 23, 2026
b979eed
MFU: use padded_vocab_size for mfu_padded_pct LM-head FLOPs
gagank1 Apr 23, 2026
44172ae
test(esm2): update perf_logger tests for split _attn_work_*_accum buf…
gagank1 Apr 24, 2026
29121fe
docs(perf_logger): note pad_to_multiple_of / cu_seq_lens_q collapse (…
gagank1 Apr 24, 2026
245d6e0
Merge remote-tracking branch 'origin/main' into gkaushik/mfu_experiment
gagank1 Apr 24, 2026
74fc4b6
Merge remote-tracking branch 'origin/main' into gkaushik/mfu_experiment
gagank1 Apr 27, 2026
ff0410d
esm2: revert log_micro_step split (no grad-acc in ESM-2)
gagank1 Apr 27, 2026
b9f31ae
perf_logger: move log_mfu gate inside PerfLogger across all recipes
gagank1 Apr 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions bionemo-recipes/recipes/codonfm_native_te/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,27 @@ python train_fsdp2.py \
A final model suitable for uploading to the Hugging Face Hub can be exported at the end of training by setting
`checkpoint.save_final_model=true`.

## MFU Tracking

Enable per-step MFU logging by adding `log_mfu=true`:

```bash
torchrun --nproc_per_node=1 train_fsdp2.py --config-name encodon_1b log_mfu=true
```

Two pairs of metrics are emitted per logging interval:

- `train/mfu_pct` / `train/tflops_per_gpu` — useful-work rate. Excludes padding of all kinds.
- `train/mfu_padded_pct` / `train/tflops_per_gpu_padded` — hardware view (HFU-like). Counts
every slot the GPU processes, including BSHD row padding.

Non-attention FLOPs scale with the token count (unpadded tokens for `mfu_pct`, all slots for
`mfu_padded_pct`). Attention FLOPs scale with `Σ(Lᵢ²)`: the unpadded metrics take per-sequence
lengths from `cu_seq_lens_q` (THD) or per-row `attention_mask.sum()` (BSHD), while the padded
metrics use `cu_seq_lens_q_padded` (THD) or the full `B·S²` (BSHD). The implementation lives in
`perf_logger.py`.

Memory: `train/gpu_memory_allocated_max_gb` is the true transient peak within each logging window;
`train/gpu_memory_allocated_mean_gb` is the post-step resting footprint.

## Developer Guide

### Running Tests
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,4 @@ quant_stats_config:
fp8_layers: null
fp4_layers: null
use_fp32_master_weights: null
log_mfu: false
244 changes: 237 additions & 7 deletions bionemo-recipes/recipes/codonfm_native_te/perf_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,147 @@
PAD_TOKEN_ID = 3


# Dense BF16 tensor core peak TFLOPS (without sparsity). Product pages often list
# the 2x sparse number; dense = sparse / 2. Sources: NVIDIA datasheets for each GPU.
#
# ORDER MATTERS: the table is consumed by a case-insensitive, first-match substring
# test against the CUDA device name, walking the dict in insertion order. A key that
# is a substring of a more specific key with a DIFFERENT peak must come after it,
# otherwise the specific SKU silently inherits the generic number (e.g. "L40S" would
# match "L40" and report MFU 2x too high). Keys that share a peak with their substring
# ("GH200"/"H200", "GB200"/"B200", "GB300"/"B300") are order-insensitive.
_GPU_PEAK_TFLOPS_BF16 = {
    "H100 PCIe": 756.0,  # PCIe card is power-limited below the SXM part; must precede "H100"
    "H100": 989.0,
    "H200": 989.0,
    "A100": 312.0,
    "A6000": 155.0,
    "L40S": 362.0,  # must precede "L40"
    "L40": 181.0,
    "GH200": 989.0,
    "B200": 2250.0,
    "GB200": 2250.0,
    "B300": 2500.0,
    "GB300": 2500.0,
}

# Model types that use gated MLP (SwiGLU/GeGLU) with 3 projections vs. standard FFN with 2.
_GATED_MLP_MODEL_TYPES = frozenset({"llama", "mistral", "qwen2"})


def _detect_peak_tflops_bf16():
    """Look up the dense BF16 peak TFLOPS of CUDA device 0.

    Returns:
        A ``(peak_tflops, device_name)`` tuple. ``peak_tflops`` is ``None`` when CUDA is
        unavailable (``device_name`` is then ``"unknown"``) or when no key of
        ``_GPU_PEAK_TFLOPS_BF16`` occurs in the reported device name.
    """
    if not torch.cuda.is_available():
        return None, "unknown"

    device_name = torch.cuda.get_device_name(0)
    lowered_name = device_name.lower()
    # First table key (insertion order) found as a case-insensitive substring wins.
    peak = next(
        (value for label, value in _GPU_PEAK_TFLOPS_BF16.items() if label.lower() in lowered_name),
        None,
    )
    return peak, device_name


def _compute_non_attn_per_token_flops(model_config_dict: dict, use_padded_vocab: bool = False) -> int:
    """Per-token FLOPs for everything EXCEPT the S² attention term.

    Q/K/V/O projections (GQA-aware) + MLP + LM head, 3x for fwd+bwd. Multiply by the
    actual total token count of the batch to get per-step non-attention FLOPs. Pairs
    with ``_compute_attn_flop_coeff``, which contributes the attention term as
    ``coeff · Σ(Lᵢ²)`` from cu_seq_lens.

    Args:
        model_config_dict: HF-style model config dict. Required keys: ``hidden_size``,
            ``num_attention_heads``, ``intermediate_size``, ``num_hidden_layers``.
            Optional keys: ``num_key_value_heads`` (defaults to ``num_attention_heads``),
            ``vocab_size`` (defaults to 0, i.e. no LM-head term), ``padded_vocab_size``
            and ``model_type``. An optional key that is present but ``None`` is treated
            the same as a missing key.
        use_padded_vocab: When True, size the LM-head matmul with ``padded_vocab_size``
            (falling back to ``vocab_size`` when it is absent or falsy).

    Returns:
        Forward+backward FLOPs per token, excluding the attention score/value matmuls.

    Raises:
        KeyError: If a required key is missing from ``model_config_dict``.
    """
    h = model_config_dict["hidden_size"]
    n_heads = model_config_dict["num_attention_heads"]
    # Serialized configs may carry an explicit None for optional fields; ``or`` folds that
    # into the default instead of failing later on ``None * int`` / ``None > 0``.
    n_kv = model_config_dict.get("num_key_value_heads") or n_heads
    head_dim = h // n_heads
    kv_dim = n_kv * head_dim
    ffn = model_config_dict["intermediate_size"]
    vocab = model_config_dict.get("vocab_size") or 0
    if use_padded_vocab:
        # LM-head matmul runs at padded width (e.g. ESM-2: vocab=33 → padded=64 for
        # FP8/tensor-core friendliness); logits are sliced back post-matmul.
        vocab = model_config_dict.get("padded_vocab_size") or vocab
    num_layers = model_config_dict["num_hidden_layers"]
    model_type = model_config_dict.get("model_type") or ""
    num_mlp_proj = 3 if model_type in _GATED_MLP_MODEL_TYPES else 2

    # Each term is 2 FLOPs (multiply + add) per weight element touched by one token.
    per_layer = (
        2 * h * h  # Q projection
        + 4 * h * kv_dim  # K + V projections (GQA-aware)
        + 2 * h * h  # O projection
        + 2 * num_mlp_proj * h * ffn  # MLP (2 or 3 projections)
    )
    lm_head = 2 * h * vocab if vocab > 0 else 0
    # 3x = 1x forward + 2x backward.
    return 3 * (num_layers * per_layer + lm_head)


def _compute_attn_flop_coeff(model_config_dict: dict) -> int:
    """Return the coefficient K for which global per-step attention FLOPs = K · Σ(Lᵢ²).

    K = 3 · num_hidden_layers · 4 · hidden_size, where:
      * 3 covers forward + backward,
      * 4 covers QK^T (2 FLOPs per multiply-add) plus softmax·V (2),
      * hidden_size enters linearly because every head contributes head_dim and
        heads * head_dim == hidden_size.

    Under context parallelism each rank computes 1/cp_size of every document's
    Lᵢ * Lᵢ score matrix, so the per-rank cost is ``K · Σ(Lᵢ²) / cp_size``.
    """
    fwd_bwd_factor = 3
    matmul_flops_per_score = 4  # QK^T (2) + softmax·V (2)
    hidden = model_config_dict["hidden_size"]
    depth = model_config_dict["num_hidden_layers"]
    return fwd_bwd_factor * depth * matmul_flops_per_score * hidden


def _attn_work_from_batch(
    batch: dict, device: torch.device, cp_size: int = 1, include_padding: bool = False
) -> torch.Tensor:
    """Compute the GLOBAL (pre-CP-shard) Σ(Lᵢ²) of a batch as an int64 scalar tensor.

    The result is never divided by ``cp_size`` here; the caller (log_step) does that to
    obtain per-rank attention work.

    Source of the lengths, in lookup order:

    * ``include_padding=False`` ("useful work", real tokens only):
        1. ``cu_seq_lens_q`` (THD; real per-doc lengths, already global),
        2. ``attention_mask.sum(dim=-1)`` per row (BSHD), scaled by ``cp_size²`` to
           recover the global value,
        3. ``cu_seq_lens_q_padded``,
        4. the full ``input_ids`` shape, scaled by ``cp_size²``.
    * ``include_padding=True`` ("hardware view", padded positions counted):
        1. ``cu_seq_lens_q_padded`` (THD; includes CP zigzag-divisibility padding),
        2. ``cu_seq_lens_q``,
        3. the full ``input_ids`` shape, scaled by ``cp_size²``.

    CodonFM currently runs FSDP without CP (cp_size=1), but the formula stays correct
    if CP is added later. Lengths are cast from int32 to int64 BEFORE squaring, since
    int32 would overflow at L ≈ 46k.

    NOTE: With the collator's ``pad_to_multiple_of`` option (FP8/FP4 alignment, inlined
    in ``CodonTHDCollator.__call__`` in dataset.py), the cu_seq_lens_q tensor is mutated
    in place to include one or more appended mock pad sequences and no
    ``cu_seq_lens_q_padded`` key is written (that key is reserved for TE's per-sequence
    CP padding). In that path the unpadded and padded metrics collapse, inflated by
    ≤``pad_to_multiple_of²`` relative to the real Σ(Lᵢ²) — typically <10⁻⁵ and below
    measurement noise. Known limitation; see
    https://github.com/NVIDIA/bionemo-framework/issues/1561.

    Args:
        batch: Collated batch dict; must contain ``input_ids`` whenever none of the
            cu_seq_lens / attention_mask keys applies.
        device: Device for the tensor built by the dense ``input_ids``-shape fallback.
        cp_size: Context-parallel world size used to rescale per-rank BSHD quantities.
        include_padding: Selects the padded ("hardware") or unpadded ("useful") variant.

    Returns:
        0-dim int64 tensor holding Σ(Lᵢ²).
    """

    def _sum_squared_segments(boundaries: torch.Tensor) -> torch.Tensor:
        # Consecutive differences of the cumulative offsets are the per-sequence lengths.
        segment_lens = (boundaries[1:] - boundaries[:-1]).to(torch.int64)
        return (segment_lens * segment_lens).sum()

    def _dense_square() -> torch.Tensor:
        # Every row attends over the full (per-rank) sequence length: B · S² · cp².
        dims = batch["input_ids"].shape
        rows, cols_per_rank = int(dims[0]), int(dims[-1])
        return torch.tensor(
            rows * cols_per_rank * cols_per_rank * cp_size * cp_size,
            dtype=torch.int64,
            device=device,
        )

    if include_padding:
        for key in ("cu_seq_lens_q_padded", "cu_seq_lens_q"):
            boundaries = batch.get(key)
            if boundaries is not None:
                return _sum_squared_segments(boundaries)
        return _dense_square()

    boundaries = batch.get("cu_seq_lens_q")
    if boundaries is not None:
        return _sum_squared_segments(boundaries)

    row_mask = batch.get("attention_mask")
    if row_mask is not None:
        real_per_row = row_mask.sum(dim=-1).to(torch.int64)
        return (real_per_row * real_per_row).sum() * cp_size * cp_size

    boundaries = batch.get("cu_seq_lens_q_padded")
    if boundaries is not None:
        return _sum_squared_segments(boundaries)

    return _dense_square()


class PerfLogger:
"""Performance logger for CodonFM training.

Expand All @@ -44,17 +185,49 @@ class PerfLogger:
Args:
dist_config: The distributed configuration.
args: The Hydra arguments.
model_config_dict: Optional HF-style model config dict. When supplied together with
``args.log_mfu`` set to True, the logger computes per-step Model FLOPs Utilization
(``train/mfu_pct``) and throughput (``train/tflops_per_gpu``) on each logging step.
"""

def __init__(self, dist_config: DistributedConfig, args: DictConfig):
def __init__(self, dist_config: DistributedConfig, args: DictConfig, model_config_dict: dict | None = None):
"""Initialize the logger."""
self._dist_config = dist_config
self._run_config = OmegaConf.to_container(args, resolve=True, throw_on_missing=True)

self.min_loss = torch.tensor(float("inf"), device=torch.device(f"cuda:{dist_config.local_rank}"))
self._device = torch.device(f"cuda:{dist_config.local_rank}")
self.min_loss = torch.tensor(float("inf"), device=self._device)

self.logging_frequency = args.logger.frequency

# MFU setup: compute per-token FLOPs and peak TFLOPS once at init. Actual FLOPs per
# step are derived at log time from the tracked unpadded token count, which already
# reflects each rank's share under DP and sequence packing.
self._log_mfu = bool(args.get("log_mfu", False)) and model_config_dict is not None
self._non_attn_per_token_flops = 0
self._non_attn_per_token_flops_padded = 0
self._attn_flop_coeff = 0
self._cp_size = int(args.get("cp_size", 1))
self._peak_tflops: float | None = None
if self._log_mfu:
self._non_attn_per_token_flops = _compute_non_attn_per_token_flops(model_config_dict)
self._non_attn_per_token_flops_padded = _compute_non_attn_per_token_flops(
model_config_dict, use_padded_vocab=True
)
self._attn_flop_coeff = _compute_attn_flop_coeff(model_config_dict)
self._peak_tflops, gpu_name = _detect_peak_tflops_bf16()
if dist_config.local_rank == 0:
logger.info(
"MFU tracking enabled: GPU=%s, peak=%s TFLOPS BF16, "
"non_attn_per_token=%.3e, attn_coeff=%.3e, seq_len=%d, cp_size=%d",
gpu_name,
f"{self._peak_tflops:.1f}" if self._peak_tflops else "unknown",
float(self._non_attn_per_token_flops),
float(self._attn_flop_coeff),
args.dataset.max_seq_length,
self._cp_size,
)

metrics_dict = {
"train/loss": torchmetrics.MeanMetric(),
"train/grad_norm": torchmetrics.MeanMetric(),
Expand All @@ -66,9 +239,18 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
"train/gpu_memory_allocated_max_gb": torchmetrics.MaxMetric(),
"train/gpu_memory_allocated_mean_gb": torchmetrics.MeanMetric(),
}
if self._log_mfu:
# Two TFLOPS/MFU pairs:
# * tflops_per_gpu / mfu_pct — useful work only (no padding)
# * tflops_per_gpu_padded / mfu_padded_pct — hardware view (counts padding slots)
metrics_dict["train/tflops_per_gpu"] = torchmetrics.MeanMetric()
metrics_dict["train/tflops_per_gpu_padded"] = torchmetrics.MeanMetric()
if self._peak_tflops is not None:
metrics_dict["train/mfu_pct"] = torchmetrics.MeanMetric()
metrics_dict["train/mfu_padded_pct"] = torchmetrics.MeanMetric()

self.metrics = torchmetrics.MetricCollection(metrics_dict)
self.metrics.to(torch.device(f"cuda:{dist_config.local_rank}"))
self.metrics.to(self._device)
self.previous_step_time = time.perf_counter()

if self._dist_config.is_main_process():
Expand All @@ -79,9 +261,13 @@ def __init__(self, dist_config: DistributedConfig, args: DictConfig):
self.quant_stats_config = args.quant_stats_config.enabled

# Gradient accumulation tracking
self._device = torch.device(f"cuda:{dist_config.local_rank}")
self.num_tokens = 0
self.num_unpadded_tokens = torch.tensor(0, dtype=torch.int64, device=self._device)
# Σ(Lᵢ²) over grad-acc micro-batches — two flavors:
# unpadded: only real tokens (useful work), drives mfu_pct
# padded: all slots including CP-zigzag / BSHD row padding, drives mfu_padded_pct
self._attn_work_unpadded_accum = torch.tensor(0, dtype=torch.int64, device=self._device)
self._attn_work_padded_accum = torch.tensor(0, dtype=torch.int64, device=self._device)
self.running_loss = torch.tensor(0.0, device=self._device)
self.grad_acc_step_count = 0

Expand All @@ -103,6 +289,15 @@ def log_micro_step(self, step: int, batch: dict[str, torch.Tensor], outputs: Mas
self.num_tokens += batch["input_ids"].numel()
num_unpadded_tokens = batch["input_ids"][batch["input_ids"] != PAD_TOKEN_ID].numel()
self.num_unpadded_tokens += num_unpadded_tokens
if self._log_mfu:
# Accumulate both unpadded (useful) and padded (hardware) Σ(Lᵢ²).
# Helper returns a GLOBAL value (pre-CP-shard); log_step divides by cp_size.
self._attn_work_unpadded_accum += _attn_work_from_batch(
batch, self._device, self._cp_size, include_padding=False
)
self._attn_work_padded_accum += _attn_work_from_batch(
batch, self._device, self._cp_size, include_padding=True
)

# Update perplexity per micro-batch since it needs logits + labels
logits = outputs.logits
Expand Down Expand Up @@ -155,9 +350,42 @@ def log_step(
self.metrics["train/tokens_per_second_per_gpu"].update(self.num_tokens / step_time)
self.metrics["train/unpadded_tokens_per_second_per_gpu"].update(self.num_unpadded_tokens / step_time)

memory_allocated = torch.cuda.memory_allocated() / (1024**3)
self.metrics["train/gpu_memory_allocated_max_gb"].update(memory_allocated)
self.metrics["train/gpu_memory_allocated_mean_gb"].update(memory_allocated)
if self._log_mfu:
# Two MFU flavors reported side-by-side:
# mfu_pct = useful-work rate. Non-attn over real tokens,
# attn over real Σ(Lᵢ²). Drops both padding types.
# mfu_padded_pct = hardware view. Non-attn over all slots, attn over
# padded Σ(Lᵢ²) (includes CP zigzag + BSHD row pad).
attn_unpadded = int(self._attn_work_unpadded_accum.item())
attn_padded = int(self._attn_work_padded_accum.item())
num_unpadded = int(self.num_unpadded_tokens.item())

non_attn_unpadded = self._non_attn_per_token_flops * num_unpadded
attn_flops_unpadded = (self._attn_flop_coeff * attn_unpadded) // self._cp_size
flops_unpadded = non_attn_unpadded + attn_flops_unpadded
tflops_unpadded = flops_unpadded / step_time / 1e12

non_attn_padded = self._non_attn_per_token_flops_padded * self.num_tokens
attn_flops_padded = (self._attn_flop_coeff * attn_padded) // self._cp_size
flops_padded = non_attn_padded + attn_flops_padded
tflops_padded = flops_padded / step_time / 1e12

self.metrics["train/tflops_per_gpu"].update(tflops_unpadded)
self.metrics["train/tflops_per_gpu_padded"].update(tflops_padded)
if self._peak_tflops is not None:
self.metrics["train/mfu_pct"].update(tflops_unpadded / self._peak_tflops * 100.0)
self.metrics["train/mfu_padded_pct"].update(tflops_padded / self._peak_tflops * 100.0)

# Report TRUE peak memory across the logging window (FSDP-gathered params +
# activations held for backward), not just the post-step resting footprint.
# Reset the peak counter so each window reports its own peak instead of a
# running max since process start. Both calls are pure host-side counter ops
# -- no sync, no kernel launch.
peak_gb = torch.cuda.max_memory_allocated() / (1024**3)
current_gb = torch.cuda.memory_allocated() / (1024**3)
torch.cuda.reset_peak_memory_stats()
self.metrics["train/gpu_memory_allocated_max_gb"].update(peak_gb)
self.metrics["train/gpu_memory_allocated_mean_gb"].update(current_gb)

metrics = self.metrics.compute()
self.metrics.reset()
Expand All @@ -179,6 +407,8 @@ def log_step(
self.running_loss.zero_()
self.num_tokens = 0
self.num_unpadded_tokens.zero_()
self._attn_work_unpadded_accum.zero_()
self._attn_work_padded_accum.zero_()
self.grad_acc_step_count = 0

def finish(self):
Expand Down
Loading
Loading