[Torch] Add aten.bilinear op decomposition #3931

Status: Open. Wants to merge 3 commits into base: main.

Changes from 1 commit
update e2e tests
Dixin Zhou committed Dec 23, 2024
commit cab423b7c2e867b098854a00b9e0f849ad2a738b
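
For context, aten.bilinear computes out[..., k] = sum over i, j of input1[..., i] * weight[k, i, j] * input2[..., j], plus bias[k]. A minimal Python sketch of one way the op can be decomposed (illustrative only; this is not the exact lowering code added in this PR):

import torch

def bilinear_decomposed(input1, input2, weight, bias=None):
    # out[..., k] = sum_{i,j} input1[..., i] * weight[k, i, j] * input2[..., j]
    out = torch.einsum("...i,kij,...j->...k", input1, weight, input2)
    if bias is not None:
        out = out + bias  # bias broadcasts over the batch dims
    return out

Judging by the paired Aten_Trilinear* tests touched below, the actual lowering likely routes through the existing aten._trilinear machinery; the sketch above only pins down the expected numerics.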
19 changes: 19 additions & 0 deletions projects/pt1/e2e_testing/xfail_sets.py
@@ -969,6 +969,14 @@
     "AtenLinearMatVec_basic",
     "AtenLinearVecMatBias_basic",
     "AtenLinearVecMat_basic",
+    "Aten_BilinearModule1D_basic",
+    "Aten_BilinearModuleND_basic",
+    "Aten_BilinearModule_basic",
+    "Aten_TrilinearModuleSumAllDims_basic",
+    "Aten_TrilinearModuleSumdims_basic",
+    "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
+    "Aten_TrilinearModuleVaryingRanks_basic",
+    "Aten_TrilinearModule_basic",
     "ReduceAminSingleDim_basic",
     "AtenDotModule_basic",
     "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
@@ -1764,6 +1772,9 @@
     "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
     "Aten_TrilinearModuleVaryingRanks_basic",
     "Aten_TrilinearModule_basic",
+    "Aten_BilinearModule1D_basic",
+    "Aten_BilinearModuleND_basic",
+    "Aten_BilinearModule_basic",
     "ElementwiseAddBoolModule_basic",
     "Exp2StaticModule_basic",
     "CosineSimilarityStaticBroadcastModule_basic",
@@ -3339,6 +3350,10 @@
     "Aten_TrilinearModuleVaryingRanks_basic",
     "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
     "Aten_TrilinearModuleZerodDimBug_basic",
+    "Aten_BilinearModule1D_basic",
+    "Aten_BilinearModuleDynamic_basic",
+    "Aten_BilinearModuleND_basic",
+    "Aten_BilinearModule_basic",
 }
 
 if torch_version_for_comparison() < version.parse("2.3.0.dev"):
@@ -4098,6 +4113,10 @@
     "Aten_TrilinearModuleVaryingRanks_basic",
     "Aten_TrilinearModuleVaryingRanksUnorderedExpands_basic",
     "Aten_TrilinearModuleZerodDimBug_basic",
+    "Aten_BilinearModule1D_basic",
+    "Aten_BilinearModuleDynamic_basic",
+    "Aten_BilinearModuleND_basic",
+    "Aten_BilinearModule_basic",
     "AtenTrilModule_basic",
     "AtenTrilWithNegDiagonalModule_basic",
     "AtenTrilWithPosDiagonalModule_basic",
28 changes: 14 additions & 14 deletions projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py
@@ -1950,7 +1950,7 @@ def Aten_TrilinearModuleZerodDimBug_basic(module, tu: TestUtils):
 # ==============================================================================
 
 
-class Aten_BilinearModuleStaticShape(torch.nn.Module):
+class Aten_BilinearModule(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -1968,12 +1968,12 @@ def forward(self, input1, input2, weight, bias):
         return torch.ops.aten.bilinear(input1, input2, weight, bias)
 
 
-@register_test_case(module_factory=lambda: Aten_BilinearModuleStaticShape())
-def Aten_BilinearModuleStaticShape_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: Aten_BilinearModule())
+def Aten_BilinearModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(8, 2), tu.rand(8, 3), tu.rand(4, 2, 3), tu.rand(4))
 
 
-class Aten_BilinearModuleDynamicShape(torch.nn.Module):
+class Aten_BilinearModuleDynamic(torch.nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -1991,8 +1991,8 @@ def forward(self, input1, input2, weight, bias):
         return torch.ops.aten.bilinear(input1, input2, weight, bias)
 
 
-@register_test_case(module_factory=lambda: Aten_BilinearModuleDynamicShape())
-def Aten_BilinearModuleDynamicShape_basic(module, tu: TestUtils):
+@register_test_case(module_factory=lambda: Aten_BilinearModuleDynamic())
+def Aten_BilinearModuleDynamic_basic(module, tu: TestUtils):
     module.forward(tu.rand(8, 2), tu.rand(8, 3), tu.rand(4, 2, 3), tu.rand(4))
 
 
@@ -2004,10 +2004,10 @@ def __init__(self):
     @annotate_args(
         [
             None,
-            ([-1], torch.float32, True),
-            ([-1], torch.float32, True),
-            ([-1, -1, -1], torch.float32, True),
-            ([-1], torch.float32, True),
+            ([2], torch.float32, True),
+            ([3], torch.float32, True),
+            ([4, 2, 3], torch.float32, True),
+            ([4], torch.float32, True),
         ]
     )
     def forward(self, input1, input2, weight, bias):
@@ -2027,10 +2027,10 @@ def __init__(self):
     @annotate_args(
         [
             None,
-            ([-1, -1, -1, -1], torch.float32, True),
-            ([-1, -1, -1, -1], torch.float32, True),
-            ([-1, -1, -1], torch.float32, True),
-            ([-1], torch.float32, True),
+            ([8, 6, 12, 2], torch.float32, True),
+            ([8, 6, 12, 3], torch.float32, True),
+            ([4, 2, 3], torch.float32, True),
+            ([4], torch.float32, True),
         ]
     )
     def forward(self, input1, input2, weight, bias):
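
The last two hunks pin the Aten_BilinearModule1D and Aten_BilinearModuleND tests to fully static shapes: in this suite each annotate_args entry is (shape, dtype, value-semantics flag), and a -1 dimension marks a dynamic size, so replacing the -1 entries with concrete sizes lets these tests run on static-shape backends. A minimal sketch of the resulting pattern (the module name here is hypothetical; imports are the ones used elsewhere in this suite):

import torch
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export

class Aten_BilinearModuleExample(torch.nn.Module):
    @export
    @annotate_args(
        [
            None,  # placeholder for `self`
            ([8, 2], torch.float32, True),  # input1: no -1 dims, fully static
            ([8, 3], torch.float32, True),  # input2
            ([4, 2, 3], torch.float32, True),  # weight
            ([4], torch.float32, True),  # bias
        ]
    )
    def forward(self, input1, input2, weight, bias):
        return torch.ops.aten.bilinear(input1, input2, weight, bias)

@register_test_case(module_factory=lambda: Aten_BilinearModuleExample())
def Aten_BilinearModuleExample_basic(module, tu: TestUtils):
    module.forward(tu.rand(8, 2), tu.rand(8, 3), tu.rand(4, 2, 3), tu.rand(4))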