From ac2350ea536b517cb5c9056731be28c8a701db82 Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Wed, 22 Apr 2026 21:26:17 +0800 Subject: [PATCH 1/2] Reject autotune default scalar inputs --- .../test_tilelang_autotune_scalar_inputs.py | 43 +++++++++++++++++++ tilelang/autotuner/tuner.py | 26 ++++++++++- 2 files changed, 67 insertions(+), 2 deletions(-) create mode 100644 testing/python/autotune/test_tilelang_autotune_scalar_inputs.py diff --git a/testing/python/autotune/test_tilelang_autotune_scalar_inputs.py b/testing/python/autotune/test_tilelang_autotune_scalar_inputs.py new file mode 100644 index 0000000000..2f444dc84a --- /dev/null +++ b/testing/python/autotune/test_tilelang_autotune_scalar_inputs.py @@ -0,0 +1,43 @@ +import pytest +import torch +import tilelang +import tilelang.language as T +import tilelang.testing +from tilelang.autotuner import set_autotune_inputs + + +@tilelang.autotune(configs=[{"threads": 128}, {"threads": 256}], warmup=1, rep=1, timeout=60) +@tilelang.jit +def add_scalar(N: int = 4096, BLOCK_N: int = 512, threads: int = 128): + @T.prim_func + def kernel(A: T.Tensor((N,), T.float32), s: T.float32): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=threads) as pid_n: + A_local = T.alloc_fragment((BLOCK_N,), T.float32) + T.copy(A[pid_n * BLOCK_N], A_local) + for i in T.Parallel(BLOCK_N): + A_local[i] += s + T.copy(A_local, A[pid_n * BLOCK_N]) + + return kernel + + +def test_autotune_scalar_inputs_require_explicit_supply(): + with pytest.raises(ValueError, match=r"set_autotune_inputs"): + add_scalar() + + +def test_autotune_scalar_inputs_with_set_autotune_inputs(): + tune_a = torch.randn((4096,), device="cuda", dtype=torch.float32) + tune_s = 0.1 + with set_autotune_inputs(tune_a, tune_s): + kernel = add_scalar() + + a = torch.randn((4096,), device="cuda", dtype=torch.float32) + before = a.clone() + kernel(a, tune_s) + + torch.testing.assert_close(a, before + tune_s, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/autotuner/tuner.py b/tilelang/autotuner/tuner.py index 413a66602b..7a15b22104 100644 --- a/tilelang/autotuner/tuner.py +++ b/tilelang/autotuner/tuner.py @@ -288,6 +288,28 @@ def set_kernel_parameters(self, k_parameters: tuple[tuple[Any, ...], tuple[tuple self._kernel_parameters = k_parameters self._function_parameters = f_parameters + def _validate_input_supply_requirements(self, prim_func: PrimFunc, out_idx: list[int] | int | None): + if self.profile_args.supply_prog is not None: + return + + if out_idx is None: + result_idx = [] + elif isinstance(out_idx, int): + result_idx = [len(prim_func.params) + out_idx if out_idx < 0 else out_idx] + else: + result_idx = [len(prim_func.params) + idx if idx < 0 else idx for idx in out_idx] + + for i, param in enumerate(prim_func.params): + if i in result_idx: + continue + if param not in prim_func.buffer_map: + kernel_name = get_prim_func_name(prim_func, "") + raise ValueError( + f"Auto-tuning kernel '{kernel_name}' does not support auto-generated scalar inputs. " + f"Found scalar input parameter '{param.name}'. " + "Please provide concrete inputs with `with set_autotune_inputs(...)`." + ) + def generate_cache_key(self, parameters: dict[str, Any], extra_parameters: dict[str, Any]) -> AutotuneResult | None: """Generate a cache key for the auto-tuning process.""" @@ -701,6 +723,8 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel | _T: return_kernel = kwargs.pop("__return_kernel", False) mode = self.jit_impl.initialize_jit_mode(*args, **kwargs) + autotuner = self.get_tunner() + autotuner._validate_input_supply_requirements(self.jit_impl.get_tir(*args, **kwargs), self.jit_impl.out_idx) norm_args = _normalize_value(args, sort_dict_items=True) norm_kwargs = _normalize_value(kwargs, sort_dict_items=True) @@ -711,7 +735,6 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel | _T: def jit_compile(**config_arg): return self.jit_impl(*args, **kwargs, __tune_params=config_arg) - autotuner = self.get_tunner() autotuner.jit_compile = jit_compile autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters) else: @@ -721,7 +744,6 @@ def jit_compile(**config_arg): merged.update(config_arg) return self.jit_impl.compile(*args, **merged) - autotuner = self.get_tunner() autotuner.jit_compile = jit_compile autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters) From 0219ee2945b9ece79f300cc9bfc8277c7f02bdac Mon Sep 17 00:00:00 2001 From: Rachmanino <18805904201@163.com> Date: Tue, 28 Apr 2026 22:20:52 +0800 Subject: [PATCH 2/2] [Autotune] Defer scalar input validation and add do_not_specialize option - Move _validate_input_supply_requirements from __call__ into run(), after the "tunable parameters already provided" check. This prevents spurious errors when all tunable params are supplied by the caller and the autotuner skips profiling entirely (fixes test_example_mha_fwd_varlen). - Add `do_not_specialize` parameter to @autotune decorator. Parameters listed here are excluded from the autotune cache key, so changing their values reuses the previously tuned configuration instead of triggering re-autotuning. Usage: @autotune(configs=get_configs(), do_not_specialize=["batch_size", "UQ"]) @tilelang.jit def kernel(batch_size, UQ, ...): ... Co-Authored-By: Claude Opus 4.6 --- ...est_tilelang_autotune_do_not_specialize.py | 150 ++++++++++++++++++ tilelang/autotuner/tuner.py | 30 +++- 2 files changed, 176 insertions(+), 4 deletions(-) create mode 100644 testing/python/autotune/test_tilelang_autotune_do_not_specialize.py diff --git a/testing/python/autotune/test_tilelang_autotune_do_not_specialize.py b/testing/python/autotune/test_tilelang_autotune_do_not_specialize.py new file mode 100644 index 0000000000..18698cb3d2 --- /dev/null +++ b/testing/python/autotune/test_tilelang_autotune_do_not_specialize.py @@ -0,0 +1,150 @@ +"""Tests for the `do_not_specialize` parameter of @autotune decorator. + +`do_not_specialize` allows users to specify parameters whose value changes +should NOT trigger re-autotuning. The best config found on the first call +is reused regardless of the value of those parameters. +""" + +import pytest +import tilelang +import tilelang.testing +from tilelang.autotuner import set_autotune_inputs +import tilelang.language as T +import torch + + +def get_configs(): + return [ + {"threads": 64}, + {"threads": 128}, + ] + + +# --------------------------------------------------------------------------- +# Kernel under test +# --------------------------------------------------------------------------- + + +@tilelang.autotune( + configs=get_configs(), + do_not_specialize=["N", "K"], + warmup=5, + rep=10, + timeout=30, +) +@tilelang.jit +def matmul_do_not_spec(M, N, K, threads=128): + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def main( + A: T.Tensor((M, K), dtype), + B: T.Tensor((N, K), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, 64), T.ceildiv(M, 64), threads=threads) as (bx, by): + A_shared = T.alloc_shared((64, 32), dtype) + B_shared = T.alloc_shared((64, 32), dtype) + C_local = T.alloc_fragment((64, 64), accum_dtype) + T.clear(C_local) + for k in T.Pipelined(T.ceildiv(K, 32), num_stages=1): + T.copy(A[by * 64, k * 32], A_shared) + T.copy(B[bx * 64, k * 32], B_shared) + T.gemm(A_shared, B_shared, C_local, transpose_B=True) + T.copy(C_local, C[by * 64, bx * 64]) + + return main + + +# --------------------------------------------------------------------------- +# Test helpers +# --------------------------------------------------------------------------- + + +def _call_and_count_new_entries(fn, M, N, K, threads=None): + """Call fn and return how many *new* cache entries were added.""" + prev_size = len(fn._tuner_cache) + a = torch.randn(M, K, dtype=torch.float16, device="cuda") + b = torch.randn(N, K, dtype=torch.float16, device="cuda") + call_kwargs = {"M": M, "N": N, "K": K} + if threads is not None: + call_kwargs["threads"] = threads + with set_autotune_inputs([a, b]): + fn(**call_kwargs) + new_size = len(fn._tuner_cache) + return new_size - prev_size + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _clear_autotune_cache(): + """Clear the shared autotune cache before each test so tests are isolated.""" + matmul_do_not_spec._tuner_cache.clear() + + +@tilelang.testing.requires_cuda +def test_do_not_specialize_same_key_for_different_values(): + """Changing N and K (in do_not_specialize) should reuse the same cache entry.""" + new1 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=64) + new2 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=512, K=512, threads=64) + + assert new1 == 1, f"First call should create 1 new cache entry, got {new1}" + assert new2 == 0, f"do_not_specialize failed: second call with different N/K should reuse cache (0 new entries), got {new2}" + + +@tilelang.testing.requires_cuda +def test_do_not_specialize_new_key_for_m_param(): + """Changing M (NOT in do_not_specialize) should trigger a new autotune run.""" + new1 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=64) + new2 = _call_and_count_new_entries(matmul_do_not_spec, M=1024, N=256, K=256, threads=64) + + assert new1 == 1, f"First call should create 1 new cache entry, got {new1}" + assert new2 == 1, f"Expected 1 new cache entry because M is NOT in do_not_specialize, got {new2}" + + +@tilelang.testing.requires_cuda +def test_do_not_specialize_kwargs_and_args(): + """do_not_specialize should work whether params are passed as args or kwargs.""" + a = torch.randn(512, 256, dtype=torch.float16, device="cuda") + b = torch.randn(256, 256, dtype=torch.float16, device="cuda") + + # First call: all kwargs + with set_autotune_inputs([a, b]): + matmul_do_not_spec(M=512, N=256, K=256, threads=64) + prev = len(matmul_do_not_spec._tuner_cache) + + # Second call: positional args, N and K differ but are in do_not_specialize + with set_autotune_inputs([a, b]): + matmul_do_not_spec(512, 512, 512, threads=64) + new = len(matmul_do_not_spec._tuner_cache) + + assert new == prev, f"do_not_specialize failed with positional args: cache grew from {prev} to {new} (expected no new entries)" + + +@tilelang.testing.requires_cuda +def test_do_not_specialize_threads_new_entry(): + """threads is NOT in do_not_specialize → different values = new cache entry.""" + new1 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=64) + new2 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=128) + + assert new1 == 1 + assert new2 == 1, f"Changing threads (not in do_not_specialize) should create a new cache entry, got {new2}" + + +@tilelang.testing.requires_cuda +def test_do_not_specialize_all_specialized_params_same_cache(): + """When all params except do_not_specialize are the same, cache is shared.""" + new1 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=64) + new2 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=999, K=333, threads=64) + + assert new1 == 1 + assert new2 == 0, f"All non-specialized params (M, threads) are same → should reuse cache, got {new2} new entries" + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/autotuner/tuner.py b/tilelang/autotuner/tuner.py index 7a15b22104..7236c1c605 100644 --- a/tilelang/autotuner/tuner.py +++ b/tilelang/autotuner/tuner.py @@ -544,6 +544,12 @@ def check_tunable_argument_value(key, parameters, key_args_tuple) -> bool: autotuner_result = AutotuneResult(libcode=jit_kernel.get_kernel_source(), func=jit_kernel.prim_func, kernel=jit_kernel) self._memory_cache[key] = autotuner_result return autotuner_result + + # After confirming tuning will actually run, validate that scalar + # inputs can be supplied (either via supply_prog or set_autotune_inputs). + if hasattr(self, "_prim_func_for_validation"): + self._validate_input_supply_requirements(self._prim_func_for_validation, self.compile_args.out_idx) + # get the cpu count available_cpu_count = get_available_cpu_count() cpu_utilizations = float(env.TILELANG_AUTO_TUNING_CPU_UTILITIES) @@ -689,6 +695,7 @@ class AutoTuneImpl(Generic[_P, _T]): skip_check: bool = False manual_check_prog: Callable = None cache_input_tensors: bool = False + do_not_specialize: tuple[str, ...] | list[str] | None = None def __post_init__(self): self._tuner_cache = {} @@ -724,10 +731,23 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel | _T: mode = self.jit_impl.initialize_jit_mode(*args, **kwargs) autotuner = self.get_tunner() - autotuner._validate_input_supply_requirements(self.jit_impl.get_tir(*args, **kwargs), self.jit_impl.out_idx) - - norm_args = _normalize_value(args, sort_dict_items=True) - norm_kwargs = _normalize_value(kwargs, sort_dict_items=True) + # Defer scalar-input validation to run(), after we know whether + # tuning will actually execute or be skipped because all tunable + # parameters are already provided by the caller. + autotuner._prim_func_for_validation = self.jit_impl.get_tir(*args, **kwargs) + + # Compute the cache key, excluding do_not_specialize parameters + # so that changing them does not trigger re-autotuning. + if self.do_not_specialize: + # Bind all args to kwargs so positional vs keyword doesn't affect key + bound = self.jit_impl.signature.bind(*args, **kwargs) + bound.apply_defaults() + filtered_kwargs = {k: v for k, v in bound.arguments.items() if k not in self.do_not_specialize} + norm_args = () + norm_kwargs = _normalize_value(filtered_kwargs, sort_dict_items=True) + else: + norm_args = _normalize_value(args, sort_dict_items=True) + norm_kwargs = _normalize_value(kwargs, sort_dict_items=True) key = (norm_args, norm_kwargs) if key not in self._tuner_cache: if mode == "lazy": @@ -785,6 +805,7 @@ def autotune( # This is the new public interface skip_check: bool = False, manual_check_prog: Callable = None, cache_input_tensors: bool = False, + do_not_specialize: tuple[str, ...] | list[str] | None = None, ): """ Just-In-Time (JIT) compiler decorator for TileLang functions. @@ -860,6 +881,7 @@ def decorator(impl): skip_check=skip_check, manual_check_prog=manual_check_prog, cache_input_tensors=cache_input_tensors, + do_not_specialize=do_not_specialize, ) return decorator