[Enhance] Reject default scalar params and support do_not_specialize for autotune #2084
New test file (`@@ -0,0 +1,150 @@`):

```python
"""Tests for the `do_not_specialize` parameter of @autotune decorator.

`do_not_specialize` allows users to specify parameters whose value changes
should NOT trigger re-autotuning. The best config found on the first call
is reused regardless of the value of those parameters.
"""

import pytest
import tilelang
import tilelang.testing
from tilelang.autotuner import set_autotune_inputs
import tilelang.language as T
import torch


def get_configs():
    return [
        {"threads": 64},
        {"threads": 128},
    ]


# ---------------------------------------------------------------------------
# Kernel under test
# ---------------------------------------------------------------------------


@tilelang.autotune(
    configs=get_configs(),
    do_not_specialize=["N", "K"],
    warmup=5,
    rep=10,
    timeout=30,
)
@tilelang.jit
def matmul_do_not_spec(M, N, K, threads=128):
    dtype = T.float16
    accum_dtype = T.float32

    @T.prim_func
    def main(
        A: T.Tensor((M, K), dtype),
        B: T.Tensor((N, K), dtype),
        C: T.Tensor((M, N), dtype),
    ):
        with T.Kernel(T.ceildiv(N, 64), T.ceildiv(M, 64), threads=threads) as (bx, by):
            A_shared = T.alloc_shared((64, 32), dtype)
            B_shared = T.alloc_shared((64, 32), dtype)
            C_local = T.alloc_fragment((64, 64), accum_dtype)
            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, 32), num_stages=1):
                T.copy(A[by * 64, k * 32], A_shared)
                T.copy(B[bx * 64, k * 32], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
            T.copy(C_local, C[by * 64, bx * 64])

    return main


# ---------------------------------------------------------------------------
# Test helpers
# ---------------------------------------------------------------------------


def _call_and_count_new_entries(fn, M, N, K, threads=None):
    """Call fn and return how many *new* cache entries were added."""
    prev_size = len(fn._tuner_cache)
    a = torch.randn(M, K, dtype=torch.float16, device="cuda")
    b = torch.randn(N, K, dtype=torch.float16, device="cuda")
    call_kwargs = {"M": M, "N": N, "K": K}
    if threads is not None:
        call_kwargs["threads"] = threads
    with set_autotune_inputs([a, b]):
        fn(**call_kwargs)
    new_size = len(fn._tuner_cache)
    return new_size - prev_size


# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------


@pytest.fixture(autouse=True)
def _clear_autotune_cache():
    """Clear the shared autotune cache before each test so tests are isolated."""
    matmul_do_not_spec._tuner_cache.clear()


@tilelang.testing.requires_cuda
def test_do_not_specialize_same_key_for_different_values():
    """Changing N and K (in do_not_specialize) should reuse the same cache entry."""
    new1 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=64)
    new2 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=512, K=512, threads=64)

    assert new1 == 1, f"First call should create 1 new cache entry, got {new1}"
    assert new2 == 0, f"do_not_specialize failed: second call with different N/K should reuse cache (0 new entries), got {new2}"


@tilelang.testing.requires_cuda
def test_do_not_specialize_new_key_for_m_param():
    """Changing M (NOT in do_not_specialize) should trigger a new autotune run."""
    new1 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=64)
    new2 = _call_and_count_new_entries(matmul_do_not_spec, M=1024, N=256, K=256, threads=64)

    assert new1 == 1, f"First call should create 1 new cache entry, got {new1}"
    assert new2 == 1, f"Expected 1 new cache entry because M is NOT in do_not_specialize, got {new2}"


@tilelang.testing.requires_cuda
def test_do_not_specialize_kwargs_and_args():
    """do_not_specialize should work whether params are passed as args or kwargs."""
    a = torch.randn(512, 256, dtype=torch.float16, device="cuda")
    b = torch.randn(256, 256, dtype=torch.float16, device="cuda")

    # First call: all kwargs
    with set_autotune_inputs([a, b]):
        matmul_do_not_spec(M=512, N=256, K=256, threads=64)
    prev = len(matmul_do_not_spec._tuner_cache)

    # Second call: positional args, N and K differ but are in do_not_specialize
    with set_autotune_inputs([a, b]):
        matmul_do_not_spec(512, 512, 512, threads=64)
    new = len(matmul_do_not_spec._tuner_cache)

    assert new == prev, f"do_not_specialize failed with positional args: cache grew from {prev} to {new} (expected no new entries)"


@tilelang.testing.requires_cuda
def test_do_not_specialize_threads_new_entry():
    """threads is NOT in do_not_specialize → different values = new cache entry."""
    new1 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=64)
    new2 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=128)

    assert new1 == 1
    assert new2 == 1, f"Changing threads (not in do_not_specialize) should create a new cache entry, got {new2}"


@tilelang.testing.requires_cuda
def test_do_not_specialize_all_specialized_params_same_cache():
    """When all params except do_not_specialize are the same, cache is shared."""
    new1 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=256, K=256, threads=64)
    new2 = _call_and_count_new_entries(matmul_do_not_spec, M=512, N=999, K=333, threads=64)

    assert new1 == 1
    assert new2 == 0, f"All non-specialized params (M, threads) are same → should reuse cache, got {new2} new entries"


if __name__ == "__main__":
    tilelang.testing.main()
```
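The caching behaviour these tests assert can be illustrated without TileLang at all. The sketch below is a hypothetical, framework-free restatement of the key construction this PR adds to `AutoTuneImpl.__call__` (the `make_cache_key` helper and the `matmul` stub are made up for illustration): arguments are bound to parameter names, names listed in `do_not_specialize` are dropped, and only the remainder decides whether a cached config is reused.

```python
import inspect


def make_cache_key(func, do_not_specialize, *args, **kwargs):
    """Bind call arguments to parameter names, drop excluded ones, return a hashable key."""
    bound = inspect.signature(func).bind(*args, **kwargs)
    bound.apply_defaults()
    kept = {k: v for k, v in bound.arguments.items() if k not in do_not_specialize}
    return tuple(sorted(kept.items()))


def matmul(M, N, K, threads=128):  # stand-in for the decorated kernel factory
    pass


exclude = ("N", "K")
k1 = make_cache_key(matmul, exclude, M=512, N=256, K=256, threads=64)
k2 = make_cache_key(matmul, exclude, 512, 512, 512, threads=64)  # positional call, different N/K
k3 = make_cache_key(matmul, exclude, M=1024, N=256, K=256, threads=64)

assert k1 == k2  # N and K are excluded → same cache entry, no re-tuning
assert k1 != k3  # M is still specialized → new entry
```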
New test file (`@@ -0,0 +1,43 @@`):

```python
import pytest
import torch
import tilelang
import tilelang.language as T
import tilelang.testing
from tilelang.autotuner import set_autotune_inputs


@tilelang.autotune(configs=[{"threads": 128}, {"threads": 256}], warmup=1, rep=1, timeout=60)
@tilelang.jit
def add_scalar(N: int = 4096, BLOCK_N: int = 512, threads: int = 128):
    @T.prim_func
    def kernel(A: T.Tensor((N,), T.float32), s: T.float32):
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=threads) as pid_n:
            A_local = T.alloc_fragment((BLOCK_N,), T.float32)
            T.copy(A[pid_n * BLOCK_N], A_local)
            for i in T.Parallel(BLOCK_N):
                A_local[i] += s
            T.copy(A_local, A[pid_n * BLOCK_N])

    return kernel


def test_autotune_scalar_inputs_require_explicit_supply():
    with pytest.raises(ValueError, match=r"set_autotune_inputs"):
        add_scalar()


def test_autotune_scalar_inputs_with_set_autotune_inputs():
    tune_a = torch.randn((4096,), device="cuda", dtype=torch.float32)
    tune_s = 0.1
    with set_autotune_inputs(tune_a, tune_s):
        kernel = add_scalar()

    a = torch.randn((4096,), device="cuda", dtype=torch.float32)
    before = a.clone()
    kernel(a, tune_s)

    torch.testing.assert_close(a, before + tune_s, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
    tilelang.testing.main()
```
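The first test above expects a `ValueError` when the scalar-input kernel is tuned without explicit inputs. The real `set_autotune_inputs` lives in `tilelang.autotuner` and its internals are not shown in this diff; the snippet below is only a generic sketch of the context-manager pattern such a helper typically follows (every name in it is made up), to make the expected control flow concrete.

```python
import contextvars
from contextlib import contextmanager

# Holds explicitly supplied tuning inputs for the duration of a `with` block.
_tuning_inputs = contextvars.ContextVar("tuning_inputs", default=None)


@contextmanager
def set_inputs(*values):
    token = _tuning_inputs.set(values)
    try:
        yield
    finally:
        _tuning_inputs.reset(token)


def supply_inputs_for_tuning():
    """What a tuner would call: use explicit inputs if provided, otherwise refuse scalars."""
    supplied = _tuning_inputs.get()
    if supplied is None:
        raise ValueError("scalar inputs cannot be auto-generated; use set_autotune_inputs(...)")
    return supplied


with set_inputs("A_tensor_placeholder", 0.1):
    print(supply_inputs_for_tuning())  # ('A_tensor_placeholder', 0.1)
```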
Autotuner implementation changes:
@@ -391,6 +391,28 @@ def set_kernel_parameters(self, k_parameters: tuple[tuple[Any, ...], tuple[tuple

```diff
         self._kernel_parameters = k_parameters
         self._function_parameters = f_parameters
 
+    def _validate_input_supply_requirements(self, prim_func: PrimFunc, out_idx: list[int] | int | None):
+        if self.profile_args.supply_prog is not None:
+            return
+
+        if out_idx is None:
+            result_idx = []
+        elif isinstance(out_idx, int):
+            result_idx = [len(prim_func.params) + out_idx if out_idx < 0 else out_idx]
+        else:
+            result_idx = [len(prim_func.params) + idx if idx < 0 else idx for idx in out_idx]
```
Reviewer comment on lines +398 to +403 (Contributor):

Validate `out_idx`: lines 298 and 300 currently accept out-of-range values without raising an error. Bounds-check each index and reject non-int/non-list values, as in the fix below.

🐛 Proposed fix:

```diff
-        if out_idx is None:
-            result_idx = []
-        elif isinstance(out_idx, int):
-            result_idx = [len(prim_func.params) + out_idx if out_idx < 0 else out_idx]
-        else:
-            result_idx = [len(prim_func.params) + idx if idx < 0 else idx for idx in out_idx]
+        param_count = len(prim_func.params)
+        if out_idx is None:
+            result_idx = set()
+        elif isinstance(out_idx, int):
+            if out_idx >= param_count or out_idx < -param_count:
+                raise ValueError(f"out_idx should be between {-param_count} and {param_count - 1}")
+            result_idx = {param_count + out_idx if out_idx < 0 else out_idx}
+        elif isinstance(out_idx, list):
+            result_idx = set()
+            for idx in out_idx:
+                if idx >= param_count or idx < -param_count:
+                    raise ValueError(f"out_idx should be between {-param_count} and {param_count - 1}")
+                result_idx.add(param_count + idx if idx < 0 else idx)
+        else:
+            raise ValueError("out_idx should be a list of integers")
```
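The suggested validation is easy to exercise in isolation. The helper below is a hypothetical, framework-free restatement of the proposed logic, with `param_count` standing in for `len(prim_func.params)`; it only demonstrates the intended accept/reject behaviour, not the actual TileLang code.

```python
def normalize_out_idx(out_idx, param_count):
    """Normalize out_idx into a set of non-negative indices, rejecting out-of-range values."""
    if out_idx is None:
        return set()
    if isinstance(out_idx, int):
        out_idx = [out_idx]
    elif not isinstance(out_idx, list):
        raise ValueError("out_idx should be a list of integers")
    result = set()
    for idx in out_idx:
        if idx >= param_count or idx < -param_count:
            raise ValueError(f"out_idx should be between {-param_count} and {param_count - 1}")
        result.add(param_count + idx if idx < 0 else idx)
    return result


assert normalize_out_idx(None, 3) == set()
assert normalize_out_idx(-1, 3) == {2}
assert normalize_out_idx([0, -1], 3) == {0, 2}
try:
    normalize_out_idx(5, 3)  # out of range → should raise
except ValueError as err:
    print(err)  # out_idx should be between -3 and 2
```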
The new method continues:

```diff
+
+        for i, param in enumerate(prim_func.params):
+            if i in result_idx:
+                continue
+            if param not in prim_func.buffer_map:
+                kernel_name = get_prim_func_name(prim_func, "<unknown>")
+                raise ValueError(
+                    f"Auto-tuning kernel '{kernel_name}' does not support auto-generated scalar inputs. "
+                    f"Found scalar input parameter '{param.name}'. "
+                    "Please provide concrete inputs with `with set_autotune_inputs(...)`."
+                )
+
     def generate_cache_key(self, parameters: dict[str, Any], extra_parameters: dict[str, Any]) -> AutotuneResult | None:
         """Generate a cache key for the auto-tuning process."""
```
@@ -625,6 +647,12 @@ def check_tunable_argument_value(key, parameters, key_args_tuple) -> bool: | |
| autotuner_result = AutotuneResult(libcode=jit_kernel.get_kernel_source(), func=jit_kernel.prim_func, kernel=jit_kernel) | ||
| self._memory_cache[key] = autotuner_result | ||
| return autotuner_result | ||
|
|
||
| # After confirming tuning will actually run, validate that scalar | ||
| # inputs can be supplied (either via supply_prog or set_autotune_inputs). | ||
| if hasattr(self, "_prim_func_for_validation"): | ||
| self._validate_input_supply_requirements(self._prim_func_for_validation, self.compile_args.out_idx) | ||
|
|
||
| # get the cpu count | ||
| available_cpu_count = get_available_cpu_count() | ||
| cpu_utilizations = float(env.TILELANG_AUTO_TUNING_CPU_UTILITIES) | ||
|
|
@@ -770,6 +798,7 @@ class AutoTuneImpl(Generic[_P, _T]): | |
| skip_check: bool = False | ||
| manual_check_prog: Callable = None | ||
| cache_input_tensors: bool = False | ||
| do_not_specialize: tuple[str, ...] | list[str] | None = None | ||
|
|
||
| def __post_init__(self): | ||
| self._tuner_cache = {} | ||
|
|
@@ -804,17 +833,31 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel | _T: | |
| return_kernel = kwargs.pop("__return_kernel", False) | ||
|
|
||
| mode = self.jit_impl.initialize_jit_mode(*args, **kwargs) | ||
|
|
||
| norm_args = _normalize_value(args, sort_dict_items=True) | ||
| norm_kwargs = _normalize_value(kwargs, sort_dict_items=True) | ||
| autotuner = self.get_tunner() | ||
| # Defer scalar-input validation to run(), after we know whether | ||
| # tuning will actually execute or be skipped because all tunable | ||
| # parameters are already provided by the caller. | ||
| autotuner._prim_func_for_validation = self.jit_impl.get_tir(*args, **kwargs) | ||
|
|
||
| # Compute the cache key, excluding do_not_specialize parameters | ||
| # so that changing them does not trigger re-autotuning. | ||
| if self.do_not_specialize: | ||
| # Bind all args to kwargs so positional vs keyword doesn't affect key | ||
| bound = self.jit_impl.signature.bind(*args, **kwargs) | ||
| bound.apply_defaults() | ||
| filtered_kwargs = {k: v for k, v in bound.arguments.items() if k not in self.do_not_specialize} | ||
| norm_args = () | ||
| norm_kwargs = _normalize_value(filtered_kwargs, sort_dict_items=True) | ||
| else: | ||
| norm_args = _normalize_value(args, sort_dict_items=True) | ||
| norm_kwargs = _normalize_value(kwargs, sort_dict_items=True) | ||
| key = (norm_args, norm_kwargs) | ||
| if key not in self._tuner_cache: | ||
| if mode == "lazy": | ||
|
|
||
| def jit_compile(**config_arg): | ||
| return self.jit_impl(*args, **kwargs, __tune_params=config_arg) | ||
|
|
||
| autotuner = self.get_tunner() | ||
| autotuner.jit_compile = jit_compile | ||
| autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters) | ||
| else: | ||
|
|
@@ -824,7 +867,6 @@ def jit_compile(**config_arg): | |
| merged.update(config_arg) | ||
| return self.jit_impl.compile(*args, **merged) | ||
|
|
||
| autotuner = self.get_tunner() | ||
| autotuner.jit_compile = jit_compile | ||
| autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters) | ||
|
|
||
|
|
@@ -866,6 +908,7 @@ def autotune( # This is the new public interface | |
| skip_check: bool = False, | ||
| manual_check_prog: Callable = None, | ||
| cache_input_tensors: bool = False, | ||
| do_not_specialize: tuple[str, ...] | list[str] | None = None, | ||
| ): | ||
| """ | ||
| Just-In-Time (JIT) compiler decorator for TileLang functions. | ||
|
|
@@ -941,6 +984,7 @@ def decorator(impl): | |
| skip_check=skip_check, | ||
| manual_check_prog=manual_check_prog, | ||
| cache_input_tensors=cache_input_tensors, | ||
| do_not_specialize=do_not_specialize, | ||
| ) | ||
|
|
||
| return decorator | ||
Reviewer comment:

Add CUDA guard to test requiring GPU. The test allocates CUDA tensors on lines 30 and 35, so it will fail on CPU-only runners during setup. Use the repository's standard decorator:

Proposed fix:

```diff
+@tilelang.testing.requires_cuda
 def test_autotune_scalar_inputs_with_set_autotune_inputs():
     tune_a = torch.randn((4096,), device="cuda", dtype=torch.float32)
```