Description
The ExportedProgram for

```python
class AvgPool2dFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ap2d = torch.nn.AvgPool2d(
            kernel_size=6,
        )

    def forward(self, x):
        return self.ap2d(x)
```
produces the call to AvgPool2d as `torch.ops.aten.avg_pool2d.default(x, [6, 6], [6, 6])`. This matches the documented behavior of the `kernel_size` parameter (https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html), which states that a single integer value will be used for both the height and width dimensions. Per the documentation, the only other accepted value for `kernel_size` is a tuple of two integers. However, a tuple with a single element works as well:
```python
class AvgPool2dFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ap2d = torch.nn.AvgPool2d(
            kernel_size=(6,),
        )

    def forward(self, x):
        return self.ap2d(x)
```
and the ExportedProgram has the call to AvgPool2d as `torch.ops.aten.avg_pool2d.default(x, [6], [6])`. Note that the `kernel_size` value is not repeated here, even though that is what effectively happens when the module is executed eagerly in Python.
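For reference, a minimal way to reproduce and inspect this (a sketch assuming the `torch.export` API; the input shape is my own choice and works with either module definition above):

```python
import torch
from torch.export import export

# Export the module and print the traced graph; the avg_pool2d call
# shows [6, 6] for kernel_size=6 but the un-repeated [6] for
# kernel_size=(6,).
ep = export(AvgPool2dFloatModule(), (torch.randn(1, 3, 12, 12),))
print(ep.graph_module.code)
```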
This ExportedProgram triggers an assertion when lowering the resulting Torch IR to Tosa/Linalg/Stablehlo, because those lowerings assume that `kernel_size` has two elements.
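For intuition, a minimal Python analogue of that assumption (the actual lowerings are C++ conversion patterns in torch-mlir; the function below is purely hypothetical):

```python
# Purely hypothetical analogue of the lowerings' assumption that the
# kernel_size list from aten.avg_pool2d always carries two entries.
def lower_avg_pool2d_sketch(kernel_size):
    kernel_h, kernel_w = kernel_size  # fine for [6, 6]
    return kernel_h, kernel_w

lower_avg_pool2d_sketch([6, 6])  # ok
lower_avg_pool2d_sketch([6])     # ValueError, mirroring the assertion
```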
So I think this can be fixed by one of the following approaches:
1. Change the behavior of the ExportedProgram in the second scenario to match the first one. I am not familiar with the PyTorch codebase, so I'm not sure where to make the change; if anyone knows where to start looking, I'd appreciate it.
2. Fix the individual lowerings, but that means repeating the same logic in three different places.
3. In Torch IR, before any of the lowerings run (possibly when `DecomposeComplexOps` is called), extend the `kernel_size` param of the `torch.aten.avg_pool2d` op to the correct size, so the individual lowerings don't need to be fixed (see the sketch after this list).
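To make option 3 concrete, here is a rough sketch of the intended normalization, written against the FX graph in Python rather than as the C++ rewrite pattern it would actually be in torch-mlir (the function name is hypothetical, and extending the stride the same way is my assumption, since it shows the same un-repeated value):

```python
import torch

# Hypothetical sketch: repeat any single-element kernel_size (and
# stride) on aten.avg_pool2d nodes so downstream lowerings always see
# two elements, e.g. [6] -> [6, 6].
def normalize_avg_pool2d(gm: torch.fx.GraphModule) -> None:
    for node in gm.graph.nodes:
        if (node.op == "call_function"
                and node.target is torch.ops.aten.avg_pool2d.default):
            args = list(node.args)
            kernel = list(args[1])
            if len(kernel) == 1:
                args[1] = kernel * 2
            if len(args) > 2:
                stride = list(args[2])
                if len(stride) == 1:
                    args[2] = stride * 2
            node.args = tuple(args)
    gm.graph.lint()
    gm.recompile()
```

The C++ version would presumably do the equivalent in-place rewrite of the op's list operands, so every lowering sees a two-element list.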
I'm leaning towards 3 (since I don't know how to make 1 work) -- is that the correct approach? If so, which pass would be the right place to add the logic -- AFAICT none of the existing passes does a similar transform where an op is replaced by the same op with different params. Should I add a new pass?
@sjarus, @vivekkhandelwal1 -- any thoughts? Thanks!