Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions testing/python/utils/test_tensor_supply_scalar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
"""Regression tests for scalar-parameter input generation.

These tests cover `tilelang.utils.tensor.get_tensor_supply` for the
scalar (empty-shape) parameter case. Previously, calling the supplier
with a scalar `KernelParam` raised `ValueError`, which broke the
autotuner for any kernel signature that included a scalar value
parameter (e.g. `def kernel(A: T.Tensor(...), s: T.float32):`).

See: https://github.com/tile-ai/tilelang/issues/2081
"""

import pytest

import tilelang
import tilelang.testing
from tilelang import tvm
from tilelang.engine.param import KernelParam
from tilelang.utils.tensor import TensorSupplyType, get_tensor_supply


# Each entry pairs a TVM dtype string with the Python builtin type that the
# scalar supplier is expected to return for that dtype family. Floating
# dtypes map to `float`, integer dtypes to `int`, and `bool` to `bool`.
# (dtype string, expected Python type for the supplied scalar)
_SCALAR_DTYPE_CASES = [
    ("float32", float),
    ("float16", float),
    ("bfloat16", float),
    ("float64", float),
    ("int32", int),
    ("int64", int),
    ("int8", int),
    ("uint8", int),
    ("bool", bool),
]


@pytest.mark.parametrize("dtype_str,expected_py_type", _SCALAR_DTYPE_CASES)
@pytest.mark.parametrize("supply_type", list(TensorSupplyType))
def test_scalar_param_returns_python_scalar(dtype_str, expected_py_type, supply_type):
    """A scalar `KernelParam` should yield a Python scalar of the right
    dtype family for every `TensorSupplyType`. This is the fallback
    that allows the autotuner to invoke kernels that take scalar
    value parameters; users can still supply explicit values via
    `supply_prog`. Regression for #2081.
    """
    param = KernelParam(dtype=tvm.DataType(dtype_str), shape=[])
    supply = get_tensor_supply(supply_type)

    value = supply(param)

    assert isinstance(value, expected_py_type), (
        f"Expected a {expected_py_type.__name__} for {dtype_str} scalar under {supply_type}, got {type(value).__name__} ({value!r})"
    )
    # `bool` is a subclass of `int`, so the isinstance check above would
    # silently accept a bool being supplied for an integer dtype. Pin the
    # boolean-ness explicitly so only the "bool" dtype yields a bool.
    assert isinstance(value, bool) == (expected_py_type is bool), (
        f"Dtype {dtype_str} under {supply_type} yielded {value!r} "
        f"({type(value).__name__}); bool-ness does not match expected "
        f"{expected_py_type.__name__}"
    )


def test_scalar_supply_does_not_require_cuda():
    """Scalar input generation must work on CPU-only hosts.

    The scalar fast path returns a plain Python value before any device
    lookup happens, so autotuner input generation succeeds both on GPU
    machines and on hosts without CUDA.
    """
    scalar_param = KernelParam(dtype=tvm.DataType("float32"), shape=[])
    supplier = get_tensor_supply(TensorSupplyType.Integer)

    # Generating the scalar must neither raise nor touch CUDA.
    produced = supplier(scalar_param)

    assert isinstance(produced, float)


# Allow running this test file directly via the project's test runner
# (in addition to collection through pytest).
if __name__ == "__main__":
    tilelang.testing.main()
15 changes: 10 additions & 5 deletions tilelang/utils/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,18 @@ def get_tensor_supply(supply_type: TensorSupplyType = TensorSupplyType.Integer):
def get_tensor(param: KernelParam) -> torch.Tensor:
# Convert tvm.DataType to torch.dtype for tensor creation
dtype: torch.dtype = param.torch_dtype()
device = get_current_device()

# Scalar value parameter (empty shape), e.g. `s: T.float32` in the
# kernel signature. Return a Python scalar of the matching dtype
# family so the autotuner can invoke kernels with scalar arguments
# without crashing. Users that need a specific scalar value can
# still pass it explicitly via `supply_prog`. See #2081.
if hasattr(param, "shape") and not param.shape:
raise ValueError(
f"TensorType must have a shape, but got {type(param)}, "
"likely you are trying to generate a random tensor with a dynamic symbolic shape."
)
if hasattr(param, "is_boolean") and param.is_boolean():
return False
return 0.0 if dtype.is_floating_point else 0

device = get_current_device()

# Check if with dynamic symbolic shape
for shape in param.shape:
Expand Down
Loading