@@ -35,5 +35,156 @@ def call(self, model: ir.Model) -> _pass_infra.PassResult:
35
35
self .assertIs (result_1 .model , result_2 .model )
36
36
37
37
38
+ class PostconditionTest (unittest .TestCase ):
39
+ """Test that postconditions are checked on the result model, not the input model."""
40
+
41
+ def test_ensures_called_with_result_model_not_input_model (self ):
42
+ """Test that ensures() is called with result.model, not the input model."""
43
+
44
+ class TestPass (_pass_infra .PassBase ):
45
+ def __init__ (self ):
46
+ self .ensures_called_with = None
47
+
48
+ @property
49
+ def in_place (self ) -> bool :
50
+ return False # Not in-place to create a new model
51
+
52
+ @property
53
+ def changes_input (self ) -> bool :
54
+ return False
55
+
56
+ def call (self , model : ir .Model ) -> _pass_infra .PassResult :
57
+ # Create a new model (different object)
58
+ new_model = ir .Model (
59
+ graph = ir .Graph ([], [], nodes = []),
60
+ ir_version = model .ir_version
61
+ )
62
+ return _pass_infra .PassResult (model = new_model , modified = True )
63
+
64
+ def ensures (self , model : ir .Model ) -> None :
65
+ # Record which model ensures was called with
66
+ self .ensures_called_with = model
67
+
68
+ pass_ = TestPass ()
69
+ input_model = ir .Model (graph = ir .Graph ([], [], nodes = []), ir_version = 10 )
70
+ result = pass_ (input_model )
71
+
72
+ # Verify that ensures was called with the result model, not the input model
73
+ self .assertIs (pass_ .ensures_called_with , result .model )
74
+ self .assertIsNot (pass_ .ensures_called_with , input_model )
75
+
76
+ def test_ensures_called_with_result_model_in_place_pass (self ):
77
+ """Test that ensures() is called with result.model for in-place passes."""
78
+
79
+ class TestInPlacePass (_pass_infra .InPlacePass ):
80
+ def __init__ (self ):
81
+ self .ensures_called_with = None
82
+
83
+ def call (self , model : ir .Model ) -> _pass_infra .PassResult :
84
+ # In-place pass returns the same model
85
+ return _pass_infra .PassResult (model = model , modified = True )
86
+
87
+ def ensures (self , model : ir .Model ) -> None :
88
+ # Record which model ensures was called with
89
+ self .ensures_called_with = model
90
+
91
+ pass_ = TestInPlacePass ()
92
+ input_model = ir .Model (graph = ir .Graph ([], [], nodes = []), ir_version = 10 )
93
+ result = pass_ (input_model )
94
+
95
+ # For in-place passes, result.model should be the same as input_model
96
+ self .assertIs (result .model , input_model )
97
+ # Verify that ensures was called with the result model (which is the same as input)
98
+ self .assertIs (pass_ .ensures_called_with , result .model )
99
+ self .assertIs (pass_ .ensures_called_with , input_model )
100
+
101
+ def test_postcondition_error_raised_when_ensures_fails (self ):
102
+ """Test that PostconditionError is raised when ensures() raises an exception."""
103
+
104
+ class TestPass (_pass_infra .PassBase ):
105
+ @property
106
+ def in_place (self ) -> bool :
107
+ return True
108
+
109
+ @property
110
+ def changes_input (self ) -> bool :
111
+ return True
112
+
113
+ def call (self , model : ir .Model ) -> _pass_infra .PassResult :
114
+ return _pass_infra .PassResult (model = model , modified = False )
115
+
116
+ def ensures (self , model : ir .Model ) -> None :
117
+ # Simulate a postcondition failure
118
+ raise ValueError ("Postcondition failed" )
119
+
120
+ pass_ = TestPass ()
121
+ model = ir .Model (graph = ir .Graph ([], [], nodes = []), ir_version = 10 )
122
+
123
+ with self .assertRaises (PostconditionError ) as cm :
124
+ pass_ (model )
125
+
126
+ self .assertIn ("Post-condition for pass 'TestPass' failed" , str (cm .exception ))
127
+ self .assertIsInstance (cm .exception .__cause__ , ValueError )
128
+
129
+ def test_postcondition_error_raised_when_ensures_raises_postcondition_error (self ):
130
+ """Test that PostconditionError is re-raised when ensures() raises PostconditionError."""
131
+
132
+ class TestPass (_pass_infra .PassBase ):
133
+ @property
134
+ def in_place (self ) -> bool :
135
+ return True
136
+
137
+ @property
138
+ def changes_input (self ) -> bool :
139
+ return True
140
+
141
+ def call (self , model : ir .Model ) -> _pass_infra .PassResult :
142
+ return _pass_infra .PassResult (model = model , modified = False )
143
+
144
+ def ensures (self , model : ir .Model ) -> None :
145
+ # Directly raise PostconditionError
146
+ raise PostconditionError ("Direct postcondition error" )
147
+
148
+ pass_ = TestPass ()
149
+ model = ir .Model (graph = ir .Graph ([], [], nodes = []), ir_version = 10 )
150
+
151
+ with self .assertRaises (PostconditionError ) as cm :
152
+ pass_ (model )
153
+
154
+ self .assertEqual (str (cm .exception ), "Direct postcondition error" )
155
+
156
+ def test_ensures_receives_correct_model_when_pass_modifies_model (self ):
157
+ """Test a more complex scenario where the pass modifies the model structure."""
158
+
159
+ class ModelModifyingPass (_pass_infra .FunctionalPass ):
160
+ def __init__ (self ):
161
+ self .ensures_model_graph_nodes_count = None
162
+
163
+ def call (self , model : ir .Model ) -> _pass_infra .PassResult :
164
+ # Create a new model with additional nodes
165
+ new_graph = ir .Graph ([], [], nodes = [])
166
+ # Add a dummy node to the graph to make it different
167
+ new_node = ir .Node (
168
+ domain = "" ,
169
+ op_type = "Identity" ,
170
+ inputs = [],
171
+ outputs = [ir .Value (name = "output" , shape = ir .Shape ([]), dtype = ir .DataType .FLOAT )],
172
+ graph = new_graph
173
+ )
174
+ new_model = ir .Model (graph = new_graph , ir_version = model .ir_version )
175
+ return _pass_infra .PassResult (model = new_model , modified = True )
176
+
177
+ def ensures (self , model : ir .Model ) -> None :
178
+ # Record the number of nodes in the model passed to ensures
179
+ self .ensures_model_graph_nodes_count = len (list (model .graph ))
180
+
181
+ pass_ = ModelModifyingPass ()
182
+ input_model = ir .Model (graph = ir .Graph ([], [], nodes = []), ir_version = 10 )
183
+ result = pass_ (input_model )
184
+
185
+ # The ensures method should see the modified model with 1 node
186
+ self .assertEqual (pass_ .ensures_model_graph_nodes_count , 1 )
187
+
188
+
38
189
if __name__ == "__main__" :
39
190
unittest .main ()
0 commit comments