WAR sort_chunks_by_index intermittent failures in L0 JAX unittest #2730
…sort_chunk_by_map bwd function Signed-off-by: tdophung <tdophung@nvidia.com>
Greptile Summary
This PR introduces a workaround (Part 1/2) for intermittent gradient corruption in the
Confidence Score: 3/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant Caller
participant OuterPrimitive as SortChunksByMapPrimitive<br/>(outer)
participant InnerPrimitive as SortChunksByMapPrimitive<br/>(inner)
participant XLA
participant Kernel as _sort_chunks_by_map_kernel
Caller->>OuterPrimitive: bind(inp, row_id_map, probs, ...)
OuterPrimitive->>OuterPrimitive: impl() — allocate output_buf = jnp.empty(...)
OuterPrimitive->>InnerPrimitive: bind(inp, row_id_map, probs, output_buf, ...)
InnerPrimitive->>XLA: lowering() with input_output_aliases={3:0}
Note over XLA: Knows output slot 0 is claimed<br/>by input 3 (output_buf).<br/>Cannot alias inp/output_grad → output.
XLA->>Kernel: launch(input_ptr, row_id_map_ptr, probs_ptr,<br/>output_buf_ptr[unused], ..., output_ptr, permuted_probs_ptr)
Note over Kernel: output_buf_ptr and output_ptr<br/>share the same GPU buffer (via alias)
Kernel-->>XLA: writes inp→output_ptr (no in-place corruption)
XLA-->>Caller: (output, permuted_probs)
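The "no in-place corruption" note in the diagram can be made concrete with a toy model in plain Python. This is only a sketch under stated assumptions: `sort_chunks_rows` and `pick_output_buffer` are hypothetical helpers, not TE, JAX, or XLA code; the serial row loop stands in for GPU thread blocks running in waves; and the `implicit_reuse_index` argument models a worst-case allocator that reuses an input buffer when no explicit alias claims output slot 0.

```python
def sort_chunks_rows(src, dst, chunk_sizes, order):
    """Toy stand-in for the permutation kernel: copy rows chunk by chunk.

    Rows are copied one at a time, so if `dst` is the same list as `src`,
    earlier writes clobber rows that have not been read yet.
    """
    starts = [0]
    for size in chunk_sizes:
        starts.append(starts[-1] + size)
    dst_row = 0
    for chunk in order:
        for src_row in range(starts[chunk], starts[chunk + 1]):
            dst[dst_row] = src[src_row]
            dst_row += 1
    return dst


def pick_output_buffer(inputs, aliases, implicit_reuse_index=None):
    """Hypothetical allocator for output slot 0 (not XLA's real algorithm).

    An explicit alias {input_index: 0} always wins. Without one, the toy
    allocator reuses `implicit_reuse_index` if given, else allocates fresh.
    """
    for input_index, output_index in aliases.items():
        if output_index == 0:
            return inputs[input_index]
    if implicit_reuse_index is not None:
        return inputs[implicit_reuse_index]
    return [None] * len(inputs[0])


EXPECTED = ["c0", "c1", "c2", "a0", "a1", "b0"]
chunk_sizes, order = [2, 1, 3], [2, 0, 1]

# Case 1: no explicit alias, so the toy allocator reuses input 0 for the output.
inp = ["a0", "a1", "b0", "c0", "c1", "c2"]
out = pick_output_buffer([inp, None, None], aliases={}, implicit_reuse_index=0)
corrupted = sort_chunks_rows(inp, out, chunk_sizes, order)

# Case 2: input 3 (output_buf) explicitly claims output slot 0, as in {3: 0}.
inp = ["a0", "a1", "b0", "c0", "c1", "c2"]
output_buf = [None] * len(inp)
out = pick_output_buffer(
    [inp, None, None, output_buf], aliases={3: 0}, implicit_reuse_index=0
)
clean = sort_chunks_rows(inp, out, chunk_sizes, order)

print("shared buffer:  ", corrupted)
print("separate buffer:", clean)
```

In the shared-buffer case the rows of chunk `c` overwrite rows of chunks `a` and `b` before they are read, so the result no longer matches `EXPECTED`; with the explicit alias the kernel writes into `output_buf` and the input stays intact.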
Last reviewed commit: 43b6991
Additional Comments (1)
The
This will cause either a Triton runtime error or silent data corruption in PyTorch as strides and data pointers are swapped. Per the comment in the kernel ( |
Signed-off-by: tdophung <tdophung@nvidia.com>
/te-ci
Note that the CI is not passing because part 2/2 has not been submitted yet.
Additional Comments (2)
When To fully mirror the aliasing pattern from
# Input indices: 0=inp, 1=row_id_map, 2=probs, 3=output_buf
# Output indices: 0=output, 1=permuted_probs
input_output_aliases = {3: 0}
if with_probs:
    input_output_aliases[...] = 1  # alias a permuted_probs_buf to output slot 1
This would require adding a
The comment Consider clarifying:
jberchtold-nvidia
left a comment
LGTM pending CI, thanks!
…ttest part 1 (NVIDIA#2730)" This reverts commit d40b9de.
WAR intermittent mismatch failures in sort_chunks_by_index sort_chunk_by_map bwd function (part 1/2)
Description
There is a bug in our CI where the following symptoms were observed:
- … (computed_grad != ref_grad)
- … row_id_map are correct
- … stress_test.sh (all tests/jax/ in one pytest process)
- … pytest test_permutation.py in isolation
After reproducing this on B200 (only when running the whole L0_jax_unittest/test.sh), I found three interacting factors that combine to produce the corruption:
1. XLA implicit buffer aliasing
XLA's buffer assignment (xla/service/buffer_assignment.cc) can assign the same GPU memory to the input (output_grad) and output (inp_grad) of the backward sort_chunks_by_map custom call. This happens because:
(*) We don't know yet whether this is the fault of XLA or of Python.
2. Permutation kernels cannot operate in-place
_sort_chunks_by_map_kernel (in transformer_engine/common/triton/permutation.py) reads from src_row and writes to dst_row, where src_row != dst_row. When input and output share the same buffer, GPU thread blocks execute in waves, so early blocks overwrite data that later blocks haven't read yet, causing corruption.
3. Triton autotuning amplifies corruption
TritonAutotunedKernelCall (in jaxlib/gpu/triton_kernels.cc) runs the kernel multiple times with different configs. It normally saves and restores aliased buffers between runs using input_output_aliases_with_sizes. But since no explicit aliases are declared for sort_chunks_by_map, the save/restore is inactive, and each autotuning trial corrupts the shared buffer further.
Some questions to answer/ponder:
Why only sort_chunks_by_map?
Other permutation kernels in TE (e.g., _permute_with_mask_map_kernel in the same file) already declare explicit input_output_aliases, which claims the output buffer slot and prevents XLA from implicitly aliasing a different input to that output. _sort_chunks_by_map_kernel lacked this explicit alias, leaving the output slot "unclaimed."
Why only in the full test suite?
Running all tests in a single pytest process changes XLA's buffer assignment decisions compared to an isolated run, likely due to different memory pressure, compilation caches, or HLO graph differences. The exact mechanism was investigated but not conclusively determined. TODO: I am currently trying to dump the VLOG output of buffer_assignment.cc in XLA for this kernel specifically while running all tests (the reproducing setup), to determine whether the buffer reuse was intentional by XLA for some specific HLO graph or a rare error somewhere in the Python stack. The answer is still TBD.
Fixes # (issue)
Type of change
Changes
(Part 1/2)
Pass a pre-allocated output buffer as an additional input and declare an explicit input_output_alias in sort_chunks_by_map_kernel (mirroring what PermuteWithMaskMapPrimitive already does).
By declaring input_output_aliases={3: 0}, XLA knows that output slot 0 (dummy output) is claimed by input 3 (output_buf). XLA/Python will not implicitly assign any other input (like output_grad) to that output buffer, so input and output are guaranteed to occupy separate memory and no corruption leading to a mismatch can occur.
(Part 2/2) is still needed to complete this fix and guarantee we will not see this issue again.
There is currently a bug (bug 5810384) that prevents triton_extension/utils.py from passing the correct input_output_alias to our TritonAutotunedKernel call. Therefore the WAR in part 1/2 will not take effect for the autotuning passes until part 2/2 is submitted, and mismatches will still be observed in our CI.
Checklist: