@@ -40,7 +40,7 @@ class PostconditionTest(unittest.TestCase):
40
40
41
41
def test_ensures_called_with_result_model_not_input_model (self ):
42
42
"""Test that ensures() is called with result.model, not the input model."""
43
-
43
+
44
44
class TestPass (_pass_infra .PassBase ):
45
45
def __init__ (self ):
46
46
self .ensures_called_with = None
@@ -56,8 +56,7 @@ def changes_input(self) -> bool:
56
56
def call (self , model : ir .Model ) -> _pass_infra .PassResult :
57
57
# Create a new model (different object)
58
58
new_model = ir .Model (
59
- graph = ir .Graph ([], [], nodes = []),
60
- ir_version = model .ir_version
59
+ graph = ir .Graph ([], [], nodes = []), ir_version = model .ir_version
61
60
)
62
61
return _pass_infra .PassResult (model = new_model , modified = True )
63
62
@@ -68,14 +67,14 @@ def ensures(self, model: ir.Model) -> None:
68
67
pass_ = TestPass ()
69
68
input_model = ir .Model (graph = ir .Graph ([], [], nodes = []), ir_version = 10 )
70
69
result = pass_ (input_model )
71
-
70
+
72
71
# Verify that ensures was called with the result model, not the input model
73
72
self .assertIs (pass_ .ensures_called_with , result .model )
74
73
self .assertIsNot (pass_ .ensures_called_with , input_model )
75
74
76
75
def test_ensures_called_with_result_model_in_place_pass (self ):
77
76
"""Test that ensures() is called with result.model for in-place passes."""
78
-
77
+
79
78
class TestInPlacePass (_pass_infra .InPlacePass ):
80
79
def __init__ (self ):
81
80
self .ensures_called_with = None
@@ -91,7 +90,7 @@ def ensures(self, model: ir.Model) -> None:
91
90
pass_ = TestInPlacePass ()
92
91
input_model = ir .Model (graph = ir .Graph ([], [], nodes = []), ir_version = 10 )
93
92
result = pass_ (input_model )
94
-
93
+
95
94
# For in-place passes, result.model should be the same as input_model
96
95
self .assertIs (result .model , input_model )
97
96
# Verify that ensures was called with the result model (which is the same as input)
@@ -100,7 +99,7 @@ def ensures(self, model: ir.Model) -> None:
100
99
101
100
def test_postcondition_error_raised_when_ensures_fails (self ):
102
101
"""Test that PostconditionError is raised when ensures() raises an exception."""
103
-
102
+
104
103
class TestPass (_pass_infra .PassBase ):
105
104
@property
106
105
def in_place (self ) -> bool :
@@ -119,16 +118,17 @@ def ensures(self, model: ir.Model) -> None:
119
118
120
119
pass_ = TestPass ()
121
120
model = ir .Model (graph = ir .Graph ([], [], nodes = []), ir_version = 10 )
122
-
123
- with self .assertRaises (PostconditionError ) as cm :
121
+
122
+ with self .assertRaisesRegex (
123
+ ir .passes .PostconditionError , "Post-condition for pass 'TestPass' failed"
124
+ ) as cm :
124
125
pass_ (model )
125
-
126
- self .assertIn ("Post-condition for pass 'TestPass' failed" , str (cm .exception ))
126
+
127
127
self .assertIsInstance (cm .exception .__cause__ , ValueError )
128
128
129
129
def test_postcondition_error_raised_when_ensures_raises_postcondition_error (self ):
130
130
"""Test that PostconditionError is re-raised when ensures() raises PostconditionError."""
131
-
131
+
132
132
class TestPass (_pass_infra .PassBase ):
133
133
@property
134
134
def in_place (self ) -> bool :
@@ -143,35 +143,33 @@ def call(self, model: ir.Model) -> _pass_infra.PassResult:
143
143
144
144
def ensures (self , model : ir .Model ) -> None :
145
145
# Directly raise PostconditionError
146
- raise PostconditionError ("Direct postcondition error" )
146
+ raise ir . passes . PostconditionError ("Direct postcondition error" )
147
147
148
148
pass_ = TestPass ()
149
149
model = ir .Model (graph = ir .Graph ([], [], nodes = []), ir_version = 10 )
150
-
151
- with self .assertRaises (PostconditionError ) as cm :
150
+
151
+ with self .assertRaisesRegex (
152
+ ir .passes .PostconditionError , "Direct postcondition error"
153
+ ):
152
154
pass_ (model )
153
-
154
- self .assertEqual (str (cm .exception ), "Direct postcondition error" )
155
155
156
156
def test_ensures_receives_correct_model_when_pass_modifies_model (self ):
157
157
"""Test a more complex scenario where the pass modifies the model structure."""
158
-
158
+
159
159
class ModelModifyingPass (_pass_infra .FunctionalPass ):
160
160
def __init__ (self ):
161
161
self .ensures_model_graph_nodes_count = None
162
162
163
163
def call (self , model : ir .Model ) -> _pass_infra .PassResult :
164
164
# Create a new model with additional nodes
165
- new_graph = ir .Graph ([], [], nodes = [])
166
- # Add a dummy node to the graph to make it different
167
- new_node = ir .Node (
168
- domain = "" ,
169
- op_type = "Identity" ,
170
- inputs = [],
171
- outputs = [ir .Value (name = "output" , shape = ir .Shape ([]), dtype = ir .DataType .FLOAT )],
172
- graph = new_graph
165
+ new_model = ir .Model (
166
+ graph = ir .Graph (
167
+ [],
168
+ [output := ir .Value (name = "output" )],
169
+ nodes = [ir .node ("TestOp" , [], outputs = [output ])],
170
+ ),
171
+ ir_version = model .ir_version ,
173
172
)
174
- new_model = ir .Model (graph = new_graph , ir_version = model .ir_version )
175
173
return _pass_infra .PassResult (model = new_model , modified = True )
176
174
177
175
def ensures (self , model : ir .Model ) -> None :
@@ -180,8 +178,8 @@ def ensures(self, model: ir.Model) -> None:
180
178
181
179
pass_ = ModelModifyingPass ()
182
180
input_model = ir .Model (graph = ir .Graph ([], [], nodes = []), ir_version = 10 )
183
- result = pass_ (input_model )
184
-
181
+ _ = pass_ (input_model )
182
+
185
183
# The ensures method should see the modified model with 1 node
186
184
self .assertEqual (pass_ .ensures_model_graph_nodes_count , 1 )
187
185
0 commit comments