
feat: Support Linear Cross Entropy fuse kernel #1322

Open
TaoZex wants to merge 31 commits into areal-project:main from TaoZex:lm

Conversation

@TaoZex (Collaborator) commented May 10, 2026

Description

Adds a fused Linear Cross Entropy (LCE) path for Megatron training to avoid materialising full [tokens, vocab] logits.

Key changes:

  • Adds Triton-based fused LCE forward/backward for per-token logprobs and entropy (a conceptual reference sketch follows this list).
  • Integrates fused LCE into Megatron via LM-head hidden/weight capture.
  • Supports tensor-parallel vocab sharding, including TP forward reductions and a d_hidden all-reduce in backward.
  • Keeps a safe fallback to the materialised reference path when fused LCE is unavailable.
  • Adds focused correctness, TP, and benchmark coverage for fused vs materialised LCE.
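For intuition, the quantity the fused kernels compute can be written as a chunked PyTorch reference that never holds the full [tokens, vocab] logits at once (illustrative only; the actual implementation is a single-pass Triton kernel, and the names below are not the PR's API):

```python
import torch

def chunked_lce_reference(hidden, weight, labels, temperature=1.0, chunk=4096):
    """Per-token logprob and entropy without materializing the full [tokens, vocab] logits.

    hidden: [T, H], weight: [V, H], labels: [T] (token ids in [0, V)).
    Streams over vocab chunks with an online logsumexp; this loop only illustrates
    the math that the fused Triton kernel performs in one pass.
    """
    T = hidden.shape[0]
    dev = hidden.device
    running_max = torch.full((T,), float("-inf"), device=dev)
    sumexp = torch.zeros(T, device=dev)
    dot = torch.zeros(T, device=dev)          # running sum of exp(logit - max) * logit
    label_logit = torch.zeros(T, device=dev)
    for start in range(0, weight.shape[0], chunk):
        logits = (hidden @ weight[start:start + chunk].T).float() / temperature  # [T, c]
        new_max = torch.maximum(running_max, logits.max(dim=-1).values)
        rescale = torch.exp(running_max - new_max)
        exp_chunk = torch.exp(logits - new_max[:, None])
        sumexp = sumexp * rescale + exp_chunk.sum(dim=-1)
        dot = dot * rescale + (exp_chunk * logits).sum(dim=-1)
        running_max = new_max
        # pick out the label logit if the label id falls inside this chunk
        in_chunk = (labels >= start) & (labels < start + logits.shape[1])
        idx = (labels - start).clamp(0, logits.shape[1] - 1)
        label_logit = torch.where(
            in_chunk, logits.gather(1, idx[:, None]).squeeze(1), label_logit
        )
    logZ = running_max + torch.log(sumexp)
    logprob = label_logit - logZ              # per-token log p(label)
    entropy = logZ - dot / sumexp             # H = logZ - E_p[logit]
    return logprob, entropy
```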

Related Issue

Fixes #TBD

Type of Change

  • 🐛 Bug fix
  • ✨ New feature
  • 💥 Breaking change
  • 📝 Documentation update
  • ♻️ Refactoring
  • ⚡ Performance improvement
  • ✅ Test coverage improvement

Checklist

  • I have read the Contributing Guide
  • Pre-commit hooks pass (pre-commit run --all-files)
  • Relevant tests pass; new tests added for new functionality
  • Documentation updated (if applicable; built with ./docs/build_all.sh)
  • Branch is up to date with main
  • Self-reviewed via /review-pr command
  • This PR was created by a coding agent via /create-pr
  • This PR is a breaking change

Breaking Change Details (if applicable):

N/A

Additional Context

Key files:

  • areal/utils/kernel/kernels.py: implements the Triton fused LCE kernels, including forward logprob/entropy computation and split-N backward.
  • areal/utils/kernel/linear_cross_entropy.py: exposes the fused LCE autograd function and handles the TP d_hidden all-reduce in backward (the TP reduction pattern is sketched after this list).
  • areal/utils/functional/linear_cross_entropy.py: provides AReaL-facing wrappers with fallback to the materialised reference path.
  • areal/engine/megatron_utils/fused_lce_capture.py: captures LM-head hidden states and weights without materialising logits.
  • areal/engine/megatron_engine.py: wires fused LCE into the Megatron training/logprob path behind actor.use_fused_linear_ce.
  • tests/test_linear_cross_entropy.py and tests/torchrun/run_lce_tp2.py: cover single-GPU and TP=2 correctness/performance checks.
  • benchmark/bench_linear_cross_entropy.py: provides standalone fused vs materialised latency/memory benchmarking, including TP mode.
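For the tensor-parallel path mentioned above, the forward reductions follow roughly this pattern (a simplified sketch with placeholder names, not the PR's API; the real code fuses these steps into the Triton kernels, avoids materializing even the per-shard logits, and additionally all-reduces d_hidden in backward):

```python
import torch
import torch.distributed as dist

def tp_lce_forward_reference(hidden, weight_shard, labels, vocab_start, group=None):
    """Each rank owns a vocab shard; per-shard statistics are all-reduced into
    global per-token logprob/entropy.

    hidden: [T, H]; weight_shard: [V_local, H]; labels: [T] with global vocab ids.
    """
    logits = (hidden @ weight_shard.T).float()                  # [T, V_local]
    global_max = logits.max(dim=-1).values
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=group)
    exp = torch.exp(logits - global_max[:, None])
    sumexp = exp.sum(dim=-1)
    dot = (exp * logits).sum(dim=-1)                            # for entropy
    # the label logit lives on exactly one shard; other ranks contribute zero
    v_local = weight_shard.shape[0]
    in_shard = (labels >= vocab_start) & (labels < vocab_start + v_local)
    idx = (labels - vocab_start).clamp(0, v_local - 1)
    label_logit = torch.where(
        in_shard, logits.gather(1, idx[:, None]).squeeze(1), torch.zeros_like(sumexp)
    )
    for t in (sumexp, dot, label_logit):
        dist.all_reduce(t, group=group)                         # default op: SUM
    logZ = global_max + torch.log(sumexp)
    return label_logit - logZ, logZ - dot / sumexp              # logprob, entropy
```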

Need help? Check the Contributing Guide or ask in https://github.com/inclusionAI/AReaL/discussions!

@TaoZex marked this pull request as ready for review May 10, 2026 04:30
@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a fused linear cross-entropy (LCE) optimization using Triton kernels to avoid materializing large logit tensors, which significantly reduces memory overhead and improves performance for the Megatron backend. The implementation includes a context manager for capturing hidden states, a custom autograd function, and comprehensive benchmarking and testing suites. Review feedback identifies opportunities to improve telemetry accuracy by using actual logit values instead of logprobs, suggests moving kernel alignment assertions into a compatibility check for graceful fallbacks, and recommends using in-place operations in the backward pass for better efficiency.

Comment thread areal/engine/megatron_engine.py Outdated
Comment thread areal/utils/kernel/kernels.py
Comment thread areal/utils/kernel/kernels.py
@TaoZex (Collaborator, Author) commented May 10, 2026

tests/test_linear_cross_entropy.py

This file validates fused Linear Cross Entropy (LCE) across correctness, gradients, TP=2 distributed execution, and performance.

Test Coverage

  • Forward correctness

    • Checks fused logprobs and entropy against the materialized logits -> log_softmax reference.
    • Covers float32, bfloat16, float16.
    • Covers temperatures 0.7, 1.0, 1.5.
  • Backward correctness

    • Checks fused hidden.grad and weight.grad against PyTorch autograd reference.
    • Covers small, medium, and large shapes:
      • 64 x 256 x 2048
      • 512 x 1024 x 32000
      • 2048 x 2048 x 32000
  • TP=2 correctness and performance

    • Launches torchrun --nproc_per_node=2 internally via subprocess.run (see the launch sketch after this list).
    • Users can run it with normal pytest; no manual torchrun is needed.
    • Validates fused TP=2 forward/backward against the full-vocab reference.
  • Single-GPU performance

    • Compares fused vs materialized forward + backward latency and peak memory.
    • Covers representative LLM shapes, including large-vocab Qwen-style cases.
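The TP=2 launch mentioned above works roughly like this (simplified; the actual test also handles GPU-count skips, environment setup, and output checks):

```python
import subprocess
import sys

def run_tp2_script(script="tests/torchrun/run_lce_tp2.py"):
    """Launch the TP=2 correctness check under torchrun from inside a pytest test."""
    cmd = [sys.executable, "-m", "torch.distributed.run",  # same entry point as `torchrun`
           "--nproc_per_node=2", script]
    result = subprocess.run(cmd, capture_output=True, text=True)
    assert result.returncode == 0, result.stderr
```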

Accuracy Tolerances

The checks use tight numerical tolerances:

| Case | rtol | atol |
| --- | --- | --- |
| Forward float32 | 1e-5 | 1e-5 |
| Forward bfloat16 | 2e-2 | 2e-2 |
| Forward float16 | 1e-2 | 1e-2 |
| Temperature float32 | 1e-5 | 1e-5 |
| Backward hidden.grad | 1e-4 | 1e-4 |
| Backward weight.grad (small/medium) | 1e-4 | 1e-4 |
| Backward weight.grad (large) | 1e-4 | 5e-4 |

These tolerances are strict enough to catch real numerical regressions while allowing expected low-precision accumulation drift.
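A check in the spirit of the forward rows of this table, with the fused entry point passed in as a parameter (a sketch, not the test file's actual code):

```python
import torch

# (rtol, atol) pairs mirroring the forward rows of the table above
FORWARD_TOLS = {
    torch.float32: (1e-5, 1e-5),
    torch.bfloat16: (2e-2, 2e-2),
    torch.float16: (1e-2, 1e-2),
}

def check_forward_against_reference(fused_fn, dtype, tokens=512, hidden_dim=1024, vocab=32000):
    """Compare a fused per-token logprob/entropy function against the materialized reference.

    `fused_fn(hidden, weight, labels) -> (logprob, entropy)` is whatever fused entry
    point is under test; the tolerances come straight from the table above.
    """
    torch.manual_seed(0)
    hidden = torch.randn(tokens, hidden_dim, dtype=dtype, device="cuda")
    weight = torch.randn(vocab, hidden_dim, dtype=dtype, device="cuda")
    labels = torch.randint(0, vocab, (tokens,), device="cuda")
    logits = hidden.float() @ weight.float().T
    ref_logprobs = torch.log_softmax(logits, dim=-1)
    ref_lp = ref_logprobs.gather(-1, labels[:, None]).squeeze(-1)
    ref_ent = -(ref_logprobs.exp() * ref_logprobs).sum(dim=-1)
    lp, ent = fused_fn(hidden, weight, labels)
    rtol, atol = FORWARD_TOLS[dtype]
    torch.testing.assert_close(lp.float(), ref_lp, rtol=rtol, atol=atol)
    torch.testing.assert_close(ent.float(), ref_ent, rtol=rtol, atol=atol)
```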

Below are the test results:
image

@TaoZex (Collaborator, Author) commented May 10, 2026

End-to-End Evaluation on H20

Qwen3-0.6B, TP=1

Task Reward

image
  • The task reward trend is almost identical between the baseline and fused LCE runs.
  • This confirms that the fused LCE path preserves training behavior and keeps accuracy-related metrics aligned with the baseline.

LCE Optimization

The fused LCE path significantly improves both step time and peak memory usage.
image

| Metric | Value |
| --- | --- |
| Baseline average step time | 38.2 s |
| Fused LCE average step time | 27.1 s |
| Average step time reduction | 11.1 s |
| Average speedup percentage | 29.52% |
| Baseline peak memory | 55.00 GB |
| Fused LCE peak memory | 32.92 GB |
| Peak memory reduction | 22.08 GB |

LCE Optimization Result: fused LCE achieves a 29.52% step-time reduction and saves 22.08 GB peak memory for Qwen3-0.6B with tp=1.


Qwen3-8B, TP=2

Task Reward

image
  • The task reward trend remains aligned with the baseline.
  • This further confirms that fused LCE does not introduce visible training quality regression under TP=2.

LCE Optimization

With tp=2, the vocabulary is split across two GPUs, so each rank handles roughly half of the vocab-related LCE workload. As expected, the step-time reduction is smaller than in the tp=1 case.
image

| Metric | Value |
| --- | --- |
| Baseline average step time | 101.9 s |
| Fused LCE average step time | 96.8 s |
| Average step time reduction | 5.1 s |
| Average speedup percentage | 5.02% |

LCE Optimization Result: fused LCE still provides a measurable 5.02% step-time reduction under tp=2, while preserving the same task reward trend as the baseline.

@TaoZex (Collaborator, Author) commented May 10, 2026

Benchmark Purpose

The script benchmark/bench_linear_cross_entropy.py compares the fused Linear Cross Entropy path against the materialized reference path.

It measures:

  • Forward + backward latency
  • Peak GPU memory usage
  • Fused vs baseline speedup
  • Memory reduction from avoiding [tokens, vocab] logits materialization
  • Single-GPU and TP multi-GPU behavior

This provides a focused way to evaluate the kernel-level benefit independently from full end-to-end training noise.
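The latency/peak-memory numbers for each variant can be collected with a measurement loop along these lines (a generic sketch, not the benchmark script's actual code):

```python
import time
import torch

def profile_variant(step_fn, warmup=3, iters=10):
    """Time one forward+backward closure and record its peak GPU memory.

    `step_fn` runs the fused or materialized LCE forward+backward once.
    Returns (average latency in ms, peak memory in GB).
    """
    for _ in range(warmup):
        step_fn()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    start = time.perf_counter()
    for _ in range(iters):
        step_fn()
    torch.cuda.synchronize()
    latency_ms = (time.perf_counter() - start) / iters * 1e3
    peak_gb = torch.cuda.max_memory_allocated() / 1024**3
    return latency_ms, peak_gb
```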

Below is a benchmark example:
image


Future FSDP Adaptation

The current pull request implements the core fused LCE capability and integrates it into the Megatron engine first.

After this PR is merged, the same design can be adapted to the FSDP engine.

@garrett4wade (Collaborator) left a comment

LGTM except for several coding style issues. We can make the code look much better.

Also, please fix the pre-commit errors with pre-commit run --all-files.

Comment thread areal/api/cli_args.py
Comment on lines +1146 to +1147
use_fused_linear_ce: bool = field(
default=False,

Let's move this field to MegatronEngineConfig for now.

We should also note that "vocab_min_logits" stats won't be available when enabling this feature. Could you please open an issue about this and call for a community fix?

Comment on lines +2119 to +2123
if (
fused_active
and fused_hidden is not None
and fused_weight is not None
):

The if-else indentation is too deep here. Try extracting the whole if-else block into a separate, granular method and using an early return to eliminate the else.

Comment on lines +917 to +923
fused_weight = mb_input.orig_mb.get(FUSED_LCE_WEIGHT_KEY)
if (
fused_weight is not None
and output.dtype != fused_weight.dtype
):
output = output.to(fused_weight.dtype)
mb_input.orig_mb[FUSED_LCE_HIDDEN_KEY] = output

Since we usually require fp32 logits, will this downcast operation cause a precision issue?


Can we move it to areal/models/kernel? Maybe areal/utils/functional should be migrated too.


How about benchmark/kernels/**?

experiment_name: ${experiment_name}
trial_name: ${trial_name}
path: Qwen/Qwen2.5-1.5B-Instruct
path: /workspace/models/Qwen3-0.6B

Should revert
