Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions testing/python/autotune/test_tilelang_autotune_scalar_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import pytest
import torch
import tilelang
import tilelang.language as T
import tilelang.testing
from tilelang.autotuner import set_autotune_inputs


@tilelang.autotune(configs=[{"threads": 128}, {"threads": 256}], warmup=1, rep=1, timeout=60)
@tilelang.jit
def add_scalar(N: int = 4096, BLOCK_N: int = 512, threads: int = 128):
    """Build an autotuned 1-D kernel that adds scalar ``s`` to tensor ``A`` in place.

    The autotuner searches over the thread count (128 or 256). Because the
    kernel takes a scalar parameter ``s``, the autotuner cannot auto-generate
    inputs for it — tuning requires `set_autotune_inputs(...)` (see tests below).
    """

    @T.prim_func
    def kernel(A: T.Tensor((N,), T.float32), s: T.float32):
        # One program instance per BLOCK_N-sized tile of A.
        with T.Kernel(T.ceildiv(N, BLOCK_N), threads=threads) as pid_n:
            A_local = T.alloc_fragment((BLOCK_N,), T.float32)
            # Stage the tile into a register fragment, add the scalar, write back.
            T.copy(A[pid_n * BLOCK_N], A_local)
            for i in T.Parallel(BLOCK_N):
                A_local[i] += s
            T.copy(A_local, A[pid_n * BLOCK_N])

    return kernel


def test_autotune_scalar_inputs_require_explicit_supply():
    """Calling the autotuned kernel without supplied inputs must raise.

    The kernel has a scalar parameter, so the autotuner cannot fabricate
    inputs and should point the user at `set_autotune_inputs`.
    """
    with pytest.raises(ValueError) as excinfo:
        add_scalar()
    assert "set_autotune_inputs" in str(excinfo.value)


@tilelang.testing.requires_cuda
def test_autotune_scalar_inputs_with_set_autotune_inputs():
    """Autotuning succeeds when scalar inputs are supplied explicitly.

    Guarded with ``requires_cuda``: the test allocates CUDA tensors, so it
    must be skipped (not errored) on CPU-only runners.
    """
    tune_a = torch.randn((4096,), device="cuda", dtype=torch.float32)
    tune_s = 0.1
    # Provide concrete tuning inputs, including the scalar `s`.
    with set_autotune_inputs(tune_a, tune_s):
        kernel = add_scalar()

    a = torch.randn((4096,), device="cuda", dtype=torch.float32)
    before = a.clone()
    kernel(a, tune_s)

    # The kernel adds `s` to every element of `a` in place.
    torch.testing.assert_close(a, before + tune_s, rtol=1e-4, atol=1e-4)
Comment on lines +29 to +39
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

#!/bin/bash
# Description: Inspect how existing tests guard CUDA-only allocations.
rg -n -C2 'pytest\.mark\.skipif|torch\.cuda\.is_available|requires_cuda|device="cuda"' testing

Repository: tile-ai/tilelang

Length of output: 50373


🏁 Script executed:

#!/bin/bash
# Examine the specific test file mentioned in the review
cat -n testing/python/autotune/test_tilelang_autotune_scalar_inputs.py

Repository: tile-ai/tilelang

Length of output: 1704


🏁 Script executed:

