@@ -97,6 +97,32 @@ def ensures(self, model: ir.Model) -> None:
97
97
self .assertIs (pass_ .ensures_called_with , result .model )
98
98
self .assertIs (pass_ .ensures_called_with , input_model )
99
99
100
+ def test_ensures_called_with_result_model_functional_pass (self ):
101
+ """Test that ensures() is called with result.model for functional passes."""
102
+
103
+ class TestPass (_pass_infra .FunctionalPass ):
104
+ def __init__ (self ):
105
+ self .ensures_called_with = None
106
+
107
+ def call (self , model : ir .Model ) -> _pass_infra .PassResult :
108
+ # Create a new model (different object)
109
+ new_model = ir .Model (
110
+ graph = ir .Graph ([], [], nodes = []), ir_version = model .ir_version
111
+ )
112
+ return _pass_infra .PassResult (model = new_model , modified = True )
113
+
114
+ def ensures (self , model : ir .Model ) -> None :
115
+ # Record which model ensures was called with
116
+ self .ensures_called_with = model
117
+
118
+ pass_ = TestPass ()
119
+ input_model = ir .Model (graph = ir .Graph ([], [], nodes = []), ir_version = 10 )
120
+ result = pass_ (input_model )
121
+
122
+ # Verify that ensures was called with the result model, not the input model
123
+ self .assertIs (pass_ .ensures_called_with , result .model )
124
+ self .assertIsNot (pass_ .ensures_called_with , input_model )
125
+
100
126
def test_postcondition_error_raised_when_ensures_fails (self ):
101
127
"""Test that PostconditionError is raised when ensures() raises an exception."""
102
128
@@ -153,36 +179,6 @@ def ensures(self, model: ir.Model) -> None:
153
179
):
154
180
pass_ (model )
155
181
156
- def test_ensures_receives_correct_model_when_pass_modifies_model (self ):
157
- """Test a more complex scenario where the pass modifies the model structure."""
158
-
159
- class ModelModifyingPass (_pass_infra .FunctionalPass ):
160
- def __init__ (self ):
161
- self .ensures_model_graph_nodes_count = None
162
-
163
- def call (self , model : ir .Model ) -> _pass_infra .PassResult :
164
- # Create a new model with additional nodes
165
- new_model = ir .Model (
166
- graph = ir .Graph (
167
- [],
168
- [output := ir .Value (name = "output" )],
169
- nodes = [ir .node ("TestOp" , [], outputs = [output ])],
170
- ),
171
- ir_version = model .ir_version ,
172
- )
173
- return _pass_infra .PassResult (model = new_model , modified = True )
174
-
175
- def ensures (self , model : ir .Model ) -> None :
176
- # Record the number of nodes in the model passed to ensures
177
- self .ensures_model_graph_nodes_count = len (list (model .graph ))
178
-
179
- pass_ = ModelModifyingPass ()
180
- input_model = ir .Model (graph = ir .Graph ([], [], nodes = []), ir_version = 10 )
181
- _ = pass_ (input_model )
182
-
183
- # The ensures method should see the modified model with 1 node
184
- self .assertEqual (pass_ .ensures_model_graph_nodes_count , 1 )
185
-
186
182
187
183
if __name__ == "__main__" :
188
184
unittest .main ()
0 commit comments