Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 24 additions & 14 deletions transformer_engine/jax/triton_extensions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,10 @@
import jax.numpy as jnp

from ..version_utils import (
TRITON_AUTOTUNED_INPUT_OUTPUT_ALIAS_MIN_JAX_VERSION,
TRITON_EXTENSION_MIN_JAX_VERSION,
is_triton_extension_supported,
jax_version_meet_requirement,
)


Expand Down Expand Up @@ -474,23 +476,31 @@ def lowering(ctx, x, *, block_size):

kernel_calls.append((config_call, str(config)))

# IMPORTANT: We pass an empty tuple for input_output_aliases_with_sizes.
#
# Background:
# 1. jax.ffi.ffi_lowering(operand_output_aliases=...) is a HINT to XLA that an
# output can reuse an input's buffer. XLA may or may not honor this.
# 2. TritonAutotunedKernelCall's input_output_aliases_with_sizes triggers
# save/restore logic during autotuning (see jaxlib/gpu/triton_kernels.cc:630-701).
#
# The problem: The save phase (triton_kernels.cc:632) only saves if buffers[input_idx] == buffers[output_idx],
# but the restore phase (triton_kernels.cc:697-700) unconditionally iterates over all aliases and tries
# to access input_copies[input_idx]. If XLA didn't actually alias the buffers, input_copies[input_idx] doesn't exist, creating an empty vector whose .data() returns nullptr, causing CUDA_ERROR_INVALID_VALUE during the restore memcpy.
#
# WAR: Don't pass aliases to TritonAutotunedKernelCall.
input_output_aliases_with_sizes = ()
Comment thread
jberchtold-nvidia marked this conversation as resolved.
if input_output_aliases:
if jax_version_meet_requirement(TRITON_AUTOTUNED_INPUT_OUTPUT_ALIAS_MIN_JAX_VERSION):
num_inputs = len(ctx.avals_in)
aliases = []
for input_idx, output_idx in input_output_aliases.items():
aval = ctx.avals_in[input_idx]
size_bytes = aval.size * jnp.dtype(aval.dtype).itemsize
# AutotunedKernelCall expects buffer indices (inputs + outputs).
buffer_output_idx = num_inputs + output_idx
aliases.append((input_idx, buffer_output_idx, size_bytes))
input_output_aliases_with_sizes = tuple(aliases)
else:
warnings.warn(
f"JAX >= {TRITON_AUTOTUNED_INPUT_OUTPUT_ALIAS_MIN_JAX_VERSION} is required "
"to safely pass input_output_aliases to TritonAutotunedKernelCall. "
"Passing empty aliases as a workaround (jax-ml/jax#35218).",
UserWarning,
stacklevel=2,
)

kernel_call = gpu_triton.TritonAutotunedKernelCall(
f"{actual_kernel_fn.__name__}_autotuned",
kernel_calls,
(), # Empty to avoid buggy save/restore in jaxlib/gpu/triton_kernels.cc
input_output_aliases_with_sizes,
)

else:
Expand Down
10 changes: 10 additions & 0 deletions transformer_engine/jax/version_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ def jax_version_meet_requirement(version: str):
# Minimum JAX version required for Triton kernel dispatch (jaxlib < 0.8.0 segfaults).
TRITON_EXTENSION_MIN_JAX_VERSION = "0.8.0"

# Minimum JAX version for safe input_output_aliases in TritonAutotunedKernelCall.
# jaxlib/gpu/triton_kernels.cc had a bug in the autotuning save/restore loop:
# it iterated over all declared aliases unconditionally, but input_copies only
# contains entries for aliases where XLA actually shared buffers at runtime.
# Accessing a missing entry produced a null vector → CUDA_ERROR_INVALID_VALUE.
# Fixed by: https://github.com/jax-ml/jax/pull/35218 (merged 2026-03-17, main).
# Ships in JAX 0.9.3 (not yet released as of 2026-03-31).
TRITON_AUTOTUNED_INPUT_OUTPUT_ALIAS_MIN_JAX_VERSION = "0.9.3"


def is_triton_extension_supported() -> bool:
"""Return True if the current JAX version supports Triton kernel dispatch.
Expand All @@ -40,4 +49,5 @@ def is_triton_extension_supported() -> bool:
"jax_version_meet_requirement",
"is_triton_extension_supported",
"TRITON_EXTENSION_MIN_JAX_VERSION",
"TRITON_AUTOTUNED_INPUT_OUTPUT_ALIAS_MIN_JAX_VERSION",
]
Loading