Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion transformer_engine/jax/triton_extensions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,7 @@ def _normalize_grid(grid_tuple):
# NVTE_JAX_ENFORCE_TRITON_AUTOTUNING=1 to raise an error instead, prompting the
# user to upgrade JAX for improved performance.
is_autotuned = isinstance(kernel_fn, autotuner.Autotuner)
+ used_autotuned_launch = False
if is_autotuned and not is_triton_autotuned_alias_safe():
val = os.environ.get("NVTE_JAX_ENFORCE_TRITON_AUTOTUNING", "0")
try:
Expand Down Expand Up @@ -574,6 +575,7 @@ def _normalize_grid(grid_tuple):
kernel_calls,
input_output_aliases_with_sizes,
)
+ used_autotuned_launch = True

else:
# Regular kernel: compile single config.
Expand Down Expand Up @@ -630,7 +632,9 @@ def _normalize_grid(grid_tuple):
ffi_operand_output_aliases = None

compressed_call_proto = zlib.compress(call_proto)
- if jax_version_meet_requirement(TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION):
+ if not used_autotuned_launch and jax_version_meet_requirement(
+ TRITON_EXTENSION_CUDA_GRAPH_MIN_JAX_VERSION
+ ):
rule = jax.ffi.ffi_lowering(
"triton_kernel_call_ffi",
backend_config={"opaque": ir.StringAttr.get(compressed_call_proto)},
Expand Down
Loading