-
Notifications
You must be signed in to change notification settings - Fork 32
[Feat][POOL] Add max_pool2d operator #832
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
RMLYC
wants to merge
10
commits into
tile-ai:main
Choose a base branch
from
RMLYC:feat/pool/max-pool2d
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 6 commits
Commits
Show all changes
10 commits
Select commit
Hold shift + click to select a range
6d8be4d
[FEAT][POOL] add max_pool2d operator
RMLYC 7a8647d
[FEAT][POOL] add max_pool2d operator
RMLYC 808ee2c
[Fix][POOL] split max_pool2d index path
RMLYC 39b0584
Merge remote-tracking branch 'origin/feat/pool/max-pool2d' into feat/…
RMLYC 2a0aa19
[Fix][POOL] widen max_pool2d indices math
RMLYC ae476b0
[Chore][Lint] fix pre-commit import order
RMLYC e529d7e
[Fix][POOL] align max_pool2d semantics
RMLYC 225aa4b
[Chore][POOL] split max_pool2d tests and benchmarks
RMLYC d654c32
[Fix][POOL] handle empty max_pool2d windows
RMLYC 8a87232
[Fix][POOL] restore max_pool2d validation assets
RMLYC File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,136 @@ | ||
| from typing import Optional, Tuple | ||
|
|
||
| import pytest | ||
| import torch | ||
| import torch.nn.functional as F | ||
|
|
||
| from benchmarks.benchmark import BenchmarkBase, BenchmarkReport | ||
| from tileops.kernels.pool.common import pool_output_dim | ||
| from tileops.ops import MaxPool2dOp | ||
|
|
||
|
|
||
| class MaxPool2dBenchCase: | ||
| def __init__( | ||
| self, | ||
| n: int, | ||
| c_in: int, | ||
| h_in: int, | ||
| w_in: int, | ||
| kernel_size: Tuple[int, int], | ||
| stride: Optional[Tuple[int, int]], | ||
| padding: Tuple[int, int], | ||
| dilation: Tuple[int, int], | ||
| ceil_mode: bool, | ||
| dtype: torch.dtype, | ||
| ) -> None: | ||
| self.n = n | ||
| self.c_in = c_in | ||
| self.h_in = h_in | ||
| self.w_in = w_in | ||
| self.kernel_size = kernel_size | ||
| self.stride = kernel_size if stride is None else stride | ||
| self.padding = padding | ||
| self.dilation = dilation | ||
| self.ceil_mode = ceil_mode | ||
| self.dtype = dtype | ||
|
|
||
| def gen_inputs(self) -> tuple[torch.Tensor]: | ||
| x = torch.randn(self.n, self.h_in, self.w_in, self.c_in, device="cuda", dtype=self.dtype).contiguous() | ||
| return (x,) | ||
|
|
||
| def ref_program(self, x: torch.Tensor) -> torch.Tensor: | ||
| return F.max_pool2d( | ||
| x, | ||
| kernel_size=self.kernel_size, | ||
| stride=self.stride, | ||
| padding=self.padding, | ||
| dilation=self.dilation, | ||
| ceil_mode=self.ceil_mode, | ||
| ) | ||
|
|
||
|
|
||
| class MaxPool2dBenchmark(BenchmarkBase): | ||
| def calculate_flops(self) -> Optional[float]: | ||
| t = self.workload | ||
| out_h = pool_output_dim(t.h_in, t.kernel_size[0], t.stride[0], t.padding[0], t.ceil_mode, t.dilation[0]) | ||
| out_w = pool_output_dim(t.w_in, t.kernel_size[1], t.stride[1], t.padding[1], t.ceil_mode, t.dilation[1]) | ||
| return t.n * t.c_in * out_h * out_w * t.kernel_size[0] * t.kernel_size[1] | ||
|
|
||
| def calculate_memory(self) -> Optional[float]: | ||
| t = self.workload | ||
| out_h = pool_output_dim(t.h_in, t.kernel_size[0], t.stride[0], t.padding[0], t.ceil_mode, t.dilation[0]) | ||
| out_w = pool_output_dim(t.w_in, t.kernel_size[1], t.stride[1], t.padding[1], t.ceil_mode, t.dilation[1]) | ||
| return (t.n * t.c_in * t.h_in * t.w_in + t.n * t.c_in * out_h * out_w) * t.dtype.itemsize | ||
|
|
||
|
|
||
| _MAX_POOL2D_BASE_CASES = [ | ||
| (2, 64, 112, 112, (3, 3), (2, 2), (1, 1), (1, 1), False, "vision-3x3-s2"), | ||
| (2, 128, 56, 56, (5, 5), (2, 2), (2, 2), (1, 1), False, "vision-5x5-s2"), | ||
| (3, 96, 55, 57, (3, 3), (2, 2), (1, 1), (2, 1), True, "ceil-dilation-nonpow2"), | ||
| ] | ||
|
|
||
| _MAX_POOL2D_BENCH_PARAMS = [ | ||
| pytest.param(*case[:-1], dtype, True, id=f"{case[-1]}-{str(dtype).split('.')[-1]}") | ||
| for case in _MAX_POOL2D_BASE_CASES | ||
| for dtype in (torch.float16, torch.bfloat16) | ||
| ] | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "n, c_in, h_in, w_in, kernel_size, stride, padding, dilation, ceil_mode, dtype, tune", | ||
| _MAX_POOL2D_BENCH_PARAMS, | ||
| ) | ||
| @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") | ||
| def test_max_pool2d_bench( | ||
| n: int, | ||
| c_in: int, | ||
| h_in: int, | ||
| w_in: int, | ||
| kernel_size: Tuple[int, int], | ||
| stride: Optional[Tuple[int, int]], | ||
| padding: Tuple[int, int], | ||
| dilation: Tuple[int, int], | ||
| ceil_mode: bool, | ||
| dtype: torch.dtype, | ||
| tune: bool, | ||
| ) -> None: | ||
| test = MaxPool2dBenchCase( | ||
| n, | ||
| c_in, | ||
| h_in, | ||
| w_in, | ||
| kernel_size, | ||
| stride, | ||
| padding, | ||
| dilation, | ||
| ceil_mode, | ||
| dtype, | ||
| ) | ||
| bm = MaxPool2dBenchmark(test) | ||
| inputs = test.gen_inputs() | ||
| (x,) = inputs | ||
| x_nchw = x.permute(0, 3, 1, 2).contiguous() | ||
|
|
||
| op = MaxPool2dOp( | ||
| n=n, | ||
| c_in=c_in, | ||
| h_in=h_in, | ||
| w_in=w_in, | ||
| kernel_size=kernel_size, | ||
| stride=stride, | ||
| padding=padding, | ||
| dilation=dilation, | ||
| return_indices=False, | ||
| ceil_mode=ceil_mode, | ||
| dtype=dtype, | ||
| tune=tune, | ||
| ) | ||
| result = bm.profile(op, *inputs) | ||
| BenchmarkReport.record(op, locals(), result, tag="tileops") | ||
|
|
||
| result_bl = bm.profile(test.ref_program, x_nchw) | ||
| BenchmarkReport.record(op, locals(), result_bl, tag="torch") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| pytest.main([__file__, "-vvs"]) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.