diff --git a/transformer_engine/common/triton/permutation.py b/transformer_engine/common/triton/permutation.py index 147742bb05..75bb85f5ec 100644 --- a/transformer_engine/common/triton/permutation.py +++ b/transformer_engine/common/triton/permutation.py @@ -599,6 +599,10 @@ def _sort_chunks_by_map_kernel( input_ptr, row_id_map_ptr, probs_ptr, + # Pre-allocated output buffer for JAX input_output_aliases. + # Aliased to output_ptr in JAX so they point to the same memory. + # In PyTorch, pass the same tensor as output_ptr. + output_buf_ptr, # pylint: disable=unused-argument # strides stride_input_token, stride_input_hidden, diff --git a/transformer_engine/jax/triton_extensions/permutation.py b/transformer_engine/jax/triton_extensions/permutation.py index 0c80f9f18c..98c54e52bb 100644 --- a/transformer_engine/jax/triton_extensions/permutation.py +++ b/transformer_engine/jax/triton_extensions/permutation.py @@ -1666,10 +1666,19 @@ class SortChunksByMapPrimitive(BasePrimitive): @staticmethod def abstract( - inp_aval, row_id_map_aval, probs_aval, *, num_tokens, hidden_size, is_forward, with_probs + inp_aval, + row_id_map_aval, + probs_aval, + output_buf_aval=None, # Pre-allocated output buffer (inner primitive only) + *, + num_tokens, + hidden_size, + is_forward, + with_probs, ): """Shape/dtype inference.""" del row_id_map_aval, is_forward + del output_buf_aval # Used for input_output_aliases only output_aval = jax.core.ShapedArray((num_tokens, hidden_size), inp_aval.dtype) @@ -1684,10 +1693,14 @@ def abstract( def impl(inp, row_id_map, probs, num_tokens, hidden_size, is_forward, with_probs): """Forward to inner primitive.""" assert SortChunksByMapPrimitive.inner_primitive is not None + + output_buf = jnp.empty((num_tokens, hidden_size), dtype=inp.dtype) + return SortChunksByMapPrimitive.inner_primitive.bind( inp, row_id_map, probs, + output_buf, num_tokens=num_tokens, hidden_size=hidden_size, is_forward=is_forward, @@ -1695,7 +1708,9 @@ def impl(inp, row_id_map, probs, 
num_tokens, hidden_size, is_forward, with_probs ) @staticmethod - def lowering(ctx, inp, row_id_map, probs, *, num_tokens, hidden_size, is_forward, with_probs): + def lowering( + ctx, inp, row_id_map, probs, output_buf, *, num_tokens, hidden_size, is_forward, with_probs + ): """MLIR lowering using triton_call_lowering.""" # Compute strides inp_stride_token = hidden_size @@ -1709,13 +1724,22 @@ def lowering(ctx, inp, row_id_map, probs, *, num_tokens, hidden_size, is_forward block_size = _get_min_block_size(_sort_chunks_by_map_kernel) grid = (num_tokens, triton.cdiv(hidden_size, block_size)) + # Declare input_output_aliases so XLA knows output slot 0 is claimed by + # input 3 (output_buf). This prevents XLA from implicitly aliasing any + # other input (like output_grad in backward) to the output buffer. + # Input indices: 0=inp, 1=row_id_map, 2=probs, 3=output_buf + # Output indices: 0=output, 1=permuted_probs + input_output_aliases = {3: 0} + return triton_call_lowering( ctx, _sort_chunks_by_map_kernel, inp, row_id_map, probs, + output_buf, grid=grid, + input_output_aliases=input_output_aliases, constexprs={ "stride_input_token": inp_stride_token, "stride_input_hidden": inp_stride_hidden, diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 6b5de9ab0f..4902bc686c 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -427,6 +427,7 @@ def sort_chunks_by_map( inp, row_id_map, probs, + output, # unused in PyTorch; fills output_buf_ptr, a workaround for JAX input_output_aliases inp.stride(0), inp.stride(1), output.stride(0),