Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions transformer_engine/common/triton/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,10 @@ def _sort_chunks_by_map_kernel(
input_ptr,
row_id_map_ptr,
probs_ptr,
# Pre-allocated output buffer for JAX input_output_aliases.
# Aliased to output_ptr in JAX so they point to the same memory.
# In PyTorch, pass the same tensor as output_ptr.
output_buf_ptr, # pylint: disable=unused-argument
# strides
stride_input_token,
stride_input_hidden,
Expand Down
28 changes: 26 additions & 2 deletions transformer_engine/jax/triton_extensions/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1666,10 +1666,19 @@ class SortChunksByMapPrimitive(BasePrimitive):

@staticmethod
def abstract(
inp_aval, row_id_map_aval, probs_aval, *, num_tokens, hidden_size, is_forward, with_probs
inp_aval,
row_id_map_aval,
probs_aval,
output_buf_aval=None, # Pre-allocated output buffer (inner primitive only)
*,
num_tokens,
hidden_size,
is_forward,
with_probs,
):
"""Shape/dtype inference."""
del row_id_map_aval, is_forward
del output_buf_aval # Used for input_output_aliases only

output_aval = jax.core.ShapedArray((num_tokens, hidden_size), inp_aval.dtype)

Expand All @@ -1684,18 +1693,24 @@ def abstract(
def impl(inp, row_id_map, probs, num_tokens, hidden_size, is_forward, with_probs):
"""Forward to inner primitive."""
assert SortChunksByMapPrimitive.inner_primitive is not None

output_buf = jnp.empty((num_tokens, hidden_size), dtype=inp.dtype)

return SortChunksByMapPrimitive.inner_primitive.bind(
inp,
row_id_map,
probs,
output_buf,
num_tokens=num_tokens,
hidden_size=hidden_size,
is_forward=is_forward,
with_probs=with_probs,
)

@staticmethod
def lowering(ctx, inp, row_id_map, probs, *, num_tokens, hidden_size, is_forward, with_probs):
def lowering(
ctx, inp, row_id_map, probs, output_buf, *, num_tokens, hidden_size, is_forward, with_probs
):
"""MLIR lowering using triton_call_lowering."""
# Compute strides
inp_stride_token = hidden_size
Expand All @@ -1709,13 +1724,22 @@ def lowering(ctx, inp, row_id_map, probs, *, num_tokens, hidden_size, is_forward
block_size = _get_min_block_size(_sort_chunks_by_map_kernel)
grid = (num_tokens, triton.cdiv(hidden_size, block_size))

# Declare input_output_aliases so XLA knows output slot 0 is claimed by
# input 3 (output_buf). This prevents XLA from implicitly aliasing any
# other input (like output_grad in backward) to the output buffer.
# Input indices: 0=inp, 1=row_id_map, 2=probs, 3=output_buf
# Output indices: 0=output, 1=permuted_probs
input_output_aliases = {3: 0}

return triton_call_lowering(
ctx,
_sort_chunks_by_map_kernel,
inp,
row_id_map,
probs,
output_buf,
grid=grid,
input_output_aliases=input_output_aliases,
constexprs={
"stride_input_token": inp_stride_token,
"stride_input_hidden": inp_stride_hidden,
Expand Down
1 change: 1 addition & 0 deletions transformer_engine/pytorch/triton/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ def sort_chunks_by_map(
inp,
row_id_map,
probs,
output, # no use in Pytorch side, serves as WAR for JAX side
inp.stride(0),
inp.stride(1),
output.stride(0),
Expand Down
Loading