Skip to content

Commit 5afacc7

Browse files
committed
Fix tests
1 parent 985daea commit 5afacc7

File tree

1 file changed

+27
-29
lines changed

1 file changed

+27
-29
lines changed

src/onnx_ir/passes/_pass_infra_test.py

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class PostconditionTest(unittest.TestCase):
4040

4141
def test_ensures_called_with_result_model_not_input_model(self):
4242
"""Test that ensures() is called with result.model, not the input model."""
43-
43+
4444
class TestPass(_pass_infra.PassBase):
4545
def __init__(self):
4646
self.ensures_called_with = None
@@ -56,8 +56,7 @@ def changes_input(self) -> bool:
5656
def call(self, model: ir.Model) -> _pass_infra.PassResult:
5757
# Create a new model (different object)
5858
new_model = ir.Model(
59-
graph=ir.Graph([], [], nodes=[]),
60-
ir_version=model.ir_version
59+
graph=ir.Graph([], [], nodes=[]), ir_version=model.ir_version
6160
)
6261
return _pass_infra.PassResult(model=new_model, modified=True)
6362

@@ -68,14 +67,14 @@ def ensures(self, model: ir.Model) -> None:
6867
pass_ = TestPass()
6968
input_model = ir.Model(graph=ir.Graph([], [], nodes=[]), ir_version=10)
7069
result = pass_(input_model)
71-
70+
7271
# Verify that ensures was called with the result model, not the input model
7372
self.assertIs(pass_.ensures_called_with, result.model)
7473
self.assertIsNot(pass_.ensures_called_with, input_model)
7574

7675
def test_ensures_called_with_result_model_in_place_pass(self):
7776
"""Test that ensures() is called with result.model for in-place passes."""
78-
77+
7978
class TestInPlacePass(_pass_infra.InPlacePass):
8079
def __init__(self):
8180
self.ensures_called_with = None
@@ -91,7 +90,7 @@ def ensures(self, model: ir.Model) -> None:
9190
pass_ = TestInPlacePass()
9291
input_model = ir.Model(graph=ir.Graph([], [], nodes=[]), ir_version=10)
9392
result = pass_(input_model)
94-
93+
9594
# For in-place passes, result.model should be the same as input_model
9695
self.assertIs(result.model, input_model)
9796
# Verify that ensures was called with the result model (which is the same as input)
@@ -100,7 +99,7 @@ def ensures(self, model: ir.Model) -> None:
10099

101100
def test_postcondition_error_raised_when_ensures_fails(self):
102101
"""Test that PostconditionError is raised when ensures() raises an exception."""
103-
102+
104103
class TestPass(_pass_infra.PassBase):
105104
@property
106105
def in_place(self) -> bool:
@@ -119,16 +118,17 @@ def ensures(self, model: ir.Model) -> None:
119118

120119
pass_ = TestPass()
121120
model = ir.Model(graph=ir.Graph([], [], nodes=[]), ir_version=10)
122-
123-
with self.assertRaises(PostconditionError) as cm:
121+
122+
with self.assertRaisesRegex(
123+
ir.passes.PostconditionError, "Post-condition for pass 'TestPass' failed"
124+
) as cm:
124125
pass_(model)
125-
126-
self.assertIn("Post-condition for pass 'TestPass' failed", str(cm.exception))
126+
127127
self.assertIsInstance(cm.exception.__cause__, ValueError)
128128

129129
def test_postcondition_error_raised_when_ensures_raises_postcondition_error(self):
130130
"""Test that PostconditionError is re-raised when ensures() raises PostconditionError."""
131-
131+
132132
class TestPass(_pass_infra.PassBase):
133133
@property
134134
def in_place(self) -> bool:
@@ -143,35 +143,33 @@ def call(self, model: ir.Model) -> _pass_infra.PassResult:
143143

144144
def ensures(self, model: ir.Model) -> None:
145145
# Directly raise PostconditionError
146-
raise PostconditionError("Direct postcondition error")
146+
raise ir.passes.PostconditionError("Direct postcondition error")
147147

148148
pass_ = TestPass()
149149
model = ir.Model(graph=ir.Graph([], [], nodes=[]), ir_version=10)
150-
151-
with self.assertRaises(PostconditionError) as cm:
150+
151+
with self.assertRaisesRegex(
152+
ir.passes.PostconditionError, "Direct postcondition error"
153+
):
152154
pass_(model)
153-
154-
self.assertEqual(str(cm.exception), "Direct postcondition error")
155155

156156
def test_ensures_receives_correct_model_when_pass_modifies_model(self):
157157
"""Test a more complex scenario where the pass modifies the model structure."""
158-
158+
159159
class ModelModifyingPass(_pass_infra.FunctionalPass):
160160
def __init__(self):
161161
self.ensures_model_graph_nodes_count = None
162162

163163
def call(self, model: ir.Model) -> _pass_infra.PassResult:
164164
# Create a new model with additional nodes
165-
new_graph = ir.Graph([], [], nodes=[])
166-
# Add a dummy node to the graph to make it different
167-
new_node = ir.Node(
168-
domain="",
169-
op_type="Identity",
170-
inputs=[],
171-
outputs=[ir.Value(name="output", shape=ir.Shape([]), dtype=ir.DataType.FLOAT)],
172-
graph=new_graph
165+
new_model = ir.Model(
166+
graph=ir.Graph(
167+
[],
168+
[output := ir.Value(name="output")],
169+
nodes=[ir.node("TestOp", [], outputs=[output])],
170+
),
171+
ir_version=model.ir_version,
173172
)
174-
new_model = ir.Model(graph=new_graph, ir_version=model.ir_version)
175173
return _pass_infra.PassResult(model=new_model, modified=True)
176174

177175
def ensures(self, model: ir.Model) -> None:
@@ -180,8 +178,8 @@ def ensures(self, model: ir.Model) -> None:
180178

181179
pass_ = ModelModifyingPass()
182180
input_model = ir.Model(graph=ir.Graph([], [], nodes=[]), ir_version=10)
183-
result = pass_(input_model)
184-
181+
_ = pass_(input_model)
182+
185183
# The ensures method should see the modified model with 1 node
186184
self.assertEqual(pass_.ensures_model_graph_nodes_count, 1)
187185

0 commit comments

Comments
 (0)