Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve handling of trailing optional inputs in pattern matching #1948

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions onnxscript/rewriter/pattern.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,19 +1040,19 @@

self._matched[pattern_node] = node

# TODO: Revisit this to handle optional trailing inputs better.
if pattern_node.allow_other_inputs:
if len(node.inputs) < len(pattern_node.inputs):
if len(node.inputs) > len(pattern_node.inputs):
if pattern_node.allow_other_inputs:
# Ignore extraneous inputs
to_match = zip(node.inputs, pattern_node.inputs)

Check warning on line 1046 in onnxscript/rewriter/pattern.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/pattern.py#L1046

Added line #L1046 was not covered by tests
else:
return self.fail(
f"Number of inputs ({len(node.inputs)}) is less than expected ({len(pattern_node.inputs)})"
f"Number of inputs ({len(node.inputs)}) is more than expected ({len(pattern_node.inputs)})"
)
else:
if len(node.inputs) != len(pattern_node.inputs):
return self.fail(
f"Input nums mismatch. {len(node.inputs)} vs {len(pattern_node.inputs)}"
)
# Inputs are padded with Nones to match against pattern
to_match = itertools.zip_longest(node.inputs, pattern_node.inputs, fillvalue=None)

Check failure

Code scanning / lintrunner

MYPY/assignment Error

Incompatible types in assignment (expression has type "zip_longest[tuple[Value | None, ValuePattern | None]]", variable has type "zip[tuple[Value | None, ValuePattern | None]]") To disable, use # type: ignore[assignment]

for arg_value, arg_pattern in zip(node.inputs, pattern_node.inputs):
for arg_value, arg_pattern in to_match:
# arg_pattern could be a Var, if it's the original arg.
if arg_pattern is None:
if arg_value is None:
Expand Down
32 changes: 32 additions & 0 deletions onnxscript/rewriter/pattern_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,6 +476,38 @@
self.assertEqual(model.graph.node(0).op_type, "ReplacedNone")
self.assertEqual(model.graph.node(1).op_type, "ReplacedNotNone")

def test_match_trailing_optional_input(self):
def none_pattern(op, optional_input, x):
# match against a call to Original where the first input may or may not be None
return op.Original(x, optional_input)

def replacement(op, optional_input, x):
if optional_input is None:
return op.ReplacedNone(x)
return op.ReplacedNotNone(x)

rule = pattern.RewriteRule(none_pattern, replacement)

@script()
def test_model(x: FLOAT[1024]) -> FLOAT[1024]:
# Pattern should match following call (with optional_input == None)
t1 = op.Original(x, None)

Check warning on line 494 in onnxscript/rewriter/pattern_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/pattern_test.py#L494

Added line #L494 was not covered by tests
# as well as this one (with optional_input != None)
t2 = op.Original(x, t1)

Check warning on line 496 in onnxscript/rewriter/pattern_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/pattern_test.py#L496

Added line #L496 was not covered by tests
# as well as this one (with optional_input == None)
z = op.Original(t2)
return z

Check warning on line 499 in onnxscript/rewriter/pattern_test.py

View check run for this annotation

Codecov / codecov/patch

onnxscript/rewriter/pattern_test.py#L498-L499

Added lines #L498 - L499 were not covered by tests

model_proto = test_model.to_model_proto()
model = ir.serde.deserialize_model(model_proto)

count = rule.apply_to_model(model)
self.assertEqual(count, 3)
self.assertEqual(len(model.graph), 3)
self.assertEqual(model.graph.node(0).op_type, "ReplacedNone")
self.assertEqual(model.graph.node(1).op_type, "ReplacedNotNone")
self.assertEqual(model.graph.node(2).op_type, "ReplacedNone")


class PatternBuilderTest(unittest.TestCase):
def test_pattern_builder_context(self):
Expand Down
Loading