feat: Support Linear Cross Entropy fuse kernel #1322
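For context, the core idea of a fused linear cross-entropy (LCE) path is to compute per-token log-probs directly from the last hidden states and the lm-head weight, instead of first materialising the full `[seq, vocab]` logits tensor. A minimal eager-PyTorch sketch of the idea (an illustrative reference, not the kernel used by this PR) that chunks over the sequence so the full logits are never held at once:

```python
import torch

def chunked_linear_logprobs(hidden, weight, labels, chunk_size=1024):
    """Reference only: per-token log p(label) from (hidden, weight), computed
    in sequence chunks so a full [seq, vocab] logits tensor is never stored."""
    # hidden: [seq, d], weight: [vocab, d], labels: [seq] (int64)
    w = weight.float()
    out = torch.empty(hidden.shape[0], dtype=torch.float32, device=hidden.device)
    for start in range(0, hidden.shape[0], chunk_size):
        h = hidden[start:start + chunk_size].float()        # fp32 accumulation
        logits = h @ w.t()                                   # [chunk, vocab]
        logz = torch.logsumexp(logits, dim=-1)               # normaliser per token
        tgt = logits.gather(-1, labels[start:start + chunk_size, None]).squeeze(-1)
        out[start:start + chunk_size] = tgt - logz           # log p(label)
    return out
```

A real fused kernel additionally fuses the matmul with the reduction and recomputes logits in the backward pass, which is where the memory savings come from.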
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -59,6 +59,11 @@ | |
| from areal.engine.megatron_utils.checkpointer import MegatronCheckpointManager | ||
| from areal.engine.megatron_utils.deterministic import set_deterministic_algorithms | ||
| from areal.engine.megatron_utils.fp8 import FP8BlockwiseTensorHelper | ||
| from areal.engine.megatron_utils.fused_lce_capture import ( | ||
| FUSED_LCE_HIDDEN_KEY, | ||
| FUSED_LCE_WEIGHT_KEY, | ||
| capture_lm_head_hidden, | ||
| ) | ||
| from areal.engine.megatron_utils.megatron import ( | ||
| all_gather_param, | ||
| convert_to_hf, | ||
|
|
@@ -106,7 +111,12 @@ | |
| split_padded_tensor_dict_into_mb_list, | ||
| unpad_logits, | ||
| ) | ||
| from areal.utils.functional import gather_logprobs, gather_logprobs_entropy | ||
| from areal.utils.functional import ( | ||
| gather_logprobs, | ||
| gather_logprobs_entropy, | ||
| linear_cross_entropy_logprobs, | ||
| linear_cross_entropy_logprobs_entropy, | ||
| ) | ||
| from areal.utils.hf_utils import load_hf_tokenizer | ||
| from areal.utils.lock import DistributedLock | ||
| from areal.utils.network import find_free_ports, format_host_for_url, gethostip | ||
|
|
@@ -710,6 +720,12 @@ def forward_backward_batch( | |
| ) -> None: | ||
| self._ensure_ready() | ||
|
|
||
| use_fused_lce = ( | ||
| getattr(self.config, "use_fused_linear_ce", False) | ||
| and not self.config.is_critic | ||
| and not self.enable_tree_training | ||
| ) | ||
|
|
||
| def forward_step(batch_iter, model): | ||
| mb_input: MicroBatchItem = next(batch_iter) | ||
|
|
||
|
|
@@ -740,12 +756,33 @@ def forward_step(batch_iter, model): | |
| cp_size = mpu.get_context_parallel_world_size() | ||
| cp_local = cp_size > 1 | ||
|
|
||
| output = packed_context_parallel_forward( | ||
| model, | ||
| mb_input.padded_mb, | ||
| gather_cp_output=not cp_local, | ||
| model_vp_stage_for_capture = getattr(model, "vp_stage", 0) | ||
| should_capture = ( | ||
| use_fused_lce | ||
| and mpu.is_pipeline_last_stage( | ||
| ignore_virtual=False, vp_stage=model_vp_stage_for_capture | ||
| ) | ||
| and not cp_local | ||
| ) | ||
|
|
||
| with capture_lm_head_hidden( | ||
| model, enabled=should_capture | ||
| ) as capture: | ||
| output = packed_context_parallel_forward( | ||
| model, | ||
| mb_input.padded_mb, | ||
| gather_cp_output=not cp_local, | ||
| ) | ||
|
|
||
| if ( | ||
| capture is not None | ||
| and capture.hidden is not None | ||
| and capture.weight is not None | ||
| ): | ||
| mb_input.orig_mb[FUSED_LCE_HIDDEN_KEY] = capture.hidden | ||
| mb_input.orig_mb[FUSED_LCE_WEIGHT_KEY] = capture.weight | ||
| mb_input.orig_mb["_fused_lce_active"] = True | ||
|
|
||
| # Release tree attention metadata after forward pass | ||
| for key in tree_attn_keys: | ||
| del mb_input.padded_mb[key] | ||
|
|
@@ -784,6 +821,17 @@ def _process_output(input_, output_): | |
| cu_seqlens=cu_seqlens, | ||
| old_cu_seqlens=mb_input.old_cu_seqlens, | ||
| ) | ||
| # Re-align Float16Module's fp32 hidden to lm-head weight dtype. | ||
| if mb_input.orig_mb.get("_fused_lce_active", False): | ||
| fused_weight = mb_input.orig_mb.get( | ||
| FUSED_LCE_WEIGHT_KEY | ||
| ) | ||
| if ( | ||
| fused_weight is not None | ||
| and output.dtype != fused_weight.dtype | ||
| ): | ||
| output = output.to(fused_weight.dtype) | ||
| mb_input.orig_mb[FUSED_LCE_HIDDEN_KEY] = output | ||
|
Comment on lines +919 to +925
Collaborator
Since we usually require fp32 logits, will this downcast operation cause a precision issue?
Collaborator
Author
The fused LCE kernel internally accumulates the matrix multiplication in fp32. In practice, the non-fused computation path materialises the logits in the parameter dtype and only then up-casts them for the fp32 log-softmax, whereas the fused path maintains fp32 accumulation inside the kernel. Casting the captured hidden state to the lm-head weight dtype therefore does not lose precision relative to the existing path.
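To make that argument concrete, here is an illustrative comparison (an eager sketch, not the fused kernel; it assumes a CUDA device and bf16 parameters) of log-probs computed from bf16-materialised logits versus fp32-accumulated logits:

```python
import torch

torch.manual_seed(0)
hidden = torch.randn(128, 1024, dtype=torch.bfloat16, device="cuda")
weight = torch.randn(32000, 1024, dtype=torch.bfloat16, device="cuda")
labels = torch.randint(0, 32000, (128,), device="cuda")

# Non-fused style: bf16 matmul -> cast logits to fp32 -> log-softmax.
logits_bf16 = (hidden @ weight.t()).float()
lp_bf16 = torch.log_softmax(logits_bf16, dim=-1).gather(-1, labels[:, None]).squeeze(-1)

# fp32-accumulated matmul, as the fused kernel does internally.
logits_fp32 = hidden.float() @ weight.float().t()
lp_fp32 = torch.log_softmax(logits_fp32, dim=-1).gather(-1, labels[:, None]).squeeze(-1)

# The remaining difference comes from bf16 matmul rounding, not from the
# hidden-state downcast performed before the fused kernel.
print((lp_bf16 - lp_fp32).abs().max())
```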
||
| return output, functools.partial(_process_output, mb_input.orig_mb) | ||
|
|
||
| forward_backward_func = get_forward_backward_func() | ||
|
|
@@ -845,7 +893,9 @@ def process_output( | |
| ) | ||
|
|
||
| # Step 4: Optimizer step | ||
| return self.optimizer_step() | ||
| result = self.optimizer_step() | ||
|
|
||
| return result | ||
|
|
||
| @torch.no_grad() | ||
| def eval_batch( | ||
|
|
@@ -1814,16 +1864,37 @@ def _compute_logprobs_and_loss( | |
| labels = cp_local_labels | ||
| else: | ||
| labels = torch.roll(inputs["input_ids"], shifts=-1, dims=-1) | ||
| logprobs, entropy = gather_logprobs_entropy( | ||
| output, | ||
| labels, | ||
| temperature=self.config.temperature, | ||
| tp_group=mpu.get_tensor_model_parallel_group() | ||
| if mpu.get_tensor_model_parallel_world_size() > 1 | ||
| else None, | ||
| ) | ||
| vocab_min_logits = output.detach().min(-1).values.float() | ||
| vocab_max_logits = output.detach().max(-1).values.float() | ||
| fused_active = inputs.get("_fused_lce_active", False) | ||
| fused_hidden = inputs.get(FUSED_LCE_HIDDEN_KEY) | ||
| fused_weight = inputs.get(FUSED_LCE_WEIGHT_KEY) | ||
| if ( | ||
| fused_active | ||
| and fused_hidden is not None | ||
| and fused_weight is not None | ||
| ): | ||
|
|
||
| logprobs, entropy = linear_cross_entropy_logprobs_entropy( | ||
| fused_hidden, | ||
| fused_weight, | ||
| labels, | ||
| temperature=self.config.temperature, | ||
| tp_group=mpu.get_tensor_model_parallel_group() | ||
| if mpu.get_tensor_model_parallel_world_size() > 1 | ||
| else None, | ||
| ) | ||
| proxy = logprobs.detach().float() | ||
| vocab_min_logits = proxy | ||
| vocab_max_logits = proxy | ||
|
|
||
| else: | ||
| logprobs, entropy = gather_logprobs_entropy( | ||
| output, | ||
| labels, | ||
| temperature=self.config.temperature, | ||
| tp_group=mpu.get_tensor_model_parallel_group() | ||
| if mpu.get_tensor_model_parallel_world_size() > 1 | ||
| else None, | ||
| ) | ||
| vocab_min_logits = output.detach().min(-1).values.float() | ||
| vocab_max_logits = output.detach().max(-1).values.float() | ||
| loss = loss_fn( | ||
| logprobs, | ||
| entropy, | ||
|
|
@@ -1860,6 +1931,24 @@ def _compute_forward_result( | |
| ) | ||
| return logprobs | ||
| labels = torch.roll(inputs["input_ids"], shifts=-1, dims=-1) | ||
| fused_active = inputs.get("_fused_lce_active", False) | ||
| fused_hidden = inputs.get(FUSED_LCE_HIDDEN_KEY) | ||
| fused_weight = inputs.get(FUSED_LCE_WEIGHT_KEY) | ||
| if ( | ||
| fused_active | ||
| and fused_hidden is not None | ||
| and fused_weight is not None | ||
| ): | ||
| logprobs = linear_cross_entropy_logprobs( | ||
| fused_hidden, | ||
| fused_weight, | ||
| labels, | ||
| temperature=self.config.temperature, | ||
| tp_group=mpu.get_tensor_model_parallel_group() | ||
| if mpu.get_tensor_model_parallel_world_size() > 1 | ||
| else None, | ||
| ) | ||
| return logprobs | ||
| logprobs = gather_logprobs( | ||
| output, | ||
| labels, | ||
|
|
||
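As a reading aid for the fused branches above, a single-rank eager reference of the quantities that `linear_cross_entropy_logprobs_entropy` is expected to return (a hypothetical reference only: the real helper in `areal.utils.functional` uses the fused kernel, avoids materialising the logits, and additionally handles the `tp_group` reduction):

```python
import torch

def linear_ce_logprobs_entropy_ref(hidden, weight, labels, temperature=1.0):
    # hidden: [seq, d], weight: [vocab, d], labels: [seq] (int64)
    logits = (hidden.float() @ weight.float().t()) / temperature   # [seq, vocab]
    logz = torch.logsumexp(logits, dim=-1)
    logprobs = logits.gather(-1, labels[:, None]).squeeze(-1) - logz
    probs = torch.softmax(logits, dim=-1)
    entropy = logz - (probs * logits).sum(dim=-1)   # H = logZ - E_p[logit]
    return logprobs, entropy
```

Note that the fused branch also replaces `vocab_min_logits` / `vocab_max_logits` with a detached log-prob proxy, since the full logits needed for an exact min/max are never materialised.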
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| """ | ||
| LM-head hidden-state capture for the fused linear-cross-entropy fast path. | ||
|
|
||
| The fused LCE kernel needs ``(hidden, weight)`` instead of materialised | ||
| ``[seq, vocab]`` logits. This module temporarily monkey-patches | ||
| ``output_layer.forward`` to capture those tensors for one microbatch. | ||
|
|
||
| Compatibility: incompatible with MuP (``use_mup``), MTP | ||
| (``mtp_num_layers > 0``), and critic heads. The engine falls back to | ||
| the materialised path automatically when any of these conditions hold. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Iterator | ||
| from contextlib import contextmanager | ||
| from dataclasses import dataclass | ||
|
|
||
| import torch | ||
| from megatron.core import parallel_state as mpu | ||
| from megatron.core.tensor_parallel.mappings import ( | ||
| gather_from_sequence_parallel_region, | ||
| ) | ||
|
|
||
| from areal.utils import logging | ||
|
|
||
| logger = logging.getLogger("FusedLCECapture") | ||
|
|
||
| FUSED_LCE_HIDDEN_KEY = "_fused_lce_hidden" | ||
| FUSED_LCE_WEIGHT_KEY = "_fused_lce_weight" | ||
|
|
||
|
|
||
| @dataclass | ||
| class _CaptureSlot: | ||
| hidden: torch.Tensor | None = None | ||
| weight: torch.Tensor | None = None | ||
|
|
||
|
|
||
| def _unwrap_to_post_process_module(model: torch.nn.Module) -> torch.nn.Module | None: | ||
| inner = model | ||
| for _ in range(8): | ||
| if hasattr(inner, "output_layer") and inner.output_layer is not None: | ||
| return inner | ||
| if not hasattr(inner, "module"): | ||
| return None | ||
| inner = inner.module | ||
| return None | ||
|
|
||
|
|
||
| def _is_compatible(post_process_module: torch.nn.Module) -> bool: | ||
| config = getattr(post_process_module, "config", None) | ||
| if config is None: | ||
| return False | ||
|
|
||
| if getattr(config, "use_mup", False): | ||
| logger.warning( | ||
| "Fused LCE disabled: MuP scaling is enabled (config.use_mup=True)." | ||
| ) | ||
| return False | ||
| if getattr(config, "mtp_num_layers", 0): | ||
| logger.warning( | ||
| "Fused LCE disabled: MTP is enabled (config.mtp_num_layers>0)." | ||
| ) | ||
| return False | ||
|
|
||
| output_layer = getattr(post_process_module, "output_layer", None) | ||
| if output_layer is None: | ||
| return False | ||
|
|
||
| parallel_output = getattr(post_process_module, "parallel_output", True) | ||
| if not parallel_output: | ||
| logger.warning( | ||
| "Fused LCE disabled: model has parallel_output=False; " | ||
| "would require an extra TP gather." | ||
| ) | ||
| return False | ||
|
|
||
| return True | ||
|
|
||
|
|
||
| @contextmanager | ||
| def capture_lm_head_hidden( | ||
| model: torch.nn.Module, *, enabled: bool | ||
| ) -> Iterator[_CaptureSlot | None]: | ||
| if not enabled: | ||
| yield None | ||
| return | ||
|
|
||
| post_process = _unwrap_to_post_process_module(model) | ||
| if post_process is None or not _is_compatible(post_process): | ||
| yield None | ||
| return | ||
|
|
||
| output_layer = post_process.output_layer | ||
| slot = _CaptureSlot() | ||
| original_forward = output_layer.forward | ||
|
|
||
| config = getattr(post_process, "config", None) | ||
| sequence_parallel = bool(getattr(config, "sequence_parallel", False)) | ||
| tp_world_size = mpu.get_tensor_model_parallel_world_size() | ||
| needs_sp_gather = sequence_parallel and tp_world_size > 1 | ||
|
|
||
| def _patched_forward(input_, weight=None, runtime_gather_output=None): | ||
| actual_weight = weight if weight is not None else output_layer.weight | ||
|
|
||
| hidden = input_ | ||
| if needs_sp_gather: | ||
| hidden = gather_from_sequence_parallel_region(hidden) | ||
|
|
||
| if hidden.dtype != actual_weight.dtype: | ||
| hidden = hidden.to(actual_weight.dtype) | ||
|
|
||
| slot.hidden = hidden | ||
| slot.weight = actual_weight | ||
| return hidden, None | ||
|
|
||
| output_layer.forward = _patched_forward # type: ignore[assignment] | ||
| try: | ||
| yield slot | ||
| finally: | ||
| try: | ||
| del output_layer.forward | ||
| except AttributeError: | ||
| output_layer.forward = original_forward # type: ignore[assignment] | ||
|
|
||
|
|
||
| __all__ = [ | ||
| "FUSED_LCE_HIDDEN_KEY", | ||
| "FUSED_LCE_WEIGHT_KEY", | ||
| "capture_lm_head_hidden", | ||
| ] |
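A side note on the patch/restore idiom in the `finally` block above: assigning to `output_layer.forward` installs an instance-level attribute that shadows the class method, and `del output_layer.forward` removes that shadow so the original class-bound method becomes visible again (the `AttributeError` fallback restores the saved bound method when the instance attribute cannot be deleted). A small standalone illustration of that behaviour:

```python
import torch

class Layer(torch.nn.Module):
    def forward(self, x):
        return x * 2

layer = Layer()
layer.forward = lambda x: x * 3    # instance-level shadow (what the patch installs)
assert layer(torch.tensor(1.0)) == 3.0
del layer.forward                  # drop the shadow; the class method is back
assert layer(torch.tensor(1.0)) == 2.0
```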