Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion transformer_engine/jax/triton_extensions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ def _normalize_grid(grid_tuple):
# NVTE_JAX_ENFORCE_TRITON_AUTOTUNING=1 to raise an error instead, prompting the
# user to upgrade JAX for improved performance.
is_autotuned = isinstance(kernel_fn, autotuner.Autotuner)
used_autotuned_launch = False
if is_autotuned and not is_triton_autotuned_alias_safe():
val = os.environ.get("NVTE_JAX_ENFORCE_TRITON_AUTOTUNING", "0")
try:
Expand Down Expand Up @@ -574,6 +575,7 @@ def _normalize_grid(grid_tuple):
kernel_calls,
input_output_aliases_with_sizes,
)
used_autotuned_launch = True

else:
# Regular kernel: compile single config.
Expand Down Expand Up @@ -630,7 +632,10 @@ def _normalize_grid(grid_tuple):
ffi_operand_output_aliases = None

compressed_call_proto = zlib.compress(call_proto)
if jax_version_meet_requirement(TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION):
if (
not used_autotuned_launch
and jax_version_meet_requirement(TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION)
):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 The new branch condition does not carry a comment explaining why autotuned launches must skip the new FFI. Without it, a future reader will only see the mechanism (the flag) but not the root cause (CUDA IMA with triton_kernel_call_ffi on autotuned kernels), making it easy to inadvertently remove the guard when refactoring.

Suggested change
if (
not used_autotuned_launch
and jax_version_meet_requirement(TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION)
):
# Autotuned kernels must use the older "triton_kernel_call" FFI: the newer
# "triton_kernel_call_ffi" path triggers CUDA IMA (Illegal Memory Access)
# errors for autotuned kernels and must be bypassed until the upstream issue
# is resolved.
if (
not used_autotuned_launch
and jax_version_meet_requirement(TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION)
):

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

rule = jax.ffi.ffi_lowering(
"triton_kernel_call_ffi",
backend_config={"opaque": ir.StringAttr.get(compressed_call_proto)},
Expand Down
Loading