diff --git a/benchmarks/ops/bench_max_pool2d.py b/benchmarks/ops/bench_max_pool2d.py new file mode 100644 index 000000000..ee96ab710 --- /dev/null +++ b/benchmarks/ops/bench_max_pool2d.py @@ -0,0 +1,136 @@ +from typing import Optional, Tuple + +import pytest +import torch +import torch.nn.functional as F + +from benchmarks.benchmark import BenchmarkBase, BenchmarkReport +from tileops.kernels.pool.common import pool_output_dim +from tileops.ops import MaxPool2dOp + + +class MaxPool2dBenchCase: + def __init__( + self, + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_size: Tuple[int, int], + stride: Optional[Tuple[int, int]], + padding: Tuple[int, int], + dilation: Tuple[int, int], + ceil_mode: bool, + dtype: torch.dtype, + ) -> None: + self.n = n + self.c_in = c_in + self.h_in = h_in + self.w_in = w_in + self.kernel_size = kernel_size + self.stride = kernel_size if stride is None else stride + self.padding = padding + self.dilation = dilation + self.ceil_mode = ceil_mode + self.dtype = dtype + + def gen_inputs(self) -> tuple[torch.Tensor]: + x = torch.randn(self.n, self.h_in, self.w_in, self.c_in, device="cuda", dtype=self.dtype).contiguous() + return (x,) + + def ref_program(self, x: torch.Tensor) -> torch.Tensor: + return F.max_pool2d( + x, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + ceil_mode=self.ceil_mode, + ) + + +class MaxPool2dBenchmark(BenchmarkBase): + def calculate_flops(self) -> Optional[float]: + t = self.workload + out_h = pool_output_dim(t.h_in, t.kernel_size[0], t.stride[0], t.padding[0], t.ceil_mode, t.dilation[0]) + out_w = pool_output_dim(t.w_in, t.kernel_size[1], t.stride[1], t.padding[1], t.ceil_mode, t.dilation[1]) + return t.n * t.c_in * out_h * out_w * t.kernel_size[0] * t.kernel_size[1] + + def calculate_memory(self) -> Optional[float]: + t = self.workload + out_h = pool_output_dim(t.h_in, t.kernel_size[0], t.stride[0], t.padding[0], t.ceil_mode, t.dilation[0]) + out_w = pool_output_dim(t.w_in, t.kernel_size[1], t.stride[1], t.padding[1], t.ceil_mode, t.dilation[1]) + return (t.n * t.c_in * t.h_in * t.w_in + t.n * t.c_in * out_h * out_w) * t.dtype.itemsize + + +_MAX_POOL2D_BASE_CASES = [ + (2, 64, 112, 112, (3, 3), (2, 2), (1, 1), (1, 1), False, "vision-3x3-s2"), + (2, 128, 56, 56, (5, 5), (2, 2), (2, 2), (1, 1), False, "vision-5x5-s2"), + (3, 96, 55, 57, (3, 3), (2, 2), (1, 1), (2, 1), True, "ceil-dilation-nonpow2"), +] + +_MAX_POOL2D_BENCH_PARAMS = [ + pytest.param(*case[:-1], dtype, True, id=f"{case[-1]}-{str(dtype).split('.')[-1]}") + for case in _MAX_POOL2D_BASE_CASES + for dtype in (torch.float16, torch.bfloat16) +] + + +@pytest.mark.parametrize( + "n, c_in, h_in, w_in, kernel_size, stride, padding, dilation, ceil_mode, dtype, tune", + _MAX_POOL2D_BENCH_PARAMS, +) +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_bench( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_size: Tuple[int, int], + stride: Optional[Tuple[int, int]], + padding: Tuple[int, int], + dilation: Tuple[int, int], + ceil_mode: bool, + dtype: torch.dtype, + tune: bool, +) -> None: + test = MaxPool2dBenchCase( + n, + c_in, + h_in, + w_in, + kernel_size, + stride, + padding, + dilation, + ceil_mode, + dtype, + ) + bm = MaxPool2dBenchmark(test) + inputs = test.gen_inputs() + (x,) = inputs + x_nchw = x.permute(0, 3, 1, 2).contiguous() + + op = MaxPool2dOp( + n=n, + c_in=c_in, + h_in=h_in, + w_in=w_in, + kernel_size=kernel_size, + stride=stride, + padding=padding, + 
dilation=dilation, + return_indices=False, + ceil_mode=ceil_mode, + dtype=dtype, + tune=tune, + ) + result = bm.profile(op, *inputs) + BenchmarkReport.record(op, locals(), result, tag="tileops") + + result_bl = bm.profile(test.ref_program, x_nchw) + BenchmarkReport.record(op, locals(), result_bl, tag="torch") + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) diff --git a/tests/ops/test_max_pool2d.py b/tests/ops/test_max_pool2d.py new file mode 100644 index 000000000..a3cf124b9 --- /dev/null +++ b/tests/ops/test_max_pool2d.py @@ -0,0 +1,471 @@ +from typing import Optional, Tuple + +import pytest +import torch +import torch.nn.functional as F + +from tests.test_base import FixtureBase, TestBase +from tileops.kernels.kernel import Kernel +from tileops.kernels.pool import MaxPool2dKernel +from tileops.ops import MaxPool2dOp + + +class _DummyValuesKernel(Kernel): + supported_archs = [80] + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x + + +class _DummyValuesIndicesKernel(Kernel): + supported_archs = [80] + + def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + return x, torch.zeros_like(x, dtype=torch.int64) + + +class MaxPool2dFixture(FixtureBase): + PARAMS = [ + ( + "n, c_in, h_in, w_in, kernel_size, stride, padding, dilation, return_indices, ceil_mode, dtype, tune", + [ + pytest.param( + 2, 64, 56, 56, (3, 3), None, (1, 1), (1, 1), False, False, torch.float16, False, + marks=[pytest.mark.smoke, pytest.mark.packaging], + id="smoke-3x3-default-stride-fp16", + ), + pytest.param( + 1, 96, 29, 31, (3, 5), (2, 2), (1, 2), (1, 1), False, True, torch.float16, False, + marks=pytest.mark.full, + id="full-ceil-nonpow2-fp16", + ), + pytest.param( + 1, 80, 28, 30, (3, 3), (2, 2), (1, 1), (2, 1), False, False, torch.bfloat16, False, + marks=pytest.mark.full, + id="full-dilation-bf16", + ), + pytest.param( + 2, 64, 56, 56, (3, 3), (2, 2), (1, 1), (2, 2), False, False, torch.float16, False, + marks=pytest.mark.full, + id="full-dilated-maxpool-2x2-fp16", + ), + pytest.param( + 1, 48, 35, 35, (3, 3), (1, 1), (1, 1), (3, 3), False, False, torch.float16, False, + marks=pytest.mark.full, + id="full-dilated-maxpool-3x3-fp16", + ), + pytest.param( + 1, 32, 16, 18, (2, 3), (2, 2), (0, 1), (1, 1), True, False, torch.float16, False, + marks=pytest.mark.full, + id="full-return-indices-fp16", + ), + ], + ), + ] + + +class MaxPool2dTest(TestBase): + def __init__( + self, + kernel_size: Tuple[int, int], + stride: Optional[Tuple[int, int]], + padding: Tuple[int, int], + dilation: Tuple[int, int], + return_indices: bool, + ceil_mode: bool, + dtype: torch.dtype, + ) -> None: + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.dilation = dilation + self.return_indices = return_indices + self.ceil_mode = ceil_mode + self.dtype = dtype + + def gen_inputs(self, n: int, c_in: int, h_in: int, w_in: int) -> tuple[torch.Tensor]: + x = torch.randn(n, h_in, w_in, c_in, device="cuda", dtype=self.dtype).contiguous() + return (x,) + + def ref_program(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + out = F.max_pool2d( + x.permute(0, 3, 1, 2).contiguous(), + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + return_indices=self.return_indices, + ceil_mode=self.ceil_mode, + ) + if self.return_indices: + values, indices = out + return ( + values.permute(0, 2, 3, 1).contiguous(), + indices.permute(0, 2, 3, 1).contiguous(), + ) + return out.permute(0, 2, 3, 
1).contiguous() + + +@MaxPool2dFixture +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_size: Tuple[int, int], + stride: Optional[Tuple[int, int]], + padding: Tuple[int, int], + dilation: Tuple[int, int], + return_indices: bool, + ceil_mode: bool, + dtype: torch.dtype, + tune: bool, +) -> None: + test = MaxPool2dTest( + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + return_indices=return_indices, + ceil_mode=ceil_mode, + dtype=dtype, + ) + op = MaxPool2dOp( + n=n, + c_in=c_in, + h_in=h_in, + w_in=w_in, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + return_indices=return_indices, + ceil_mode=ceil_mode, + dtype=dtype, + tune=tune, + ) + atol = 1e-3 if dtype == torch.float16 else 1.6e-2 + rtol = 1e-3 if dtype == torch.float16 else 1.6e-2 + test.check(op, *test.gen_inputs(n, c_in, h_in, w_in), atol=atol, rtol=rtol) + + +@pytest.mark.smoke +def test_max_pool2d_dispatches_kernel(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=32, + h_in=28, + w_in=28, + kernel_size=(3, 3), + stride=(2, 2), + padding=(1, 1), + ) + assert isinstance(op.kernel, MaxPool2dKernel) + + +@pytest.mark.smoke +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_returns_indices_when_requested(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=4, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + return_indices=True, + kernel_map={"max_pool2d_kernel": _DummyValuesIndicesKernel}, + ) + x = torch.randn(1, 8, 8, 4, device="cuda", dtype=torch.float16) + values, indices = op(x) + assert values is x + assert indices.dtype == torch.int64 + assert indices.shape == x.shape + + +@pytest.mark.smoke +def test_max_pool2d_default_path_uses_values_only_kernel(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + + def fail_if_called(*args, **kwargs): + raise AssertionError("indices kernel should not be used when return_indices=False") + + def return_values(*args, **kwargs): + x = args[-1] + return x + + monkeypatch.setattr( + "tileops.kernels.pool.max_pool2d._max_pool2d_values_indices_wrapped_kernel", + fail_if_called, + ) + monkeypatch.setattr( + "tileops.kernels.pool.max_pool2d._max_pool2d_values_wrapped_kernel", + return_values, + ) + op = MaxPool2dOp( + n=1, + c_in=4, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + ) + x = torch.randn(1, 8, 8, 4, device="cuda", dtype=torch.float16) + out = op(x) + assert out is x + + +@pytest.mark.smoke +def test_max_pool2d_indices_path_uses_values_indices_kernel(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + + def fail_if_called(*args, **kwargs): + raise AssertionError("values-only kernel should not be used when return_indices=True") + + def return_values_indices(*args, **kwargs): + x = args[-1] + return x, torch.zeros_like(x, dtype=torch.int64) + + monkeypatch.setattr( + "tileops.kernels.pool.max_pool2d._max_pool2d_values_wrapped_kernel", + fail_if_called, + ) + monkeypatch.setattr( + "tileops.kernels.pool.max_pool2d._max_pool2d_values_indices_wrapped_kernel", + return_values_indices, + ) + op = MaxPool2dOp( + n=1, + c_in=4, + h_in=8, + w_in=8, + 
kernel_size=(2, 2), + stride=(2, 2), + return_indices=True, + ) + x = torch.randn(1, 8, 8, 4, device="cuda", dtype=torch.float16) + values, indices = op(x) + assert values is x + assert indices.dtype == torch.int64 + + +@pytest.mark.smoke +def test_max_pool2d_rejects_non_positive_dilation() -> None: + with pytest.raises(ValueError, match="dilation must be greater than zero"): + MaxPool2dOp( + n=1, + c_in=8, + h_in=16, + w_in=16, + kernel_size=(3, 3), + dilation=(1, 0), + ) + + +@pytest.mark.smoke +def test_max_pool2d_rejects_invalid_padding_for_effective_kernel() -> None: + with pytest.raises(ValueError, match="padding must be at most half"): + MaxPool2dOp( + n=1, + c_in=8, + h_in=16, + w_in=16, + kernel_size=(3, 3), + padding=(3, 1), + dilation=(2, 1), + ) + + +@pytest.mark.smoke +def test_max_pool2d_rejects_padding_pyTorch_rejects_when_dilated() -> None: + with pytest.raises(ValueError, match="padding must be at most half"): + MaxPool2dOp( + n=1, + c_in=8, + h_in=16, + w_in=16, + kernel_size=(3, 3), + padding=(2, 2), + dilation=(2, 2), + ) + + +@pytest.mark.smoke +@pytest.mark.parametrize( + ("kwargs", "match"), + [ + ({"dilation": True}, "dilation must be an int or a tuple of 2 ints"), + ({"dilation": (1, True)}, "dilation must contain only ints"), + ({"kernel_size": True}, "kernel_size must be an int or a tuple of 2 ints"), + ({"stride": True}, "stride must be an int or a tuple of 2 ints"), + ({"padding": True}, "padding must be an int or a tuple of 2 ints"), + ], +) +def test_max_pool2d_rejects_invalid_param_types(kwargs: dict[str, object], match: str) -> None: + base_kwargs = { + "n": 1, + "c_in": 8, + "h_in": 16, + "w_in": 16, + "kernel_size": (3, 3), + } + base_kwargs.update(kwargs) + with pytest.raises((TypeError, ValueError), match=match): + MaxPool2dOp(**base_kwargs) + + +@pytest.mark.smoke +def test_max_pool2d_rejects_unsupported_dtype(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + with pytest.raises(ValueError, match="only supports dtypes"): + MaxPool2dOp( + n=1, + c_in=8, + h_in=16, + w_in=16, + kernel_size=(3, 3), + dtype=torch.float32, + ) + + +@pytest.mark.smoke +def test_max_pool2d_forward_rejects_non_cuda_input(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=4, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + kernel_map={"max_pool2d_kernel": _DummyValuesKernel}, + ) + x = torch.randn(1, 8, 8, 4) + with pytest.raises(ValueError, match="CUDA"): + op(x) + + +@pytest.mark.smoke +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_forward_rejects_nchw_shape(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=4, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + kernel_map={"max_pool2d_kernel": _DummyValuesKernel}, + ) + x = torch.randn(1, 4, 8, 8, device="cuda", dtype=torch.float16) + with pytest.raises(ValueError, match="NHWC"): + op(x) + + +@pytest.mark.smoke +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_forward_warns_on_ambiguous_nhwc_shape(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + op = MaxPool2dOp( + n=1, + c_in=8, + h_in=8, + w_in=8, + kernel_size=(2, 2), + stride=(2, 2), + kernel_map={"max_pool2d_kernel": _DummyValuesKernel}, + ) + x = 
torch.randn(1, 8, 8, 8, device="cuda", dtype=torch.float16) + with pytest.warns(UserWarning, match="ambiguous NHWC shape"): + out = op(x) + assert out is x + + +@pytest.mark.smoke +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_return_indices_handles_all_negative_infinity(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + x = torch.full((1, 3, 3, 1), float("-inf"), device="cuda", dtype=torch.float16) + op = MaxPool2dOp( + n=1, + c_in=1, + h_in=3, + w_in=3, + kernel_size=(2, 2), + stride=(1, 1), + return_indices=True, + ) + values, indices = op(x) + expected_values = torch.full((1, 2, 2, 1), float("-inf"), device="cuda", dtype=torch.float16) + expected_indices = torch.tensor([[[[0], [1]], [[3], [4]]]], device="cuda", dtype=torch.int64) + torch.testing.assert_close(values, expected_values) + torch.testing.assert_close(indices, expected_indices) + + +@pytest.mark.smoke +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_returns_empty_tensor_for_zero_sized_output(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + + def fail_if_called(*args, **kwargs): + raise AssertionError("kernel wrapper should not be called for zero-sized outputs") + + monkeypatch.setattr( + "tileops.kernels.pool.max_pool2d._max_pool2d_values_wrapped_kernel", + fail_if_called, + ) + op = MaxPool2dOp( + n=1, + c_in=1, + h_in=1, + w_in=1, + kernel_size=(2, 2), + stride=(2, 2), + dilation=(2, 2), + return_indices=False, + ) + x = torch.randn(1, 1, 1, 1, device="cuda", dtype=torch.float16) + out = op(x) + assert out.shape == (1, 0, 0, 1) + assert out.numel() == 0 + + +@pytest.mark.smoke +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +def test_max_pool2d_return_indices_returns_empty_tensors_for_zero_sized_output( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr("tileops.ops.op.get_sm_version", lambda: 80) + + def fail_if_called(*args, **kwargs): + raise AssertionError("kernel wrapper should not be called for zero-sized outputs") + + monkeypatch.setattr( + "tileops.kernels.pool.max_pool2d._max_pool2d_values_indices_wrapped_kernel", + fail_if_called, + ) + op = MaxPool2dOp( + n=1, + c_in=1, + h_in=1, + w_in=1, + kernel_size=(2, 2), + stride=(2, 2), + dilation=(2, 2), + return_indices=True, + ) + x = torch.randn(1, 1, 1, 1, device="cuda", dtype=torch.float16) + values, indices = op(x) + assert values.shape == (1, 0, 0, 1) + assert values.numel() == 0 + assert indices.shape == (1, 0, 0, 1) + assert indices.dtype == torch.int64 + assert indices.numel() == 0 + + +if __name__ == "__main__": + pytest.main([__file__, "-vvs"]) diff --git a/tileops/kernels/__init__.py b/tileops/kernels/__init__.py index 1c1e8645a..18d38bb9f 100644 --- a/tileops/kernels/__init__.py +++ b/tileops/kernels/__init__.py @@ -58,7 +58,7 @@ LayerNormKernel, RmsNormKernel, ) -from .pool import AvgPool1dKernel, AvgPool2dKernel, AvgPool3dKernel +from .pool import AvgPool1dKernel, AvgPool2dKernel, AvgPool3dKernel, MaxPool2dKernel from .rope import ( RopeLlama31Kernel, RopeLongRopeKernel, @@ -72,6 +72,7 @@ "AvgPool1dKernel", "AvgPool2dKernel", "AvgPool3dKernel", + "MaxPool2dKernel", "BatchNormBwdKernel", "BatchNormFwdInferKernel", "BatchNormFwdTrainKernel", diff --git a/tileops/kernels/pool/__init__.py b/tileops/kernels/pool/__init__.py index bf026df3d..98fbb94fa 100644 --- a/tileops/kernels/pool/__init__.py +++ 
b/tileops/kernels/pool/__init__.py @@ -1,5 +1,6 @@ from .avg_pool1d import AvgPool1dKernel from .avg_pool2d import AvgPool2dKernel from .avg_pool3d import AvgPool3dKernel +from .max_pool2d import MaxPool2dKernel -__all__ = ["AvgPool1dKernel", "AvgPool2dKernel", "AvgPool3dKernel"] +__all__ = ["AvgPool1dKernel", "AvgPool2dKernel", "AvgPool3dKernel", "MaxPool2dKernel"] diff --git a/tileops/kernels/pool/common.py b/tileops/kernels/pool/common.py index 9a4bf20d9..78add9acc 100644 --- a/tileops/kernels/pool/common.py +++ b/tileops/kernels/pool/common.py @@ -40,10 +40,13 @@ def validate_pool_params( kernel_size: tuple[int, ...], stride: tuple[int, ...], padding: tuple[int, ...], + dilation: tuple[int, ...] | None = None, divisor_override: int | None = None, ) -> None: if len(kernel_size) != ndim or len(stride) != ndim or len(padding) != ndim: raise ValueError("kernel_size, stride, and padding must match pooling dimensionality") + if dilation is not None and len(dilation) != ndim: + raise ValueError("dilation must match pooling dimensionality") for name, values in ( ("kernel_size", kernel_size), @@ -62,9 +65,15 @@ def validate_pool_params( if any(v < 0 for v in padding): raise ValueError("padding must be non-negative") + if dilation is not None: + if not all(isinstance(v, int) and not isinstance(v, bool) for v in dilation): + raise TypeError("dilation must contain only ints") + if any(v <= 0 for v in dilation): + raise ValueError("dilation must be greater than zero") + for pad, kernel in zip(padding, kernel_size, strict=True): if pad > kernel // 2: - raise ValueError("padding must be at most half of the effective kernel size") + raise ValueError("padding must be at most half of the kernel size") if divisor_override is not None and (not isinstance(divisor_override, int) or isinstance(divisor_override, bool)): raise TypeError("divisor_override must be an int or None") @@ -103,11 +112,13 @@ def pool_output_dim( stride: int, padding: int, ceil_mode: bool, + dilation: int = 1, ) -> int: + effective_kernel = (kernel_size - 1) * dilation + 1 if ceil_mode: - out = (input_size + 2 * padding - kernel_size + stride - 1) // stride + 1 + out = (input_size + 2 * padding - effective_kernel + stride - 1) // stride + 1 else: - out = (input_size + 2 * padding - kernel_size) // stride + 1 + out = (input_size + 2 * padding - effective_kernel) // stride + 1 if ceil_mode and out > 0 and (out - 1) * stride >= input_size + padding: out -= 1 diff --git a/tileops/kernels/pool/max_pool2d.py b/tileops/kernels/pool/max_pool2d.py new file mode 100644 index 000000000..86acbf3ad --- /dev/null +++ b/tileops/kernels/pool/max_pool2d.py @@ -0,0 +1,467 @@ +import functools +import itertools +from typing import Optional + +import tilelang +import tilelang.language as T +import torch + +from tileops.kernels.kernel import Kernel +from tileops.kernels.pool.common import pool_output_dim + +__all__ = ["MaxPool2dKernel"] + +_SUPPORTED_DTYPES = (torch.float16, torch.bfloat16) + + +@functools.lru_cache(maxsize=64) +def _max_pool2d_values_kernel( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + ceil_mode: bool, + dtype: str = "float16", +): + accum_dtype = "float32" + out_h = pool_output_dim(h_in, kernel_h, stride_h, pad_h, ceil_mode, dilation_h) + out_w = pool_output_dim(w_in, kernel_w, stride_w, pad_w, ceil_mode, dilation_w) + + @tilelang.jit(out_idx=[1], compile_flags=["-O3", "-DENABLE_BF16"]) + def 
_max_pool2d_func(block_m: int, block_c: int, threads: int): + @T.prim_func + def _max_pool2d_main( + x: T.Tensor((n, h_in, w_in, c_in), dtype), # type: ignore + out: T.Tensor((n, out_h, out_w, c_in), dtype), # type: ignore + ): + with T.Kernel( + T.ceildiv(c_in, block_c), + T.ceildiv(n * out_h * out_w, block_m), + threads=threads, + ) as (bx, by): + out_flat = T.Tensor((n * out_h * out_w, c_in), dtype, out.data) + + for i, j in T.Parallel(block_m, block_c): + m_idx = by * block_m + i + c_idx = bx * block_c + j + if m_idx < n * out_h * out_w and c_idx < c_in: + batch = m_idx // (out_h * out_w) + out_idx = m_idx % (out_h * out_w) + oh = out_idx // out_w + ow = out_idx % out_w + max_val = T.alloc_var(T.float32) + max_val = -T.infinity(accum_dtype) + + for kh in T.serial(kernel_h): + for kw in T.serial(kernel_w): + ih = oh * stride_h + kh * dilation_h - pad_h + iw = ow * stride_w + kw * dilation_w - pad_w + if ih >= 0 and ih < h_in and iw >= 0 and iw < w_in: + candidate = T.cast(x[batch, ih, iw, c_idx], accum_dtype) + max_val = T.max(max_val, candidate) + + out_flat[m_idx, c_idx] = T.cast(max_val, dtype) + + return _max_pool2d_main + + return _max_pool2d_func + + +@functools.lru_cache(maxsize=64) +def _max_pool2d_values_indices_kernel( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + ceil_mode: bool, + dtype: str = "float16", +): + accum_dtype = "float32" + out_h = pool_output_dim(h_in, kernel_h, stride_h, pad_h, ceil_mode, dilation_h) + out_w = pool_output_dim(w_in, kernel_w, stride_w, pad_w, ceil_mode, dilation_w) + + @tilelang.jit(out_idx=[1, 2], compile_flags=["-O3", "-DENABLE_BF16"]) + def _max_pool2d_func(block_m: int, block_c: int, threads: int): + @T.prim_func + def _max_pool2d_main( + x: T.Tensor((n, h_in, w_in, c_in), dtype), # type: ignore + out: T.Tensor((n, out_h, out_w, c_in), dtype), # type: ignore + out_indices: T.Tensor((n, out_h, out_w, c_in), "int64"), # type: ignore + ): + with T.Kernel( + T.ceildiv(c_in, block_c), + T.ceildiv(n * out_h * out_w, block_m), + threads=threads, + ) as (bx, by): + out_flat = T.Tensor((n * out_h * out_w, c_in), dtype, out.data) + indices_flat = T.Tensor((n * out_h * out_w, c_in), "int64", out_indices.data) + + for i, j in T.Parallel(block_m, block_c): + m_idx = by * block_m + i + c_idx = bx * block_c + j + if m_idx < n * out_h * out_w and c_idx < c_in: + batch = m_idx // (out_h * out_w) + out_idx = m_idx % (out_h * out_w) + oh = out_idx // out_w + ow = out_idx % out_w + window_h_start = oh * stride_h - pad_h + window_w_start = ow * stride_w - pad_w + first_kh = T.ceildiv(T.max(-window_h_start, 0), dilation_h) + first_kw = T.ceildiv(T.max(-window_w_start, 0), dilation_w) + first_ih = window_h_start + first_kh * dilation_h + first_iw = window_w_start + first_kw * dilation_w + has_valid = ( + first_kh < kernel_h + and first_kw < kernel_w + and first_ih >= 0 + and first_ih < h_in + and first_iw >= 0 + and first_iw < w_in + ) + max_val = T.alloc_var(T.float32) + max_index = T.alloc_var(T.int64) + max_val = T.if_then_else( + has_valid, + T.cast(x[batch, first_ih, first_iw, c_idx], accum_dtype), + -T.infinity(accum_dtype), + ) + max_index = T.if_then_else( + has_valid, + T.cast(first_ih, "int64") * T.cast(w_in, "int64") + + T.cast(first_iw, "int64"), + T.cast(pad_h, "int64") * T.cast(w_in, "int64") + T.cast(pad_w, "int64"), + ) + + for kh in T.serial(kernel_h): + for kw in T.serial(kernel_w): + ih = oh * stride_h + kh * 
dilation_h - pad_h + iw = ow * stride_w + kw * dilation_w - pad_w + if ih >= 0 and ih < h_in and iw >= 0 and iw < w_in: + candidate = T.cast(x[batch, ih, iw, c_idx], accum_dtype) + candidate_index = ( + T.cast(ih, "int64") * T.cast(w_in, "int64") + + T.cast(iw, "int64") + ) + if candidate > max_val: + max_val = candidate + max_index = candidate_index + + out_flat[m_idx, c_idx] = T.cast(max_val, dtype) + indices_flat[m_idx, c_idx] = max_index + + return _max_pool2d_main + + return _max_pool2d_func + + +@torch.library.custom_op("top::max_pool2d_values_wrapped_kernel", mutates_args=()) +def _max_pool2d_values_wrapped_kernel( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + ceil_mode: bool, + dtype: str, + block_m: int, + block_c: int, + threads: int, + x: torch.Tensor, +) -> torch.Tensor: + return _max_pool2d_values_kernel( + n, + c_in, + h_in, + w_in, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + ceil_mode, + dtype, + )(block_m, block_c, threads)(x) + + +@_max_pool2d_values_wrapped_kernel.register_fake +def _max_pool2d_values_wrapped_kernel_fake( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + ceil_mode: bool, + dtype: str, + block_m: int, + block_c: int, + threads: int, + x: torch.Tensor, +) -> torch.Tensor: + _ = (dtype, block_m, block_c, threads) + out_h = pool_output_dim(h_in, kernel_h, stride_h, pad_h, ceil_mode, dilation_h) + out_w = pool_output_dim(w_in, kernel_w, stride_w, pad_w, ceil_mode, dilation_w) + return torch.empty((n, out_h, out_w, c_in), dtype=x.dtype, device=x.device) + + +@torch.library.custom_op("top::max_pool2d_values_indices_wrapped_kernel", mutates_args=()) +def _max_pool2d_values_indices_wrapped_kernel( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + ceil_mode: bool, + dtype: str, + block_m: int, + block_c: int, + threads: int, + x: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + return _max_pool2d_values_indices_kernel( + n, + c_in, + h_in, + w_in, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + ceil_mode, + dtype, + )(block_m, block_c, threads)(x) + + +@_max_pool2d_values_indices_wrapped_kernel.register_fake +def _max_pool2d_values_indices_wrapped_kernel_fake( + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_h: int, + kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + ceil_mode: bool, + dtype: str, + block_m: int, + block_c: int, + threads: int, + x: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + _ = (dtype, block_m, block_c, threads) + out_h = pool_output_dim(h_in, kernel_h, stride_h, pad_h, ceil_mode, dilation_h) + out_w = pool_output_dim(w_in, kernel_w, stride_w, pad_w, ceil_mode, dilation_w) + return ( + torch.empty((n, out_h, out_w, c_in), dtype=x.dtype, device=x.device), + torch.empty((n, out_h, out_w, c_in), dtype=torch.int64, device=x.device), + ) + + +class MaxPool2dKernel(Kernel): + supported_archs: list[int] = [80, 86, 89, 90] + SUPPORTED_DTYPES = _SUPPORTED_DTYPES + + def __init__( + self, + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_h: int, + 
kernel_w: int, + stride_h: int, + stride_w: int, + pad_h: int, + pad_w: int, + dilation_h: int, + dilation_w: int, + ceil_mode: bool, + return_indices: bool, + dtype: torch.dtype, + config: Optional[dict] = None, + tune: bool = False, + ) -> None: + super().__init__() + if self.SUPPORTED_DTYPES is not None and dtype not in self.SUPPORTED_DTYPES: + supported = ", ".join(str(dt) for dt in self.SUPPORTED_DTYPES) + raise ValueError( + f"{self.__class__.__name__} only supports dtypes [{supported}], got {dtype}" + ) + self.n = n + self.c_in = c_in + self.h_in = h_in + self.w_in = w_in + self.kernel_h = kernel_h + self.kernel_w = kernel_w + self.stride_h = stride_h + self.stride_w = stride_w + self.pad_h = pad_h + self.pad_w = pad_w + self.dilation_h = dilation_h + self.dilation_w = dilation_w + self.ceil_mode = ceil_mode + self.return_indices = return_indices + self.dtype = dtype + self.out_h = pool_output_dim(h_in, kernel_h, stride_h, pad_h, ceil_mode, dilation_h) + self.out_w = pool_output_dim(w_in, kernel_w, stride_w, pad_w, ceil_mode, dilation_w) + + if return_indices: + self.kernel = _max_pool2d_values_indices_kernel( + n, + c_in, + h_in, + w_in, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + ceil_mode, + self.dtype_str, + ) + else: + self.kernel = _max_pool2d_values_kernel( + n, + c_in, + h_in, + w_in, + kernel_h, + kernel_w, + stride_h, + stride_w, + pad_h, + pad_w, + dilation_h, + dilation_w, + ceil_mode, + self.dtype_str, + ) + self.init_config(config, tune) + + @property + def default_config(self) -> dict: + return { + "block_m": 128, + "block_c": 64, + "threads": 128, + } + + @property + def autotune_configs(self) -> list[dict]: + configs = itertools.product([64, 128, 256], [32, 64, 128], [128, 256]) + return [ + { + "block_m": block_m, + "block_c": block_c, + "threads": threads, + } + for block_m, block_c, threads in configs + ] + + def forward(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if self.out_h == 0 or self.out_w == 0: + empty_values = torch.empty( + (self.n, self.out_h, self.out_w, self.c_in), + dtype=x.dtype, + device=x.device, + ) + if self.return_indices: + empty_indices = torch.empty( + (self.n, self.out_h, self.out_w, self.c_in), + dtype=torch.int64, + device=x.device, + ) + return empty_values, empty_indices + return empty_values + if self.return_indices: + return _max_pool2d_values_indices_wrapped_kernel( + self.n, + self.c_in, + self.h_in, + self.w_in, + self.kernel_h, + self.kernel_w, + self.stride_h, + self.stride_w, + self.pad_h, + self.pad_w, + self.dilation_h, + self.dilation_w, + self.ceil_mode, + self.dtype_str, + self.config["block_m"], + self.config["block_c"], + self.config["threads"], + x, + ) + return _max_pool2d_values_wrapped_kernel( + self.n, + self.c_in, + self.h_in, + self.w_in, + self.kernel_h, + self.kernel_w, + self.stride_h, + self.stride_w, + self.pad_h, + self.pad_w, + self.dilation_h, + self.dilation_w, + self.ceil_mode, + self.dtype_str, + self.config["block_m"], + self.config["block_c"], + self.config["threads"], + x, + ) diff --git a/tileops/ops/__init__.py b/tileops/ops/__init__.py index f7c8a6d50..a5069b29f 100644 --- a/tileops/ops/__init__.py +++ b/tileops/ops/__init__.py @@ -31,6 +31,7 @@ from .gqa_sliding_window_fwd import GqaSlidingWindowFwdOp from .gqa_sliding_window_varlen_fwd import GqaSlidingWindowVarlenFwdOp from .grouped_gemm import GroupedGemmOp +from .max_pool2d import MaxPool2dOp from .mha import MultiHeadAttentionBwdOp, MultiHeadAttentionFwdOp 
from .mha_decode import MultiHeadAttentionDecodeWithKVCacheOp from .mha_decode_paged import MultiHeadAttentionDecodePagedWithKVCacheOp @@ -95,6 +96,7 @@ "AvgPool1dOp", "AvgPool2dOp", "AvgPool3dOp", + "MaxPool2dOp", "AdaLayerNormOp", "AdaLayerNormZeroOp", "BatchNormBwdOp", diff --git a/tileops/ops/max_pool2d.py b/tileops/ops/max_pool2d.py new file mode 100644 index 000000000..8f1b2e577 --- /dev/null +++ b/tileops/ops/max_pool2d.py @@ -0,0 +1,101 @@ +from typing import Dict, Optional, Tuple + +import torch + +from tileops.kernels.kernel import Kernel +from tileops.kernels.pool import MaxPool2dKernel +from tileops.kernels.pool.common import ( + _normalize_pool_dims, + validate_channels_last_input, + validate_pool_params, +) + +from .op import Op + +__all__ = ["MaxPool2dOp"] + + +class MaxPool2dOp(Op): + """Max pooling over channels-last `NHWC` inputs.""" + + def __init__( + self, + n: int, + c_in: int, + h_in: int, + w_in: int, + kernel_size: int | Tuple[int, int], + stride: Optional[int | Tuple[int, int]] = None, + padding: int | Tuple[int, int] = 0, + dilation: int | Tuple[int, int] = 1, + return_indices: bool = False, + ceil_mode: bool = False, + dtype: torch.dtype = torch.float16, + kernel_map: Optional[Dict[str, Kernel]] = None, + tune: bool = False, + ) -> None: + self.n = n + self.c_in = c_in + self.h_in = h_in + self.w_in = w_in + self.kernel_size = _normalize_pool_dims("kernel_size", kernel_size, 2) + self.stride = ( + self.kernel_size + if stride is None + else _normalize_pool_dims("stride", stride, 2) + ) + self.padding = _normalize_pool_dims("padding", padding, 2) + self.dilation = _normalize_pool_dims("dilation", dilation, 2) + self.return_indices = return_indices + self.ceil_mode = ceil_mode + self.dtype = dtype + validate_pool_params( + ndim=2, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + ) + + self.dispatch_kernel(kernel_map) + if "max_pool2d_kernel" not in self.kernel_map: + raise NotImplementedError("MaxPool2dOp requires 'max_pool2d_kernel' in kernel_map") + self.kernel = self.kernel_map["max_pool2d_kernel"]( + n=n, + c_in=c_in, + h_in=h_in, + w_in=w_in, + kernel_h=self.kernel_size[0], + kernel_w=self.kernel_size[1], + stride_h=self.stride[0], + stride_w=self.stride[1], + pad_h=self.padding[0], + pad_w=self.padding[1], + dilation_h=self.dilation[0], + dilation_w=self.dilation[1], + ceil_mode=ceil_mode, + return_indices=return_indices, + dtype=dtype, + tune=tune, + ) + + @property + def default_kernel_map(self) -> Dict[str, Kernel]: + return {"max_pool2d_kernel": MaxPool2dKernel} + + def forward(self, x: torch.Tensor) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: + if not x.is_cuda: + raise ValueError("Input must be a CUDA tensor") + if x.dtype != self.dtype: + raise ValueError(f"Expected x.dtype {self.dtype}, got {x.dtype}") + validate_channels_last_input( + op_name=type(self).__name__, + x_shape=tuple(x.shape), + expected_shape=(self.n, self.h_in, self.w_in, self.c_in), + layout="NHWC", + ambiguous_layout_shape=(self.n, self.c_in, self.h_in, self.w_in), + ) + if self.return_indices: + values, indices = self.kernel(x) + return values, indices + return self.kernel(x)
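
For reference, below is a minimal usage sketch of the new op, distilled from the test fixture and ref_program in this diff. It assumes a CUDA device and an installed tileops build containing this patch; the shapes, pooling parameters, and fp16 tolerances mirror the "smoke-3x3-default-stride-fp16" / benchmark cases above rather than introducing new behavior.

# Usage sketch (assumption: CUDA available, this patch installed).
# MaxPool2dOp consumes NHWC input; the reference is computed with
# F.max_pool2d on an NCHW permutation and permuted back, as the tests do.
import torch
import torch.nn.functional as F

from tileops.ops import MaxPool2dOp

n, c, h, w = 2, 64, 56, 56
x = torch.randn(n, h, w, c, device="cuda", dtype=torch.float16)  # NHWC layout

op = MaxPool2dOp(
    n=n, c_in=c, h_in=h, w_in=w,
    kernel_size=(3, 3), stride=(2, 2), padding=(1, 1),
    dilation=(1, 1), ceil_mode=False, dtype=torch.float16,
)
out = op(x)  # NHWC output

# Reference in NCHW, permuted back to NHWC for comparison.
ref = F.max_pool2d(
    x.permute(0, 3, 1, 2).contiguous(),
    kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), dilation=(1, 1),
).permute(0, 2, 3, 1).contiguous()

# fp16 tolerances match those used in test_max_pool2d above.
torch.testing.assert_close(out, ref, atol=1e-3, rtol=1e-3)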