WAR sort_chunks_by_index intermittent failures in L0 JAX unittest #2730

Merged
tdophung merged 2 commits into NVIDIA:main from tdophung:sort-chunks-WAR
Mar 5, 2026

Collaborator

@tdophung tdophung commented Mar 4, 2026

WAR for intermittent mismatch failures in the sort_chunks_by_map bwd function of sort_chunks_by_index (part 1/2)

Description

There is a bug in our CI where the following symptoms were observed:

  • ~50% element mismatch in backward pass gradients (computed_grad != ref_grad)
  • Forward pass and row_id_map are correct
  • Fails ~30% of the time overall across systems, but deterministically when running the full L0 test suite via stress_test.sh (all tests/jax/ in one pytest process)
  • Never fails when running pytest test_permutation.py in isolation

After being able to reproduce this on B200 (only when running the whole L0_jax_unittest/test.sh), I found that 3 interacting factors combine to produce the corruption:

1. XLA implicit buffer aliasing

XLA's buffer assignment (xla/service/buffer_assignment.cc) can assign the same GPU memory to the input (output_grad) and output (inp_grad) of the backward sort_chunks_by_map custom call. This happens because:

  • They have the same shape
  • XLA determines the input is dead after the operation

(*) We don't yet know whether this is the fault of XLA or of the Python stack.

2. Permutation kernels cannot operate in-place

_sort_chunks_by_map_kernel (in transformer_engine/common/triton/permutation.py) reads from src_row and writes to dst_row where src_row != dst_row. When input and output share the same buffer, GPU thread blocks execute in waves — early blocks overwrite data that later blocks haven't read yet, causing corruption.
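
The overwrite hazard can be reproduced without a GPU. The following is a minimal sketch in plain Python, not the Triton kernel: sort_rows, row_map, and the values are hypothetical, and rows are processed one after another to stand in for thread blocks that run in successive waves.

```python
def sort_rows(src, dst, row_map):
    """Toy stand-in for the kernel: row i of src is written to row row_map[i] of dst.

    Rows are handled sequentially, mimicking thread blocks that do not all
    run at the same time.
    """
    for src_row, dst_row in enumerate(row_map):
        dst[dst_row] = src[src_row]


row_map = [2, 0, 3, 1]     # hypothetical permutation: src row -> dst row
grad = [10, 11, 12, 13]    # stands in for output_grad

# Separate buffers: every read sees the original data.
separate = [None] * len(grad)
sort_rows(grad, separate, row_map)

# Aliased buffers: src and dst are the same list, as if the memory were shared.
aliased = list(grad)
sort_rows(aliased, aliased, row_map)

mismatches = sum(a != b for a, b in zip(separate, aliased))
print("separate:", separate)
print("aliased: ", aliased)
print("mismatching rows:", mismatches)
```

In the aliased run, a row that was already overwritten is later read as if it still held its original value, which is the same failure pattern the description attributes to early blocks overwriting data that later blocks have not read yet.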

3. Triton autotuning amplifies corruption

TritonAutotunedKernelCall (in jaxlib/gpu/triton_kernels.cc) runs the kernel multiple times with different configs. It normally saves/restores aliased buffers between runs using input_output_aliases_with_sizes. But since no explicit aliases are declared for sort_chunks_by_map, the save/restore is inactive, and each autotuning trial corrupts the shared buffer further.
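
How repeated trials interact with a shared buffer can be sketched with a toy autotuner. This is an illustration in plain Python under stated assumptions, not the jaxlib implementation: autotune, num_configs, and save_restore are hypothetical names, and the snapshot/restore step only imitates the save/restore behaviour described above.

```python
def sort_rows_in_place(buf, row_map):
    """Toy kernel running on a buffer that is both its input and its output."""
    for src_row, dst_row in enumerate(row_map):
        buf[dst_row] = buf[src_row]


def autotune(buf, row_map, num_configs, save_restore):
    """Toy autotuner: one timing trial per config, then the real launch."""
    snapshot = list(buf) if save_restore else None
    for _ in range(num_configs):
        sort_rows_in_place(buf, row_map)   # timing trial
        if save_restore:
            buf[:] = snapshot              # undo the writes of this trial
    sort_rows_in_place(buf, row_map)       # final launch
    return buf


row_map = [2, 0, 3, 1]
expected = [11, 13, 10, 12]                # result with separate buffers

with_restore = autotune([10, 11, 12, 13], row_map, num_configs=3, save_restore=True)
no_restore = autotune([10, 11, 12, 13], row_map, num_configs=3, save_restore=False)


def count_mismatches(result):
    return sum(a != b for a, b in zip(expected, result))


print("with restore:", with_restore, count_mismatches(with_restore))
print("no restore:  ", no_restore, count_mismatches(no_restore))
```

In this toy, restoring between trials limits the damage to that of a single in-place launch but does not remove it, while skipping the restore lets each trial start from already corrupted data. That is consistent with the description's point that the autotuning path needs the alias information as well.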

Some questions to answer/ponder:

Why only sort_chunks_by_map?

Other permutation kernels in TE (e.g., _permute_with_mask_map_kernel in the same file) already declare explicit input_output_aliases, which claims the output buffer slot and prevents XLA from implicitly aliasing a different input to that output. _sort_chunks_by_map_kernel lacked this explicit alias, leaving the output slot "unclaimed."

Why only in the full test suite?

Running all tests in a single pytest process changes XLA's buffer assignment decisions compared to an isolated run, likely due to different memory pressure, compilation caches, or HLO graph differences. The exact mechanism was investigated but not conclusively determined. TODO: I am currently trying to dump the VLOG prints in buffer_assignment.cc in XLA for this kernel specifically while running all tests (the reproducing setup), to determine whether the buffer reuse was intentional by XLA in some specific HLO graph, or whether it was a rare error somewhere in the Python stack. The answer to this is still TBD.
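
One way such a debugging run could be configured is sketched below. This is an assumption-laden sketch, not the setup used in this PR: xla_debug_env and the dump directory are hypothetical, and it relies on the XLA_FLAGS dump options (--xla_dump_to, --xla_dump_hlo_as_text) and the TF_CPP_VMODULE / TF_CPP_MIN_LOG_LEVEL logging variables being honored by the installed jaxlib.

```python
import os


def xla_debug_env(dump_dir, vlog_level=2):
    """Return a copy of the current environment with XLA dump and VLOG settings."""
    env = dict(os.environ)
    extra_flags = f"--xla_dump_to={dump_dir} --xla_dump_hlo_as_text"
    # Append to any XLA_FLAGS that are already set instead of replacing them.
    env["XLA_FLAGS"] = f"{env.get('XLA_FLAGS', '')} {extra_flags}".strip()
    # Per-file verbose logging for xla/service/buffer_assignment.cc.
    env["TF_CPP_VMODULE"] = f"buffer_assignment={vlog_level}"
    # VLOG output is emitted at INFO level, so INFO must not be filtered out.
    env["TF_CPP_MIN_LOG_LEVEL"] = "0"
    return env


env = xla_debug_env("/tmp/xla_dump")       # hypothetical dump directory
command = ["pytest", "tests/jax/"]         # the single-process reproducing run
# To launch it with these settings:
#   import subprocess; subprocess.run(command, env=env, check=False)
```

The dump directory should then contain the HLO text for each compiled module, which can be compared between the isolated run and the full-suite run.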

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

(Part 1/2)
Pass a pre-allocated output buffer as an additional input and declare an explicit input_output_alias in sort_chunks_by_map_kernel (mirroring what PermuteWithMaskMapPrimitive already does).

By declaring input_output_aliases={3: 0}, XLA knows output slot 0 (dummy output) is claimed by input 3 (output_buf). XLA/Python will not implicitly assign any other input (like output_grad) to that output buffer. Input and output are guaranteed separate memory. Therefore, no corruption leading to mismatch will happen.
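
The reasoning above can be pictured with a toy model. It is only a sketch of the argument, not XLA's buffer assignment algorithm: choose_output_source, the shapes, and the set of dead inputs are hypothetical.

```python
def choose_output_source(input_shapes, dead_inputs, output_shape,
                         explicit_aliases, out_slot=0):
    """Return the index of the input whose buffer backs output `out_slot`.

    Returns None when the toy "compiler" would allocate a fresh buffer.
    """
    # An explicit alias {input_idx: output_idx} claims the output slot outright.
    for in_idx, aliased_slot in explicit_aliases.items():
        if aliased_slot == out_slot:
            return in_idx
    # Otherwise any dead input with a matching shape may be reused implicitly.
    for in_idx, shape in enumerate(input_shapes):
        if in_idx in dead_inputs and shape == output_shape:
            return in_idx
    return None


# Hypothetical backward call. Inputs: 0=output_grad, 1=row_id_map, 2=probs, 3=output_buf
shapes = [(8, 4), (8,), (8,), (8, 4)]
dead = {0, 3}                   # both are unused after the custom call

implicit = choose_output_source(shapes, dead, (8, 4), explicit_aliases={})
explicit = choose_output_source(shapes, dead, (8, 4), explicit_aliases={3: 0})

print("no alias declared -> output backed by input", implicit)
print("{3: 0} declared   -> output backed by input", explicit)
```

Without a declared alias the toy picks input 0 (output_grad), the case that forces the kernel to run in-place. With {3: 0} the slot is already taken by the dedicated output_buf.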

(Part 2/2) is still needed to complete this fix and guarantee we will not see this issue again.
A separate bug (bug 5810384) currently prevents triton_extension/utils.py from passing the correct input_output_alias to our TritonAutotunedKernel call. As a result, the WAR in part 1/2 will not take effect for the autotuning passes until part 2/2 is submitted, and mismatches will still be observed in our CI.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…sort_chunk_by_map bwd function

Signed-off-by: tdophung <tdophung@nvidia.com>
@tdophung tdophung marked this pull request as draft March 4, 2026 00:46
Contributor

greptile-apps Bot commented Mar 4, 2026

Greptile Summary

This PR introduces a workaround (Part 1/2) for intermittent gradient corruption in the sort_chunks_by_map backward kernel by mirroring the input_output_aliases pattern already used by PermuteWithMaskMapPrimitive. The root cause is XLA's buffer assignment assigning the same GPU memory to both the output_grad input and inp_grad output of the backward custom call (they share shape), combined with _sort_chunks_by_map_kernel being unable to operate in-place (reads src_row, writes dst_row where src_row ≠ dst_row).

  • Fix approach: A pre-allocated output_buf tensor is added as a 4th input to the inner JAX primitive and declared in input_output_aliases = {3: 0}, which "claims" output slot 0 so XLA cannot implicitly alias any other live input (e.g., output_grad) to that slot.
  • PyTorch callsite: The same output tensor is passed as the new output_buf_ptr argument — safe because the kernel never reads this parameter.
  • Potential gap: When with_probs=True, the permuted_probs output (slot 1) has no corresponding alias, leaving the probs → permuted_probs path theoretically susceptible to the same aliasing. PermuteWithMaskMapPrimitive protects both slots conditionally.
  • Known incompleteness: As stated in the PR, Part 2/2 is required to fix a separate bug in triton_extension/utils.py that prevents input_output_aliases from being forwarded to TritonAutotunedKernelCall. Until Part 2/2 lands, the autotuning passes remain unprotected and mismatches can still occur in CI.

Confidence Score: 3/5

  • Partial WAR that correctly prevents non-autotuning aliasing corruption; safe to merge with awareness that Part 2/2 is required to fully resolve CI failures
  • The core aliasing prevention mechanism is technically correct and follows an established pattern in the codebase. However, the fix is intentionally incomplete (Part 1/2): the permuted_probs output slot (slot 1) lacks an alias declaration for with_probs=True, and the autotuning path remains unprotected until Part 2/2 ships. The PR description itself acknowledges mismatches may still appear in CI until the companion fix arrives.
  • transformer_engine/jax/triton_extensions/permutation.py — the input_output_aliases declaration only covers output slot 0; slot 1 (permuted_probs) is unprotected when with_probs=True

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/common/triton/permutation.py | Adds output_buf_ptr as a new (unused) parameter to _sort_chunks_by_map_kernel, enabling JAX's input_output_aliases mechanism. The parameter is correctly marked pylint: disable=unused-argument and its position in the argument list is consistent between JAX lowering and PyTorch call-site. |
| transformer_engine/jax/triton_extensions/permutation.py | Core of the WAR: adds output_buf input to the inner primitive, declares input_output_aliases = {3: 0}, and correctly mirrors the PermuteWithMaskMapPrimitive pattern. However, when with_probs=True, the permuted_probs output (slot 1) has no corresponding alias, leaving a potential aliasing gap for that output path. |
| transformer_engine/pytorch/triton/permutation.py | Correctly passes output as the new output_buf_ptr argument to match the updated kernel signature. The same tensor is passed for both output_buf_ptr (4th arg) and output_ptr (11th arg), which is intentional and safe since the kernel never reads output_buf_ptr. |

Sequence Diagram

sequenceDiagram
    participant Caller
    participant OuterPrimitive as SortChunksByMapPrimitive<br/>(outer)
    participant InnerPrimitive as SortChunksByMapPrimitive<br/>(inner)
    participant XLA
    participant Kernel as _sort_chunks_by_map_kernel

    Caller->>OuterPrimitive: bind(inp, row_id_map, probs, ...)
    OuterPrimitive->>OuterPrimitive: impl() — allocate output_buf = jnp.empty(...)
    OuterPrimitive->>InnerPrimitive: bind(inp, row_id_map, probs, output_buf, ...)
    InnerPrimitive->>XLA: lowering() with input_output_aliases={3:0}
    Note over XLA: Knows output slot 0 is claimed<br/>by input 3 (output_buf).<br/>Cannot alias inp/output_grad → output.
    XLA->>Kernel: launch(input_ptr, row_id_map_ptr, probs_ptr,<br/>output_buf_ptr[unused], ..., output_ptr, permuted_probs_ptr)
    Note over Kernel: output_buf_ptr and output_ptr<br/>share the same GPU buffer (via alias)
    Kernel-->>XLA: writes inp→output_ptr (no in-place corruption)
    XLA-->>Caller: (output, permuted_probs)

Last reviewed commit: 43b6991

Contributor

greptile-apps Bot commented Mar 4, 2026

Additional Comments (1)

transformer_engine/pytorch/triton/permutation.py, line 441
Missing output_buf_ptr argument breaks PyTorch kernel call

The _sort_chunks_by_map_kernel in transformer_engine/common/triton/permutation.py has output_buf_ptr inserted as the 4th positional argument (after probs_ptr and before the strides). However, this PyTorch call site does not pass the new output_buf_ptr, causing all subsequent positional arguments to be shifted by one position and mismatched:

  • inp.stride(0) → gets bound to output_buf_ptr (expects a tensor pointer, not an int)
  • inp.stride(1) → gets bound to stride_input_token
  • output.stride(0) → gets bound to stride_input_hidden
  • … and so on

This will cause either a Triton runtime error or silent data corruption in PyTorch as strides and data pointers are swapped.

Per the comment in the kernel (# In PyTorch, pass the same tensor as output_ptr), output should be passed as the output_buf_ptr argument:

    _sort_chunks_by_map_kernel[grid](
        inp,
        row_id_map,
        probs,
        output,  # output_buf_ptr: pass same tensor as output_ptr
        inp.stride(0),
        inp.stride(1),
        output.stride(0),
        output.stride(1),
        probs.stride(0) if probs is not None else None,
        permuted_probs.stride(0) if permuted_probs is not None else None,
        output,
        permuted_probs,
        hidden_size,
        PERMUTE_PROBS=probs is not None,
        FORWARD=is_forward,
    )
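
The positional shift described in this comment can be checked with a stub. This is a minimal sketch, assuming only the first six parameter names mentioned in this thread: new_kernel is a hypothetical stand-in for the updated signature, and the arguments are plain strings used as labels.

```python
import inspect


def new_kernel(input_ptr, row_id_map_ptr, probs_ptr, output_buf_ptr,
               stride_input_token, stride_input_hidden):
    """Hypothetical stub holding the first parameters of the updated kernel."""


signature = inspect.signature(new_kernel)

# Old call site: output_buf_ptr is not passed, so every later argument shifts.
old_args = ["inp", "row_id_map", "probs", "inp.stride(0)", "inp.stride(1)"]
old_binding = dict(signature.bind_partial(*old_args).arguments)

# Fixed call site: `output` fills the new fourth slot.
new_args = ["inp", "row_id_map", "probs", "output", "inp.stride(0)", "inp.stride(1)"]
new_binding = dict(signature.bind_partial(*new_args).arguments)

for name in ("output_buf_ptr", "stride_input_token", "stride_input_hidden"):
    print(f"{name}: old={old_binding.get(name)!r} new={new_binding.get(name)!r}")
```

Passing the kernel arguments by keyword at the call site would avoid this class of error, at the cost of a more verbose launch.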

Signed-off-by: tdophung <tdophung@nvidia.com>
Collaborator Author

tdophung commented Mar 4, 2026

/te-ci

Collaborator Author

tdophung commented Mar 4, 2026

Note that the CI is not passing because part 2/2 has not been submitted yet.

Collaborator

@phu0ngng phu0ngng left a comment

LGTM. Thanks

Contributor

greptile-apps Bot commented Mar 4, 2026

Additional Comments (2)

transformer_engine/jax/triton_extensions/permutation.py, line 1732
Incomplete aliasing for permuted_probs output (slot 1)

When with_probs=True, the permuted_probs output (output slot 1) has no corresponding alias declared. Both probs (input) and permuted_probs (output) share the same shape (num_tokens,) — the same XLA buffer-reuse condition that caused the corruption for activations applies equally here. If XLA assigns the probs input buffer to the permuted_probs output, _sort_chunks_by_map_kernel would experience the same read-before-write corruption described in the PR.

To fully mirror the aliasing pattern from PermuteWithMaskMapPrimitive and protect both outputs, consider conditionally declaring the second alias when with_probs=True:

    # Input indices: 0=inp, 1=row_id_map, 2=probs, 3=output_buf
    # Output indices: 0=output, 1=permuted_probs
    input_output_aliases = {3: 0}
    if with_probs:
        input_output_aliases[...] = 1  # alias a permuted_probs_buf to output slot 1

This would require adding a permuted_probs_buf input analogous to output_buf, similar to what PermuteWithMaskMapPrimitive does with permuted_probs_buf_aval. The PR description acknowledges this is Part 1/2, but if with_probs=True is ever used in the backward pass, this gap could still surface the aliasing bug.


transformer_engine/pytorch/triton/permutation.py, line 430
Misleading comment about output_buf_ptr in PyTorch

The comment # no use in Pytorch side, serves as WAR for JAX side describes the motivation for the parameter but doesn't clarify that PyTorch is passing output (the same object used as output_ptr on line 437). A future reader maintaining the kernel signature may not understand why output is passed here or what happens if the two become different tensors.

Consider clarifying:

        output,  # output_buf_ptr: unused by this kernel, but required by the shared kernel signature (WAR for JAX input_output_aliases)

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Collaborator

@jberchtold-nvidia jberchtold-nvidia left a comment

LGTM pending CI, thanks!

@tdophung tdophung merged commit d40b9de into NVIDIA:main Mar 5, 2026
25 of 33 checks passed
tdophung added a commit to tdophung/TransformerEngine that referenced this pull request May 29, 2026

3 participants