diff --git a/areal/api/cli_args.py b/areal/api/cli_args.py index ea55f557a8..3681bd2458 100644 --- a/areal/api/cli_args.py +++ b/areal/api/cli_args.py @@ -929,6 +929,15 @@ class MegatronEngineConfig: }, ) + use_fused_linear_ce: bool = field( + default=False, + metadata={ + "help": "Fuse the linear projection with cross-entropy so that the " + "[num_tokens, vocab_size] logits tensor is never materialised. " + "Only effective for the Megatron actor backend with parallel_output=True." + }, + ) + class SchedulingStrategyType(str, Enum): separation = "separation" diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index a512469bc0..4f11162964 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -64,6 +64,11 @@ from areal.engine.megatron_utils.checkpointer import MegatronCheckpointManager from areal.engine.megatron_utils.deterministic import set_deterministic_algorithms from areal.engine.megatron_utils.fp8 import FP8BlockwiseTensorHelper +from areal.engine.megatron_utils.fused_lce_capture import ( + FUSED_LCE_HIDDEN_KEY, + FUSED_LCE_WEIGHT_KEY, + capture_lm_head_hidden, +) from areal.engine.megatron_utils.megatron import ( all_gather_param, convert_to_hf, @@ -83,6 +88,10 @@ ) from areal.infra.dist_rollout import DistRolloutCoordinator from areal.infra.platforms import current_platform +from areal.models.kernel import ( + linear_cross_entropy_logprobs, + linear_cross_entropy_logprobs_entropy, +) from areal.models.mcore.hf_load import load_weights_from_hf_with_mbridge_fast from areal.models.mcore.hf_save import ( save_critic_value_head, @@ -117,7 +126,10 @@ split_padded_tensor_dict_into_mb_list, unpad_logits, ) -from areal.utils.functional import gather_logprobs, gather_logprobs_entropy +from areal.utils.functional import ( + gather_logprobs, + gather_logprobs_entropy, +) from areal.utils.hf_utils import load_hf_processor_and_tokenizer, load_hf_tokenizer from areal.utils.lock import DistributedLock from areal.utils.network import find_free_ports, format_host_for_url, gethostip @@ -805,6 +817,12 @@ def forward_backward_batch( ) -> None: self._ensure_ready() + use_fused_lce = ( + getattr(self.config.megatron, "use_fused_linear_ce", False) + and not self.config.is_critic + and not self.enable_tree_training + ) + def forward_step(batch_iter, model): mb_input: MicroBatchItem = next(batch_iter) @@ -835,13 +853,32 @@ def forward_step(batch_iter, model): cp_size = mpu.get_context_parallel_world_size() cp_local = cp_size > 1 - output = packed_context_parallel_forward( - model, - mb_input.padded_mb, - gather_cp_output=not cp_local, - is_vision_model=self.is_vision_model, + model_vp_stage_for_capture = getattr(model, "vp_stage", 0) + should_capture = ( + use_fused_lce + and mpu.is_pipeline_last_stage( + ignore_virtual=False, vp_stage=model_vp_stage_for_capture + ) + and not cp_local ) + with capture_lm_head_hidden(model, enabled=should_capture) as capture: + output = packed_context_parallel_forward( + model, + mb_input.padded_mb, + gather_cp_output=not cp_local, + is_vision_model=self.is_vision_model, + ) + + if ( + capture is not None + and capture.hidden is not None + and capture.weight is not None + ): + mb_input.orig_mb[FUSED_LCE_HIDDEN_KEY] = capture.hidden + mb_input.orig_mb[FUSED_LCE_WEIGHT_KEY] = capture.weight + mb_input.orig_mb["_fused_lce_active"] = True + # Release tree attention metadata after forward pass for key in tree_attn_keys: del mb_input.padded_mb[key] @@ -877,6 +914,15 @@ def _process_output(input_, output_): 
cu_seqlens=cu_seqlens, old_cu_seqlens=mb_input.old_cu_seqlens, ) + # Re-align Float16Module's fp32 hidden to lm-head weight dtype. + if mb_input.orig_mb.get("_fused_lce_active", False): + fused_weight = mb_input.orig_mb.get(FUSED_LCE_WEIGHT_KEY) + if ( + fused_weight is not None + and output.dtype != fused_weight.dtype + ): + output = output.to(fused_weight.dtype) + mb_input.orig_mb[FUSED_LCE_HIDDEN_KEY] = output return output, functools.partial(_process_output, mb_input.orig_mb) forward_backward_func = get_forward_backward_func() @@ -2063,64 +2109,9 @@ def _compute_logprobs_and_loss( else None, ) else: - cp_local_labels = inputs.get("_cp_local_labels") - cp_padded_cu_seqlens = inputs.get("_cp_padded_cu_seqlens") - if cp_local_labels is not None: - labels = cp_local_labels - else: - labels = torch.roll(inputs["input_ids"], shifts=-1, dims=-1) - logprobs, entropy = gather_logprobs_entropy( - output, - labels, - temperature=self.config.temperature, - tp_group=mpu.get_tensor_model_parallel_group() - if mpu.get_tensor_model_parallel_world_size() > 1 - else None, + logprobs, entropy, vocab_min_logits, vocab_max_logits, inputs = ( + self._compute_packed_logprobs_entropy(output, inputs) ) - vocab_min_logits = output.detach().min(-1).values.float() - vocab_max_logits = output.detach().max(-1).values.float() - if cp_padded_cu_seqlens is not None: - logprobs = reassemble_cp_packed_logprobs( - logprobs, cp_padded_cu_seqlens - ) - entropy = reassemble_cp_packed_logprobs( - entropy, cp_padded_cu_seqlens - ) - vocab_min_logits = reassemble_cp_packed_logprobs( - vocab_min_logits, cp_padded_cu_seqlens - ) - vocab_max_logits = reassemble_cp_packed_logprobs( - vocab_max_logits, cp_padded_cu_seqlens - ) - cp_padding_length = inputs.get("_cp_padding_length", 0) - cp_old_cu_seqlens = inputs.get("_cp_old_cu_seqlens") - logprobs = unpad_logits( - logprobs, - cp_padding_length, - cp_padded_cu_seqlens, - cp_old_cu_seqlens, - ) - entropy = unpad_logits( - entropy, - cp_padding_length, - cp_padded_cu_seqlens, - cp_old_cu_seqlens, - ) - vocab_min_logits = unpad_logits( - vocab_min_logits, - cp_padding_length, - cp_padded_cu_seqlens, - cp_old_cu_seqlens, - ) - vocab_max_logits = unpad_logits( - vocab_max_logits, - cp_padding_length, - cp_padded_cu_seqlens, - cp_old_cu_seqlens, - ) - inputs = { - k: v for k, v in inputs.items() if not k.startswith("_cp_") - } loss = loss_fn( logprobs, entropy, @@ -2135,6 +2126,99 @@ def _compute_logprobs_and_loss( loss_scale = local_weight / total_loss_weight * loss_multiplier return loss * loss_scale + def _compute_packed_logprobs_entropy( + self, + output: torch.Tensor, + inputs: dict[str, Any], + ) -> tuple[ + torch.Tensor, + torch.Tensor, + torch.Tensor | None, + torch.Tensor | None, + dict[str, Any], + ]: + """Compute per-token logprobs/entropy for the non-tree packed path. + + Returns ``(logprobs, entropy, vocab_min_logits, vocab_max_logits, inputs)``. + ``inputs`` is returned because the materialised CP branch strips the + ``_cp_*`` keys before the loss is invoked. + """ + cp_local_labels = inputs.get("_cp_local_labels") + cp_padded_cu_seqlens = inputs.get("_cp_padded_cu_seqlens") + if cp_local_labels is not None: + labels = cp_local_labels + else: + labels = torch.roll(inputs["input_ids"], shifts=-1, dims=-1) + + tp_group = ( + mpu.get_tensor_model_parallel_group() + if mpu.get_tensor_model_parallel_world_size() > 1 + else None + ) + + # Fused LCE fast path: logits are never materialised, so we skip the + # min telemetry rather than report a misleading proxy. 
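+        # The fused kernel's online-softmax pass only tracks a running max,
+        # so a true vocab-min would require a second sweep over the logits.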
+ fused_active = inputs.get("_fused_lce_active", False) + fused_hidden = inputs.get(FUSED_LCE_HIDDEN_KEY) + fused_weight = inputs.get(FUSED_LCE_WEIGHT_KEY) + if fused_active and fused_hidden is not None and fused_weight is not None: + logprobs, entropy, vocab_max_logits = linear_cross_entropy_logprobs_entropy( + fused_hidden, + fused_weight, + labels, + temperature=self.config.temperature, + tp_group=tp_group, + return_max_logits=True, + ) + return logprobs, entropy, None, vocab_max_logits, inputs + + # Materialised path. + logprobs, entropy = gather_logprobs_entropy( + output, + labels, + temperature=self.config.temperature, + tp_group=tp_group, + ) + vocab_min_logits = output.detach().min(-1).values.float() + vocab_max_logits = output.detach().max(-1).values.float() + if cp_padded_cu_seqlens is not None: + logprobs = reassemble_cp_packed_logprobs(logprobs, cp_padded_cu_seqlens) + entropy = reassemble_cp_packed_logprobs(entropy, cp_padded_cu_seqlens) + vocab_min_logits = reassemble_cp_packed_logprobs( + vocab_min_logits, cp_padded_cu_seqlens + ) + vocab_max_logits = reassemble_cp_packed_logprobs( + vocab_max_logits, cp_padded_cu_seqlens + ) + cp_padding_length = inputs.get("_cp_padding_length", 0) + cp_old_cu_seqlens = inputs.get("_cp_old_cu_seqlens") + logprobs = unpad_logits( + logprobs, + cp_padding_length, + cp_padded_cu_seqlens, + cp_old_cu_seqlens, + ) + entropy = unpad_logits( + entropy, + cp_padding_length, + cp_padded_cu_seqlens, + cp_old_cu_seqlens, + ) + vocab_min_logits = unpad_logits( + vocab_min_logits, + cp_padding_length, + cp_padded_cu_seqlens, + cp_old_cu_seqlens, + ) + vocab_max_logits = unpad_logits( + vocab_max_logits, + cp_padding_length, + cp_padded_cu_seqlens, + cp_old_cu_seqlens, + ) + inputs = {k: v for k, v in inputs.items() if not k.startswith("_cp_")} + return logprobs, entropy, vocab_min_logits, vocab_max_logits, inputs + def _compute_forward_result( self, output: torch.Tensor, @@ -2157,6 +2241,20 @@ def _compute_forward_result( ) return logprobs labels = torch.roll(inputs["input_ids"], shifts=-1, dims=-1) + fused_active = inputs.get("_fused_lce_active", False) + fused_hidden = inputs.get(FUSED_LCE_HIDDEN_KEY) + fused_weight = inputs.get(FUSED_LCE_WEIGHT_KEY) + if fused_active and fused_hidden is not None and fused_weight is not None: + logprobs = linear_cross_entropy_logprobs( + fused_hidden, + fused_weight, + labels, + temperature=self.config.temperature, + tp_group=mpu.get_tensor_model_parallel_group() + if mpu.get_tensor_model_parallel_world_size() > 1 + else None, + ) + return logprobs logprobs = gather_logprobs( output, labels, diff --git a/areal/engine/megatron_utils/fused_lce_capture.py b/areal/engine/megatron_utils/fused_lce_capture.py new file mode 100644 index 0000000000..fd58d83c7f --- /dev/null +++ b/areal/engine/megatron_utils/fused_lce_capture.py @@ -0,0 +1,183 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +LM-head hidden-state capture for the fused linear-cross-entropy fast path. + +The fused LCE kernel needs ``(hidden, weight)`` instead of materialised +``[seq, vocab]`` logits. This module temporarily monkey-patches +``output_layer.forward`` to capture those tensors for one microbatch. + +Compatibility: incompatible with MuP (``use_mup``), MTP +(``mtp_num_layers > 0``), critic heads, and hidden sizes that do not +satisfy the fused-kernel alignment requirement. The engine falls back +to the materialised path automatically when any of these conditions hold. 
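+
+A minimal sketch of the intended call pattern (mirroring the engine's
+``forward_step``; ``run_forward`` is a hypothetical stand-in for the actual
+pipeline forward call)::
+
+    with capture_lm_head_hidden(model, enabled=True) as capture:
+        output = run_forward(model, batch)
+    if capture is not None and capture.hidden is not None:
+        hidden, weight = capture.hidden, capture.weight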
+""" + +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from dataclasses import dataclass + +import torch +from megatron.core import parallel_state as mpu +from megatron.core.tensor_parallel.mappings import ( + gather_from_sequence_parallel_region, +) + +from areal.utils import logging + +logger = logging.getLogger("FusedLCECapture") + +FUSED_LCE_HIDDEN_KEY = "_fused_lce_hidden" +FUSED_LCE_WEIGHT_KEY = "_fused_lce_weight" +_HIDDEN_SIZE_ALIGNMENT = 128 +_WARNED_INCOMPATIBILITIES: set[str] = set() + + +@dataclass +class _CaptureSlot: + hidden: torch.Tensor | None = None + weight: torch.Tensor | None = None + + +def _unwrap_to_post_process_module(model: torch.nn.Module) -> torch.nn.Module | None: + inner = model + for _ in range(8): + if hasattr(inner, "output_layer") and inner.output_layer is not None: + return inner + if not hasattr(inner, "module"): + return None + inner = inner.module + return None + + +def _warn_incompatible_once(key: str, message: str, *args: object) -> None: + if key in _WARNED_INCOMPATIBILITIES: + return + _WARNED_INCOMPATIBILITIES.add(key) + logger.warning(message, *args) + + +def _get_lm_head_hidden_size( + config: object, + output_layer: torch.nn.Module, +) -> int | None: + hidden_size = getattr(config, "hidden_size", None) + if hidden_size is not None: + return int(hidden_size) + + weight = getattr(output_layer, "weight", None) + if weight is not None and hasattr(weight, "shape") and len(weight.shape) > 0: + return int(weight.shape[-1]) + + return None + + +def _is_compatible(post_process_module: torch.nn.Module) -> bool: + config = getattr(post_process_module, "config", None) + if config is None: + return False + + if getattr(config, "use_mup", False): + _warn_incompatible_once( + "use_mup", + "Fused LCE disabled: MuP scaling is enabled (config.use_mup=True).", + ) + return False + if getattr(config, "mtp_num_layers", 0): + _warn_incompatible_once( + "mtp", "Fused LCE disabled: MTP is enabled (config.mtp_num_layers>0)." + ) + return False + + output_layer = getattr(post_process_module, "output_layer", None) + if output_layer is None: + return False + + hidden_size = _get_lm_head_hidden_size(config, output_layer) + if hidden_size is not None and hidden_size % _HIDDEN_SIZE_ALIGNMENT != 0: + _warn_incompatible_once( + f"hidden_size:{hidden_size}", + "Fused LCE disabled: hidden_size=%s is not divisible by %s.", + hidden_size, + _HIDDEN_SIZE_ALIGNMENT, + ) + return False + + parallel_output = getattr(post_process_module, "parallel_output", True) + if not parallel_output: + _warn_incompatible_once( + "parallel_output", + "Fused LCE disabled: model has parallel_output=False; " + "would require an extra TP gather.", + ) + return False + + # The Triton kernel hard-requires hidden_size to be a multiple of 128 + # (BLOCK_HD constant). Surface this constraint at the gating layer so + # incompatible models fall back to the materialised path before the + # autograd graph is built; an assert raised inside ``backward`` would + # otherwise hard-kill the training loop. 
+ hidden_size = getattr(config, "hidden_size", None) + if hidden_size is None or hidden_size % 128 != 0: + logger.warning( + "Fused LCE disabled: hidden_size=%s is not a multiple of 128 " + "(Triton kernel BLOCK_HD constraint).", + hidden_size, + ) + return False + + return True + + +@contextmanager +def capture_lm_head_hidden( + model: torch.nn.Module, *, enabled: bool +) -> Iterator[_CaptureSlot | None]: + if not enabled: + yield None + return + + post_process = _unwrap_to_post_process_module(model) + if post_process is None or not _is_compatible(post_process): + yield None + return + + output_layer = post_process.output_layer + slot = _CaptureSlot() + original_forward = output_layer.forward + + config = getattr(post_process, "config", None) + sequence_parallel = bool(getattr(config, "sequence_parallel", False)) + tp_world_size = mpu.get_tensor_model_parallel_world_size() + needs_sp_gather = sequence_parallel and tp_world_size > 1 + + def _patched_forward(input_, weight=None, runtime_gather_output=None): + actual_weight = weight if weight is not None else output_layer.weight + + hidden = input_ + if needs_sp_gather: + hidden = gather_from_sequence_parallel_region(hidden) + + if hidden.dtype != actual_weight.dtype: + hidden = hidden.to(actual_weight.dtype) + + slot.hidden = hidden + slot.weight = actual_weight + return hidden, None + + output_layer.forward = _patched_forward # type: ignore[assignment] + try: + yield slot + finally: + try: + del output_layer.forward + except AttributeError: + output_layer.forward = original_forward # type: ignore[assignment] + + +__all__ = [ + "FUSED_LCE_HIDDEN_KEY", + "FUSED_LCE_WEIGHT_KEY", + "capture_lm_head_hidden", +] diff --git a/areal/models/kernel/__init__.py b/areal/models/kernel/__init__.py new file mode 100644 index 0000000000..0d61c6f2cf --- /dev/null +++ b/areal/models/kernel/__init__.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Triton-based fused linear-cross-entropy kernels for AReaL. + +The kernel implementations under :mod:`areal.models.kernel.kernels` fuse +the matmul with cross-entropy reduction, preserving numerical semantics +while avoiding materialization of the ``[num_tokens, vocab_size]`` logits +tensor. The :class:`LinearCrossEntropy` autograd function exposed below +provides a memory-efficient drop-in replacement for the materialized +``logits = hidden @ weight.T`` followed by softmax / log-softmax / +entropy computation. + +The :mod:`areal.models.kernel.functional` submodule additionally provides +high-level wrappers (``linear_cross_entropy_logprobs`` / +``linear_cross_entropy_logprobs_entropy``) that fall back to a +materialized reference implementation when the fused kernel is +unavailable. +""" + +from areal.models.kernel.functional import ( + linear_cross_entropy_logprobs, + linear_cross_entropy_logprobs_entropy, +) +from areal.models.kernel.linear_cross_entropy import ( + LinearCrossEntropy, + linear_cross_entropy, +) + +__all__ = [ + "LinearCrossEntropy", + "linear_cross_entropy", + "linear_cross_entropy_logprobs", + "linear_cross_entropy_logprobs_entropy", +] diff --git a/areal/models/kernel/functional.py b/areal/models/kernel/functional.py new file mode 100644 index 0000000000..f08b5108e0 --- /dev/null +++ b/areal/models/kernel/functional.py @@ -0,0 +1,186 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Fused linear cross-entropy entry points for AReaL. 
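+
+Set the ``AREAL_DISABLE_FUSED_LCE=1`` environment variable to force the
+materialised reference path even when Triton and CUDA are available.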
+ +These wrappers bridge the fused Triton kernel into AReaL's +:func:`gather_logprobs_entropy` interface so the Megatron path can opt in +via a single config flag. They fall back to the materialised reference +path when Triton is unavailable or inputs are not on CUDA. +""" + +from __future__ import annotations + +import os + +import torch +import torch.distributed as dist + +from areal.utils import logging + +logger = logging.getLogger("LinearCrossEntropy") + + +def _force_fallback() -> bool: + return os.environ.get("AREAL_DISABLE_FUSED_LCE", "0") == "1" + + +def _kernel_available() -> bool: + if _force_fallback(): + return False + if not torch.cuda.is_available(): + return False + try: + import triton # noqa: F401 + except ImportError: + return False + return True + + +def _reference_logprobs_entropy( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: float, + tp_group: dist.ProcessGroup | None, + return_max_logits: bool = False, +) -> ( + tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor] +): + flat_hidden = hidden.reshape(-1, hidden.shape[-1]) + flat_labels = labels.reshape(-1) + + logits = torch.matmul(flat_hidden.float(), weight.float().t()) + if temperature != 1.0: + logits = logits / temperature + + if tp_group is not None and dist.get_world_size(tp_group) > 1: + world_size = dist.get_world_size(tp_group) + gathered = [torch.empty_like(logits) for _ in range(world_size)] + dist.all_gather(gathered, logits, group=tp_group) + logits = torch.cat(gathered, dim=-1) + + log_softmax = torch.nn.functional.log_softmax(logits, dim=-1) + log_probs_labels = log_softmax.gather( + dim=-1, index=flat_labels.unsqueeze(-1) + ).squeeze(-1) + probs = log_softmax.exp() + entropy = -(probs * log_softmax).sum(dim=-1) + if return_max_logits: + # Return max of the post-temperature logits, scaled back by ``temperature`` + # so the value matches ``raw_logits.max(-1).values`` (matches the + # non-fused telemetry path exactly). + max_logits = logits.detach().max(dim=-1).values.float() + if temperature != 1.0: + max_logits = max_logits * temperature + return log_probs_labels, entropy, max_logits + return log_probs_labels, entropy + + +def linear_cross_entropy_logprobs_entropy( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: float = 1.0, + tp_group: dist.ProcessGroup | None = None, + return_max_logits: bool = False, +) -> ( + tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor] +): + """Compute per-token log-prob and entropy via the fused kernel. + + Falls back to the materialised reference path when the fused kernel is + unavailable. + + Args: + hidden: ``(..., hidden_size)`` last-layer hidden states. + weight: ``(vocab_size, hidden_size)`` lm-head weight; may be + vocab-sharded when ``tp_group`` is set. + labels: ``(...,)`` integer label ids. With TP, labels must hold + *global* vocab ids. + temperature: softmax temperature. + tp_group: optional tensor-parallel group when ``weight`` is sharded. + return_max_logits: when ``True``, additionally returns the per-token + max of the **raw** (pre-temperature) logits, shape ``labels.shape``, + dtype ``float32``. The fused kernel internally tracks + ``max(logits/temperature)``; we multiply it back by ``temperature`` + so the value is numerically identical to + ``raw_logits.max(-1).values`` from the non-fused path. 
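+
+    Example (an illustrative single-GPU sketch; names and shapes are made
+    up, but ``hidden_size`` must be a multiple of 128 for the fused kernel
+    to engage)::
+
+        hidden = torch.randn(8, 256, device="cuda", dtype=torch.bfloat16)
+        weight = torch.randn(1000, 256, device="cuda", dtype=torch.bfloat16)
+        labels = torch.randint(0, 1000, (8,), device="cuda")
+        logprobs, entropy = linear_cross_entropy_logprobs_entropy(
+            hidden, weight, labels, temperature=1.0
+        )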
+ + Returns: + ``(logprobs, entropy)`` both shaped like ``labels``; or + ``(logprobs, entropy, max_logits)`` when ``return_max_logits=True``. + """ + leading_shape = labels.shape + + if _kernel_available(): + from areal.models.kernel.linear_cross_entropy import linear_cross_entropy + + if hidden.device.type != "cuda": + logger.warning( + "Fused LCE requested but hidden is on %s; falling back to reference.", + hidden.device, + ) + else: + try: + if return_max_logits: + logprobs, entropy, max_logits = linear_cross_entropy( + hidden, + weight, + labels, + temperature, + "none", + tp_group, + return_max_logits=True, + ) + return ( + logprobs.reshape(leading_shape), + entropy.reshape(leading_shape), + max_logits.reshape(leading_shape), + ) + logprobs, entropy = linear_cross_entropy( + hidden, + weight, + labels, + temperature, + "none", + tp_group, + ) + return logprobs.reshape(leading_shape), entropy.reshape(leading_shape) + except Exception as exc: + logger.warning( + "Fused LCE kernel raised %s; falling back to reference.", + exc, + ) + + if return_max_logits: + logprobs, entropy, max_logits = _reference_logprobs_entropy( + hidden, + weight, + labels, + temperature, + tp_group, + return_max_logits=True, + ) + return ( + logprobs.reshape(leading_shape), + entropy.reshape(leading_shape), + max_logits.reshape(leading_shape), + ) + logprobs, entropy = _reference_logprobs_entropy( + hidden, weight, labels, temperature, tp_group + ) + return logprobs.reshape(leading_shape), entropy.reshape(leading_shape) + + +def linear_cross_entropy_logprobs( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: float = 1.0, + tp_group: dist.ProcessGroup | None = None, +) -> torch.Tensor: + """Logprobs-only counterpart of :func:`linear_cross_entropy_logprobs_entropy`.""" + logprobs, _ = linear_cross_entropy_logprobs_entropy( + hidden, weight, labels, temperature, tp_group + ) + return logprobs diff --git a/areal/models/kernel/kernels.py b/areal/models/kernel/kernels.py new file mode 100644 index 0000000000..00575ef9ff --- /dev/null +++ b/areal/models/kernel/kernels.py @@ -0,0 +1,1038 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Implementations of the linear cross entropy with token entropy kernel. + +Ref some code from verl. +The Triton kernel implementations fuse the matmul with cross-entropy +reduction so that the ``[num_tokens, vocab_size]`` logits tensor is never +materialized, trading kernel-launch overhead for large memory savings. 
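+
+The forward pass is a two-level online-softmax reduction: a mainloop kernel
+produces per-token partials (running max, exp-sum accumulator, and the
+``sum(logits * exp_logits)`` entropy term) for each ``vocab_per_split`` slice
+of the vocabulary, and an epilogue kernel merges those partials (and, under
+tensor parallelism, the per-rank partials) into final logprobs and entropy.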
+""" + +import torch +import torch.distributed as dist + + +def _is_cuda_available() -> bool: + return torch.cuda.is_available() + + +def get_device_capability(): + if torch.cuda.is_available(): + return torch.cuda.get_device_capability() + return (0, 0) + + +def get_device_name() -> str: + if torch.cuda.is_available(): + return "cuda" + return "cpu" + + +def get_torch_device(): + return torch.cuda + + +is_cuda_available = _is_cuda_available() + + +try: + import triton + import triton.language as tl + + HAVE_TRITON = True + SUPPORT_CUDA_TMA = ( + is_cuda_available + and get_device_capability()[0] >= 9 + and hasattr(tl, "make_tensor_descriptor") + ) + +except ImportError: + HAVE_TRITON = False + SUPPORT_CUDA_TMA = False + +if not HAVE_TRITON: + from contextlib import contextmanager + from unittest.mock import MagicMock + + @contextmanager + def null_decorator(*args, **kwargs): + if len(kwargs) == 0 and len(args) == 1 and callable(args[0]): + return args[0] + else: + + def inner(func): + return func + + return inner + + triton = MagicMock() + triton.jit = null_decorator + triton.autotune = null_decorator + tl = MagicMock() + +elif SUPPORT_CUDA_TMA: + # TMA descriptors require a global memory allocation + def alloc_fn(size: int, alignment: int, stream: int | None): + return torch.empty(size, device=get_device_name(), dtype=torch.int8) + + # https://github.com/triton-lang/triton/commit/43625fc968b693ab51884ca95adbcf3e43483fd0 + # Triton 3.5.0 stores allocators in ContextVar; values do not propagate to new + # threads by default. Some execution paths use thread pools (e.g., + # concurrent.futures), so we set a ContextVar *default* to avoid falling + # back to NullAllocator in worker threads. + try: + import contextvars + + import triton.runtime._allocation as _triton_allocation + + if isinstance( + getattr(_triton_allocation, "_allocator", None), contextvars.ContextVar + ): + _triton_allocation._allocator = contextvars.ContextVar( + _triton_allocation._allocator.name, + default=alloc_fn, + ) + except (ImportError, AttributeError): + pass + + triton.set_allocator(alloc_fn) + + +_REDUCTION_NONE = 0 + + +def get_entropy_reduction_enum_number(reduction: str) -> int: + if reduction == "none": + return _REDUCTION_NONE + raise ValueError(f"Only reduction='none' is supported, got {reduction!r}") + + +_USE_TRITON = True + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 32}, + num_stages=3, + num_warps=8, + ) + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_kernel_general_mainloop( + rank, + hidden_ptr, + weight_ptr, + labels_ptr, + num_tokens, + hidden_size, + vocab_size, + vocab_per_split, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + max_ptr, + stride_max_m: tl.int64, + stride_max_n: tl.int64, + accu_ptr, + stride_accu_m: tl.int64, + stride_accu_n: tl.int64, + entropy_b_ptr, + stride_entropy_b_m: tl.int64, + stride_entropy_b_n: tl.int64, + global_logprobs_ptr, + stride_global_logprobs: tl.int64, + rcp_temperature: tl.float32, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + USE_TMA: tl.constexpr, +): + """ + forward mainloop + """ + pid = tl.program_id(axis=0) + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N) + pid_m = pid % num_pid_m + pid_n = 
pid // num_pid_m + + # create pointers for the first blocks of hidden + start_offs_am = pid_m * BLOCK_SIZE_M + offs_am = start_offs_am + tl.arange(0, BLOCK_SIZE_M) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + if USE_TMA: + # using TMA and device-side descriptor creation + hidden_desc = tl.make_tensor_descriptor( + hidden_ptr, + shape=[num_tokens, hidden_size], + strides=[stride_hidden_m, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + ) + + weight_desc = tl.make_tensor_descriptor( + weight_ptr, + shape=[vocab_size, hidden_size], + strides=[stride_weight_n, 1], + block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], + ) + + else: + hidden_ptrs = hidden_ptr + ( + offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k + ) + + # load labels for this block + labels = tl.load(labels_ptr + offs_am, mask=offs_am < num_tokens) + + # traverse over N dimension + # _max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _max = tl.full((BLOCK_SIZE_M,), -float("inf"), dtype=tl.float32) + _accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + _logprobs = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + vocab_bound = min((pid_n + 1) * vocab_per_split, vocab_size) + for n in range(0, num_pid_n): + start_offs_bn = pid_n * vocab_per_split + n * BLOCK_SIZE_N + offs_bn = start_offs_bn + tl.arange(0, BLOCK_SIZE_N) + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + if not USE_TMA: + # weight_ptrs = weight_ptr + (offs_k[:, None] * stride_weight_k + offs_bn[None, :] * stride_weight_n) + weight_ptrs = weight_ptr + ( + offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k + ) + + # iterate over K dimension + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + if USE_TMA: + # load the next block of hidden and weight + start_offs_k = k * BLOCK_SIZE_K + _hidden = hidden_desc.load([start_offs_am, start_offs_k]) + _weight = weight_desc.load([start_offs_bn, start_offs_k]) + else: + # load the next block of hidden and weight + _hidden = tl.load( + hidden_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_am[:, None] < num_tokens), + other=0.0, + ) + + _weight = tl.load( + weight_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & ( + offs_bn[:, None] + < (min((pid_n + 1) * vocab_per_split, vocab_size)) + ), + other=0.0, + ) + + # advance the ptrs to the next K block + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + + # GEMM + logits = tl.dot(_hidden, _weight.trans(), logits) + + if not USE_TMA: + # reset hidden_ptrs for next iteration + hidden_ptrs -= hidden_size * stride_hidden_k + + # scale logits by temperature + logits *= rcp_temperature + + logits_for_lse = tl.where(offs_bn[None, :] < vocab_bound, logits, float("-inf")) + + # update global maximum + _max_old = _max + m_pid_n = tl.max(logits_for_lse, axis=1) + _max = tl.maximum(_max_old, m_pid_n) + + exp_logits = tl.exp(logits_for_lse - _max[:, None]) + coeff = tl.exp(_max_old - _max) + _accu = coeff * _accu + tl.sum(exp_logits, axis=1) + + _entropy_b = _entropy_b * coeff + tl.sum(logits * exp_logits, axis=1) + + label_mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + _logprobs += tl.sum(logits * label_mask, axis=1) + + # store maximum + offs_max_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_max_n = pid_n + maximum_ptrs = max_ptr + offs_max_n * stride_max_n + offs_max_m * stride_max_m + tl.store( + maximum_ptrs, _max, mask=(offs_max_m < num_tokens) & 
(offs_max_n < num_splits) + ) + + # store entropy + accu_ptrs = accu_ptr + offs_max_n * stride_accu_n + offs_max_m * stride_accu_m + tl.store( + accu_ptrs, + _accu, + mask=(offs_max_m < num_tokens) & (offs_max_n[None] < num_splits), + ) + entropy_b_ptrs = ( + entropy_b_ptr + + offs_max_n * stride_entropy_b_n + + offs_max_m * stride_entropy_b_m + ) + tl.store( + entropy_b_ptrs, + _entropy_b, + mask=(offs_max_m < num_tokens) & (offs_max_n < num_splits), + ) + # store logprobs + vocab_left_idx = pid_n * vocab_per_split + rank * vocab_size + vocab_right_idx = min((pid_n + 1) * vocab_per_split, vocab_size) + rank * vocab_size + mask = (labels >= vocab_left_idx) & (labels < vocab_right_idx) + mask &= offs_am < num_tokens + global_logprobs_ptrs = global_logprobs_ptr + offs_am * stride_global_logprobs + # tl.atomic_add(global_logprobs_ptrs, _logprobs, mask=mask) + tl.store(global_logprobs_ptrs, _logprobs, mask=mask) + + +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], + key=["num_tokens", "num_splits"], +) +@triton.jit +def efficient_entropy_triton_kernel_epilogue( + max_ptr, + stride_max_m: tl.int64, + stride_max_n: tl.int64, + num_tokens, + num_splits, + global_max_ptr, + stride_global_max: tl.int64, + accu_ptr, + stride_accu_m: tl.int64, + stride_accu_n: tl.int64, + global_accu_ptr, + stride_global_accu: tl.int64, + entropy_b_ptr, + stride_entropy_b_m: tl.int64, + stride_entropy_b_n: tl.int64, + global_entropy_b_ptr, + stride_global_entropy_b: tl.int64, + global_entropy_ptr, + stride_global_entropy: tl.int64, + global_logprobs_ptr, + stride_global_logprobs: tl.int64, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + """ + foward epilogue + """ + pid_m = tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + max_ptrs = ( + max_ptr + offs_m[:, None] * stride_max_m + offs_n[None, :] * stride_max_n + ) + + _max = tl.load( + max_ptrs, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + + accu_ptrs = ( + accu_ptr + offs_m[:, None] * stride_accu_m + offs_n[None, :] * stride_accu_n + ) + _accu = tl.load( + accu_ptrs, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + + entropy_b_ptrs = ( + entropy_b_ptr + + offs_m[:, None] * stride_entropy_b_m + + offs_n[None, :] * stride_entropy_b_n + ) + _entropy_b = tl.load( + entropy_b_ptrs, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + + # local reduction + _max_old = global_max + _local_max = tl.max(_max, axis=1) + global_max = tl.maximum(global_max, _local_max) + + _scale = tl.exp(_max - global_max[:, None]) + _coeff = tl.exp(_max_old - global_max) + global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) + global_entropy_b = _coeff * global_entropy_b + tl.sum( + _scale * _entropy_b, axis=1 + ) + + # store + maximum_ptrs = global_max_ptr + offs_m * stride_global_max + tl.store(maximum_ptrs, global_max, mask=offs_m < num_tokens) + + # store entropy_b + global_entropy_b = tl.fdiv(global_entropy_b, global_accu) # entropy_b + tl.store( + global_entropy_b_ptr + offs_m * stride_global_entropy_b, + global_entropy_b, + mask=offs_m < num_tokens, + ) + + 
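+    # Final identities in online-softmax form:
+    #   logsumexp = log(global_accu) + global_max
+    #   E_p[logits] = global_entropy_b (after the division above)
+    #   entropy = logsumexp - E_p[logits]
+    #   logprob(label) = label_logit - logsumexp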
# store entropy + global_accu_ptrs = global_accu_ptr + offs_m * stride_global_accu + tl.store(global_accu_ptrs, global_accu, mask=offs_m < num_tokens) + global_entropy = tl.log(global_accu) + global_max - global_entropy_b # entropy_a + global_entropy_ptrs = global_entropy_ptr + offs_m * stride_global_entropy + tl.store(global_entropy_ptrs, global_entropy, mask=offs_m < num_tokens) + # update logprobs + global_logprobs_ptrs = global_logprobs_ptr + offs_m * stride_global_logprobs + global_logprobs = tl.load(global_logprobs_ptrs, mask=offs_m < num_tokens) + global_logprobs = global_max + tl.log(global_accu) - global_logprobs + + global_logprobs = -1 * global_logprobs + tl.store(global_logprobs_ptrs, global_logprobs, mask=offs_m < num_tokens) + + +@triton.autotune( + configs=[triton.Config({"BLOCK_SIZE_M": 16, "BLOCK_SIZE_N": 64})], + key=["num_tokens", "num_splits"], +) +@triton.jit +def efficient_entropy_triton_kernel_epilogue_tp( + num_tokens, + num_splits, + reduced_max_ptr, + stride_reduced_max_m: tl.int64, + stride_reduced_max_n: tl.int64, + original_max_ptr, + stride_original_max_m: tl.int64, + stride_original_max_n: tl.int64, + accu_ptr, + stride_accu_m: tl.int64, + stride_accu_n: tl.int64, + entropy_b_ptr, + stride_entropy_b_m: tl.int64, + stride_entropy_b_n: tl.int64, + global_max_ptr, + stride_global_max: tl.int64, + global_accu_ptr, + stride_global_accu: tl.int64, + global_entropy_b_ptr, + stride_global_entropy_b: tl.int64, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + global_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_accu = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + global_entropy_b = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32) + for pid_n in range(0, tl.cdiv(num_splits, BLOCK_SIZE_N)): + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + + _reduced_max = tl.load( + reduced_max_ptr + + offs_m[:, None] * stride_reduced_max_m + + offs_n[None, :] * stride_reduced_max_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + _original_max = tl.load( + original_max_ptr + + offs_m[:, None] * stride_original_max_m + + offs_n[None, :] * stride_original_max_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + _accu = tl.load( + accu_ptr + + offs_m[:, None] * stride_accu_m + + offs_n[None, :] * stride_accu_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + + # local reduce-max + _max_old = global_max + _local_max = tl.max(_reduced_max, axis=1) + global_max = tl.maximum(global_max, _local_max) + + # update accumulate + _coeff = tl.exp(_max_old - global_max) + _scale = tl.exp(_original_max - global_max[:, None]) + global_accu = _coeff * global_accu + tl.sum(_scale * _accu, axis=1) + + # update entropy_b + _entropy_b = tl.load( + entropy_b_ptr + + offs_m[:, None] * stride_entropy_b_m + + offs_n[None, :] * stride_entropy_b_n, + mask=(offs_m[:, None] < num_tokens) & (offs_n[None, :] < num_splits), + other=0.0, + ) + global_entropy_b = _coeff * global_entropy_b + tl.sum( + _scale * _entropy_b, axis=1 + ) + + # store + tl.store( + global_max_ptr + offs_m * stride_global_max, + global_max, + mask=offs_m < num_tokens, + ) + tl.store( + global_accu_ptr + offs_m * stride_global_accu, + global_accu, + mask=offs_m < num_tokens, + ) + tl.store( + global_entropy_b_ptr + offs_m * stride_global_entropy_b, + global_entropy_b, + mask=offs_m < 
num_tokens, + ) + + +@triton.autotune(configs=[triton.Config({"BLOCK_SIZE_M": 16})], key=["num_tokens"]) +@triton.jit +def efficient_entropy_triton_epilogue_tp_update( + num_tokens, + logprobs_ptr, + stride_logprobs: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accumulate_ptr, + stride_accumulate: tl.int64, + entropy_b_ptr, + stride_entropy_b: tl.int64, + entropy_ptr, + stride_entropy: tl.int64, + logprobs_out_ptr, + BLOCK_SIZE_M: tl.constexpr, +): + pid_m = tl.program_id(axis=0) + + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + + maximum = tl.load(maximum_ptr + offs_m * stride_maximum, mask=offs_m < num_tokens) + accumulate = tl.load( + accumulate_ptr + offs_m * stride_accumulate, mask=offs_m < num_tokens + ) + + entropy_b = tl.load( + entropy_b_ptr + offs_m * stride_entropy_b, mask=offs_m < num_tokens + ) + entropy_b = tl.fdiv(entropy_b, accumulate) + tl.store( + entropy_b_ptr + offs_m * stride_entropy_b, entropy_b, mask=offs_m < num_tokens + ) + + entropy = tl.log(accumulate) + maximum - entropy_b + tl.store(entropy_ptr + offs_m * stride_entropy, entropy, mask=offs_m < num_tokens) + + logprobs = tl.load( + logprobs_ptr + offs_m * stride_logprobs, mask=offs_m < num_tokens + ) + logprobs = maximum + tl.log(accumulate) - logprobs + + logprobs = -1 * logprobs + tl.store( + logprobs_out_ptr + offs_m * stride_logprobs, logprobs, mask=offs_m < num_tokens + ) + + +_dedicated_stream, _dedicated_events = None, None + + +def efficient_entropy_forward( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + reduction: int | None = _REDUCTION_NONE, + temperature: float | None = 1.0, + dist_process_group: dist.ProcessGroup | None = None, +) -> list[torch.Tensor]: + assert hidden.is_cuda and weight.is_cuda and labels.is_cuda + assert weight.device == hidden.device and labels.device == hidden.device + assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1 + assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous() + + assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1] + + _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) + _world_size = ( + 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + ) + + if dist_process_group is not None and not hasattr( + efficient_entropy_forward, "_initialized" + ): + global _dedicated_stream, _dedicated_events + _dedicated_stream = get_torch_device().Stream(hidden.device) + _dedicated_events = [get_torch_device().Event() for _ in range(2)] + efficient_entropy_forward._initialized = True + + num_tokens, hidden_size = hidden.shape + num_tokens = labels.shape[0] + vocab_size, hidden_size = weight.shape + assert hidden_size % 128 == 0 + + if reduction != _REDUCTION_NONE: + raise ValueError(f"Invalid reduction: {reduction}") + if dist_process_group is None: + logprobs = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + else: + logprobs = torch.zeros((num_tokens,), device=hidden.device, dtype=torch.float32) + + entropy = torch.empty((num_tokens,), device=hidden.device, dtype=torch.float32) + assert logprobs.is_contiguous() and entropy.is_contiguous() + + maximum = torch.empty_like(entropy) + accumulate_and_entropy_b = torch.empty( + (num_tokens * 2,), device=hidden.device, dtype=torch.float32 + ) + accumulate_and_entropy_b_view = accumulate_and_entropy_b.view(2, num_tokens) + accumulate = accumulate_and_entropy_b_view[0, :] + entropy_b = accumulate_and_entropy_b_view[1, :] + assert ( + 
maximum.is_contiguous() + and accumulate.is_contiguous() + and entropy_b.is_contiguous() + ) + + vocab_per_split = 1024 + assert vocab_per_split % 128 == 0 + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + + _max = torch.empty( + (num_tokens, num_splits), device=hidden.device, dtype=torch.float32 + ) + _accu = torch.empty( + (num_tokens, num_splits), device=hidden.device, dtype=torch.float32 + ) + _entropy_b = torch.empty( + (num_tokens, num_splits), device=hidden.device, dtype=torch.float32 + ) + + _logprobs = logprobs + + assert _accu.is_contiguous() and _entropy_b.is_contiguous() and _max.is_contiguous() + assert _accu.is_cuda and _entropy_b.is_cuda and _max.is_cuda + + if _USE_TRITON: + # 1D kernel launch, then split the tile + def mainloop_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) * num_splits,) + + efficient_entropy_kernel_general_mainloop[mainloop_grid]( + _rank, + hidden, + weight, + labels, + num_tokens, + hidden_size, + vocab_size, + vocab_per_split, + hidden.stride(0), + hidden.stride(1), + weight.stride(0), + weight.stride(1), + _max, + _max.stride(0), + _max.stride(1), + _accu, + _accu.stride(0), + _accu.stride(1), + _entropy_b, + _entropy_b.stride(0), + _entropy_b.stride(1), + _logprobs, + _logprobs.stride(0), + 1.0 / temperature, + USE_TMA=SUPPORT_CUDA_TMA + and hidden.stride(1) == 1 + and weight.stride(1) == 1, + ) + else: + raise AssertionError("Triton is required for efficient entropy kernel") + + # reduction on maximum and maximum_indices + def epilogue_grid(meta): + return (triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]),) + + if dist_process_group is None: + efficient_entropy_triton_kernel_epilogue[epilogue_grid]( + _max, + _max.stride(0), + _max.stride(1), + num_tokens, + num_splits, + maximum, + maximum.stride(0), + _accu, + _accu.stride(0), + _accu.stride(1), + accumulate, + accumulate.stride(0), + _entropy_b, + _entropy_b.stride(0), + _entropy_b.stride(1), + entropy_b, + entropy_b.stride(0), + entropy, + entropy.stride(0), + _logprobs, + _logprobs.stride(0), + ) + else: + # tensor-parallel + _max_backup = _max.clone() + dist.all_reduce(_max, op=dist.ReduceOp.MAX, group=dist_process_group) + + get_torch_device().current_stream().record_event(_dedicated_events[0]) + with get_torch_device().stream(_dedicated_stream): + _dedicated_stream.wait_event(_dedicated_events[0]) + dist.all_reduce(_logprobs, op=dist.ReduceOp.SUM, group=dist_process_group) + _dedicated_stream.record_event(_dedicated_events[1]) + + efficient_entropy_triton_kernel_epilogue_tp[epilogue_grid]( + num_tokens, + num_splits, + _max, + _max.stride(0), + _max.stride(1), + _max_backup, + _max_backup.stride(0), + _max_backup.stride(1), + _accu, + _accu.stride(0), + _accu.stride(1), + _entropy_b, + _entropy_b.stride(0), + _entropy_b.stride(1), + maximum, + maximum.stride(0), + accumulate, + accumulate.stride(0), + entropy_b, + entropy_b.stride(0), + ) + get_torch_device().current_stream().wait_event(_dedicated_events[1]) + + dist.all_reduce( + accumulate_and_entropy_b, op=dist.ReduceOp.SUM, group=dist_process_group + ) + + # update logprobs & entropy + efficient_entropy_triton_epilogue_tp_update[epilogue_grid]( + num_tokens, + _logprobs, + _logprobs.stride(0), + maximum, + maximum.stride(0), + accumulate, + accumulate.stride(0), + entropy_b, + entropy_b.stride(0), + entropy, + entropy.stride(0), + logprobs, + ) + + return (logprobs, entropy, maximum, accumulate, entropy_b) + + +@triton.autotune( + configs=[ + triton.Config( + { + "BLOCK_SIZE_M": 128, + 
"BLOCK_SIZE_N": 256, + "BLOCK_SIZE_K": 32, + "GROUP_SIZE_M": 16, + }, + num_stages=3, + num_warps=8, + ), + ], + key=["num_tokens", "hidden_size", "vocab_size"], +) +@triton.jit +def efficient_entropy_backward_kernel_general_d_logits_split_N( + split_idx: int, + num_tokens: int, + hidden_size: int, + vocab_size: int, + vocab_per_split: int, + rank: int, + hidden_ptr, + stride_hidden_m: tl.int64, + stride_hidden_k: tl.int64, + weight_ptr, + stride_weight_n: tl.int64, + stride_weight_k: tl.int64, + labels_ptr, + stride_labels: tl.int64, + maximum_ptr, + stride_maximum: tl.int64, + accu_ptr, + stride_accu: tl.int64, + d_entropy_ptr, + stride_d_entropy: tl.int64, + d_logprobs_ptr, + stride_d_logprobs: tl.int64, + entropy_b_ptr, + stride_entropy_b: tl.int64, + d_logits_ptr, + stride_d_logits_m: tl.int64, + stride_d_logits_n: tl.int64, + rcp_temperature: tl.float32, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + USE_TMA: tl.constexpr, +): + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(num_tokens, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(vocab_per_split, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + start_offs_am = pid_m * BLOCK_SIZE_M + offs_am = start_offs_am + tl.arange(0, BLOCK_SIZE_M) + start_offs_bn = split_idx * vocab_per_split + pid_n * BLOCK_SIZE_N + offs_bn = start_offs_bn + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + + maximum = tl.load( + maximum_ptr + offs_am * stride_maximum, mask=offs_am < num_tokens, other=0.0 + ) + accu = tl.load( + accu_ptr + offs_am * stride_accu, mask=offs_am < num_tokens, other=1e-6 + ) + accu_rcp = tl.fdiv(1.0, accu) + d_entropy = tl.load( + d_entropy_ptr + offs_am * stride_d_entropy, mask=offs_am < num_tokens, other=0.0 + ) + d_logprobs = tl.load( + d_logprobs_ptr + offs_am * stride_d_logprobs, + mask=offs_am < num_tokens, + other=0.0, + ) + d_logprobs = -1 * d_logprobs + entropy_b = tl.load( + entropy_b_ptr + offs_am * stride_entropy_b, mask=offs_am < num_tokens, other=0.0 + ) + labels = tl.load( + labels_ptr + offs_am * stride_labels, mask=offs_am < num_tokens, other=0 + ) + + logits = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + if USE_TMA: + # using TMA and device-side descriptor creation + hidden_desc = tl.make_tensor_descriptor( + hidden_ptr, + shape=[num_tokens, hidden_size], + strides=[stride_hidden_m, 1], + block_shape=[BLOCK_SIZE_M, BLOCK_SIZE_K], + ) + weight_desc = tl.make_tensor_descriptor( + weight_ptr, + shape=[vocab_size, hidden_size], + strides=[stride_weight_n, 1], + block_shape=[BLOCK_SIZE_N, BLOCK_SIZE_K], + ) + else: + hidden_ptrs = hidden_ptr + ( + offs_am[:, None] * stride_hidden_m + offs_k[None, :] * stride_hidden_k + ) + weight_ptrs = weight_ptr + ( + offs_bn[:, None] * stride_weight_n + offs_k[None, :] * stride_weight_k + ) + vocab_right_bound = min((split_idx + 1) * vocab_per_split, vocab_size) + + for k in range(0, tl.cdiv(hidden_size, BLOCK_SIZE_K)): + if USE_TMA: + start_offs_k = k * BLOCK_SIZE_K + _hidden = hidden_desc.load([start_offs_am, start_offs_k]) + _weight = weight_desc.load([start_offs_bn, start_offs_k]) + else: + _hidden = tl.load( + hidden_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_am[:, None] < 
num_tokens), + other=0.0, + ) + _weight = tl.load( + weight_ptrs, + mask=(offs_k[None, :] < hidden_size - k * BLOCK_SIZE_K) + & (offs_bn[:, None] < vocab_right_bound), + other=0.0, + ) + hidden_ptrs += BLOCK_SIZE_K * stride_hidden_k + weight_ptrs += BLOCK_SIZE_K * stride_weight_k + logits = tl.dot(_hidden, _weight.T, logits) + + logits *= rcp_temperature + exp_logits = tl.exp(logits - maximum[:, None]) + + mask = (offs_bn + rank * vocab_size)[None, :] == labels[:, None] + d_logits = d_logprobs[:, None] * (exp_logits * accu_rcp[:, None] - mask) + d_logits += ( + d_entropy[:, None] + * (-exp_logits * accu_rcp[:, None]) + * (logits - entropy_b[:, None]) + ) + + d_logits *= rcp_temperature + + # filter d_logits with mask + result_offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + mask = (offs_am[:, None] < num_tokens) & (result_offs_n[None, :] < vocab_per_split) + + tl.store( + d_logits_ptr + + offs_am[:, None] * stride_d_logits_m + + result_offs_n[None, :] * stride_d_logits_n, + d_logits, + mask, + ) + + +def efficient_entropy_backward( + dlogprobs: torch.Tensor, + dentropy: torch.Tensor, + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + maximum: torch.Tensor, + acc: torch.Tensor, + entropy_b: torch.Tensor, + reduction: int | None = _REDUCTION_NONE, + should_return_fp32_grad: bool = False, + temperature: float | None = 1.0, + dist_process_group: dist.ProcessGroup | None = None, +) -> list[torch.Tensor]: + assert hidden.is_cuda and weight.is_cuda and labels.is_cuda + assert weight.device == hidden.device and labels.device == hidden.device + assert hidden.dim() == 2 and weight.dim() == 2 and labels.dim() == 1 + assert hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous() + assert hidden.shape[0] == labels.shape[0] and hidden.shape[1] == weight.shape[1] + + _rank = 0 if dist_process_group is None else dist.get_rank(dist_process_group) + _world_size = ( + 1 if dist_process_group is None else dist.get_world_size(dist_process_group) + ) + + num_tokens, hidden_size = hidden.shape + num_tokens = labels.shape[0] + vocab_size, hidden_size = weight.shape + assert hidden_size % 128 == 0 + + if reduction != _REDUCTION_NONE: + raise ValueError(f"Invalid reduction: {reduction}") + assert dlogprobs.shape == (num_tokens,) + + assert dlogprobs.is_contiguous() and dentropy.is_contiguous() + assert dlogprobs.is_cuda and dentropy.is_cuda + assert dlogprobs.device == hidden.device and dlogprobs.device == dentropy.device + assert dentropy.shape == (num_tokens,) + + grad_dtype = torch.float32 if should_return_fp32_grad else hidden.dtype + d_hidden = torch.empty_like(hidden, dtype=grad_dtype, device=hidden.device) + d_weight = torch.empty_like(weight, dtype=grad_dtype, device=weight.device) + assert d_hidden.is_contiguous() and d_weight.is_contiguous() + + assert maximum.is_contiguous() and acc.is_contiguous() + assert maximum.device == hidden.device and acc.device == hidden.device + assert maximum.shape == labels.shape == acc.shape + assert maximum.is_cuda and acc.is_cuda + + assert entropy_b.is_contiguous() and entropy_b.is_cuda + assert entropy_b.shape == (num_tokens,) + + vocab_per_split = 9504 + num_splits = (vocab_size + vocab_per_split - 1) // vocab_per_split + + _d_logits = torch.empty( + (num_tokens, vocab_per_split), device=hidden.device, dtype=hidden.dtype + ).contiguous() + assert _d_logits.is_contiguous() + + def d_logits_grid(meta): + return ( + triton.cdiv(num_tokens, meta["BLOCK_SIZE_M"]) + * triton.cdiv(vocab_per_split, meta["BLOCK_SIZE_N"]), + 
) + + for split_idx in range(num_splits): + efficient_entropy_backward_kernel_general_d_logits_split_N[d_logits_grid]( + split_idx, + num_tokens, + hidden_size, + vocab_size, + vocab_per_split, + _rank, + hidden, + hidden.stride(0), + hidden.stride(1), + weight, + weight.stride(0), + weight.stride(1), + labels, + labels.stride(0), + maximum, + maximum.stride(0), + acc, + acc.stride(0), + dentropy, + dentropy.stride(0), + dlogprobs, + dlogprobs.stride(0), + entropy_b, + entropy_b.stride(0), + _d_logits, + _d_logits.stride(0), + _d_logits.stride(1), + 1.0 / temperature, + USE_TMA=SUPPORT_CUDA_TMA + and hidden.stride(1) == 1 + and weight.stride(1) == 1, + ) + + split_start = split_idx * vocab_per_split + split_end = min(split_start + vocab_per_split, vocab_size) + current_d_logits = _d_logits[:, : split_end - split_start] + current_weight = weight[split_start:split_end, :] + current_d_weight = d_weight[split_start:split_end, :] + + if split_idx == 0: + torch.matmul(current_d_logits, current_weight, out=d_hidden) + else: + d_hidden += torch.matmul(current_d_logits, current_weight) + torch.matmul(current_d_logits.T, hidden, out=current_d_weight) + return d_hidden, d_weight diff --git a/areal/models/kernel/linear_cross_entropy.py b/areal/models/kernel/linear_cross_entropy.py new file mode 100644 index 0000000000..32bbca3c86 --- /dev/null +++ b/areal/models/kernel/linear_cross_entropy.py @@ -0,0 +1,181 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Fused linear + cross-entropy autograd Function. + +Dispatches to a Triton kernel that fuses the matmul with cross-entropy +so that the ``[num_tokens, vocab_size]`` logits tensor is never materialised. +""" + +from __future__ import annotations + +import torch +import torch.distributed as dist + + +class LinearCrossEntropy(torch.autograd.Function): + """Fused linear + cross-entropy autograd Function. + + Args: + hidden: ``(num_tokens, hidden_size)`` contiguous CUDA tensor. + weight: ``(vocab_size, hidden_size)`` lm-head weight, contiguous CUDA. + labels: ``(num_tokens,)`` integer label ids on CUDA. + temperature: softmax temperature; defaults to ``1.0``. + reduction: only ``"none"`` is supported. + dist_process_group: optional TP group for vocab-sharded ``weight``. + ``labels`` must contain *global* vocab ids on every rank. + return_max_logits: when ``True``, the autograd Function additionally + returns the per-token raw-logit max (kernel-internal + ``max(logits/temperature)`` re-scaled by ``temperature``). The + extra output is detached / non-differentiable. + + Returns: + ``(logprobs, entropy)`` both shaped ``(num_tokens,)``; or + ``(logprobs, entropy, max_logits)`` when ``return_max_logits=True``. 
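+
+    Prefer the :func:`linear_cross_entropy` wrapper defined below over
+    calling ``LinearCrossEntropy.apply`` directly; the wrapper keeps the
+    positional-argument order in one place.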
+ """ + + @staticmethod + def forward( + ctx, + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: float | None = 1.0, + reduction: str | None = "none", + dist_process_group: dist.ProcessGroup | None = None, + return_max_logits: bool = False, + ) -> ( + tuple[torch.Tensor, torch.Tensor] + | tuple[torch.Tensor, torch.Tensor, torch.Tensor] + ): + if not isinstance(temperature, float): + temperature = float(temperature) + if not isinstance(reduction, str): + raise TypeError(f"reduction must be str, got {type(reduction)}") + + from areal.models.kernel import kernels + + REDUCTION = kernels.get_entropy_reduction_enum_number(reduction.lower()) + + original_hidden_shape = hidden.shape + if hidden.dim() != 2: + hidden = hidden.reshape(-1, hidden.shape[-1]) + if labels.dim() != 1: + labels = labels.reshape(-1) + + assert hidden.is_cuda and weight.is_cuda and labels.is_cuda, ( + "LinearCrossEntropy requires CUDA inputs" + ) + assert ( + hidden.is_contiguous() and weight.is_contiguous() and labels.is_contiguous() + ), "LinearCrossEntropy requires contiguous tensors" + + ( + logprobs, + entropy, + _maximum, + _accumulate, + _entropy_b, + ) = kernels.efficient_entropy_forward( + hidden, + weight, + labels, + REDUCTION, + temperature, + dist_process_group, + ) + + ctx.save_for_backward(hidden, weight, labels, _maximum, _accumulate, _entropy_b) + ctx.original_hidden_shape = original_hidden_shape + ctx.REDUCTION = REDUCTION + ctx.dist_process_group = dist_process_group + ctx.should_return_fp32_grad = False + ctx.temperature = temperature + + if return_max_logits: + # ``_maximum`` is the per-token max of ``logits / temperature`` + # (post-temperature, online-softmax accumulator). Multiply back + # by ``temperature`` to recover the raw-logit max so the value + # matches ``raw_logits.max(-1).values`` from the non-fused path. + if temperature != 1.0: + max_logits = _maximum.detach() * temperature + else: + max_logits = _maximum.detach().clone() + return logprobs, entropy, max_logits + + return logprobs, entropy + + @staticmethod + def backward( + ctx, + dlogprobs: torch.Tensor, + dentropy: torch.Tensor, + dmax_logits: torch.Tensor | None = None, + ) -> tuple: + from areal.models.kernel import kernels + + ( + hidden, + weight, + labels, + _maximum, + _accumulate, + _entropy_b, + ) = ctx.saved_tensors + + dlogprobs = dlogprobs.contiguous() + dentropy = dentropy.contiguous() + + d_hidden, d_weight = kernels.efficient_entropy_backward( + dlogprobs, + dentropy, + hidden, + weight, + labels, + _maximum, + _accumulate, + _entropy_b, + ctx.REDUCTION, + ctx.should_return_fp32_grad, + ctx.temperature, + ctx.dist_process_group, + ) + + # TP all-reduce on d_hidden: the fused path bypasses mcore's + # ColumnParallelLinear which normally inserts this reduction. + # d_weight does NOT need all-reduce (each rank owns its vocab shard). 
+ if ( + ctx.dist_process_group is not None + and dist.get_world_size(ctx.dist_process_group) > 1 + ): + dist.all_reduce( + d_hidden, + op=dist.ReduceOp.SUM, + group=ctx.dist_process_group, + ) + + d_hidden = d_hidden.view(ctx.original_hidden_shape) + + return d_hidden, d_weight, None, None, None, None, None + + +def linear_cross_entropy( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + temperature: float = 1.0, + reduction: str = "none", + dist_process_group: dist.ProcessGroup | None = None, + return_max_logits: bool = False, +) -> ( + tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, torch.Tensor] +): + """Functional wrapper around :class:`LinearCrossEntropy`.""" + return LinearCrossEntropy.apply( + hidden, + weight, + labels, + temperature, + reduction, + dist_process_group, + return_max_logits, + ) diff --git a/areal/trainer/ppo/actor.py b/areal/trainer/ppo/actor.py index 07944a31a5..b0cb676ccc 100644 --- a/areal/trainer/ppo/actor.py +++ b/areal/trainer/ppo/actor.py @@ -546,9 +546,13 @@ def grpo_loss_fn( if "filtered_fraction" in stat: stats_tracker.scalar(rs_filtered_fraction=stat["filtered_fraction"]) - if vocab_min_logits is not None and vocab_max_logits is not None: + if vocab_min_logits is not None: stats_tracker.stat( vocab_min_logits=vocab_min_logits, + denominator="n_tokens", + ) + if vocab_max_logits is not None: + stats_tracker.stat( vocab_max_logits=vocab_max_logits, denominator="n_tokens", ) diff --git a/areal/trainer/sft/lm_engine.py b/areal/trainer/sft/lm_engine.py index 6705af0dbd..f617649e2a 100644 --- a/areal/trainer/sft/lm_engine.py +++ b/areal/trainer/sft/lm_engine.py @@ -121,9 +121,13 @@ def compute_packed_sft_loss( stats_tracker.stat(ppl=(-seqlogp).exp().float(), denominator="n_seqs") stats_tracker.stat(loss=-logprobs.detach(), denominator="n_valid_tokens") - if vocab_min_logits is not None and vocab_max_logits is not None: + if vocab_min_logits is not None: stats_tracker.stat( vocab_min_logits=vocab_min_logits, + denominator="n_tokens", + ) + if vocab_max_logits is not None: + stats_tracker.stat( vocab_max_logits=vocab_max_logits, denominator="n_tokens", ) diff --git a/areal/utils/network.py b/areal/utils/network.py index 3481016039..0720200e00 100644 --- a/areal/utils/network.py +++ b/areal/utils/network.py @@ -23,6 +23,22 @@ def gethostip(probe_host: str = "8.8.8.8", probe_port: int = 80) -> str: Raises: RuntimeError: If no suitable address can be determined """ + try: + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: + sock.connect((probe_host, probe_port)) + return sock.getsockname()[0] + except OSError: + pass + + try: + with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as sock: + sock.connect(("2001:4860:4860::8888", probe_port)) + ip6 = sock.getsockname()[0] + if ip6 and ip6 != "::1": + return ip6 + except OSError: + pass + try: hostname = socket.gethostname() infos = socket.getaddrinfo(hostname, None, socket.AF_UNSPEC, socket.SOCK_DGRAM) @@ -38,19 +54,7 @@ def gethostip(probe_host: str = "8.8.8.8", probe_port: int = 80) -> str: except socket.gaierror: pass - try: - with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: - sock.connect((probe_host, probe_port)) - return sock.getsockname()[0] - except OSError as e: - try: - with socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) as sock: - sock.connect(("2001:4860:4860::8888", probe_port)) - ip6 = sock.getsockname()[0] - if ip6 and ip6 != "::1": - return ip6 - except OSError: - raise RuntimeError("Could not determine host IP") from e + 
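+    # Reaching this point means the IPv4/IPv6 UDP probes and hostname
+    # resolution all failed.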
raise RuntimeError("Could not determine host IP") def get_loopback_ip() -> str: diff --git a/benchmark/kernels/bench_linear_cross_entropy.py b/benchmark/kernels/bench_linear_cross_entropy.py new file mode 100644 index 0000000000..5b470f1929 --- /dev/null +++ b/benchmark/kernels/bench_linear_cross_entropy.py @@ -0,0 +1,304 @@ +# SPDX-License-Identifier: Apache-2.0 +""" +Standalone benchmark for the fused linear-cross-entropy kernel. + +Designed to be run outside pytest to measure forward+backward latency and +peak memory for the materialised reference path and the fused Triton path. + +Usage:: + + # Qwen3 single-GPU full-vocab benchmark + uv run python -m benchmark.kernels.bench_linear_cross_entropy \\ + --mode both --tokens 2048 --hidden 4096 --vocab 152064 \\ + --dtype bfloat16 --warmup 5 --iters 15 --check-correctness + + # Qwen3 TP=2 benchmark. The reference path materialises only local + # [tokens, vocab/tp] logits and uses vocab-parallel reductions. + uv run torchrun --nproc_per_node=2 --nnodes=1 \\ + --master-addr=localhost --master_port=29501 \\ + -m benchmark.kernels.bench_linear_cross_entropy \\ + --mode both --tp-size 2 --tokens 2048 --hidden 4096 --vocab 152064 \\ + --dtype bfloat16 --warmup 5 --iters 15 --check-correctness + + # Qwen3 TP=4 benchmark + uv run torchrun --nproc_per_node=4 --nnodes=1 \\ + --master-addr=localhost --master_port=29501 \\ + -m benchmark.kernels.bench_linear_cross_entropy \\ + --mode both --tp-size 4 --tokens 2048 --hidden 4096 --vocab 152064 \\ + --dtype bfloat16 --warmup 5 --iters 15 --check-correctness +""" + +from __future__ import annotations + +import argparse +import gc +import math +import os +import sys + +import torch +import torch.distributed as dist + +from areal.utils.functional import gather_logprobs_entropy + + +def _setup_distributed(tp_size: int): + if tp_size == 1: + return None + if not dist.is_available(): + raise RuntimeError("torch.distributed is required when --tp-size > 1") + if not dist.is_initialized(): + required = ("RANK", "WORLD_SIZE", "LOCAL_RANK", "MASTER_PORT") + missing = [k for k in required if k not in os.environ] + if missing: + raise RuntimeError( + "--tp-size > 1 must be launched with torchrun; missing env vars: " + + ", ".join(missing) + ) + torch.cuda.set_device(int(os.environ["LOCAL_RANK"])) + dist.init_process_group(backend="nccl") + world_size = dist.get_world_size() + if world_size != tp_size: + raise RuntimeError( + f"--tp-size={tp_size} must match torchrun world_size={world_size}" + ) + torch.cuda.set_device(int(os.environ.get("LOCAL_RANK", dist.get_rank()))) + return dist.group.WORLD + + +def _rank(tp_group): + return dist.get_rank(tp_group) if tp_group is not None else 0 + + +def _world_size(tp_group): + return dist.get_world_size(tp_group) if tp_group is not None else 1 + + +def _make_inputs(num_tokens, hidden_size, vocab_size, dtype, tp_group=None, seed=0): + world_size = _world_size(tp_group) + rank = _rank(tp_group) + if vocab_size % world_size != 0: + raise ValueError( + f"vocab_size={vocab_size} must be divisible by tp_size={world_size}" + ) + local_vocab_size = vocab_size // world_size + + g = torch.Generator(device="cuda").manual_seed(seed) + hidden = ( + torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda", generator=g) + * 0.02 + ) + weight = ( + torch.randn( + local_vocab_size, + hidden_size, + dtype=dtype, + device="cuda", + generator=g, + ) + * 0.02 + ) + if tp_group is not None: + weight = weight + (rank * 0.001) + labels = torch.randint(0, vocab_size, (num_tokens,), device="cuda", 
generator=g) + return hidden.contiguous(), weight.contiguous(), labels.contiguous() + + +def _ref_step(hidden, weight, labels, temperature=1.0, tp_group=None): + h = hidden.detach().clone().requires_grad_(True) + w = weight.detach().clone().requires_grad_(True) + logits = h.float() @ w.float().t() + if tp_group is not None: + lp, ent = gather_logprobs_entropy( + logits, labels, temperature=temperature, tp_group=tp_group + ) + else: + log_softmax = torch.nn.functional.log_softmax(logits / temperature, dim=-1) + lp = log_softmax.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + probs = log_softmax.exp() + ent = -(probs * log_softmax).sum(dim=-1) + (lp.sum() + ent.sum()).backward() + if tp_group is not None: + dist.all_reduce(h.grad, op=dist.ReduceOp.SUM, group=tp_group) + return lp.detach(), ent.detach(), h.grad.detach(), w.grad.detach() + + +def _fused_step(hidden, weight, labels, temperature=1.0, tp_group=None): + from areal.models.kernel import linear_cross_entropy + + h = hidden.detach().clone().requires_grad_(True) + w = weight.detach().clone().requires_grad_(True) + lp, ent = linear_cross_entropy(h, w, labels, temperature, "none", tp_group) + (lp.sum() + ent.sum()).backward() + return lp.detach(), ent.detach(), h.grad.detach(), w.grad.detach() + + +def _check_correctness(hidden, weight, labels, dtype, tp_group=None): + ref_lp, ref_ent, ref_dh, ref_dw = _ref_step( + hidden, weight, labels, tp_group=tp_group + ) + fused_lp, fused_ent, fused_dh, fused_dw = _fused_step( + hidden, weight, labels, tp_group=tp_group + ) + + if dtype == torch.float32: + rtol, atol = 1e-4, 1e-4 + elif dtype == torch.bfloat16: + rtol, atol = 3e-2, 3e-2 + else: + rtol, atol = 2e-2, 2e-2 + + torch.testing.assert_close(fused_lp.float(), ref_lp.float(), rtol=rtol, atol=atol) + torch.testing.assert_close(fused_ent.float(), ref_ent.float(), rtol=rtol, atol=atol) + torch.testing.assert_close(fused_dh.float(), ref_dh.float(), rtol=rtol, atol=atol) + torch.testing.assert_close(fused_dw.float(), ref_dw.float(), rtol=rtol, atol=atol) + + +def _measure(label, fn, hidden, weight, labels, warmup, iters, tp_group=None): + nvtx = torch.cuda.nvtx + times = [] + mems = [] + + # Warmup + nvtx.range_push(f"{label}/warmup") + for _ in range(warmup): + fn(hidden, weight, labels, tp_group=tp_group) + gc.collect() + torch.cuda.empty_cache() + nvtx.range_pop() + + nvtx.range_push(f"{label}/measure") + for i in range(iters): + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + nvtx.range_push(f"{label}/iter{i}") + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn(hidden, weight, labels, tp_group=tp_group) + end.record() + torch.cuda.synchronize() + nvtx.range_pop() + times.append(start.elapsed_time(end)) + mems.append(torch.cuda.max_memory_allocated() / (1024 * 1024)) + nvtx.range_pop() + + return times, mems + + +def _distributed_max(value, tp_group): + if tp_group is None: + return value + tensor = torch.tensor(value, dtype=torch.float64, device="cuda") + dist.all_reduce(tensor, op=dist.ReduceOp.MAX, group=tp_group) + return float(tensor.item()) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--tokens", type=int, default=4096) + parser.add_argument("--hidden", type=int, default=4096) + parser.add_argument("--vocab", type=int, default=152064) + parser.add_argument("--tp-size", type=int, default=1) + parser.add_argument( + "--dtype", + choices=["bfloat16", "float16", "float32"], + default="bfloat16", 
+ ) + parser.add_argument("--warmup", type=int, default=3) + parser.add_argument("--iters", type=int, default=10) + parser.add_argument("--check-correctness", action="store_true") + parser.add_argument( + "--use-cuda-profiler-api", + action="store_true", + help="Wrap the measurement region with cudaProfilerStart/Stop.", + ) + parser.add_argument("--mode", choices=["both", "ref", "fused"], default="both") + args = parser.parse_args() + + if not torch.cuda.is_available(): + print("CUDA is not available; aborting.", file=sys.stderr) + sys.exit(1) + + tp_group = _setup_distributed(args.tp_size) + dtype = { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, + }[args.dtype] + hidden, weight, labels = _make_inputs( + args.tokens, args.hidden, args.vocab, dtype, tp_group=tp_group + ) + if _rank(tp_group) == 0: + print( + f"[bench] tokens={args.tokens} hidden={args.hidden} vocab={args.vocab} " + f"tp={args.tp_size} dtype={args.dtype} warmup={args.warmup} " + f"iters={args.iters}" + ) + + if args.check_correctness: + _check_correctness(hidden, weight, labels, dtype, tp_group=tp_group) + if _rank(tp_group) == 0: + print("[bench] correctness check passed") + + if args.use_cuda_profiler_api: + torch.cuda.cudart().cudaProfilerStart() + + results = {} + if args.mode in ("both", "ref"): + t, m = _measure( + "reference", + _ref_step, + hidden, + weight, + labels, + args.warmup, + args.iters, + tp_group=tp_group, + ) + results["reference"] = (t, m) + if args.mode in ("both", "fused"): + t, m = _measure( + "fused", + _fused_step, + hidden, + weight, + labels, + args.warmup, + args.iters, + tp_group=tp_group, + ) + results["fused"] = (t, m) + + if args.use_cuda_profiler_api: + torch.cuda.cudart().cudaProfilerStop() + + summaries = {} + for name, (t, m) in results.items(): + local_median = sorted(t)[len(t) // 2] + local_peak = max(m) + summaries[name] = ( + _distributed_max(local_median, tp_group), + _distributed_max(local_peak, tp_group), + ) + + if _rank(tp_group) == 0: + for name, (median, peak) in summaries.items(): + print(f"[bench] {name:9s} median={median:7.2f}ms peak_mem={peak:8.1f}MB") + + if "reference" in summaries and "fused" in summaries: + ref_med, ref_peak = summaries["reference"] + fused_med, fused_peak = summaries["fused"] + speedup = ref_med / fused_med if fused_med > 0 else math.inf + mem_ratio = fused_peak / ref_peak if ref_peak > 0 else math.inf + print( + f"[bench] speedup={speedup:.2f}x fused_peak/ref_peak={mem_ratio:.2f}x" + ) + + if dist.is_initialized(): + dist.destroy_process_group() + + +if __name__ == "__main__": + main() diff --git a/docs/en/cli_reference.md b/docs/en/cli_reference.md index 0b217a2673..8c6260daff 100644 --- a/docs/en/cli_reference.md +++ b/docs/en/cli_reference.md @@ -370,6 +370,7 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. Only effective for the Megatron actor backend with parallel_output=True. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. 
Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -442,6 +443,7 @@ Configuration for PPO critic model, a subclass of a TrainEngine. | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. Only effective for the Megatron actor backend with parallel_output=True. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -488,6 +490,7 @@ Core configuration for model training, including optimization and backend settin | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. Only effective for the Megatron actor backend with parallel_output=True. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -993,6 +996,7 @@ fields. | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. 
| +| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. Only effective for the Megatron actor backend with parallel_output=True. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -1257,6 +1261,7 @@ Configuration class: TeacherConfig | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. Only effective for the Megatron actor backend with parallel_output=True. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | diff --git a/docs/zh/cli_reference.md b/docs/zh/cli_reference.md index e9e6f11180..31bf3cac24 100644 --- a/docs/zh/cli_reference.md +++ b/docs/zh/cli_reference.md @@ -368,6 +368,7 @@ Configuration for PPO actor model, a subclass of a TrainEngine. | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. Only effective for the Megatron actor backend with parallel_output=True. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. 
**Choices:** `v1`, `v2` | @@ -440,6 +441,7 @@ Configuration for PPO critic model, a subclass of a TrainEngine. | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. Only effective for the Megatron actor backend with parallel_output=True. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -486,6 +488,7 @@ Core configuration for model training, including optimization and backend settin | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. Only effective for the Megatron actor backend with parallel_output=True. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. | | `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` | @@ -991,6 +994,7 @@ fields. | `target_modules` | list of string | **Required** | lora target_modules. | | `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. | | `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. | +| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. Only effective for the Megatron actor backend with parallel_output=True. | | `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. | | `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 
'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. |
| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` |
@@ -1255,6 +1259,7 @@ Configuration class: TeacherConfig
| `target_modules` | list of string | **Required** | lora target_modules. |
| `peft_type` | string | `"lora"` | peft method type. Only LoRA is supported for now. |
| `enable_tree_training` | boolean | `False` | Enable tree training with flex attention module. |
+| `use_fused_linear_ce` | boolean | `False` | Fuse the linear projection with cross-entropy so that the \[num_tokens, vocab_size\] logits tensor is never materialised. Only effective for the Megatron actor backend with parallel_output=True. |
| `scheduling_spec` | `tuple` | **Required** | Train engine schedule specs. Can accept 1 or 2 SchedulingSpec: if 1 spec provided, it's used for both worker and engine, engine is embedded in the worker; if 2 specs provided, first one is for worker, second one is for engine. Currently only used by the TrainController. |
| `backend` | string | **Required** | Backend and parallelism strategy. Must include an explicit backend prefix, e.g. 'fsdp:d4', 'megatron:d4t2p2', 'archon:d2'. Required. |
| `_version` | string | `"v1"` | Train controller implementation version. Use 'v1' for legacy TrainController, 'v2' for GatewayTrainController. **Choices:** `v1`, `v2` |
diff --git a/tests/test_linear_cross_entropy.py b/tests/test_linear_cross_entropy.py
new file mode 100644
index 0000000000..90d82f29d0
--- /dev/null
+++ b/tests/test_linear_cross_entropy.py
@@ -0,0 +1,464 @@
+# SPDX-License-Identifier: Apache-2.0
+"""
+Correctness + performance tests for the fused linear-cross-entropy kernel.
+
+The test suite verifies that
+:func:`areal.models.kernel.linear_cross_entropy_logprobs_entropy` produces
+results numerically equivalent to the materialised ``hidden @ weight.T`` +
+``log_softmax`` reference, and that it provides a measurable wall-clock /
+memory benefit over the reference path on representative LLM shapes.
+
+The performance assertions are intentionally loose (fused must stay within
+1.5x of the reference runtime and 1.2x of its peak memory) so they remain
+meaningful in CI where clock throttling and shared-GPU contention can swing
+absolute timings; the printed report is the authoritative artifact for
+review.
+
+Run only the correctness checks (fast, single-GPU)::
+
+    pytest tests/test_linear_cross_entropy.py -k correctness -s
+
+Run the full benchmark (includes large-vocab cases, slow)::
+
+    pytest tests/test_linear_cross_entropy.py -m slow -s
+"""
+
+from __future__ import annotations
+
+import gc
+import math
+
+import pytest
+import torch
+
+CUDA_AVAILABLE = torch.cuda.is_available()
+try:
+    import triton  # noqa: F401
+
+    TRITON_AVAILABLE = True
+except ImportError:
+    TRITON_AVAILABLE = False
+
+
+pytestmark = pytest.mark.skipif(
+    not (CUDA_AVAILABLE and TRITON_AVAILABLE),
+    reason="Fused LCE requires CUDA + Triton",
+)
+
+
+def _reference_logprobs_entropy(
+    hidden: torch.Tensor,
+    weight: torch.Tensor,
+    labels: torch.Tensor,
+    temperature: float = 1.0,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """Materialised-logits reference.
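+    Materialises the full ``[num_tokens, vocab_size]`` logits tensor, so its
+    peak memory scales with the vocabulary size.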
Same math, no fusion.""" + logits = (hidden.float() @ weight.float().t()) / temperature + log_softmax = torch.nn.functional.log_softmax(logits, dim=-1) + logprobs = log_softmax.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + probs = log_softmax.exp() + entropy = -(probs * log_softmax).sum(dim=-1) + return logprobs, entropy + + +def _make_inputs( + num_tokens: int, + hidden_size: int, + vocab_size: int, + dtype: torch.dtype, + device: str = "cuda", + seed: int = 0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + g = torch.Generator(device=device).manual_seed(seed) + hidden = ( + torch.randn(num_tokens, hidden_size, dtype=dtype, device=device, generator=g) + * 0.02 + ) + weight = ( + torch.randn(vocab_size, hidden_size, dtype=dtype, device=device, generator=g) + * 0.02 + ) + labels = torch.randint(0, vocab_size, (num_tokens,), device=device, generator=g) + return hidden.contiguous(), weight.contiguous(), labels.contiguous() + + +# --------------------------------------------------------------------------- +# Correctness +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + "num_tokens,hidden_size,vocab_size,dtype", + [ + (256, 512, 4096, torch.float32), + (512, 1024, 32000, torch.bfloat16), + (128, 768, 8192, torch.float16), + ], +) +def test_linear_cross_entropy_correctness( + num_tokens: int, + hidden_size: int, + vocab_size: int, + dtype: torch.dtype, +) -> None: + """Fused forward output must match the materialised reference.""" + from areal.models.kernel import linear_cross_entropy_logprobs_entropy + + hidden, weight, labels = _make_inputs(num_tokens, hidden_size, vocab_size, dtype) + + ref_logprobs, ref_entropy = _reference_logprobs_entropy(hidden, weight, labels) + fused_logprobs, fused_entropy = linear_cross_entropy_logprobs_entropy( + hidden, weight, labels, temperature=1.0 + ) + + # Tolerances are dtype-dependent. The fused kernel performs the same + # matmul + log-softmax math as the reference, so fp32 inputs should agree + # to within a few ulps (~1e-5). bf16 / fp16 inputs are widened only to + # absorb the documented matmul-accumulation drift; anything looser would + # mask real numerical regressions. + if dtype == torch.float32: + rtol, atol = 1e-5, 1e-5 + elif dtype == torch.bfloat16: + rtol, atol = 2e-2, 2e-2 + else: # float16 + rtol, atol = 1e-2, 1e-2 + + torch.testing.assert_close( + fused_logprobs.float(), ref_logprobs.float(), rtol=rtol, atol=atol + ) + torch.testing.assert_close( + fused_entropy.float(), ref_entropy.float(), rtol=rtol, atol=atol + ) + + +@pytest.mark.parametrize("temperature", [0.7, 1.0, 1.5]) +def test_linear_cross_entropy_temperature(temperature: float) -> None: + """Temperature scaling matches the reference for non-trivial values.""" + from areal.models.kernel import linear_cross_entropy_logprobs_entropy + + hidden, weight, labels = _make_inputs( + num_tokens=128, hidden_size=512, vocab_size=4096, dtype=torch.float32 + ) + ref_lp, ref_h = _reference_logprobs_entropy(hidden, weight, labels, temperature) + fused_lp, fused_h = linear_cross_entropy_logprobs_entropy( + hidden, weight, labels, temperature=temperature + ) + # fp32 inputs: fused vs reference must agree to ~1e-5 (a few ulps). + torch.testing.assert_close(fused_lp, ref_lp, rtol=1e-5, atol=1e-5) + torch.testing.assert_close(fused_h, ref_h, rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize( + "num_tokens,hidden_size,vocab_size", + [ + # Small shape: catches obvious correctness bugs cheaply. 
+ (64, 256, 2048), + # Medium shape: typical SFT microbatch. + (512, 1024, 32000), + # Large shape: stresses the fused backward at LLM-class dimensions + # where the materialised reference begins to dominate memory but is + # still fp32-tractable on a single GPU. This is the configuration + # most likely to surface accumulation-order bugs in d_hidden / + # d_weight reductions. + (2048, 2048, 32000), + ], + ids=["small_64x256x2048", "medium_512x1024x32k", "large_2048x2048x32k"], +) +def test_linear_cross_entropy_backward_matches_reference( + num_tokens: int, + hidden_size: int, + vocab_size: int, +) -> None: + """Backward gradients on hidden/weight match autograd through the reference. + + Runs across small / medium / large shapes so that any accumulation-order + drift in the fused d_hidden / d_weight kernels is caught at scale rather + than only on toy inputs. + """ + from areal.models.kernel import linear_cross_entropy + + hidden_a, weight_a, labels = _make_inputs( + num_tokens, hidden_size, vocab_size, torch.float32 + ) + hidden_b = hidden_a.clone() + weight_b = weight_a.clone() + hidden_a.requires_grad_(True) + weight_a.requires_grad_(True) + hidden_b.requires_grad_(True) + weight_b.requires_grad_(True) + + # Reference path + logits = hidden_b @ weight_b.t() + log_softmax = torch.nn.functional.log_softmax(logits, dim=-1) + ref_lp = log_softmax.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + probs = log_softmax.exp() + ref_h = -(probs * log_softmax).sum(dim=-1) + (ref_lp.sum() + 0.5 * ref_h.sum()).backward() + + # Fused path + fused_lp, fused_h = linear_cross_entropy( + hidden_a, weight_a, labels, 1.0, "none", None + ) + (fused_lp.sum() + 0.5 * fused_h.sum()).backward() + + # fp32 inputs: backward must match the reference to ~1e-4. The fused + # kernel's d_weight accumulates ``num_tokens`` partial products, so we + # use a slightly looser absolute tolerance for d_weight at the largest + # shape; rtol stays tight to catch directional errors. + torch.testing.assert_close(hidden_a.grad, hidden_b.grad, rtol=1e-4, atol=1e-4) + weight_atol = 1e-4 if num_tokens <= 512 else 5e-4 + torch.testing.assert_close( + weight_a.grad, weight_b.grad, rtol=1e-4, atol=weight_atol + ) + + +# --------------------------------------------------------------------------- +# Tensor-parallel (TP=2) correctness + performance +# +# These tests are invoked through pytest, while the 2-rank distributed body is +# launched with subprocess.run(["torchrun", ...]) following the repository's +# distributed-test pattern. Users do not need to run torchrun manually. 
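+#
+# Each invocation picks a fresh rendezvous port via find_free_ports, so
+# repeated or concurrent runs on a shared host should not collide.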
+# --------------------------------------------------------------------------- + + +def _tp2_available() -> bool: + """Whether we can launch a 2-rank TP test on this host.""" + if not (CUDA_AVAILABLE and TRITON_AVAILABLE): + return False + if torch.cuda.device_count() < 2: + return False + return True + + +_tp2_skip = pytest.mark.skipif( + not _tp2_available(), reason="TP=2 requires >= 2 CUDA GPUs" +) + + +def _run_lce_tp2_with_torchrun( + test_type: str, + num_tokens: int, + hidden_size: int, + vocab_size: int, + dtype: str = "bfloat16", +) -> None: + import subprocess + + from areal.utils.network import find_free_ports + + port = find_free_ports(1)[0] + try: + subprocess.run( + [ + "torchrun", + "--nproc_per_node=2", + "--nnodes=1", + "--master-addr=localhost", + f"--master_port={port}", + "tests/torchrun/run_lce_tp2.py", + f"--test_type={test_type}", + f"--num_tokens={num_tokens}", + f"--hidden_size={hidden_size}", + f"--vocab_size={vocab_size}", + f"--dtype={dtype}", + ], + check=True, + capture_output=True, + text=True, + ) + except subprocess.CalledProcessError as e: + pytest.fail( + f"TP=2 LCE torchrun test failed:\nSTDOUT:\n{e.stdout}\nSTDERR:\n{e.stderr}" + ) + + +@_tp2_skip +@pytest.mark.multi_gpu +@pytest.mark.parametrize( + "num_tokens,hidden_size,vocab_size,dtype_str", + [ + (128, 512, 8192, "float32"), + (256, 1024, 32000, "bfloat16"), + ], +) +def test_linear_cross_entropy_tp2_correctness( + num_tokens: int, + hidden_size: int, + vocab_size: int, + dtype_str: str, +) -> None: + """TP=2 fused forward+backward matches a full-vocab reference. + + The 2-rank worker is launched via torchrun inside this pytest test, so the + caller can use a normal pytest command. + """ + _run_lce_tp2_with_torchrun( + "correctness", num_tokens, hidden_size, vocab_size, dtype_str + ) + + +@_tp2_skip +@pytest.mark.multi_gpu +@pytest.mark.slow +@pytest.mark.parametrize( + "num_tokens,hidden_size,vocab_size", + [ + (1024, 1024, 32000), + (2048, 4096, 152064), + ], +) +def test_linear_cross_entropy_tp2_performance_benchmark( + num_tokens: int, + hidden_size: int, + vocab_size: int, +) -> None: + """TP=2 fused vs TP-materialised forward+backward time and peak memory.""" + _run_lce_tp2_with_torchrun("performance", num_tokens, hidden_size, vocab_size) + + +# --------------------------------------------------------------------------- +# Performance benchmark (single-GPU) +# --------------------------------------------------------------------------- + + +def _peak_memory_mb(fn, *args, **kwargs) -> tuple[float, float]: + """Return (elapsed_ms, peak_mem_mb) of a single forward+backward pass.""" + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + out = fn(*args, **kwargs) + if isinstance(out, tuple): + loss = sum( + t.float().sum() for t in out if t.requires_grad or t.grad_fn is not None + ) + else: + loss = out.float().sum() + loss.backward() + end.record() + torch.cuda.synchronize() + elapsed = start.elapsed_time(end) + peak = torch.cuda.max_memory_allocated() / (1024 * 1024) + return elapsed, peak + + +def _run_reference_forward_backward(hidden, weight, labels, temperature): + h = hidden.detach().clone().requires_grad_(True) + w = weight.detach().clone().requires_grad_(True) + logits = (h.float() @ w.float().t()) / temperature + log_softmax = torch.nn.functional.log_softmax(logits, dim=-1) + lp = log_softmax.gather(dim=-1, 
index=labels.unsqueeze(-1)).squeeze(-1) + probs = log_softmax.exp() + ent = -(probs * log_softmax).sum(dim=-1) + return lp, ent + + +def _run_fused_forward_backward(hidden, weight, labels, temperature): + from areal.models.kernel import linear_cross_entropy + + h = hidden.detach().clone().requires_grad_(True) + w = weight.detach().clone().requires_grad_(True) + return linear_cross_entropy(h, w, labels, temperature, "none", None) + + +@pytest.mark.slow +@pytest.mark.parametrize( + "num_tokens,hidden_size,vocab_size", + [ + # Small: validates the speedup is measurable even on toy shapes. + (1024, 1024, 32000), + # Medium: typical 7B-class one-microbatch shape. + (4096, 4096, 128256), + # Large vocab: where fused kernel really wins (e.g. Qwen3). + (2048, 4096, 152064), + ], +) +def test_linear_cross_entropy_performance_benchmark( + num_tokens: int, + hidden_size: int, + vocab_size: int, +) -> None: + """Compare fused vs materialised forward+backward time and peak memory. + + Failures here mean the fused path *regressed* against the reference; the + captured numbers are also printed for human review. + """ + dtype = torch.bfloat16 + hidden, weight, labels = _make_inputs(num_tokens, hidden_size, vocab_size, dtype) + + # warm-up + for _ in range(2): + lp, ent = _run_reference_forward_backward(hidden, weight, labels, 1.0) + (lp.sum() + ent.sum()).backward() + del lp, ent + gc.collect() + torch.cuda.empty_cache() + for _ in range(2): + lp, ent = _run_fused_forward_backward(hidden, weight, labels, 1.0) + (lp.sum() + ent.sum()).backward() + del lp, ent + gc.collect() + torch.cuda.empty_cache() + + # Reference timing + ref_times = [] + ref_mems = [] + for _ in range(5): + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + lp, ent = _run_reference_forward_backward(hidden, weight, labels, 1.0) + (lp.sum() + ent.sum()).backward() + end.record() + torch.cuda.synchronize() + ref_times.append(start.elapsed_time(end)) + ref_mems.append(torch.cuda.max_memory_allocated() / (1024 * 1024)) + del lp, ent + + # Fused timing + fused_times = [] + fused_mems = [] + for _ in range(5): + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + lp, ent = _run_fused_forward_backward(hidden, weight, labels, 1.0) + (lp.sum() + ent.sum()).backward() + end.record() + torch.cuda.synchronize() + fused_times.append(start.elapsed_time(end)) + fused_mems.append(torch.cuda.max_memory_allocated() / (1024 * 1024)) + del lp, ent + + ref_med = sorted(ref_times)[len(ref_times) // 2] + fused_med = sorted(fused_times)[len(fused_times) // 2] + ref_peak = max(ref_mems) + fused_peak = max(fused_mems) + speedup = ref_med / fused_med if fused_med > 0 else math.inf + mem_ratio = fused_peak / ref_peak if ref_peak > 0 else math.inf + + print( + f"\n[LCE-Bench] tokens={num_tokens} hidden={hidden_size} vocab={vocab_size} " + f"dtype={dtype}\n" + f" reference: {ref_med:7.2f} ms / {ref_peak:7.1f} MB peak\n" + f" fused : {fused_med:7.2f} ms / {fused_peak:7.1f} MB peak\n" + f" speedup : {speedup:5.2f}x memory_ratio: {mem_ratio:5.2f}x" + ) + + # Soft assertions: fused path must not be drastically slower or more + # memory-hungry. Tight thresholds would cause flaky CI on shared GPUs. 
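+    # Thresholds mirror the module docstring: fused must stay within 1.5x of
+    # the reference runtime and 1.2x of its peak memory.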
+ assert fused_med < ref_med * 1.5, ( + f"Fused LCE is more than 1.5x slower than reference " + f"(fused={fused_med:.2f}ms ref={ref_med:.2f}ms). Please investigate." + ) + assert fused_peak < ref_peak * 1.2, ( + f"Fused LCE peak memory exceeds reference by >20% " + f"(fused={fused_peak:.1f}MB ref={ref_peak:.1f}MB)." + ) diff --git a/tests/torchrun/run_lce_tp2.py b/tests/torchrun/run_lce_tp2.py new file mode 100644 index 0000000000..e489b57c29 --- /dev/null +++ b/tests/torchrun/run_lce_tp2.py @@ -0,0 +1,267 @@ +import argparse +import gc +import math +import os + +import torch +import torch.distributed as dist + +from areal.infra.platforms import current_platform +from areal.models.kernel import linear_cross_entropy +from areal.utils.functional import gather_logprobs_entropy + + +def _setup_distributed_environment() -> None: + if dist.is_initialized(): + return + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + master_addr = os.environ.get("MASTER_ADDR", "localhost") + master_port = os.environ["MASTER_PORT"] + dist.init_process_group( + backend="nccl", + init_method=f"tcp://{master_addr}:{master_port}", + world_size=world_size, + rank=rank, + ) + current_platform.set_device(rank) + + +def _get_tp_group() -> dist.ProcessGroup: + return dist.distributed_c10d._get_default_group() + + +def _make_tp_inputs( + num_tokens: int, + hidden_size: int, + vocab_size: int, + dtype: torch.dtype, + device: str, + seed: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + rank = dist.get_rank() + world_size = dist.get_world_size() + vocab_per_rank = vocab_size // world_size + assert vocab_size % world_size == 0 + + g = torch.Generator(device=device).manual_seed(seed) + hidden = ( + torch.randn(num_tokens, hidden_size, dtype=dtype, device=device, generator=g) + * 0.02 + ) + labels = torch.randint(0, vocab_size, (num_tokens,), device=device, generator=g) + weight_full = ( + torch.randn(vocab_size, hidden_size, dtype=dtype, device=device, generator=g) + * 0.02 + ) + weight_shard = weight_full[ + rank * vocab_per_rank : (rank + 1) * vocab_per_rank + ].contiguous() + return ( + hidden.contiguous(), + labels.contiguous(), + weight_full.contiguous(), + weight_shard, + ) + + +def _run_full_reference( + hidden: torch.Tensor, + weight: torch.Tensor, + labels: torch.Tensor, + entropy_weight: float, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + hidden_ref = hidden.detach().clone().requires_grad_(True) + weight_ref = weight.detach().clone().requires_grad_(True) + logits = hidden_ref.float() @ weight_ref.float().t() + log_softmax = torch.nn.functional.log_softmax(logits, dim=-1) + ref_lp = log_softmax.gather(dim=-1, index=labels.unsqueeze(-1)).squeeze(-1) + probs = log_softmax.exp() + ref_h = -(probs * log_softmax).sum(dim=-1) + (ref_lp.sum() + entropy_weight * ref_h.sum()).backward() + return ref_lp, ref_h, hidden_ref.grad, weight_ref.grad + + +def _run_tp_materialized_step( + hidden: torch.Tensor, + weight_shard: torch.Tensor, + labels: torch.Tensor, + tp_group: dist.ProcessGroup, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + h = hidden.detach().clone().requires_grad_(True) + w = weight_shard.detach().clone().requires_grad_(True) + local_logits = h.float() @ w.float().t() + lp, ent = gather_logprobs_entropy(local_logits, labels, tp_group=tp_group) + (lp.sum() + ent.sum()).backward() + return lp, ent, h.grad, w.grad + + +def _run_fused_step( + hidden: torch.Tensor, + weight_shard: torch.Tensor, + labels: torch.Tensor, + 
tp_group: dist.ProcessGroup, + entropy_weight: float = 1.0, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + h = hidden.detach().clone().requires_grad_(True) + w = weight_shard.detach().clone().requires_grad_(True) + lp, ent = linear_cross_entropy(h, w, labels, 1.0, "none", tp_group) + (lp.sum() + entropy_weight * ent.sum()).backward() + return lp, ent, h.grad, w.grad + + +def _test_tp2_correctness( + num_tokens: int, + hidden_size: int, + vocab_size: int, + dtype: torch.dtype, +) -> None: + rank = dist.get_rank() + world_size = dist.get_world_size() + assert world_size == 2 + device = current_platform.current_device() + tp_group = _get_tp_group() + + hidden, labels, weight_full, weight_shard = _make_tp_inputs( + num_tokens, hidden_size, vocab_size, dtype, device, seed=42 + ) + vocab_per_rank = vocab_size // world_size + + ref_lp, ref_h, ref_dh, ref_dw = _run_full_reference( + hidden, weight_full, labels, entropy_weight=0.5 + ) + fused_lp, fused_h, fused_dh, fused_dw = _run_fused_step( + hidden, weight_shard, labels, tp_group, entropy_weight=0.5 + ) + + if dtype == torch.float32: + rtol, atol = 2e-4, 2e-4 + else: + rtol, atol = 3e-2, 3e-2 + + torch.testing.assert_close(fused_lp.float(), ref_lp.float(), rtol=rtol, atol=atol) + torch.testing.assert_close(fused_h.float(), ref_h.float(), rtol=rtol, atol=atol) + torch.testing.assert_close(fused_dh.float(), ref_dh.float(), rtol=rtol, atol=atol) + torch.testing.assert_close( + fused_dw.float(), + ref_dw[rank * vocab_per_rank : (rank + 1) * vocab_per_rank].float(), + rtol=rtol, + atol=atol, + ) + + if rank == 0: + print( + f"[PASS] tp2_correctness: T={num_tokens} H={hidden_size} " + f"V={vocab_size} dtype={dtype}" + ) + + +def _time_step(fn) -> tuple[float, float]: + torch.cuda.synchronize() + torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + fn() + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end), torch.cuda.max_memory_allocated() / (1024 * 1024) + + +def _test_tp2_performance( + num_tokens: int, + hidden_size: int, + vocab_size: int, +) -> None: + rank = dist.get_rank() + world_size = dist.get_world_size() + assert world_size == 2 + device = current_platform.current_device() + dtype = torch.bfloat16 + tp_group = _get_tp_group() + + hidden, labels, _, weight_shard = _make_tp_inputs( + num_tokens, hidden_size, vocab_size, dtype, device, seed=0 + ) + + for _ in range(2): + _run_fused_step(hidden, weight_shard, labels, tp_group) + gc.collect() + torch.cuda.empty_cache() + for _ in range(2): + _run_tp_materialized_step(hidden, weight_shard, labels, tp_group) + gc.collect() + torch.cuda.empty_cache() + + fused_times = [] + fused_mems = [] + for _ in range(5): + t, m = _time_step( + lambda: _run_fused_step(hidden, weight_shard, labels, tp_group) + ) + fused_times.append(t) + fused_mems.append(m) + + ref_times = [] + ref_mems = [] + for _ in range(5): + t, m = _time_step( + lambda: _run_tp_materialized_step(hidden, weight_shard, labels, tp_group) + ) + ref_times.append(t) + ref_mems.append(m) + + ref_med = sorted(ref_times)[len(ref_times) // 2] + fused_med = sorted(fused_times)[len(fused_times) // 2] + ref_peak = max(ref_mems) + fused_peak = max(fused_mems) + speedup = ref_med / fused_med if fused_med > 0 else math.inf + mem_ratio = fused_peak / ref_peak if ref_peak > 0 else math.inf + + if rank == 0: + print( + f"\n[LCE-TP2-Bench] tokens={num_tokens} hidden={hidden_size} " + 
f"vocab={vocab_size} dtype={dtype}\n" + f" tp materialized: {ref_med:7.2f} ms / {ref_peak:7.1f} MB peak\n" + f" fused : {fused_med:7.2f} ms / {fused_peak:7.1f} MB peak\n" + f" speedup : {speedup:5.2f}x memory_ratio: {mem_ratio:5.2f}x" + ) + + assert fused_med < ref_med * 1.5, ( + f"Fused TP=2 LCE is more than 1.5x slower than TP materialized reference " + f"(fused={fused_med:.2f}ms ref={ref_med:.2f}ms)." + ) + assert fused_peak < ref_peak * 1.2, ( + f"Fused TP=2 LCE peak memory exceeds TP materialized reference by >20% " + f"(fused={fused_peak:.1f}MB ref={ref_peak:.1f}MB)." + ) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument( + "--test_type", choices=["correctness", "performance"], required=True + ) + parser.add_argument("--num_tokens", type=int, required=True) + parser.add_argument("--hidden_size", type=int, required=True) + parser.add_argument("--vocab_size", type=int, required=True) + parser.add_argument("--dtype", choices=["float32", "bfloat16"], default="bfloat16") + args = parser.parse_args() + + dtype = {"float32": torch.float32, "bfloat16": torch.bfloat16}[args.dtype] + _setup_distributed_environment() + try: + if args.test_type == "correctness": + _test_tp2_correctness( + args.num_tokens, args.hidden_size, args.vocab_size, dtype + ) + else: + _test_tp2_performance(args.num_tokens, args.hidden_size, args.vocab_size) + finally: + if dist.is_initialized(): + dist.destroy_process_group() + + +if __name__ == "__main__": + main()