from executorch.exir.passes import dead_code_elimination_pass

from parameterized.parameterized import parameterized
-from torch._ops import OpOverload
from torch.fx.passes.infra.pass_base import PassResult

@@ -87,36 +86,46 @@ def assertTargetCountsEqual(

    @parameterized.expand(
        [
-            # Regular MM
-            [(64, 33), (33, 128)],
-            # Batched MM
-            [(2, 48, 48), (2, 48, 48)],
-        ]
+            (
+                "regular",
+                (64, 33),  # x_shape
+                (33, 128),  # y_shape
+            ),
+            (
+                "batched",
+                (2, 48, 48),  # x_shape
+                (2, 48, 48),  # y_shape
+            ),
+        ],
    )
    @torch.no_grad()
    def test_replace_matmul_with_transposed_matmul(
        self,
+        _,
        x_shape: Tuple[int],
        y_shape: Tuple[int],
    ) -> None:
-        class MatMul(torch.nn.Module):
-            def __init__(self) -> None:
-                super(MatMul, self).__init__()
-
-            def forward(self, x, y):
-                return torch.matmul(x, y)
-
-        model = MatMul()
-        X = torch.randn(x_shape)
-        Y = torch.randn(y_shape)
-        p = ReplaceMatmulWithTransposedMatmulPass()
-        inputs = (X, Y)
-        graph_module = (
-            quantize_and_export_to_edge(model, inputs).exported_program().graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*x_shape, dtype=torch.float32))
+        y = builder.placeholder("y", torch.randn(*y_shape, dtype=torch.float32))
+        matmul = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_matmul.default,
+            args=(
+                x,
+                0,  # X_zero_point
+                y,
+                0,  # Y_zero_point
+                None,  # bias
+                1,  # out_multiplier
+                0,  # out_shift
+                0,  # out_zero_point
+                False,  # transposed=False
+            ),
        )
-        # pyre-fixme[16]: Optional type has no attribute `graph_module`
-        graph_after_passes = p(graph_module).graph_module
-
+        builder.output([matmul])
+        original = builder.get_graph_module()
+        p = ReplaceMatmulWithTransposedMatmulPass()
+        graph_after_passes = cast(PassResult, p(original)).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int),
            1,
@@ -130,33 +139,24 @@ def forward(self, x, y):

    @parameterized.expand(
        [
-            [(3, 5), (0, 0)],
-            [
-                (20, 1, 80),
-                (0, 0),
-            ],
-        ]
+            ("2d", (3, 5), [0, 0]),  # name, shape, padding
+            ("3d", (20, 1, 80), [0, 0, 0]),  # name, shape, padding
+        ],
    )
    @torch.no_grad()
    def test_replace_constant_pad_nd_with_slice(
-        self, shape: Tuple[int], padding: Tuple[int]
+        self, _, shape: Tuple[int], padding: Tuple[int]
    ):
-        # F.pad is converted to aten::constant_pad_nd after functionalization & decomposition.
-        class Padding(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-                self.padding = padding
-
-            def forward(self, x: torch.Tensor):
-                return F.pad(x, self.padding)
-
-        model = Padding()
-        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
+        pad = builder.call_operator(
+            op=exir_ops.edge.aten.constant_pad_nd.default,
+            args=(x, [0, 0, 0, 0]),
+        )
+        builder.output([pad])
+        original = builder.get_graph_module()
        p = ReplaceConstantPadNdWithSlicePass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.slice.Tensor),
            1,
@@ -169,142 +169,140 @@ def forward(self, x: torch.Tensor):

    @parameterized.expand(
        [
-            [(7, 5, 6), 1.23],
-            [(7, 5), 2],
+            ["3d", (7, 5, 6), 1.23],
+            ["2d", (7, 5), 2],
+            ["1d", (10,), 42949],
        ]
    )
    @torch.no_grad()
-    def test_add_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float):
-        class Add(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.add.Scalar(x, other)
-
-        model = Add()
+    def test_add_replace_scalar_with_tensor_arg(
+        self, _, shape: Tuple[int], other: float
+    ):
        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.add.Scalar,
+            args=(x, other),
+        )
        p = ReplaceScalarWithTensorArgPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
            1,
        )
-
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.add.Scalar),
            0,
        )

    @parameterized.expand(
        [
-            [(7, 5, 6), 1.23],
-            [(7, 5), 2],
-            [(10), 42949],
+            ["3d", (7, 5, 6), 1.23],
+            ["2d", (7, 5), 2],
+            ["1d", (10,), 42949],
        ]
    )
    @torch.no_grad()
-    def test_sub_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float):
-        class Sub(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.sub.Scalar(x, other)
-
-        model = Sub()
+    def test_sub_replace_scalar_with_tensor_arg(
+        self, _, shape: Tuple[int], other: float
+    ):
        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.sub.Scalar,
+            args=(x, other),
+        )
        p = ReplaceScalarWithTensorArgPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.sub.Tensor),
            1,
        )
