diff --git a/testing/python/autotune/test_tilelang_autotune_do_not_specialize.py b/testing/python/autotune/test_tilelang_autotune_do_not_specialize.py
new file mode 100644
index 0000000000..18698cb3d2
--- /dev/null
+++ b/testing/python/autotune/test_tilelang_autotune_do_not_specialize.py
@@ -0,0 +1,150 @@
+"""Tests for the `do_not_specialize` parameter of the @autotune decorator.
+
+`do_not_specialize` allows users to specify parameters whose value changes
+should NOT trigger re-autotuning. The best config found on the first call
+is reused regardless of the value of those parameters.
+"""
+
+import pytest
+import tilelang
+import tilelang.testing
+from tilelang.autotuner import set_autotune_inputs
+import tilelang.language as T
+import torch
+
+
+def get_configs():
+    return [
+        {"threads": 64},
+        {"threads": 128},
+    ]
+
+
+# ---------------------------------------------------------------------------
+# Kernel under test
+# ---------------------------------------------------------------------------
+
+
+@tilelang.autotune(
+    configs=get_configs(),
+    do_not_specialize=["N", "K"],
+    warmup=5,
+    rep=10,
+    timeout=30,
+)
+@tilelang.jit
+def matmul_do_not_spec(M, N, K, threads=128):
+    dtype = T.float16
+    accum_dtype = T.float32
+
+    @T.prim_func
+    def main(
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((N, K), dtype),
+            C: T.Tensor((M, N), dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, 64), T.ceildiv(M, 64), threads=threads) as (bx, by):
+            A_shared = T.alloc_shared((64, 32), dtype)
+            B_shared = T.alloc_shared((64, 32), dtype)
+            C_local = T.alloc_fragment((64, 64), accum_dtype)
+            T.clear(C_local)
+            for k in T.Pipelined(T.ceildiv(K, 32), num_stages=1):
+                T.copy(A[by * 64, k * 32], A_shared)
+                T.copy(B[bx * 64, k * 32], B_shared)
+                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
+            T.copy(C_local, C[by * 64, bx * 64])
+
+    return main
+
+
+# ---------------------------------------------------------------------------
+# Test helpers
+# ---------------------------------------------------------------------------
+
+
+def _call_and_count_new_entries(fn, M, N, K, threads=None):
+    """Call fn and return how many *new* cache entries were added."""
+    prev_size = len(fn._tuner_cache)
+    a = torch.randn(M, K, dtype=torch.float16, device="cuda")
+    b = torch.randn(N, K, dtype=torch.float16, device="cuda")
+    call_kwargs = {"M": M, "N": N, "K": K}
+    if threads is not None:
+        call_kwargs["threads"] = threads
+    with set_autotune_inputs([a, b]):
+        fn(**call_kwargs)
+    new_size = len(fn._tuner_cache)
+    return new_size - prev_size
+
+
+# ---------------------------------------------------------------------------
+# Tests
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture(autouse=True)
+def _clear_autotune_cache():
+    """Clear the shared autotune cache before each test so tests are isolated."""
+    matmul_do_not_spec._tuner_cache.clear()
+
+
+@tilelang.testing.requires_cuda
+def test_do_not_specialize_same_key_for_different_values():
+    """Changing N and K (in do_not_specialize) should reuse the same cache entry."""
+    new1 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=64)
+    new2 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=512, K=512, threads=64)
+
+    assert new1 == 1, f"First call should create 1 new cache entry, got {new1}"
+    assert new2 == 0, f"do_not_specialize failed: second call with different N/K should reuse cache (0 new entries), got {new2}"
+
+
+@tilelang.testing.requires_cuda
+def test_do_not_specialize_new_key_for_m_param():
"""Changing M (NOT in do_not_specialize) should trigger a new autotune run.""" + new1 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=64) + new2 = _call_and_count_new_entries(matmul_do_not_spec, M=1024, N=256, K=256, threads=64) + + assert new1 == 1, f"First call should create 1 new cache entry, got {new1}" + assert new2 == 1, f"Expected 1 new cache entry because M is NOT in do_not_specialize, got {new2}" + + +@tilelang.testing.requires_cuda +def test_do_not_specialize_kwargs_and_args(): + """do_not_specialize should work whether params are passed as args or kwargs.""" + a = torch.randn(512, 256, dtype=torch.float16, device="cuda") + b = torch.randn(256, 256, dtype=torch.float16, device="cuda") + + # First call: all kwargs + with set_autotune_inputs([a, b]): + matmul_do_not_spec(M=512, N=256, K=256, threads=64) + prev = len(matmul_do_not_spec._tuner_cache) + + # Second call: positional args, N and K differ but are in do_not_specialize + with set_autotune_inputs([a, b]): + matmul_do_not_spec(512, 512, 512, threads=64) + new = len(matmul_do_not_spec._tuner_cache) + + assert new == prev, f"do_not_specialize failed with positional args: cache grew from {prev} to {new} (expected no new entries)" + + +@tilelang.testing.requires_cuda +def test_do_not_specialize_threads_new_entry(): + """threads is NOT in do_not_specialize → different values = new cache entry.""" + new1 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=64) + new2 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=128) + + assert new1 == 1 + assert new2 == 1, f"Changing threads (not in do_not_specialize) should create a new cache entry, got {new2}" + + +@tilelang.testing.requires_cuda +def test_do_not_specialize_all_specialized_params_same_cache(): + """When all params except do_not_specialize are the same, cache is shared.""" + new1 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=64) + new2 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=999, K=333, threads=64) + + assert new1 == 1 + assert new2 == 0, f"All non-specialized params (M, threads) are same → should reuse cache, got {new2} new entries" + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/testing/python/autotune/test_tilelang_autotune_scalar_inputs.py b/testing/python/autotune/test_tilelang_autotune_scalar_inputs.py new file mode 100644 index 0000000000..2f444dc84a --- /dev/null +++ b/testing/python/autotune/test_tilelang_autotune_scalar_inputs.py @@ -0,0 +1,43 @@ +import pytest +import torch +import tilelang +import tilelang.language as T +import tilelang.testing +from tilelang.autotuner import set_autotune_inputs + + +@tilelang.autotune(configs=[{"threads": 128}, {"threads": 256}], warmup=1, rep=1, timeout=60) +@tilelang.jit +def add_scalar(N: int = 4096, BLOCK_N: int = 512, threads: int = 128): + @T.prim_func + def kernel(A: T.Tensor((N,), T.float32), s: T.float32): + with T.Kernel(T.ceildiv(N, BLOCK_N), threads=threads) as pid_n: + A_local = T.alloc_fragment((BLOCK_N,), T.float32) + T.copy(A[pid_n * BLOCK_N], A_local) + for i in T.Parallel(BLOCK_N): + A_local[i] += s + T.copy(A_local, A[pid_n * BLOCK_N]) + + return kernel + + +def test_autotune_scalar_inputs_require_explicit_supply(): + with pytest.raises(ValueError, match=r"set_autotune_inputs"): + add_scalar() + + +def test_autotune_scalar_inputs_with_set_autotune_inputs(): + tune_a = torch.randn((4096,), device="cuda", dtype=torch.float32) + tune_s = 0.1 + with 
set_autotune_inputs(tune_a, tune_s): + kernel = add_scalar() + + a = torch.randn((4096,), device="cuda", dtype=torch.float32) + before = a.clone() + kernel(a, tune_s) + + torch.testing.assert_close(a, before + tune_s, rtol=1e-4, atol=1e-4) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/tilelang/autotuner/tuner.py b/tilelang/autotuner/tuner.py index 9a82294825..66b81dc8ce 100644 --- a/tilelang/autotuner/tuner.py +++ b/tilelang/autotuner/tuner.py @@ -391,6 +391,28 @@ def set_kernel_parameters(self, k_parameters: tuple[tuple[Any, ...], tuple[tuple self._kernel_parameters = k_parameters self._function_parameters = f_parameters + def _validate_input_supply_requirements(self, prim_func: PrimFunc, out_idx: list[int] | int | None): + if self.profile_args.supply_prog is not None: + return + + if out_idx is None: + result_idx = [] + elif isinstance(out_idx, int): + result_idx = [len(prim_func.params) + out_idx if out_idx < 0 else out_idx] + else: + result_idx = [len(prim_func.params) + idx if idx < 0 else idx for idx in out_idx] + + for i, param in enumerate(prim_func.params): + if i in result_idx: + continue + if param not in prim_func.buffer_map: + kernel_name = get_prim_func_name(prim_func, "") + raise ValueError( + f"Auto-tuning kernel '{kernel_name}' does not support auto-generated scalar inputs. " + f"Found scalar input parameter '{param.name}'. " + "Please provide concrete inputs with `with set_autotune_inputs(...)`." + ) + def generate_cache_key(self, parameters: dict[str, Any], extra_parameters: dict[str, Any]) -> AutotuneResult | None: """Generate a cache key for the auto-tuning process.""" @@ -625,6 +647,12 @@ def check_tunable_argument_value(key, parameters, key_args_tuple) -> bool: autotuner_result = AutotuneResult(libcode=jit_kernel.get_kernel_source(), func=jit_kernel.prim_func, kernel=jit_kernel) self._memory_cache[key] = autotuner_result return autotuner_result + + # After confirming tuning will actually run, validate that scalar + # inputs can be supplied (either via supply_prog or set_autotune_inputs). + if hasattr(self, "_prim_func_for_validation"): + self._validate_input_supply_requirements(self._prim_func_for_validation, self.compile_args.out_idx) + # get the cpu count available_cpu_count = get_available_cpu_count() cpu_utilizations = float(env.TILELANG_AUTO_TUNING_CPU_UTILITIES) @@ -770,6 +798,7 @@ class AutoTuneImpl(Generic[_P, _T]): skip_check: bool = False manual_check_prog: Callable = None cache_input_tensors: bool = False + do_not_specialize: tuple[str, ...] | list[str] | None = None def __post_init__(self): self._tuner_cache = {} @@ -804,9 +833,24 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel | _T: return_kernel = kwargs.pop("__return_kernel", False) mode = self.jit_impl.initialize_jit_mode(*args, **kwargs) - - norm_args = _normalize_value(args, sort_dict_items=True) - norm_kwargs = _normalize_value(kwargs, sort_dict_items=True) + autotuner = self.get_tunner() + # Defer scalar-input validation to run(), after we know whether + # tuning will actually execute or be skipped because all tunable + # parameters are already provided by the caller. + autotuner._prim_func_for_validation = self.jit_impl.get_tir(*args, **kwargs) + + # Compute the cache key, excluding do_not_specialize parameters + # so that changing them does not trigger re-autotuning. 
+        if self.do_not_specialize:
+            # Bind all args to kwargs so positional vs keyword doesn't affect key
+            bound = self.jit_impl.signature.bind(*args, **kwargs)
+            bound.apply_defaults()
+            filtered_kwargs = {k: v for k, v in bound.arguments.items() if k not in self.do_not_specialize}
+            norm_args = ()
+            norm_kwargs = _normalize_value(filtered_kwargs, sort_dict_items=True)
+        else:
+            norm_args = _normalize_value(args, sort_dict_items=True)
+            norm_kwargs = _normalize_value(kwargs, sort_dict_items=True)
         key = (norm_args, norm_kwargs)
         if key not in self._tuner_cache:
             if mode == "lazy":
@@ -814,7 +858,6 @@ def jit_compile(**config_arg):
                     return self.jit_impl(*args, **kwargs, __tune_params=config_arg)
 
-                autotuner = self.get_tunner()
                 autotuner.jit_compile = jit_compile
                 autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters)
             else:
@@ -824,7 +867,6 @@ def jit_compile(**config_arg):
                     merged.update(config_arg)
                     return self.jit_impl.compile(*args, **merged)
 
-                autotuner = self.get_tunner()
                 autotuner.jit_compile = jit_compile
                 autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters)
@@ -866,6 +908,7 @@ def autotune( # This is the new public interface
     skip_check: bool = False,
     manual_check_prog: Callable = None,
     cache_input_tensors: bool = False,
+    do_not_specialize: tuple[str, ...] | list[str] | None = None,
 ):
     """
     Just-In-Time (JIT) compiler decorator for TileLang functions.
@@ -941,6 +984,7 @@ def decorator(impl):
             skip_check=skip_check,
             manual_check_prog=manual_check_prog,
             cache_input_tensors=cache_input_tensors,
+            do_not_specialize=do_not_specialize,
         )
 
     return decorator
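
For reference, a minimal usage sketch of the new `do_not_specialize` parameter, based on the tests in this patch. `my_kernel`, its shapes, and the config values are illustrative (not part of the patch), and the kernel body is elided:

import tilelang


@tilelang.autotune(
    configs=[{"threads": 64}, {"threads": 128}],
    do_not_specialize=["N", "K"],  # changing N or K reuses the config tuned on the first call
)
@tilelang.jit
def my_kernel(M, N, K, threads=128):
    ...  # build and return a T.prim_func here, e.g. the 64x64 matmul from the test above


# The first call autotunes and caches; the second call reuses that cache entry even though
# N and K differ, because only parameters outside do_not_specialize (here M and threads)
# contribute to the cache key.
my_kernel(M=512, N=256, K=256)
my_kernel(M=512, N=512, K=512)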