[Feat][POOL] Add max_pool2d operator #832
Conversation
Code Review
This pull request implements the MaxPool2d operation, including a tilelang-based kernel, a high-level Op interface, and comprehensive tests and benchmarks. The implementation supports dilation and ceil_mode, with updates to common pooling utilities to accommodate these features. Feedback focuses on optimizing memory bandwidth by making index generation conditional on the return_indices flag and addressing a potential integer overflow in spatial index calculations.
Ibuki-wind
left a comment
- tests/ops/test_max_pool2d.py:1 — Trust-boundary violation: an Implementation PR must not modify tests/; move the correctness suite into a separate Test-stage PR that targets the merged operator interface.
- benchmarks/ops/bench_max_pool2d.py:1 — Trust-boundary violation: an Implementation PR must not modify benchmarks/; move the profiling coverage into a separate Benchmark-stage PR after the op lands.
Ibuki-wind
left a comment
- tileops/kernels/pool/common.py:79 — `validate_pool_params()` now validates max-pool padding against the effective dilated kernel, but PyTorch rejects padding greater than `kernel_size // 2` even when dilation > 1. This implementation accepts configurations that the reference API errors on. -> Validate max-pool padding with PyTorch's rule before constructing the kernel.
- tileops/kernels/pool/max_pool2d.py:125 — the indices path initializes `max_val = -inf` and only updates on `candidate > max_val`, so an all-`-inf` window never records its first valid element. PyTorch still returns the first valid index for tied maxima in that case. -> Seed the first valid sample explicitly, or update on equality only until an index has been initialized, so `return_indices=True` matches PyTorch on `-inf` tie windows.
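PyTorch's padding rule referenced above can be sketched as a standalone check; the helper name below is illustrative, not the PR's actual API, and the error message approximates PyTorch's wording:

```python
def validate_max_pool_padding(kernel_size, padding):
    """Hypothetical validator for PyTorch's max-pool padding rule:
    padding may not exceed half the kernel size, regardless of dilation.
    kernel_size and padding are (h, w) tuples.
    """
    for k, p in zip(kernel_size, padding):
        if p > k // 2:
            raise ValueError(
                f"pad should be at most half of kernel size, "
                f"but got pad={p} and kernel_size={k}"
            )
```

Note that dilation does not appear in the check at all, which is exactly why validating against the effective dilated kernel diverges from the reference API.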
Ibuki-wind
left a comment
- tileops/kernels/pool/max_pool2d.py:_max_pool2d_values_indices_kernel — fully padded windows with `return_indices=True` seed `max_val`/`max_index` from an out-of-bounds `(first_ih, first_iw)` instead of matching PyTorch's `-inf` + sentinel-index behavior → detect the no-valid-element case before the initial load and handle it explicitly.
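The per-window reduction the review is asking for can be sketched in plain Python. This is a reference for the intended semantics only, not the kernel code; the function name and the `-1` sentinel index are assumptions:

```python
import math

def window_max_with_index(taps):
    """Reference reduction for one pooling window.

    taps: list of (flat_index, value) pairs for the in-bounds elements;
    empty when the window is fully padded. Seeds from the first valid
    element and then updates only on a strictly greater candidate, so
    tied maxima (including all--inf windows) keep the first index.
    """
    if not taps:                         # no valid element: fully padded window
        return -math.inf, -1             # -inf value + sentinel index (assumed)
    max_idx, max_val = taps[0]           # explicit seed from the first valid tap
    for idx, val in taps[1:]:
        if val > max_val:                # strict '>' preserves the first index on ties
            max_val, max_idx = val, idx
    return max_val, max_idx
```

Seeding from the first in-bounds tap rather than from `-inf` is what makes the all-`-inf` window return its first valid index instead of never recording one.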
Ibuki-wind
left a comment
- PR boundary — this branch no longer matches the repo's current tests-first TDD split: it is an implementation-only slice of an older mixed test+implementation+benchmark effort, so the implementation is not reviewable against the stage artifact that is supposed to precede it. Re-split the work so the test PR lands first, then rebase the implementation PR on top of that stage.
- PR body — the current body still claims dedicated tests/benchmarks and benchmark results that are not part of this diff, so the evidence chain is no longer auditable from this PR alone. Rewrite the body so it only describes artifacts that are actually present in this PR, or move the missing evidence into the proper stage-specific PRs and reference them explicitly.
- tileops/kernels/pool/max_pool2d.py:56 — the values-only path still divides by `out_h * out_w` even when a legal parameter combination produces an empty output shape, which crashes lowering instead of returning an empty tensor. Add an explicit zero-output short-circuit before emitting index arithmetic.
- tileops/kernels/pool/max_pool2d.py:119 — the values+indices path has the same zero-output failure mode, so `return_indices=True` is also not PyTorch-compatible for legal empty-output cases. Guard this path the same way.
- Performance evidence — on the local H200 / Torch 2.9.1+cu128 environment, the reported speedups are only reproducible after autotuning, not on the default `tune=False` path exposed by this PR, and one claimed ceil-mode number was not reproduced locally. Narrow or correct the performance claim accordingly.
```python
m_idx = by * block_m + i
c_idx = bx * block_c + j
if m_idx < n * out_h * out_w and c_idx < c_in:
    batch = m_idx // (out_h * out_w)
```
_max_pool2d_values_kernel still computes m_idx // (out_h * out_w) and m_idx % (out_h * out_w) even when a legal parameter set makes out_h * out_w == 0, which crashes TIR lowering with a divide-by-zero internal error instead of returning an empty tensor like PyTorch -> short-circuit the zero-output case before any flattened-output index arithmetic is emitted.
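Such a short-circuit can sit ahead of kernel emission. Below is a hedged sketch: `pool_out_size` and `maybe_empty_output` are hypothetical names, and the shape formula follows the one documented for PyTorch's MaxPool2d:

```python
import math

def pool_out_size(size, k, s, p, d, ceil_mode=False):
    # Output-size formula from PyTorch's MaxPool2d docs, clamped at zero
    eff_k = d * (k - 1) + 1              # dilated kernel extent
    num = size + 2 * p - eff_k
    out = (math.ceil(num / s) if ceil_mode else num // s) + 1
    if ceil_mode and (out - 1) * s >= size + p:
        out -= 1                         # last window must start inside the padded input
    return max(out, 0)

def maybe_empty_output(n, c, h, w, kernel, stride, padding, dilation, ceil_mode=False):
    # Hypothetical pre-lowering guard: if out_h * out_w == 0, return the
    # empty NHWC shape instead of emitting m_idx // (out_h * out_w).
    out_h = pool_out_size(h, kernel[0], stride[0], padding[0], dilation[0], ceil_mode)
    out_w = pool_out_size(w, kernel[1], stride[1], padding[1], dilation[1], ceil_mode)
    if out_h * out_w == 0:
        return (n, out_h, out_w, c)      # shape of the empty result; skip the kernel
    return None                          # non-empty: safe to emit the division
```

Because the guard runs on host-side Python integers before any TIR is built, the divide-by-zero never reaches lowering.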
```python
m_idx = by * block_m + i
c_idx = bx * block_c + j
if m_idx < n * out_h * out_w and c_idx < c_in:
    batch = m_idx // (out_h * out_w)
```
_max_pool2d_values_indices_kernel has the same empty-output bug on the return_indices=True path, so legal inputs can still fail before execution starts -> add the same zero-output guard here and return empty value/index tensors with the expected shapes instead of lowering invalid arithmetic.
Closes #769
Summary
- `MaxPool2dKernel` and `MaxPool2dOp` with NHWC max-pooling support, dilation, `ceil_mode`, and optional `return_indices`
- `return_indices=False` path does not pay the extra indices write cost

Test plan
- `python -m pytest tests/ops/test_max_pool2d.py -q`
- `python -m pytest benchmarks/ops/bench_max_pool2d.py -q`

Structural Readiness
SKIP (recommended): benchmark does not include a dedicated "default-parameter-values" case; current coverage is focused on representative performance workloads.

Benchmark
`python -m pytest benchmarks/ops/bench_max_pool2d.py -q` against the `torch` baseline: 6 passed

Dtype support matrix
Supported dtypes: `torch.float16`, `torch.bfloat16`, `torch.float32`

Benchmarked cases:
- n=2, c=64, h=w=112, fp16
- n=2, c=64, h=w=112, bf16
- n=2, c=128, h=w=56, fp16
- n=2, c=128, h=w=56, bf16
- n=3, c=96, h=55, w=57, ceil, fp16
- n=3, c=96, h=55, w=57, ceil, bf16

Takeaways:
The `max_pool2d` benchmark focused on the standalone pooling path instead of paying for indices generation.

Additional context
`max_pool2d.return_indices` was added to align the functional interface more closely with PyTorch semantics.
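As a concrete anchor for those semantics, here is a pure-Python, single-channel reference (a sketch only; the PR's actual MaxPool2dOp signature is assumed, not confirmed) that follows `torch.nn.functional.max_pool2d`'s convention of indices flattened over the input's h*w:

```python
import math

def ref_max_pool2d_single(x, kernel, stride, padding=(0, 0), dilation=(1, 1),
                          ceil_mode=False, return_indices=False):
    # x is a 2D list (one channel). Indices returned are flat over h*w.
    h, w = len(x), len(x[0])
    (kh, kw), (sh, sw) = kernel, stride
    (ph, pw), (dh, dw) = padding, dilation

    def out_size(size, k, s, p, d):
        # Same output-size formula as PyTorch's MaxPool2d docs
        num = size + 2 * p - (d * (k - 1) + 1)
        out = (math.ceil(num / s) if ceil_mode else num // s) + 1
        if ceil_mode and (out - 1) * s >= size + p:
            out -= 1
        return max(out, 0)

    out_h, out_w = out_size(h, kh, sh, ph, dh), out_size(w, kw, sw, pw, dw)
    vals = [[-math.inf] * out_w for _ in range(out_h)]
    idxs = [[-1] * out_w for _ in range(out_h)]
    for oh in range(out_h):
        for ow in range(out_w):
            for i in range(kh):
                for j in range(kw):
                    ih, iw = oh * sh - ph + i * dh, ow * sw - pw + j * dw
                    if 0 <= ih < h and 0 <= iw < w:
                        v = x[ih][iw]
                        # strict '>' plus the -1 seed keeps the FIRST valid
                        # element for all--inf tie windows, per the review
                        if v > vals[oh][ow] or idxs[oh][ow] == -1:
                            vals[oh][ow], idxs[oh][ow] = v, ih * w + iw
    return (vals, idxs) if return_indices else vals
```

A reference like this is cheap to compare against in tests, since it exercises both the `return_indices=True` index convention and the `-inf` tie case discussed in the review.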