34 commits
c2852f3  feat(engine): LinearCrossEntropy (May 8, 2026)
fc5211b  fix(kernel): continus (May 8, 2026)
e494738  test(linear_cross_entropy): add test for tp > 1 (May 8, 2026)
9495141  test(linear_cross_entropy): log test (May 8, 2026)
77ea280  refactor(test): remove useless test code (May 8, 2026)
47fd3e2  perf: NVTX (May 8, 2026)
203a7e4  feat(config): add use_fused_linear_ce config (May 8, 2026)
e2e5d70  fix(utils): network (May 8, 2026)
fecfcbd  feat(profiling):nsys flush (May 8, 2026)
0b5eefe  fix(engine): fix (May 8, 2026)
e69674e  feat(profiler): torch profile (May 8, 2026)
d444cbe  fix(sequence_parallel): fix sp (May 8, 2026)
07f03d8  fix(engine): dtype (May 8, 2026)
003296b  fix(engine): dtype again (May 8, 2026)
0020faa  test(linear_cross_entropy): fix test (May 9, 2026)
fa82bb9  perf(benchmark): benchmark (May 9, 2026)
4e9f8c7  refactor(fsdp): remove useless (May 9, 2026)
64b91a4  refactor: remove test profile (May 9, 2026)
92c4298  refactor(kernel): fix (May 9, 2026)
c640e03  fix(kernel): fix code (May 10, 2026)
8b1a243  refactor(kernel): fix (May 10, 2026)
a6952f3  refactor(kernel): remove useless (May 10, 2026)
6346396  refactor(kernel): comment (May 10, 2026)
0b044e5  docs(kernels): fix (May 10, 2026)
5e35bbe  feat(test): fix (May 10, 2026)
09eff54  fix(engine): fix (May 10, 2026)
825d44c  fix(megatron): fix vocab (May 10, 2026)
8aaff93  feat(engine): fix conflict (May 10, 2026)
292dab1  Merge branch 'main' into lm (TaoZex, May 10, 2026)
8addb48  Merge branch 'main' into lm (TaoZex, May 10, 2026)
c5525d6  feat: precommit fix (May 10, 2026)
9a6b467  feat: fix by comment (TaoZex, May 14, 2026)
7fcc99b  feat: fix (TaoZex, May 14, 2026)
fa91e9d  Merge branch 'main' into lm (TaoZex, May 14, 2026)
8 changes: 8 additions & 0 deletions areal/api/cli_args.py
@@ -1137,6 +1137,14 @@ class TrainEngineConfig:
default=False,
metadata={"help": "Enable tree training with flex attention module."},
)
use_fused_linear_ce: bool = field(
default=False,
metadata={
"help": "Fuse the linear projection with cross-entropy so that the "
"[num_tokens, vocab_size] logits tensor is never materialised. "
"Only effective for the Megatron actor backend with parallel_output=True."
},
)

# Scheduling
scheduling_spec: tuple[SchedulingSpec, ...] = field(
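For context, a minimal sketch of enabling the new flag, assuming the remaining `TrainEngineConfig` fields all carry defaults (as the surrounding dataclass suggests):

```python
# Hedged usage sketch: enable the fused linear-cross-entropy path.
# Assumes TrainEngineConfig is constructible with defaults elsewhere.
from areal.api.cli_args import TrainEngineConfig

config = TrainEngineConfig(use_fused_linear_ce=True)

# Per the help text above, the flag only takes effect for the Megatron
# actor backend with parallel_output=True; everywhere else the engine
# keeps the materialised [num_tokens, vocab_size] logits path.
```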
121 changes: 105 additions & 16 deletions areal/engine/megatron_engine.py
@@ -59,6 +59,11 @@
from areal.engine.megatron_utils.checkpointer import MegatronCheckpointManager
from areal.engine.megatron_utils.deterministic import set_deterministic_algorithms
from areal.engine.megatron_utils.fp8 import FP8BlockwiseTensorHelper
from areal.engine.megatron_utils.fused_lce_capture import (
FUSED_LCE_HIDDEN_KEY,
FUSED_LCE_WEIGHT_KEY,
capture_lm_head_hidden,
)
from areal.engine.megatron_utils.megatron import (
all_gather_param,
convert_to_hf,
@@ -106,7 +111,12 @@
split_padded_tensor_dict_into_mb_list,
unpad_logits,
)
from areal.utils.functional import gather_logprobs, gather_logprobs_entropy
from areal.utils.functional import (
gather_logprobs,
gather_logprobs_entropy,
linear_cross_entropy_logprobs,
linear_cross_entropy_logprobs_entropy,
)
from areal.utils.hf_utils import load_hf_tokenizer
from areal.utils.lock import DistributedLock
from areal.utils.network import find_free_ports, format_host_for_url, gethostip
@@ -710,6 +720,12 @@ def forward_backward_batch(
) -> None:
self._ensure_ready()

use_fused_lce = (
getattr(self.config, "use_fused_linear_ce", False)
and not self.config.is_critic
and not self.enable_tree_training
)

def forward_step(batch_iter, model):
mb_input: MicroBatchItem = next(batch_iter)

@@ -740,12 +756,33 @@ def forward_step(batch_iter, model):
cp_size = mpu.get_context_parallel_world_size()
cp_local = cp_size > 1

output = packed_context_parallel_forward(
model,
mb_input.padded_mb,
gather_cp_output=not cp_local,
model_vp_stage_for_capture = getattr(model, "vp_stage", 0)
should_capture = (
use_fused_lce
and mpu.is_pipeline_last_stage(
ignore_virtual=False, vp_stage=model_vp_stage_for_capture
)
and not cp_local
)

with capture_lm_head_hidden(
model, enabled=should_capture
) as capture:
output = packed_context_parallel_forward(
model,
mb_input.padded_mb,
gather_cp_output=not cp_local,
)

if (
capture is not None
and capture.hidden is not None
and capture.weight is not None
):
mb_input.orig_mb[FUSED_LCE_HIDDEN_KEY] = capture.hidden
mb_input.orig_mb[FUSED_LCE_WEIGHT_KEY] = capture.weight
mb_input.orig_mb["_fused_lce_active"] = True

# Release tree attention metadata after forward pass
for key in tree_attn_keys:
del mb_input.padded_mb[key]
@@ -784,6 +821,17 @@ def _process_output(input_, output_):
cu_seqlens=cu_seqlens,
old_cu_seqlens=mb_input.old_cu_seqlens,
)
# Re-align Float16Module's fp32 hidden to lm-head weight dtype.
if mb_input.orig_mb.get("_fused_lce_active", False):
fused_weight = mb_input.orig_mb.get(
FUSED_LCE_WEIGHT_KEY
)
if (
fused_weight is not None
and output.dtype != fused_weight.dtype
):
output = output.to(fused_weight.dtype)
mb_input.orig_mb[FUSED_LCE_HIDDEN_KEY] = output
Comment on lines +919 to +925

Collaborator:
Since we usually require fp32 logits, will this downcast operation cause a precision issue?

Collaborator (Author):
The fused LCE kernel internally accumulates the matrix multiplication in fp32. Therefore, even with bf16 input hidden states, the precision of the logits and log-softmax computations within the kernel remains fully preserved in fp32.

In practice, the non-fused computation path follows:
bf16 hidden → bf16 matmul → bf16 logits → fp32 logits (upcast by Float16Module) → fp32 log-softmax.

In contrast, the fused path maintains fp32 accumulation throughout the entire computation, ensuring its numerical precision is at least on par with, if not better than, the non-fused baseline.
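The dtype point can be illustrated with a small PyTorch sketch that emulates (rather than calls) the fused kernel; shapes are arbitrary:

```python
# Illustrative comparison of the two precision paths (not the Triton kernel).
import torch

torch.manual_seed(0)
hidden = torch.randn(4, 8)   # fp32 reference activations
weight = torch.randn(16, 8)  # fp32 reference lm-head weight

h_bf16, w_bf16 = hidden.bfloat16(), weight.bfloat16()

# Non-fused baseline: bf16 matmul -> bf16 logits -> fp32 upcast -> log-softmax.
baseline = torch.log_softmax((h_bf16 @ w_bf16.T).float(), dim=-1)

# Fused-style: the matmul accumulates in fp32 (emulated here via upcast),
# so the intermediate logits are never rounded to bf16.
fused_like = torch.log_softmax(h_bf16.float() @ w_bf16.float().T, dim=-1)

print((baseline - fused_like).abs().max())  # rounding error of the baseline
```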

return output, functools.partial(_process_output, mb_input.orig_mb)

forward_backward_func = get_forward_backward_func()
@@ -845,7 +893,9 @@ def process_output(
)

# Step 4: Optimizer step
return self.optimizer_step()
result = self.optimizer_step()

return result

@torch.no_grad()
def eval_batch(
@@ -1814,16 +1864,37 @@ def _compute_logprobs_and_loss(
labels = cp_local_labels
else:
labels = torch.roll(inputs["input_ids"], shifts=-1, dims=-1)
logprobs, entropy = gather_logprobs_entropy(
output,
labels,
temperature=self.config.temperature,
tp_group=mpu.get_tensor_model_parallel_group()
if mpu.get_tensor_model_parallel_world_size() > 1
else None,
)
vocab_min_logits = output.detach().min(-1).values.float()
vocab_max_logits = output.detach().max(-1).values.float()
fused_active = inputs.get("_fused_lce_active", False)
fused_hidden = inputs.get(FUSED_LCE_HIDDEN_KEY)
fused_weight = inputs.get(FUSED_LCE_WEIGHT_KEY)
if (
fused_active
and fused_hidden is not None
and fused_weight is not None
):
logprobs, entropy = linear_cross_entropy_logprobs_entropy(
fused_hidden,
fused_weight,
labels,
temperature=self.config.temperature,
tp_group=mpu.get_tensor_model_parallel_group()
if mpu.get_tensor_model_parallel_world_size() > 1
else None,
)
proxy = logprobs.detach().float()
vocab_min_logits = proxy
vocab_max_logits = proxy
else:
logprobs, entropy = gather_logprobs_entropy(
output,
labels,
temperature=self.config.temperature,
tp_group=mpu.get_tensor_model_parallel_group()
if mpu.get_tensor_model_parallel_world_size() > 1
else None,
)
vocab_min_logits = output.detach().min(-1).values.float()
vocab_max_logits = output.detach().max(-1).values.float()
loss = loss_fn(
logprobs,
entropy,
@@ -1860,6 +1931,24 @@ def _compute_forward_result(
)
return logprobs
labels = torch.roll(inputs["input_ids"], shifts=-1, dims=-1)
fused_active = inputs.get("_fused_lce_active", False)
fused_hidden = inputs.get(FUSED_LCE_HIDDEN_KEY)
fused_weight = inputs.get(FUSED_LCE_WEIGHT_KEY)
if (
fused_active
and fused_hidden is not None
and fused_weight is not None
):
logprobs = linear_cross_entropy_logprobs(
fused_hidden,
fused_weight,
labels,
temperature=self.config.temperature,
tp_group=mpu.get_tensor_model_parallel_group()
if mpu.get_tensor_model_parallel_world_size() > 1
else None,
)
return logprobs
logprobs = gather_logprobs(
output,
labels,
132 changes: 132 additions & 0 deletions areal/engine/megatron_utils/fused_lce_capture.py
@@ -0,0 +1,132 @@
# SPDX-License-Identifier: Apache-2.0
"""
LM-head hidden-state capture for the fused linear-cross-entropy fast path.

The fused LCE kernel needs ``(hidden, weight)`` instead of materialised
``[seq, vocab]`` logits. This module temporarily monkey-patches
``output_layer.forward`` to capture those tensors for one microbatch.

Compatibility: incompatible with MuP (``use_mup``), MTP
(``mtp_num_layers > 0``), and critic heads. The engine falls back to
the materialised path automatically when any of these conditions hold.
"""

from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass

import torch
from megatron.core import parallel_state as mpu
from megatron.core.tensor_parallel.mappings import (
gather_from_sequence_parallel_region,
)

from areal.utils import logging

logger = logging.getLogger("FusedLCECapture")

FUSED_LCE_HIDDEN_KEY = "_fused_lce_hidden"
FUSED_LCE_WEIGHT_KEY = "_fused_lce_weight"


@dataclass
class _CaptureSlot:
hidden: torch.Tensor | None = None
weight: torch.Tensor | None = None


def _unwrap_to_post_process_module(model: torch.nn.Module) -> torch.nn.Module | None:
inner = model
for _ in range(8):
if hasattr(inner, "output_layer") and inner.output_layer is not None:
return inner
if not hasattr(inner, "module"):
return None
inner = inner.module
return None


def _is_compatible(post_process_module: torch.nn.Module) -> bool:
config = getattr(post_process_module, "config", None)
if config is None:
return False

if getattr(config, "use_mup", False):
logger.warning(
"Fused LCE disabled: MuP scaling is enabled (config.use_mup=True)."
)
return False
if getattr(config, "mtp_num_layers", 0):
logger.warning(
"Fused LCE disabled: MTP is enabled (config.mtp_num_layers>0)."
)
return False

output_layer = getattr(post_process_module, "output_layer", None)
if output_layer is None:
return False

parallel_output = getattr(post_process_module, "parallel_output", True)
if not parallel_output:
logger.warning(
"Fused LCE disabled: model has parallel_output=False; "
"would require an extra TP gather."
)
return False

return True


@contextmanager
def capture_lm_head_hidden(
model: torch.nn.Module, *, enabled: bool
) -> Iterator[_CaptureSlot | None]:
if not enabled:
yield None
return

post_process = _unwrap_to_post_process_module(model)
if post_process is None or not _is_compatible(post_process):
yield None
return

output_layer = post_process.output_layer
slot = _CaptureSlot()
original_forward = output_layer.forward

config = getattr(post_process, "config", None)
sequence_parallel = bool(getattr(config, "sequence_parallel", False))
tp_world_size = mpu.get_tensor_model_parallel_world_size()
needs_sp_gather = sequence_parallel and tp_world_size > 1

def _patched_forward(input_, weight=None, runtime_gather_output=None):
actual_weight = weight if weight is not None else output_layer.weight

hidden = input_
if needs_sp_gather:
hidden = gather_from_sequence_parallel_region(hidden)

if hidden.dtype != actual_weight.dtype:
hidden = hidden.to(actual_weight.dtype)

slot.hidden = hidden
slot.weight = actual_weight
return hidden, None

output_layer.forward = _patched_forward # type: ignore[assignment]
try:
yield slot
finally:
try:
del output_layer.forward
except AttributeError:
output_layer.forward = original_forward # type: ignore[assignment]


__all__ = [
"FUSED_LCE_HIDDEN_KEY",
"FUSED_LCE_WEIGHT_KEY",
"capture_lm_head_hidden",
]
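For reference, a hedged sketch of how the engine consumes this context manager; `model` and `microbatch` are placeholders, and the real call site is `forward_step` in `megatron_engine.py`:

```python
# Usage sketch mirroring forward_step (names are illustrative).
with capture_lm_head_hidden(model, enabled=True) as slot:
    output = model(**microbatch)  # output_layer.forward is patched here

if slot is not None and slot.hidden is not None and slot.weight is not None:
    # (hidden, weight) go to the fused LCE kernel; the [seq, vocab]
    # logits tensor is never materialised.
    hidden, weight = slot.hidden, slot.weight
```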
7 changes: 7 additions & 0 deletions areal/utils/functional/__init__.py
@@ -11,6 +11,10 @@
reward_overlong_penalty,
sapo_loss_fn,
)
from areal.utils.functional.linear_cross_entropy import (
linear_cross_entropy_logprobs,
linear_cross_entropy_logprobs_entropy,
)
from areal.utils.functional.vocab_parallel import (
gather_logprobs,
gather_logprobs_entropy,
@@ -30,4 +34,7 @@
# vocab_parallel.py
"gather_logprobs",
"gather_logprobs_entropy",
# linear_cross_entropy.py (fused linear + CE/entropy via Triton)
"linear_cross_entropy_logprobs",
"linear_cross_entropy_logprobs_entropy",
]
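A hedged sketch of calling the exported fused API; the keyword arguments mirror the call sites in `megatron_engine.py`, while tensor shapes, dtypes, and devices are assumptions:

```python
# Sketch of the exported fused API; shapes are assumptions, the keyword
# arguments follow the call sites in megatron_engine.py.
import torch

from areal.utils.functional import linear_cross_entropy_logprobs_entropy

hidden = torch.randn(512, 1024, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(32000, 1024, dtype=torch.bfloat16, device="cuda")
labels = torch.randint(0, 32000, (512,), device="cuda")

logprobs, entropy = linear_cross_entropy_logprobs_entropy(
    hidden,
    weight,
    labels,
    temperature=1.0,
    tp_group=None,  # pass mpu.get_tensor_model_parallel_group() when TP > 1
)
```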