150 changes: 150 additions & 0 deletions testing/python/autotune/test_tilelang_autotune_do_not_specialize.py
@@ -0,0 +1,150 @@
"""Tests for the `do_not_specialize` parameter of @autotune decorator.

`do_not_specialize` allows users to specify parameters whose value changes
should NOT trigger re-autotuning. The best config found on the first call
is reused regardless of the value of those parameters.
"""

import pytest
import tilelang
import tilelang.testing
from tilelang.autotuner import set_autotune_inputs
import tilelang.language as T
import torch


def get_configs():
    return [
        {"threads": 64},
        {"threads": 128},
    ]


# ---------------------------------------------------------------------------
# Kernel under test
# ---------------------------------------------------------------------------


@tilelang.autotune(
    configs=get_configs(),
    do_not_specialize=["N", "K"],
    warmup=5,
    rep=10,
    timeout=30,
)
@tilelang.jit
def matmul_do_not_spec(M, N, K, threads=128):
    dtype = T.float16
    accum_dtype = T.float32

    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, 64), T.ceildiv(M, 64), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((64, 32), dtype)
            B_shared = T.alloc_shared((64, 32), dtype)
            C_local = T.alloc_fragment((64, 64), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, 32), num_stages=1):
                T.copy(A[by * 64, k * 32], A_shared)
                T.copy(B[bx * 64, k * 32], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
            T.copy(C_local, C[by * 64, bx * 64])

    return main


# ---------------------------------------------------------------------------
# Test helpers
# ---------------------------------------------------------------------------


def _call_and_count_new_entries(fn, M, N, K, threads=None):
    """Call fn and return how many *new* cache entries were added."""
    prev_size = len(fn._tuner_cache)
    a = torch.randn(M, K, dtype=torch.float16, device="cuda")
    b = torch.randn(N, K, dtype=torch.float16, device="cuda")
    call_kwargs = {"M": M, "N": N, "K": K}
    if threads is not None:
        call_kwargs["threads"] = threads
    with set_autotune_inputs([a, b]):
        fn(**call_kwargs)
    new_size = len(fn._tuner_cache)
    return new_size - prev_size


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


@pytest.fixture(autouse=True)
def _clear_autotune_cache():
    """Clear the shared autotune cache before each test so tests are isolated."""
    matmul_do_not_spec._tuner_cache.clear()


@tilelang.testing.requires_cuda
def test_do_not_specialize_same_key_for_different_values():
    """Changing N and K (in do_not_specialize) should reuse the same cache entry."""
    new1 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=64)
    new2 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=512, K=512, threads=64)

    assert new1 == 1, f"First call should create 1 new cache entry, got {new1}"
    assert new2 == 0, f"do_not_specialize failed: second call with different N/K should reuse cache (0 new entries), got {new2}"


@tilelang.testing.requires_cuda
def test_do_not_specialize_new_key_for_m_param():
    """Changing M (NOT in do_not_specialize) should trigger a new autotune run."""
    new1 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=64)
    new2 = _call_and_count_new_entries(matmul_do_not_spec, M=1024, N=256, K=256, threads=64)

    assert new1 == 1, f"First call should create 1 new cache entry, got {new1}"
    assert new2 == 1, f"Expected 1 new cache entry because M is NOT in do_not_specialize, got {new2}"


@tilelang.testing.requires_cuda
def test_do_not_specialize_kwargs_and_args():
    """do_not_specialize should work whether params are passed as args or kwargs."""
    a = torch.randn(512, 256, dtype=torch.float16, device="cuda")
    b = torch.randn(256, 256, dtype=torch.float16, device="cuda")

    # First call: all kwargs
    with set_autotune_inputs([a, b]):
        matmul_do_not_spec(M=512, N=256, K=256, threads=64)
    prev = len(matmul_do_not_spec._tuner_cache)

    # Second call: positional args, N and K differ but are in do_not_specialize
    with set_autotune_inputs([a, b]):
        matmul_do_not_spec(512, 512, 512, threads=64)
    new = len(matmul_do_not_spec._tuner_cache)

    assert new == prev, f"do_not_specialize failed with positional args: cache grew from {prev} to {new} (expected no new entries)"


@tilelang.testing.requires_cuda
def test_do_not_specialize_threads_new_entry():
    """threads is NOT in do_not_specialize → different values = new cache entry."""
    new1 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=64)
    new2 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=128)

    assert new1 == 1
    assert new2 == 1, f"Changing threads (not in do_not_specialize) should create a new cache entry, got {new2}"


@tilelang.testing.requires_cuda
def test_do_not_specialize_all_specialized_params_same_cache():
    """When all params except do_not_specialize are the same, cache is shared."""
    new1 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=64)
    new2 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=999, K=333, threads=64)

    assert new1 == 1
    assert new2 == 0, f"All non-specialized params (M, threads) are same → should reuse cache, got {new2} new entries"


if __name__ == "__main__":
    tilelang.testing.main()
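
The behavior these tests pin down reduces to one idea: bind the call's arguments to parameter names, drop every name listed in do_not_specialize, and use what remains as the cache key. A standalone sketch of that idea — make_cache_key is a hypothetical helper, not the tilelang API; the real logic lives in tilelang/autotuner/tuner.py, shown further below:

import inspect

def make_cache_key(func, do_not_specialize, *args, **kwargs):
    # Bind positional and keyword arguments to names so the calling style
    # does not affect the key, then drop the non-specializing parameters.
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()  # omitted defaults and explicit values hash alike
    filtered = {k: v for k, v in bound.arguments.items()
                if k not in (do_not_specialize or ())}
    # Sort items so keyword order does not affect the key.
    return tuple(sorted(filtered.items()))

def matmul(M, N, K, threads=128):
    ...

# N and K are excluded, so these two calls share one cache entry:
k1 = make_cache_key(matmul, ["N", "K"], 512, 256, 256, threads=64)
k2 = make_cache_key(matmul, ["N", "K"], M=512, N=512, K=512, threads=64)
assert k1 == k2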
43 changes: 43 additions & 0 deletions testing/python/autotune/test_tilelang_autotune_scalar_inputs.py
@@ -0,0 +1,43 @@
import pytest
import torch
import tilelang
import tilelang.language as T
import tilelang.testing
from tilelang.autotuner import set_autotune_inputs


@tilelang.autotune(configs=[{"threads": 128}, {"threads": 256}], warmup=1, rep=1, timeout=60)
@tilelang.jit
def add_scalar(N: int = 4096, BLOCK_N: int = 512, threads: int = 128):
    @T.prim_func
    def kernel(A: T.Tensor((N,), T.float32), s: T.float32):
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=threads) as pid_n:
            A_local = T.alloc_fragment((BLOCK_N,), T.float32)
            T.copy(A[pid_n * BLOCK_N], A_local)
            for i in T.Parallel(BLOCK_N):
                A_local[i] += s
            T.copy(A_local, A[pid_n * BLOCK_N])

    return kernel


def test_autotune_scalar_inputs_require_explicit_supply():
    with pytest.raises(ValueError, match=r"set_autotune_inputs"):
        add_scalar()


def test_autotune_scalar_inputs_with_set_autotune_inputs():
    tune_a = torch.randn((4096,), device="cuda", dtype=torch.float32)
    tune_s = 0.1
    with set_autotune_inputs(tune_a, tune_s):
        kernel = add_scalar()

    a = torch.randn((4096,), device="cuda", dtype=torch.float32)
    before = a.clone()
    kernel(a, tune_s)

    torch.testing.assert_close(a, before + tune_s, rtol=1e-4, atol=1e-4)
Comment on lines +29 to +39
⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Description: Inspect how existing tests guard CUDA-only allocations.
rg -n -C2 'pytest\.mark\.skipif|torch\.cuda\.is_available|requires_cuda|device="cuda"' testing

🏁 Script executed:

#!/bin/bash
# Examine the specific test file mentioned in the review
cat -n testing/python/autotune/test_tilelang_autotune_scalar_inputs.py

🏁 Script executed:

#!/bin/bash
# Check how other tests in the autotune directory are guarded
ls -la testing/python/autotune/*.py | head -20
rg -n '@pytest.mark.skipif|@tilelang.testing.requires_cuda|import pytest|import torch' testing/python/autotune/ -A 2 | head -100

🏁 Script executed:

#!/bin/bash
# Check if there's a conftest or similar setup in autotune directory that might guard all tests
ls -la testing/python/autotune/
fd conftest testing/python/autotune/

🏁 Script executed:

#!/bin/bash
# Check what tilelang.testing.requires_cuda is
rg -A5 'def requires_cuda|requires_cuda =' tilelang/testing --type python | head -30

🏁 Script executed:

#!/bin/bash
# Also check for imports and definition
fd -e py testing | xargs grep -l 'requires_cuda' | head -5
find . -path ./node_modules -prune -o -name '*.py' -type f -exec grep -l 'requires_cuda.*=' {} \; | head -5

🏁 Script executed:

#!/bin/bash
# Look for tilelang.testing module definition
find . -name "testing.py" -o -name "testing" -type d | grep tilelang
ls -la tilelang/testing* 2>/dev/null || echo "No direct tilelang/testing files"

🏁 Script executed:

#!/bin/bash
# Check the tilelang/testing/__init__.py file
cat tilelang/testing/__init__.py


Add CUDA guard to test requiring GPU.

The test allocates CUDA tensors on lines 30 and 35, so it will fail on CPU-only runners during setup. Use the repository's standard decorator:

Proposed fix

+@tilelang.testing.requires_cuda
 def test_autotune_scalar_inputs_with_set_autotune_inputs():
     tune_a = torch.randn((4096,), device="cuda", dtype=torch.float32)
Suggested change

@tilelang.testing.requires_cuda
def test_autotune_scalar_inputs_with_set_autotune_inputs():
    tune_a = torch.randn((4096,), device="cuda", dtype=torch.float32)
    tune_s = 0.1
    with set_autotune_inputs(tune_a, tune_s):
        kernel = add_scalar()

    a = torch.randn((4096,), device="cuda", dtype=torch.float32)
    before = a.clone()
    kernel(a, tune_s)

    torch.testing.assert_close(a, before + tune_s, rtol=1e-4, atol=1e-4)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In testing/python/autotune/test_tilelang_autotune_scalar_inputs.py around lines
29-39, test_autotune_scalar_inputs_with_set_autotune_inputs allocates CUDA
tensors (both the tuning inputs passed to set_autotune_inputs and the local
tensor passed to the compiled kernel), so it needs the repository's standard
CUDA guard: decorate it with @tilelang.testing.requires_cuda, or an equivalent
pytest.mark.skipif(not torch.cuda.is_available()) marker, and import the guard
if necessary. This ensures the test is skipped on CPU-only runners while
leaving set_autotune_inputs and add_scalar unchanged.
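
For reference, guards like tilelang.testing.requires_cuda are typically thin wrappers around pytest's skip machinery; a minimal sketch under that assumption (not the actual tilelang implementation):

import pytest
import torch

# Skip-marker guard in the style of tilelang.testing.requires_cuda (a sketch).
requires_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="CUDA is not available",
)

@requires_cuda
def test_needs_gpu():
    assert torch.randn(4, device="cuda").is_cuda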



if __name__ == "__main__":
    tilelang.testing.main()
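
One note on the mechanism both test files lean on: set_autotune_inputs is used as a context manager that supplies concrete tuning inputs for the duration of the block. A toy sketch of that pattern — hypothetical names and storage, not the actual tilelang internals:

import contextlib

_active_inputs = None  # stand-in for wherever the tuner reads supplied inputs

@contextlib.contextmanager
def supply_inputs(*inputs):
    # Stash inputs for the enclosed autotune call, restoring the previous
    # value on exit even if tuning raises.
    global _active_inputs
    previous, _active_inputs = _active_inputs, inputs
    try:
        yield
    finally:
        _active_inputs = previous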
54 changes: 49 additions & 5 deletions tilelang/autotuner/tuner.py
@@ -391,6 +391,28 @@ def set_kernel_parameters(self, k_parameters: tuple[tuple[Any, ...], tuple[tuple
        self._kernel_parameters = k_parameters
        self._function_parameters = f_parameters

    def _validate_input_supply_requirements(self, prim_func: PrimFunc, out_idx: list[int] | int | None):
        if self.profile_args.supply_prog is not None:
            return

        if out_idx is None:
            result_idx = []
        elif isinstance(out_idx, int):
            result_idx = [len(prim_func.params) + out_idx if out_idx < 0 else out_idx]
        else:
            result_idx = [len(prim_func.params) + idx if idx < 0 else idx for idx in out_idx]
Comment on lines +398 to +403
⚠️ Potential issue | 🟡 Minor

Validate out_idx before normalizing negative indices.

Line 298 and Line 300 currently accept out-of-range values such as -len(params) - 1, normalizing them to invalid indices. Mirror the existing result-index legalization so this validator does not silently accept bad output indices or raise a misleading scalar-input error.

🐛 Proposed fix
-        if out_idx is None:
-            result_idx = []
-        elif isinstance(out_idx, int):
-            result_idx = [len(prim_func.params) + out_idx if out_idx < 0 else out_idx]
-        else:
-            result_idx = [len(prim_func.params) + idx if idx < 0 else idx for idx in out_idx]
+        param_count = len(prim_func.params)
+        if out_idx is None:
+            result_idx = set()
+        elif isinstance(out_idx, int):
+            if out_idx >= param_count or out_idx < -param_count:
+                raise ValueError(f"out_idx should be between {-param_count} and {param_count - 1}")
+            result_idx = {param_count + out_idx if out_idx < 0 else out_idx}
+        elif isinstance(out_idx, list):
+            result_idx = set()
+            for idx in out_idx:
+                if idx >= param_count or idx < -param_count:
+                    raise ValueError(f"out_idx should be between {-param_count} and {param_count - 1}")
+                result_idx.add(param_count + idx if idx < 0 else idx)
+        else:
+            raise ValueError("out_idx should be a list of integers")
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In tilelang/autotuner/tuner.py around lines 295-300, validate out_idx values
before converting negative indices: in both the single-int and list branches,
check that each original index is within the valid range (use the same bounds
logic as the existing result-index legalization) instead of blindly adding
len(prim_func.params); if any index is out of range, raise the same validation
error used elsewhere (not the scalar-input error). Produce result_idx only
after this bounds check passes, so values like -len(params) - 1 are rejected
rather than normalized to an invalid index.
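
To see the pitfall concretely: with three parameters, valid out_idx values are -3..2; an out-of-range -4 is "normalized" to 3 + (-4) = -1, which Python then treats as the last parameter instead of raising. A minimal illustration:

params = ["A", "B", "C"]          # stand-ins for prim_func.params

out_idx = -4                      # out of range: valid indices are -3..2
normalized = len(params) + out_idx if out_idx < 0 else out_idx
print(normalized)                 # -1: still negative, an invalid "normalized" index
print(params[normalized])         # 'C' -- silently wraps to the last parameter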


        for i, param in enumerate(prim_func.params):
            if i in result_idx:
                continue
            if param not in prim_func.buffer_map:
                kernel_name = get_prim_func_name(prim_func, "<unknown>")
                raise ValueError(
                    f"Auto-tuning kernel '{kernel_name}' does not support auto-generated scalar inputs. "
                    f"Found scalar input parameter '{param.name}'. "
                    "Please provide concrete inputs with `with set_autotune_inputs(...)`."
                )

    def generate_cache_key(self, parameters: dict[str, Any], extra_parameters: dict[str, Any]) -> AutotuneResult | None:
        """Generate a cache key for the auto-tuning process."""

@@ -625,6 +647,12 @@ def check_tunable_argument_value(key, parameters, key_args_tuple) -> bool:
            autotuner_result = AutotuneResult(libcode=jit_kernel.get_kernel_source(), func=jit_kernel.prim_func, kernel=jit_kernel)
            self._memory_cache[key] = autotuner_result
            return autotuner_result

        # After confirming tuning will actually run, validate that scalar
        # inputs can be supplied (either via supply_prog or set_autotune_inputs).
        if hasattr(self, "_prim_func_for_validation"):
            self._validate_input_supply_requirements(self._prim_func_for_validation, self.compile_args.out_idx)

        # get the cpu count
        available_cpu_count = get_available_cpu_count()
        cpu_utilizations = float(env.TILELANG_AUTO_TUNING_CPU_UTILITIES)
@@ -770,6 +798,7 @@ class AutoTuneImpl(Generic[_P, _T]):
    skip_check: bool = False
    manual_check_prog: Callable = None
    cache_input_tensors: bool = False
    do_not_specialize: tuple[str, ...] | list[str] | None = None

    def __post_init__(self):
        self._tuner_cache = {}
@@ -804,17 +833,31 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel | _T:
        return_kernel = kwargs.pop("__return_kernel", False)

        mode = self.jit_impl.initialize_jit_mode(*args, **kwargs)

        norm_args = _normalize_value(args, sort_dict_items=True)
        norm_kwargs = _normalize_value(kwargs, sort_dict_items=True)
        autotuner = self.get_tunner()
        # Defer scalar-input validation to run(), after we know whether
        # tuning will actually execute or be skipped because all tunable
        # parameters are already provided by the caller.
        autotuner._prim_func_for_validation = self.jit_impl.get_tir(*args, **kwargs)

        # Compute the cache key, excluding do_not_specialize parameters
        # so that changing them does not trigger re-autotuning.
        if self.do_not_specialize:
            # Bind all args to kwargs so positional vs keyword doesn't affect key
            bound = self.jit_impl.signature.bind(*args, **kwargs)
            bound.apply_defaults()
            filtered_kwargs = {k: v for k, v in bound.arguments.items() if k not in self.do_not_specialize}
            norm_args = ()
            norm_kwargs = _normalize_value(filtered_kwargs, sort_dict_items=True)
        else:
            norm_args = _normalize_value(args, sort_dict_items=True)
            norm_kwargs = _normalize_value(kwargs, sort_dict_items=True)
        key = (norm_args, norm_kwargs)
        if key not in self._tuner_cache:
            if mode == "lazy":

                def jit_compile(**config_arg):
                    return self.jit_impl(*args, **kwargs, __tune_params=config_arg)

                autotuner = self.get_tunner()
                autotuner.jit_compile = jit_compile
                autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters)
            else:
@@ -824,7 +867,6 @@ def jit_compile(**config_arg):
                    merged.update(config_arg)
                    return self.jit_impl.compile(*args, **merged)

                autotuner = self.get_tunner()
                autotuner.jit_compile = jit_compile
                autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters)

@@ -866,6 +908,7 @@ def autotune(  # This is the new public interface
    skip_check: bool = False,
    manual_check_prog: Callable = None,
    cache_input_tensors: bool = False,
    do_not_specialize: tuple[str, ...] | list[str] | None = None,
):
    """
    Just-In-Time (JIT) compiler decorator for TileLang functions.
@@ -941,6 +984,7 @@ def decorator(impl):
            skip_check=skip_check,
            manual_check_prog=manual_check_prog,
            cache_input_tensors=cache_input_tensors,
            do_not_specialize=do_not_specialize,
        )

    return decorator
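
Stepping back from the diff: the validator added above walks the PrimFunc's parameters, skips the declared output indices, and rejects any remaining parameter that has no buffer mapping (i.e., a scalar), since the tuner cannot auto-generate a value for it. A simplified, framework-free sketch of that rule — all names here are stand-ins, not the tilelang API:

def validate_inputs_suppliable(param_names, buffer_params, out_idx):
    """Reject scalar (non-buffer) inputs unless they are outputs.

    param_names:   ordered parameter names of the kernel
    buffer_params: names backed by a buffer (tensor) the tuner can generate
    out_idx:       indices of output parameters (may be negative)
    """
    n = len(param_names)
    if out_idx is None:
        result_idx = set()
    elif isinstance(out_idx, int):
        result_idx = {n + out_idx if out_idx < 0 else out_idx}
    else:
        result_idx = {n + i if i < 0 else i for i in out_idx}

    for i, name in enumerate(param_names):
        if i in result_idx or name in buffer_params:
            continue
        raise ValueError(
            f"Scalar input parameter '{name}' cannot be auto-generated; "
            "provide concrete inputs with set_autotune_inputs(...)."
        )

# The scalar parameter 's' triggers the error the scalar-inputs test expects:
# validate_inputs_suppliable(["A", "s"], {"A"}, out_idx=None)  -> ValueError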