Skip to content

Commit 4ade286

Browse files
committed
Update tests
Signed-off-by: Justin Chu <[email protected]>
1 parent 5afacc7 commit 4ade286

File tree

1 file changed

+26
-30
lines changed

1 file changed

+26
-30
lines changed

src/onnx_ir/passes/_pass_infra_test.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,32 @@ def ensures(self, model: ir.Model) -> None:
9797
self.assertIs(pass_.ensures_called_with, result.model)
9898
self.assertIs(pass_.ensures_called_with, input_model)
9999

100+
def test_ensures_called_with_result_model_functional_pass(self):
101+
"""Test that ensures() is called with result.model for functional passes."""
102+
103+
class TestPass(_pass_infra.FunctionalPass):
104+
def __init__(self):
105+
self.ensures_called_with = None
106+
107+
def call(self, model: ir.Model) -> _pass_infra.PassResult:
108+
# Create a new model (different object)
109+
new_model = ir.Model(
110+
graph=ir.Graph([], [], nodes=[]), ir_version=model.ir_version
111+
)
112+
return _pass_infra.PassResult(model=new_model, modified=True)
113+
114+
def ensures(self, model: ir.Model) -> None:
115+
# Record which model ensures was called with
116+
self.ensures_called_with = model
117+
118+
pass_ = TestPass()
119+
input_model = ir.Model(graph=ir.Graph([], [], nodes=[]), ir_version=10)
120+
result = pass_(input_model)
121+
122+
# Verify that ensures was called with the result model, not the input model
123+
self.assertIs(pass_.ensures_called_with, result.model)
124+
self.assertIsNot(pass_.ensures_called_with, input_model)
125+
100126
def test_postcondition_error_raised_when_ensures_fails(self):
101127
"""Test that PostconditionError is raised when ensures() raises an exception."""
102128

@@ -153,36 +179,6 @@ def ensures(self, model: ir.Model) -> None:
153179
):
154180
pass_(model)
155181

156-
def test_ensures_receives_correct_model_when_pass_modifies_model(self):
157-
"""Test a more complex scenario where the pass modifies the model structure."""
158-
159-
class ModelModifyingPass(_pass_infra.FunctionalPass):
160-
def __init__(self):
161-
self.ensures_model_graph_nodes_count = None
162-
163-
def call(self, model: ir.Model) -> _pass_infra.PassResult:
164-
# Create a new model with additional nodes
165-
new_model = ir.Model(
166-
graph=ir.Graph(
167-
[],
168-
[output := ir.Value(name="output")],
169-
nodes=[ir.node("TestOp", [], outputs=[output])],
170-
),
171-
ir_version=model.ir_version,
172-
)
173-
return _pass_infra.PassResult(model=new_model, modified=True)
174-
175-
def ensures(self, model: ir.Model) -> None:
176-
# Record the number of nodes in the model passed to ensures
177-
self.ensures_model_graph_nodes_count = len(list(model.graph))
178-
179-
pass_ = ModelModifyingPass()
180-
input_model = ir.Model(graph=ir.Graph([], [], nodes=[]), ir_version=10)
181-
_ = pass_(input_model)
182-
183-
# The ensures method should see the modified model with 1 node
184-
self.assertEqual(pass_.ensures_model_graph_nodes_count, 1)
185-
186182

187183
if __name__ == "__main__":
188184
unittest.main()

0 commit comments

Comments
 (0)