diff --git a/transformer_engine/jax/triton_extensions/utils.py b/transformer_engine/jax/triton_extensions/utils.py index 95ee370c81..f94d353940 100644 --- a/transformer_engine/jax/triton_extensions/utils.py +++ b/transformer_engine/jax/triton_extensions/utils.py @@ -485,6 +485,7 @@ def _normalize_grid(grid_tuple): # NVTE_JAX_ENFORCE_TRITON_AUTOTUNING=1 to raise an error instead, prompting the # user to upgrade JAX for improved performance. is_autotuned = isinstance(kernel_fn, autotuner.Autotuner) + used_autotuned_launch = False if is_autotuned and not is_triton_autotuned_alias_safe(): val = os.environ.get("NVTE_JAX_ENFORCE_TRITON_AUTOTUNING", "0") try: @@ -574,6 +575,7 @@ def _normalize_grid(grid_tuple): kernel_calls, input_output_aliases_with_sizes, ) + used_autotuned_launch = True else: # Regular kernel: compile single config. @@ -630,7 +632,9 @@ def _normalize_grid(grid_tuple): ffi_operand_output_aliases = None compressed_call_proto = zlib.compress(call_proto) - if jax_version_meet_requirement(TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION): + if not used_autotuned_launch and jax_version_meet_requirement( + TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION + ): rule = jax.ffi.ffi_lowering( "triton_kernel_call_ffi", backend_config={"opaque": ir.StringAttr.get(compressed_call_proto)},