[Enhance] Reject default scalar params and support do_not_specialize for autotune (#2084)
First changed file — a new test (hunk `@@ -0,0 +1,43 @@`, whole file added):

```python
import pytest
import torch
import tilelang
import tilelang.language as T
import tilelang.testing
from tilelang.autotuner import set_autotune_inputs


@tilelang.autotune(configs=[{"threads": 128}, {"threads": 256}], warmup=1, rep=1, timeout=60)
@tilelang.jit
def add_scalar(N: int = 4096, BLOCK_N: int = 512, threads: int = 128):

    @T.prim_func
    def kernel(A: T.Tensor((N,), T.float32), s: T.float32):
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=threads) as pid_n:
            A_local = T.alloc_fragment((BLOCK_N,), T.float32)
            T.copy(A[pid_n * BLOCK_N], A_local)
            for i in T.Parallel(BLOCK_N):
                A_local[i] += s
            T.copy(A_local, A[pid_n * BLOCK_N])

    return kernel


def test_autotune_scalar_inputs_require_explicit_supply():
    with pytest.raises(ValueError, match=r"set_autotune_inputs"):
        add_scalar()


def test_autotune_scalar_inputs_with_set_autotune_inputs():
    tune_a = torch.randn((4096,), device="cuda", dtype=torch.float32)
    tune_s = 0.1
    with set_autotune_inputs(tune_a, tune_s):
        kernel = add_scalar()

    a = torch.randn((4096,), device="cuda", dtype=torch.float32)
    before = a.clone()
    kernel(a, tune_s)

    torch.testing.assert_close(a, before + tune_s, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
    tilelang.testing.main()
```
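The first test pins down the new failure mode: calling the autotuned `add_scalar` without concrete inputs raises a `ValueError` whose message matches `set_autotune_inputs`, because the kernel's PrimFunc takes a scalar parameter `s` that the profiler cannot auto-generate. The second test exercises the supported path: `with set_autotune_inputs(tune_a, tune_s):` supplies the profiling inputs, tuning succeeds, and the returned kernel is then checked for numerical correctness.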
Second changed file — the autotuner implementation:
`@@ -288,6 +288,28 @@ def set_kernel_parameters(self, k_parameters: tuple[tuple[Any, ...], tuple[tuple`

```python
        self._kernel_parameters = k_parameters
        self._function_parameters = f_parameters

    def _validate_input_supply_requirements(self, prim_func: PrimFunc, out_idx: list[int] | int | None):
        if self.profile_args.supply_prog is not None:
            return

        if out_idx is None:
            result_idx = []
        elif isinstance(out_idx, int):
            result_idx = [len(prim_func.params) + out_idx if out_idx < 0 else out_idx]
        else:
            result_idx = [len(prim_func.params) + idx if idx < 0 else idx for idx in out_idx]
```
Comment on lines +398 to +403 (Contributor):

Validate `out_idx`. Lines 298 and 300 currently accept out-of-range values (e.g. an index at or beyond `len(prim_func.params)`), so a bad output index passes through unvalidated instead of being rejected.

🐛 Proposed fix:

```diff
-        if out_idx is None:
-            result_idx = []
-        elif isinstance(out_idx, int):
-            result_idx = [len(prim_func.params) + out_idx if out_idx < 0 else out_idx]
-        else:
-            result_idx = [len(prim_func.params) + idx if idx < 0 else idx for idx in out_idx]
+        param_count = len(prim_func.params)
+        if out_idx is None:
+            result_idx = set()
+        elif isinstance(out_idx, int):
+            if out_idx >= param_count or out_idx < -param_count:
+                raise ValueError(f"out_idx should be between {-param_count} and {param_count - 1}")
+            result_idx = {param_count + out_idx if out_idx < 0 else out_idx}
+        elif isinstance(out_idx, list):
+            result_idx = set()
+            for idx in out_idx:
+                if idx >= param_count or idx < -param_count:
+                    raise ValueError(f"out_idx should be between {-param_count} and {param_count - 1}")
+                result_idx.add(param_count + idx if idx < 0 else idx)
+        else:
+            raise ValueError("out_idx should be a list of integers")
```
The hunk continues:

```python
        for i, param in enumerate(prim_func.params):
            if i in result_idx:
                continue
            if param not in prim_func.buffer_map:
                kernel_name = get_prim_func_name(prim_func, "<unknown>")
                raise ValueError(
                    f"Auto-tuning kernel '{kernel_name}' does not support auto-generated scalar inputs. "
                    f"Found scalar input parameter '{param.name}'. "
                    "Please provide concrete inputs with `with set_autotune_inputs(...)`."
                )

    def generate_cache_key(self, parameters: dict[str, Any], extra_parameters: dict[str, Any]) -> AutotuneResult | None:
        """Generate a cache key for the auto-tuning process."""
```
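Two details of the new check are worth calling out: it returns early when `profile_args.supply_prog` is set, since a user-provided supply program already controls how profiling inputs are produced; and a parameter is treated as a scalar exactly when it has no entry in `prim_func.buffer_map`, which is why it cannot be auto-generated the way tensor arguments are.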
`@@ -701,6 +723,8 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel | _T:`

```python
        return_kernel = kwargs.pop("__return_kernel", False)

        mode = self.jit_impl.initialize_jit_mode(*args, **kwargs)
        autotuner = self.get_tunner()
        autotuner._validate_input_supply_requirements(self.jit_impl.get_tir(*args, **kwargs), self.jit_impl.out_idx)

        norm_args = _normalize_value(args, sort_dict_items=True)
        norm_kwargs = _normalize_value(kwargs, sort_dict_items=True)
```
`@@ -711,7 +735,6 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel | _T:`

```python
            def jit_compile(**config_arg):
                return self.jit_impl(*args, **kwargs, __tune_params=config_arg)

            autotuner = self.get_tunner()
            autotuner.jit_compile = jit_compile
            autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters)
        else:
```
`@@ -721,7 +744,6 @@ def jit_compile(**config_arg):`

```python
                merged.update(config_arg)
                return self.jit_impl.compile(*args, **merged)

            autotuner = self.get_tunner()
            autotuner.jit_compile = jit_compile
            autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters)
```
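Both branches hand the autotuner a `jit_compile` closure that captures the caller's arguments and accepts only the per-config tuning parameters. A minimal standalone sketch of that pattern (a simplified stand-in with illustrative names, not tilelang's actual classes):

```python
# Simplified stand-in for the deferred-compile closure pattern above;
# `make_jit_compile` and `impl` are illustrative names, not tilelang APIs.
from typing import Any, Callable


def make_jit_compile(impl: Callable[..., Any], *args: Any, **kwargs: Any) -> Callable[..., Any]:
    def jit_compile(**config_arg: Any) -> Any:
        # The autotuner calls this once per candidate config; the user's
        # original arguments stay fixed while the tuning knobs vary.
        return impl(*args, **kwargs, __tune_params=config_arg)
    return jit_compile


# Usage: each candidate config compiles against the same user arguments.
build = make_jit_compile(lambda *a, **k: k, 4096, BLOCK_N=512)
print(build(threads=128))  # {'BLOCK_N': 512, '__tune_params': {'threads': 128}}
```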
A second review comment targets the new test file:

Add CUDA guard to test requiring GPU. The test allocates CUDA tensors on lines 30 and 35, so it will fail on CPU-only runners during setup. Use the repository's standard decorator:

Proposed fix:

```diff
+@tilelang.testing.requires_cuda
 def test_autotune_scalar_inputs_with_set_autotune_inputs():
     tune_a = torch.randn((4096,), device="cuda", dtype=torch.float32)
```
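Guards like this are typically thin `pytest.mark.skipif` wrappers. A minimal sketch of the idea (illustrative only — `tilelang.testing.requires_cuda` is the project's real decorator and may be implemented differently):

```python
# Illustrative sketch of a requires_cuda-style guard; tilelang's actual
# decorator lives in tilelang.testing and may differ from this.
import pytest
import torch

requires_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(),
    reason="CUDA is required for this test",
)


@requires_cuda
def test_needs_gpu():
    # Skipped automatically on CPU-only runners instead of erroring in setup.
    assert torch.zeros(1, device="cuda").is_cuda
```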