feat: Support Linear Cross Entropy fuse kernel #1322
Conversation
Code Review
This pull request introduces a fused linear cross-entropy (LCE) optimization using Triton kernels to avoid materializing large logit tensors, which significantly reduces memory overhead and improves performance for the Megatron backend. The implementation includes a context manager for capturing hidden states, a custom autograd function, and comprehensive benchmarking and testing suites. Review feedback identifies opportunities to improve telemetry accuracy by using actual logit values instead of logprobs, suggests moving kernel alignment assertions into a compatibility check for graceful fallbacks, and recommends using in-place operations in the backward pass for better efficiency.
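For intuition, here is a minimal sketch of the underlying idea in plain PyTorch: a custom autograd function that processes the sequence in row chunks so the full `[tokens, vocab]` logit matrix never exists at once. The PR implements this with fused Triton kernels instead; the class below, its name, and the chunking scheme are illustrative assumptions only.

```python
import torch

class ChunkedLinearCE(torch.autograd.Function):
    """Compute per-token logprobs from hidden states and the LM-head weight
    in row chunks, so the full [tokens, vocab] logit matrix is never alive
    at once. Illustrative sketch; the PR uses fused Triton kernels."""

    @staticmethod
    def forward(ctx, hidden, weight, labels, chunk=1024):
        logprobs = hidden.new_empty(hidden.shape[0], dtype=torch.float32)
        for s in range(0, hidden.shape[0], chunk):
            logits = (hidden[s : s + chunk] @ weight.t()).float()  # [chunk, vocab]
            lse = torch.logsumexp(logits, dim=-1)
            tgt = logits.gather(1, labels[s : s + chunk, None]).squeeze(1)
            logprobs[s : s + chunk] = tgt - lse  # chunk of logits freed each loop
        ctx.save_for_backward(hidden, weight, labels)
        ctx.chunk = chunk
        return logprobs

    @staticmethod
    def backward(ctx, grad_out):
        hidden, weight, labels = ctx.saved_tensors
        d_hidden = torch.zeros_like(hidden)
        d_weight = torch.zeros_like(weight)
        for s in range(0, hidden.shape[0], ctx.chunk):
            h = hidden[s : s + ctx.chunk]
            probs = torch.softmax((h @ weight.t()).float(), dim=-1)
            probs[torch.arange(h.shape[0]), labels[s : s + ctx.chunk]] -= 1.0
            # d(logprob)/d(logits) = onehot - softmax
            d_logits = (-probs * grad_out[s : s + ctx.chunk, None]).to(h.dtype)
            d_hidden[s : s + ctx.chunk] = d_logits @ weight
            d_weight += d_logits.t() @ h
        return d_hidden, d_weight, None, None

# Usage: logprobs = ChunkedLinearCE.apply(hidden, weight, labels)
```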
| Case | rtol | atol |
|---|---|---|
| Forward float32 | 1e-5 | 1e-5 |
| Forward bfloat16 | 2e-2 | 2e-2 |
| Forward float16 | 1e-2 | 1e-2 |
| Temperature float32 | 1e-5 | 1e-5 |
| Backward hidden.grad | 1e-4 | 1e-4 |
| Backward weight.grad small/medium | 1e-4 | 1e-4 |
| Backward weight.grad large | 1e-4 | 5e-4 |
These tolerances are strict enough to catch real numerical regressions while allowing expected low-precision accumulation drift.
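As an illustration of how these tolerances would be applied (not the PR's actual test code; shapes and the placeholder fused output are assumptions):

```python
import torch

torch.manual_seed(0)
hidden = torch.randn(128, 64, dtype=torch.bfloat16)
weight = torch.randn(512, 64, dtype=torch.bfloat16)
labels = torch.randint(0, 512, (128,))

# Materialised fp32 reference: full logits, then per-token logprobs.
logits = hidden.float() @ weight.float().t()
ref = logits.gather(1, labels[:, None]).squeeze(1) - torch.logsumexp(logits, -1)

# The fused output is compared against the reference with the
# "Forward bfloat16" row of the table above:
fused_out = ref  # placeholder; the real test calls the fused kernel here
torch.testing.assert_close(fused_out, ref, rtol=2e-2, atol=2e-2)
```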
Benchmark Purpose

The benchmark measures the per-call latency and peak GPU memory of the fused kernel versus the materialised reference path.
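A minimal sketch of such a micro-benchmark loop (the helper is an assumption, not the internals of `benchmark/bench_linear_cross_entropy.py`):

```python
import time
import torch

def bench(fn, iters=20, warmup=5):
    """Hypothetical micro-benchmark helper: mean latency and peak CUDA memory."""
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    latency_ms = (time.perf_counter() - start) / iters * 1e3
    peak_gib = torch.cuda.max_memory_allocated() / 2**30
    return latency_ms, peak_gib
```

Running it once with the fused callable and once with the materialised reference isolates the kernel-level difference.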
This provides a focused way to evaluate the kernel-level benefit independently of full end-to-end training noise.

Future FSDP Adaptation

The current pull request implements the core fused LCE capability and integrates it into the Megatron engine first. After this PR is merged, the same design can be adapted to the FSDP engine.
garrett4wade left a comment
LGTM except for several coding style issues; we can make the code look much better.
Also, please fix the pre-commit errors with `pre-commit run --all-files`.
```python
use_fused_linear_ce: bool = field(
    default=False,
```
Let's move this field to `MegatronEngineConfig` for now.
We should also note that `vocab_min_logits` stats won't be available when this feature is enabled. Could you please open an issue about this and call for a community fix?
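A sketch of what the relocated field might look like (the class layout and metadata text are assumptions, not the merged code):

```python
from dataclasses import dataclass, field

@dataclass
class MegatronEngineConfig:
    use_fused_linear_ce: bool = field(
        default=False,
        metadata={
            "help": "Fuse the LM-head matmul with cross-entropy. Note: "
            "'vocab_min_logits' stats are unavailable when enabled."
        },
    )
```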
```python
if (
    fused_active
    and fused_hidden is not None
    and fused_weight is not None
):
```
The if-else nesting is too deep here. Try extracting the whole if-else block into a separate method and using an early return to eliminate the else branch, as in the sketch below.
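For example (the helper name and signature are hypothetical, and `fused_linear_cross_entropy` stands in for the PR's wrapper):

```python
def _maybe_fused_lce(fused_active, fused_hidden, fused_weight):
    # Early return keeps the happy path unindented and drops the else branch.
    if not fused_active or fused_hidden is None or fused_weight is None:
        return None  # caller falls back to the materialised logits path
    return fused_linear_cross_entropy(fused_hidden, fused_weight)
```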
```python
fused_weight = mb_input.orig_mb.get(FUSED_LCE_WEIGHT_KEY)
if (
    fused_weight is not None
    and output.dtype != fused_weight.dtype
):
    output = output.to(fused_weight.dtype)
mb_input.orig_mb[FUSED_LCE_HIDDEN_KEY] = output
```
Since we usually require fp32 logits, will this downcast operation cause a precision issue?
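A quick way to see the potential drift (illustrative only; assumes bf16 weights and fp32 hidden states):

```python
import torch

torch.manual_seed(0)
h32 = torch.randn(512, 128)              # fp32 hidden states
w16 = torch.randn(4096, 128).bfloat16()  # low-precision LM-head weight

# Downcast hidden before the matmul (what the snippet above does) ...
lse_down = torch.logsumexp((h32.bfloat16() @ w16.t()).float(), -1)
# ... versus upcasting the weight and keeping the matmul in fp32.
lse_up = torch.logsumexp(h32 @ w16.float().t(), -1)
print((lse_down - lse_up).abs().max())   # drift introduced by the bf16 matmul
```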
Can we move it to `areal/models/kernel`? Maybe `areal/utils/functional` should be migrated too.
How about `benchmark/kernels/**`?
```diff
 experiment_name: ${experiment_name}
 trial_name: ${trial_name}
-path: Qwen/Qwen2.5-1.5B-Instruct
+path: /workspace/models/Qwen3-0.6B
```
Description
Adds a fused Linear Cross Entropy (LCE) path for Megatron training to avoid materialising full `[tokens, vocab]` logits.

Key changes:
- Fused Triton kernels compute `logprobs` and entropy without materialising the full logits.
- The fused autograd function handles the TP `d_hidden` all-reduce in backward.

Related Issue
Fixes #TBD
Type of Change
Checklist

- [ ] Pre-commit checks pass (`pre-commit run --all-files`)
- [ ] Docs build (`./docs/build_all.sh`)
- [ ] Up to date with `main`
- [ ] `/review-pr` command
- [ ] `/create-pr` command

Breaking Change Details (if applicable):
N/A
Additional Context
Key files:
- `areal/utils/kernel/kernels.py`: implements the Triton fused LCE kernels, including forward logprob/entropy computation and split-N backward.
- `areal/utils/kernel/linear_cross_entropy.py`: exposes the fused LCE autograd function and handles the TP `d_hidden` all-reduce in backward.
- `areal/utils/functional/linear_cross_entropy.py`: provides AReaL-facing wrappers with fallback to the materialised reference path.
- `areal/engine/megatron_utils/fused_lce_capture.py`: captures LM-head hidden states and weights without materialising logits.
- `areal/engine/megatron_engine.py`: wires fused LCE into the Megatron training/logprob path behind `actor.use_fused_linear_ce`.
- `tests/test_linear_cross_entropy.py` and `tests/torchrun/run_lce_tp2.py`: cover single-GPU and TP=2 correctness/performance checks.
- `benchmark/bench_linear_cross_entropy.py`: provides standalone fused vs materialised latency/memory benchmarking, including TP mode.
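To enable the fused path, the flag from this PR would be set in the training config along these lines (the exact nesting is an assumption based on `actor.use_fused_linear_ce`):

```yaml
actor:
  use_fused_linear_ce: true  # falls back to the materialised path when unsupported
```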
Need help? Check the Contributing Guide or ask in https://github.com/inclusionAI/AReaL/discussions!