-
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.sub.Scalar),
            0,
        )

    @parameterized.expand(
        [
-            [(7, 5, 6), 1.23],
-            [(7, 5), 2],
-            [(513), 3],
+            ["3d", (7, 5, 6), 1.23],
+            ["2d", (7, 5), 2],
+            ["1d", (10,), 42949],
        ]
    )
    @torch.no_grad()
-    def test_mul_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float):
-        class Mul(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.mul.Scalar(x, other)
-
-        model = Mul()
+    def test_mul_replace_scalar_with_tensor_arg(
+        self, _, shape: Tuple[int], other: float
+    ):
        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.mul.Scalar,
+            args=(x, other),
+        )
        p = ReplaceScalarWithTensorArgPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
            1,
        )
-
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.mul.Scalar),
            0,
        )

    @parameterized.expand(
        [
-            [(7, 5, 6), 1.23],
-            [(7, 5), 2],
+            ["3d", (7, 5, 6), 1.23],
+            ["2d", (7, 5), 2],
+            ["1d", (10,), 42949],
        ]
    )
    @torch.no_grad()
    def test_div_replace_scalar_with_tensor_arg(
        self,
+        _,
        shape: Tuple[int],
        other: float,
    ):
-        class Div(torch.nn.Module):
-            def forward(self, x):
-                return torch.ops.aten.div.Scalar(x, other)
-
-        model = Div()
-        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
-
+        x = torch.randn(*shape)
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.div.Scalar,
+            args=(x, other),
+        )
        p = ReplaceScalarWithTensorArgPass()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.div.Tensor),
            1,
        )
-
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.div.Scalar),
            0,
        )

    @parameterized.expand(
        [
-            [(2, 3, 5, 6)],
-            [(7, 6, 5)],
-            [(4, 4)],
-            [(316)],
+            ["4d", (2, 3, 5, 6)],
+            ["3d", (7, 6, 5)],
+            ["2d", (4, 4)],
+            ["1d", (316,)],
        ]
    )
    @torch.no_grad()
-    def test_replace_functionally_equivalent_op_targets_relu(self, shape: Tuple[int]):
-        model = torch.nn.ReLU()
+    def test_replace_functionally_equivalent_op_targets_relu(
+        self, _, shape: Tuple[int]
+    ):
        x = torch.randn(shape)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.relu_.default,
+            args=(x,),
+        )
        p = ReplaceFunctionallyEquivalentOpTargets()
+        graph_after_passes = cast(PassResult, p(original)).graph_module

-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.relu.default),
            1,
@@ -315,56 +313,29 @@ def test_replace_functionally_equivalent_op_targets_relu(self, shape: Tuple[int]
        )

    @parameterized.expand(
-        [
-            # split the only dimension
-            [(50,), i, 0]
-            for i in range(2, 7)
-        ]
-        + [
-            # split the leading dim
-            [(10, 2, 3), i, 0]
-            for i in range(2, 7)
-        ]
-        + [
-            # split the trailing dim
-            [(3, 3, 6), i, 2]
-            for i in range(2, 6)
-        ]
-        + [
-            # split the dim in the middle
-            [(3, 5, 14, 2, 3), i, 2]
-            for i in range(2, 7)
-        ]
+        [["split_linear_tensor", (50,), i, 0] for i in range(2, 7)]
+        + [["split_leading_dim", (10, 2, 3), i, 0] for i in range(2, 7)]
+        + [["split_trailing_dim", (3, 3, 6), i, 2] for i in range(2, 6)]
+        + [["split_middle_dim", (3, 5, 14, 2, 3), i, 2] for i in range(2, 7)]
    )
    @torch.no_grad()
    def test_replace_functionally_equivalent_op_targets_unsafe_split(
-        self, shape: Tuple[int], split_size: int, dim: int
+        self, _, shape: Tuple[int], split_size: int, dim: int
    ):
-        class TensorSplitWithSizes(torch.nn.Module):
-            def __init__(self, split_size: int, dim: int, op: OpOverload):
-                super().__init__()
-                self.split_size = split_size
-                self.dim = dim
-                self.op = op
-
-            def forward(self, x: torch.Tensor):
-                return self.op(x, self.split_size, self.dim)
-
        x = torch.randn(shape)
-        model = TensorSplitWithSizes(split_size, dim, torch.unsafe_split)
-        graph_module = export_to_edge(model, (x,)).exported_program().graph_module
+        original = single_op_builder(
+            placeholders=(x,),
+            op=exir_ops.edge.aten.unsafe_split.Tensor,
+            args=(x, split_size, dim),
+        )
        p = ReplaceFunctionallyEquivalentOpTargets()
-
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        graph_after_passes = cast(PassResult, p(original)).graph_module
        self.assertEqual(
-            count_node(
-                graph_after_passes, exir_ops.edge.aten.split_with_sizes_copy.default
-            ),
+            count_node(graph_after_passes, exir_ops.edge.aten.split_copy.Tensor),
            1,
        )
        self.assertEqual(
-            count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor),
-            0,
+            count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor), 0
        )

    @parameterized.expand(