Skip to content
3 changes: 2 additions & 1 deletion tileops/kernels/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
LayerNormKernel,
RmsNormKernel,
)
from .pool import AvgPool1dKernel, AvgPool2dKernel, AvgPool3dKernel
from .pool import AvgPool1dKernel, AvgPool2dKernel, AvgPool3dKernel, MaxPool2dKernel
from .rope import (
RopeLlama31Kernel,
RopeLongRopeKernel,
Expand All @@ -72,6 +72,7 @@
"AvgPool1dKernel",
"AvgPool2dKernel",
"AvgPool3dKernel",
"MaxPool2dKernel",
"BatchNormBwdKernel",
"BatchNormFwdInferKernel",
"BatchNormFwdTrainKernel",
Expand Down
3 changes: 2 additions & 1 deletion tileops/kernels/pool/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .avg_pool1d import AvgPool1dKernel
from .avg_pool2d import AvgPool2dKernel
from .avg_pool3d import AvgPool3dKernel
from .max_pool2d import MaxPool2dKernel

__all__ = ["AvgPool1dKernel", "AvgPool2dKernel", "AvgPool3dKernel"]
__all__ = ["AvgPool1dKernel", "AvgPool2dKernel", "AvgPool3dKernel", "MaxPool2dKernel"]
17 changes: 14 additions & 3 deletions tileops/kernels/pool/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,13 @@ def validate_pool_params(
kernel_size: tuple[int, ...],
stride: tuple[int, ...],
padding: tuple[int, ...],
dilation: tuple[int, ...] | None = None,
divisor_override: int | None = None,
) -> None:
if len(kernel_size) != ndim or len(stride) != ndim or len(padding) != ndim:
raise ValueError("kernel_size, stride, and padding must match pooling dimensionality")
if dilation is not None and len(dilation) != ndim:
raise ValueError("dilation must match pooling dimensionality")

for name, values in (
("kernel_size", kernel_size),
Expand All @@ -62,9 +65,15 @@ def validate_pool_params(
if any(v < 0 for v in padding):
raise ValueError("padding must be non-negative")

if dilation is not None:
if not all(isinstance(v, int) and not isinstance(v, bool) for v in dilation):
raise TypeError("dilation must contain only ints")
if any(v <= 0 for v in dilation):
raise ValueError("dilation must be greater than zero")

for pad, kernel in zip(padding, kernel_size, strict=True):
if pad > kernel // 2:
raise ValueError("padding must be at most half of the effective kernel size")
raise ValueError("padding must be at most half of the kernel size")

if divisor_override is not None and (not isinstance(divisor_override, int) or isinstance(divisor_override, bool)):
raise TypeError("divisor_override must be an int or None")
Expand Down Expand Up @@ -103,11 +112,13 @@ def pool_output_dim(
stride: int,
padding: int,
ceil_mode: bool,
dilation: int = 1,
) -> int:
effective_kernel = (kernel_size - 1) * dilation + 1
if ceil_mode:
out = (input_size + 2 * padding - kernel_size + stride - 1) // stride + 1
out = (input_size + 2 * padding - effective_kernel + stride - 1) // stride + 1
else:
out = (input_size + 2 * padding - kernel_size) // stride + 1
out = (input_size + 2 * padding - effective_kernel) // stride + 1

if ceil_mode and out > 0 and (out - 1) * stride >= input_size + padding:
out -= 1
Expand Down
Loading
Loading