Skip to content

Commit 985daea

Browse files
Add Tests for PR#78 (#79)
This PR adds tests for #78 ### Commits: - Add unit tests for postcondition checking bug fix Tests verify that ensures() is called with result.model instead of input model, covering the fix in PassBase.__call__ where postconditions should check the transformed model rather than the original input model. - Add unit tests for postcondition checking bug fix Tests verify that ensures() is called with result.model instead of input model, covering the fix in PassBase.__call__ where postconditions should check the transformed model rather than the original input model. Co-authored-by: codecov-ai[bot] <156709835+codecov-ai[bot]@users.noreply.github.com>
1 parent c3c9dca commit 985daea

File tree

1 file changed

+151
-0
lines changed

1 file changed

+151
-0
lines changed

src/onnx_ir/passes/_pass_infra_test.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,5 +35,156 @@ def call(self, model: ir.Model) -> _pass_infra.PassResult:
3535
self.assertIs(result_1.model, result_2.model)
3636

3737

38+
class PostconditionTest(unittest.TestCase):
39+
"""Test that postconditions are checked on the result model, not the input model."""
40+
41+
def test_ensures_called_with_result_model_not_input_model(self):
42+
"""Test that ensures() is called with result.model, not the input model."""
43+
44+
class TestPass(_pass_infra.PassBase):
45+
def __init__(self):
46+
self.ensures_called_with = None
47+
48+
@property
49+
def in_place(self) -> bool:
50+
return False # Not in-place to create a new model
51+
52+
@property
53+
def changes_input(self) -> bool:
54+
return False
55+
56+
def call(self, model: ir.Model) -> _pass_infra.PassResult:
57+
# Create a new model (different object)
58+
new_model = ir.Model(
59+
graph=ir.Graph([], [], nodes=[]),
60+
ir_version=model.ir_version
61+
)
62+
return _pass_infra.PassResult(model=new_model, modified=True)
63+
64+
def ensures(self, model: ir.Model) -> None:
65+
# Record which model ensures was called with
66+
self.ensures_called_with = model
67+
68+
pass_ = TestPass()
69+
input_model = ir.Model(graph=ir.Graph([], [], nodes=[]), ir_version=10)
70+
result = pass_(input_model)
71+
72+
# Verify that ensures was called with the result model, not the input model
73+
self.assertIs(pass_.ensures_called_with, result.model)
74+
self.assertIsNot(pass_.ensures_called_with, input_model)
75+
76+
def test_ensures_called_with_result_model_in_place_pass(self):
77+
"""Test that ensures() is called with result.model for in-place passes."""
78+
79+
class TestInPlacePass(_pass_infra.InPlacePass):
80+
def __init__(self):
81+
self.ensures_called_with = None
82+
83+
def call(self, model: ir.Model) -> _pass_infra.PassResult:
84+
# In-place pass returns the same model
85+
return _pass_infra.PassResult(model=model, modified=True)
86+
87+
def ensures(self, model: ir.Model) -> None:
88+
# Record which model ensures was called with
89+
self.ensures_called_with = model
90+
91+
pass_ = TestInPlacePass()
92+
input_model = ir.Model(graph=ir.Graph([], [], nodes=[]), ir_version=10)
93+
result = pass_(input_model)
94+
95+
# For in-place passes, result.model should be the same as input_model
96+
self.assertIs(result.model, input_model)
97+
# Verify that ensures was called with the result model (which is the same as input)
98+
self.assertIs(pass_.ensures_called_with, result.model)
99+
self.assertIs(pass_.ensures_called_with, input_model)
100+
101+
def test_postcondition_error_raised_when_ensures_fails(self):
102+
"""Test that PostconditionError is raised when ensures() raises an exception."""
103+
104+
class TestPass(_pass_infra.PassBase):
105+
@property
106+
def in_place(self) -> bool:
107+
return True
108+
109+
@property
110+
def changes_input(self) -> bool:
111+
return True
112+
113+
def call(self, model: ir.Model) -> _pass_infra.PassResult:
114+
return _pass_infra.PassResult(model=model, modified=False)
115+
116+
def ensures(self, model: ir.Model) -> None:
117+
# Simulate a postcondition failure
118+
raise ValueError("Postcondition failed")
119+
120+
pass_ = TestPass()
121+
model = ir.Model(graph=ir.Graph([], [], nodes=[]), ir_version=10)
122+
123+
with self.assertRaises(PostconditionError) as cm:
124+
pass_(model)
125+
126+
self.assertIn("Post-condition for pass 'TestPass' failed", str(cm.exception))
127+
self.assertIsInstance(cm.exception.__cause__, ValueError)
128+
129+
def test_postcondition_error_raised_when_ensures_raises_postcondition_error(self):
130+
"""Test that PostconditionError is re-raised when ensures() raises PostconditionError."""
131+
132+
class TestPass(_pass_infra.PassBase):
133+
@property
134+
def in_place(self) -> bool:
135+
return True
136+
137+
@property
138+
def changes_input(self) -> bool:
139+
return True
140+
141+
def call(self, model: ir.Model) -> _pass_infra.PassResult:
142+
return _pass_infra.PassResult(model=model, modified=False)
143+
144+
def ensures(self, model: ir.Model) -> None:
145+
# Directly raise PostconditionError
146+
raise PostconditionError("Direct postcondition error")
147+
148+
pass_ = TestPass()
149+
model = ir.Model(graph=ir.Graph([], [], nodes=[]), ir_version=10)
150+
151+
with self.assertRaises(PostconditionError) as cm:
152+
pass_(model)
153+
154+
self.assertEqual(str(cm.exception), "Direct postcondition error")
155+
156+
def test_ensures_receives_correct_model_when_pass_modifies_model(self):
157+
"""Test a more complex scenario where the pass modifies the model structure."""
158+
159+
class ModelModifyingPass(_pass_infra.FunctionalPass):
160+
def __init__(self):
161+
self.ensures_model_graph_nodes_count = None
162+
163+
def call(self, model: ir.Model) -> _pass_infra.PassResult:
164+
# Create a new model with additional nodes
165+
new_graph = ir.Graph([], [], nodes=[])
166+
# Add a dummy node to the graph to make it different
167+
new_node = ir.Node(
168+
domain="",
169+
op_type="Identity",
170+
inputs=[],
171+
outputs=[ir.Value(name="output", shape=ir.Shape([]), dtype=ir.DataType.FLOAT)],
172+
graph=new_graph
173+
)
174+
new_model = ir.Model(graph=new_graph, ir_version=model.ir_version)
175+
return _pass_infra.PassResult(model=new_model, modified=True)
176+
177+
def ensures(self, model: ir.Model) -> None:
178+
# Record the number of nodes in the model passed to ensures
179+
self.ensures_model_graph_nodes_count = len(list(model.graph))
180+
181+
pass_ = ModelModifyingPass()
182+
input_model = ir.Model(graph=ir.Graph([], [], nodes=[]), ir_version=10)
183+
result = pass_(input_model)
184+
185+
# The ensures method should see the modified model with 1 node
186+
self.assertEqual(pass_.ensures_model_graph_nodes_count, 1)
187+
188+
38189
if __name__ == "__main__":
39190
unittest.main()

0 commit comments

Comments
 (0)