#!/bin/bash
# Check how other tests in the autotune directory are guarded
ls -la testing/python/autotune/*.py | head -20
rg -n '@pytest.mark.skipif|@tilelang.testing.requires_cuda|import pytest|import torch' testing/python/autotune/ -A 2 | head -100

Repository: tile-ai/tilelang

Length of output: 3362


🏁 Script executed:

#!/bin/bash
# Check if there's a conftest or similar setup in autotune directory that might guard all tests
ls -la testing/python/autotune/
fd conftest testing/python/autotune/

Repository: tile-ai/tilelang

Length of output: 692


🏁 Script executed:

#!/bin/bash
# Check what tilelang.testing.requires_cuda is
rg -A5 'def requires_cuda|requires_cuda =' tilelang/testing --type py | head -30

Repository: tile-ai/tilelang

Length of output: 90


🏁 Script executed:

#!/bin/bash
# Also check for imports and definition
fd -e py testing | xargs grep -l 'requires_cuda' | head -5
find . -path ./node_modules -prune -o -name '*.py' -type f -exec grep -l 'requires_cuda.*=' {} \; | head -5

Repository: tile-ai/tilelang

Length of output: 90


🏁 Script executed:

#!/bin/bash
# Look for tilelang.testing module definition
find . -name "testing.py" -o -name "testing" -type d | grep tilelang
ls -la tilelang/testing* 2>/dev/null || echo "No direct tilelang/testing files"

Repository: tile-ai/tilelang

Length of output: 322


🏁 Script executed:

#!/bin/bash
# Check the tilelang/testing/__init__.py file
cat tilelang/testing/__init__.py

Repository: tile-ai/tilelang

Length of output: 3930


Add CUDA guard to test requiring GPU.

The test allocates CUDA tensors on lines 30 and 35, so it will fail on CPU-only runners during setup. Use the repository's standard decorator:

Proposed fix
+@tilelang.testing.requires_cuda
 def test_autotune_scalar_inputs_with_set_autotune_inputs():
     tune_a = torch.randn((4096,), device="cuda", dtype=torch.float32)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def test_autotune_scalar_inputs_with_set_autotune_inputs():
tune_a = torch.randn((4096,), device="cuda", dtype=torch.float32)
tune_s = 0.1
with set_autotune_inputs(tune_a, tune_s):
kernel = add_scalar()
a = torch.randn((4096,), device="cuda", dtype=torch.float32)
before = a.clone()
kernel(a, tune_s)
torch.testing.assert_close(a, before + tune_s, rtol=1e-4, atol=1e-4)
`@tilelang.testing.requires_cuda`
def test_autotune_scalar_inputs_with_set_autotune_inputs():
tune_a = torch.randn((4096,), device="cuda", dtype=torch.float32)
tune_s = 0.1
with set_autotune_inputs(tune_a, tune_s):
kernel = add_scalar()
a = torch.randn((4096,), device="cuda", dtype=torch.float32)
before = a.clone()
kernel(a, tune_s)
torch.testing.assert_close(a, before + tune_s, rtol=1e-4, atol=1e-4)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@testing/python/autotune/test_tilelang_autotune_scalar_inputs.py` around lines
29 - 39, The test test_autotune_scalar_inputs_with_set_autotune_inputs allocates
CUDA tensors via set_autotune_inputs and local tensors (add_scalar), so add the
repository's standard CUDA guard to the test (e.g., decorate
test_autotune_scalar_inputs_with_set_autotune_inputs with the project's
`@requires_cuda` or a pytest.skipif(not torch.cuda.is_available()) equivalent) and
import the guard if necessary; this ensures the test is skipped on CPU-only
runners while leaving set_autotune_inputs and add_scalar unchanged.



if __name__ == "__main__":
tilelang.testing.main()
26 changes: 24 additions & 2 deletions tilelang/autotuner/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,28 @@ def set_kernel_parameters(self, k_parameters: tuple[tuple[Any, ...], tuple[tuple
self._kernel_parameters = k_parameters
self._function_parameters = f_parameters

def _validate_input_supply_requirements(self, prim_func: PrimFunc, out_idx: list[int] | int | None):
    """Ensure the autotuner can supply an input for every non-output parameter.

    Parameters
    ----------
    prim_func : PrimFunc
        The kernel to be tuned; its ``params``/``buffer_map`` distinguish
        buffer parameters (auto-suppliable) from scalar parameters.
    out_idx : list[int] | int | None
        Indices of output parameters; negative indices count from the end.

    Raises
    ------
    ValueError
        If ``out_idx`` contains an out-of-range index, or if a non-output
        scalar parameter is found while no supply program / explicit inputs
        were provided via ``set_autotune_inputs``.
    """
    # A user-provided supply program generates all inputs itself.
    if self.profile_args.supply_prog is not None:
        return

    param_count = len(prim_func.params)

    def _legalize(idx: int) -> int:
        # Reject out-of-range indices BEFORE normalizing negatives, so a bad
        # out_idx (e.g. -param_count - 1) raises a clear error here instead of
        # being silently normalized to an invalid index and later surfacing as
        # a misleading scalar-input error.
        if idx >= param_count or idx < -param_count:
            raise ValueError(
                f"out_idx should be between {-param_count} and {param_count - 1}, got {idx}")
        return param_count + idx if idx < 0 else idx

    if out_idx is None:
        result_idx = set()
    elif isinstance(out_idx, int):
        result_idx = {_legalize(out_idx)}
    else:
        result_idx = {_legalize(idx) for idx in out_idx}

    for i, param in enumerate(prim_func.params):
        if i in result_idx:
            # Output parameters are allocated by the runner, not supplied.
            continue
        if param not in prim_func.buffer_map:
            # A non-buffer (scalar) input cannot be auto-generated.
            kernel_name = get_prim_func_name(prim_func, "<unknown>")
            raise ValueError(
                f"Auto-tuning kernel '{kernel_name}' does not support auto-generated scalar inputs. "
                f"Found scalar input parameter '{param.name}'. "
                "Please provide concrete inputs with `with set_autotune_inputs(...)`."
            )

def generate_cache_key(self, parameters: dict[str, Any], extra_parameters: dict[str, Any]) -> AutotuneResult | None:
"""Generate a cache key for the auto-tuning process."""

Expand Down Expand Up @@ -701,6 +723,8 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel | _T:
return_kernel = kwargs.pop("__return_kernel", False)

mode = self.jit_impl.initialize_jit_mode(*args, **kwargs)
autotuner = self.get_tunner()
autotuner._validate_input_supply_requirements(self.jit_impl.get_tir(*args, **kwargs), self.jit_impl.out_idx)

norm_args = _normalize_value(args, sort_dict_items=True)
norm_kwargs = _normalize_value(kwargs, sort_dict_items=True)
Expand All @@ -711,7 +735,6 @@ def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> JITKernel | _T:
def jit_compile(**config_arg):
return self.jit_impl(*args, **kwargs, __tune_params=config_arg)

autotuner = self.get_tunner()
autotuner.jit_compile = jit_compile
autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters)
else:
Expand All @@ -721,7 +744,6 @@ def jit_compile(**config_arg):
merged.update(config_arg)
return self.jit_impl.compile(*args, **merged)

autotuner = self.get_tunner()
autotuner.jit_compile = jit_compile
autotuner.set_kernel_parameters(key, self.jit_impl.signature.parameters)

Expand Down
Loading