Skip to content

Commit fbe2e58

Browse files
authored
Use GraphBuilder in test_replace_ops_passes. #1
Differential Revision: D75911655 Pull Request resolved: #11344
1 parent 27cb43d commit fbe2e58

File tree

1 file changed

+118
-147
lines changed

1 file changed

+118
-147
lines changed

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 118 additions & 147 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,6 @@
5757
from executorch.exir.passes import dead_code_elimination_pass
5858

5959
from parameterized.parameterized import parameterized
60-
from torch._ops import OpOverload
6160
from torch.fx.passes.infra.pass_base import PassResult
6261

6362

@@ -87,36 +86,46 @@ def assertTargetCountsEqual(
8786

8887
@parameterized.expand(
8988
[
90-
# Regular MM
91-
[(64, 33), (33, 128)],
92-
# Batched MM
93-
[(2, 48, 48), (2, 48, 48)],
94-
]
89+
(
90+
"regular",
91+
(64, 33), # x_shape
92+
(33, 128), # y_shape
93+
),
94+
(
95+
"batched",
96+
(2, 48, 48), # x_shape
97+
(2, 48, 48), # y_shape
98+
),
99+
],
95100
)
96101
@torch.no_grad()
97102
def test_replace_matmul_with_transposed_matmul(
98103
self,
104+
_,
99105
x_shape: Tuple[int],
100106
y_shape: Tuple[int],
101107
) -> None:
102-
class MatMul(torch.nn.Module):
103-
def __init__(self) -> None:
104-
super(MatMul, self).__init__()
105-
106-
def forward(self, x, y):
107-
return torch.matmul(x, y)
108-
109-
model = MatMul()
110-
X = torch.randn(x_shape)
111-
Y = torch.randn(y_shape)
112-
p = ReplaceMatmulWithTransposedMatmulPass()
113-
inputs = (X, Y)
114-
graph_module = (
115-
quantize_and_export_to_edge(model, inputs).exported_program().graph_module
108+
builder = GraphBuilder()
109+
x = builder.placeholder("x", torch.randn(*x_shape, dtype=torch.float32))
110+
y = builder.placeholder("y", torch.randn(*y_shape, dtype=torch.float32))
111+
matmul = builder.call_operator(
112+
op=exir_ops.edge.cadence.quantized_matmul.default,
113+
args=(
114+
x,
115+
0, # X_zero_point
116+
y,
117+
0, # Y_zero_point,
118+
None, # bias
119+
1, # out_multiplier
120+
0, # out_shift
121+
0, # out_zero_point
122+
False, # transposed=False
123+
),
116124
)
117-
# pyre-fixme[16]: Optional type has no attribute `graph_module`
118-
graph_after_passes = p(graph_module).graph_module
119-
125+
builder.output([matmul])
126+
original = builder.get_graph_module()
127+
p = ReplaceMatmulWithTransposedMatmulPass()
128+
graph_after_passes = cast(PassResult, p(original)).graph_module
120129
self.assertEqual(
121130
count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int),
122131
1,
@@ -130,33 +139,24 @@ def forward(self, x, y):
130139

131140
@parameterized.expand(
132141
[
133-
[(3, 5), (0, 0)],
134-
[
135-
(20, 1, 80),
136-
(0, 0),
137-
],
138-
]
142+
("2d", (3, 5), [0, 0]), # shape # padding
143+
("3d", (20, 1, 80), [0, 0, 0]), # shape # padding
144+
],
139145
)
140146
@torch.no_grad()
141147
def test_replace_constant_pad_nd_with_slice(
142-
self, shape: Tuple[int], padding: Tuple[int]
148+
self, _, shape: Tuple[int], padding: Tuple[int]
143149
):
144-
# F.pad is converted to aten::constant_pad_nd after functionalization & decomposition.
145-
class Padding(torch.nn.Module):
146-
def __init__(self):
147-
super().__init__()
148-
self.padding = padding
149-
150-
def forward(self, x: torch.Tensor):
151-
return F.pad(x, self.padding)
152-
153-
model = Padding()
154-
x = torch.randn(shape)
155-
graph_module = export_to_edge(model, (x,)).exported_program().graph_module
156-
150+
builder = GraphBuilder()
151+
x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
152+
matmul = builder.call_operator(
153+
op=exir_ops.edge.aten.constant_pad_nd.default,
154+
args=(x, [0, 0, 0, 0]),
155+
)
156+
builder.output([matmul])
157+
original = builder.get_graph_module()
157158
p = ReplaceConstantPadNdWithSlicePass()
158-
159-
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
159+
graph_after_passes = cast(PassResult, p(original)).graph_module
160160
self.assertEqual(
161161
count_node(graph_after_passes, exir_ops.edge.aten.slice.Tensor),
162162
1,
@@ -169,142 +169,140 @@ def forward(self, x: torch.Tensor):
169169

170170
@parameterized.expand(
171171
[
172-
[(7, 5, 6), 1.23],
173-
[(7, 5), 2],
172+
["3d", (7, 5, 6), 1.23],
173+
["2d", (7, 5), 2],
174+
["1d", (10,), 42949],
174175
]
175176
)
176177
@torch.no_grad()
177-
def test_add_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float):
178-
class Add(torch.nn.Module):
179-
def forward(self, x):
180-
return torch.ops.aten.add.Scalar(x, other)
181-
182-
model = Add()
178+
def test_add_replace_scalar_with_tensor_arg(
179+
self, _, shape: Tuple[int], other: float
180+
):
183181
x = torch.randn(shape)
184-
graph_module = export_to_edge(model, (x,)).exported_program().graph_module
185-
182+
original = single_op_builder(
183+
placeholders=(x,),
184+
op=exir_ops.edge.aten.add.Scalar,
185+
args=(x, other),
186+
)
186187
p = ReplaceScalarWithTensorArgPass()
187-
188-
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
188+
graph_after_passes = cast(PassResult, p(original)).graph_module
189189
self.assertEqual(
190190
count_node(graph_after_passes, exir_ops.edge.aten.add.Tensor),
191191
1,
192192
)
193-
194193
self.assertEqual(
195194
count_node(graph_after_passes, exir_ops.edge.aten.add.Scalar),
196195
0,
197196
)
198197

199198
@parameterized.expand(
200199
[
201-
[(7, 5, 6), 1.23],
202-
[(7, 5), 2],
203-
[(10), 42949],
200+
["3d", (7, 5, 6), 1.23],
201+
["2d", (7, 5), 2],
202+
["1d", (10,), 42949],
204203
]
205204
)
206205
@torch.no_grad()
207-
def test_sub_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float):
208-
class Sub(torch.nn.Module):
209-
def forward(self, x):
210-
return torch.ops.aten.sub.Scalar(x, other)
211-
212-
model = Sub()
206+
def test_sub_replace_scalar_with_tensor_arg(
207+
self, _, shape: Tuple[int], other: float
208+
):
213209
x = torch.randn(shape)
214-
graph_module = export_to_edge(model, (x,)).exported_program().graph_module
215-
210+
original = single_op_builder(
211+
placeholders=(x,),
212+
op=exir_ops.edge.aten.sub.Scalar,
213+
args=(x, other),
214+
)
216215
p = ReplaceScalarWithTensorArgPass()
217-
218-
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
216+
graph_after_passes = cast(PassResult, p(original)).graph_module
219217
self.assertEqual(
220218
count_node(graph_after_passes, exir_ops.edge.aten.sub.Tensor),
221219
1,
222220
)
223-
224221
self.assertEqual(
225222
count_node(graph_after_passes, exir_ops.edge.aten.sub.Scalar),
226223
0,
227224
)
228225

229226
@parameterized.expand(
230227
[
231-
[(7, 5, 6), 1.23],
232-
[(7, 5), 2],
233-
[(513), 3],
228+
["3d", (7, 5, 6), 1.23],
229+
["2d", (7, 5), 2],
230+
["1d", (10,), 42949],
234231
]
235232
)
236233
@torch.no_grad()
237-
def test_mul_replace_scalar_with_tensor_arg(self, shape: Tuple[int], other: float):
238-
class Mul(torch.nn.Module):
239-
def forward(self, x):
240-
return torch.ops.aten.mul.Scalar(x, other)
241-
242-
model = Mul()
234+
def test_mul_replace_scalar_with_tensor_arg(
235+
self, _, shape: Tuple[int], other: float
236+
):
243237
x = torch.randn(shape)
244-
graph_module = export_to_edge(model, (x,)).exported_program().graph_module
245-
238+
original = single_op_builder(
239+
placeholders=(x,),
240+
op=exir_ops.edge.aten.mul.Scalar,
241+
args=(x, other),
242+
)
246243
p = ReplaceScalarWithTensorArgPass()
247-
248-
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
244+
graph_after_passes = cast(PassResult, p(original)).graph_module
249245
self.assertEqual(
250246
count_node(graph_after_passes, exir_ops.edge.aten.mul.Tensor),
251247
1,
252248
)
253-
254249
self.assertEqual(
255250
count_node(graph_after_passes, exir_ops.edge.aten.mul.Scalar),
256251
0,
257252
)
258253

259254
@parameterized.expand(
260255
[
261-
[(7, 5, 6), 1.23],
262-
[(7, 5), 2],
256+
["3d", (7, 5, 6), 1.23],
257+
["2d", (7, 5), 2],
258+
["1d", (10,), 42949],
263259
]
264260
)
265261
@torch.no_grad()
266262
def test_div_replace_scalar_with_tensor_arg(
267263
self,
264+
_,
268265
shape: Tuple[int],
269266
other: float,
270267
):
271-
class Div(torch.nn.Module):
272-
def forward(self, x):
273-
return torch.ops.aten.div.Scalar(x, other)
274-
275-
model = Div()
276-
x = torch.randn(shape)
277-
graph_module = export_to_edge(model, (x,)).exported_program().graph_module
278-
268+
x = torch.randn(*shape)
269+
original = single_op_builder(
270+
placeholders=(x,),
271+
op=exir_ops.edge.aten.div.Scalar,
272+
args=(x, other),
273+
)
279274
p = ReplaceScalarWithTensorArgPass()
280-
281-
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
275+
graph_after_passes = cast(PassResult, p(original)).graph_module
282276
self.assertEqual(
283277
count_node(graph_after_passes, exir_ops.edge.aten.div.Tensor),
284278
1,
285279
)
286-
287280
self.assertEqual(
288281
count_node(graph_after_passes, exir_ops.edge.aten.div.Scalar),
289282
0,
290283
)
291284

292285
@parameterized.expand(
293286
[
294-
[(2, 3, 5, 6)],
295-
[(7, 6, 5)],
296-
[(4, 4)],
297-
[(316)],
287+
["4d", (2, 3, 5, 6)],
288+
["3d", (7, 6, 5)],
289+
["2d", (4, 4)],
290+
["1d", (316)],
298291
]
299292
)
300293
@torch.no_grad()
301-
def test_replace_functionally_equivalent_op_targets_relu(self, shape: Tuple[int]):
302-
model = torch.nn.ReLU()
294+
def test_replace_functionally_equivalent_op_targets_relu(
295+
self, _, shape: Tuple[int]
296+
):
303297
x = torch.randn(shape)
304-
graph_module = export_to_edge(model, (x,)).exported_program().graph_module
298+
original = single_op_builder(
299+
placeholders=(x,),
300+
op=exir_ops.edge.aten.relu_.default,
301+
args=(x,),
302+
)
305303
p = ReplaceFunctionallyEquivalentOpTargets()
304+
graph_after_passes = cast(PassResult, p(original)).graph_module
306305

307-
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
308306
self.assertEqual(
309307
count_node(graph_after_passes, exir_ops.edge.aten.relu.default),
310308
1,
@@ -315,56 +313,29 @@ def test_replace_functionally_equivalent_op_targets_relu(self, shape: Tuple[int]
315313
)
316314

317315
@parameterized.expand(
318-
[
319-
# split the only dimension
320-
[(50,), i, 0]
321-
for i in range(2, 7)
322-
]
323-
+ [
324-
# split the leading dim
325-
[(10, 2, 3), i, 0]
326-
for i in range(2, 7)
327-
]
328-
+ [
329-
# split the trailing dim
330-
[(3, 3, 6), i, 2]
331-
for i in range(2, 6)
332-
]
333-
+ [
334-
# split the dim in the middle
335-
[(3, 5, 14, 2, 3), i, 2]
336-
for i in range(2, 7)
337-
]
316+
[["split_linear_tensor", (50,), i, 0] for i in range(2, 7)]
317+
+ [["split_leading_dim", (10, 2, 3), i, 0] for i in range(2, 7)]
318+
+ [["split_trailing_dim", (3, 3, 6), i, 2] for i in range(2, 6)]
319+
+ [["split_middle_dim", (3, 5, 14, 2, 3), i, 2] for i in range(2, 7)]
338320
)
339321
@torch.no_grad()
340322
def test_replace_functionally_equivalent_op_targets_unsafe_split(
341-
self, shape: Tuple[int], split_size: int, dim: int
323+
self, _, shape: Tuple[int], split_size: int, dim: int
342324
):
343-
class TensorSplitWithSizes(torch.nn.Module):
344-
def __init__(self, split_size: int, dim: int, op: OpOverload):
345-
super().__init__()
346-
self.split_size = split_size
347-
self.dim = dim
348-
self.op = op
349-
350-
def forward(self, x: torch.Tensor):
351-
return self.op(x, self.split_size, self.dim)
352-
353325
x = torch.randn(shape)
354-
model = TensorSplitWithSizes(split_size, dim, torch.unsafe_split)
355-
graph_module = export_to_edge(model, (x,)).exported_program().graph_module
326+
original = single_op_builder(
327+
placeholders=(x,),
328+
op=exir_ops.edge.aten.unsafe_split.Tensor,
329+
args=(x, split_size, dim),
330+
)
356331
p = ReplaceFunctionallyEquivalentOpTargets()
357-
358-
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
332+
graph_after_passes = cast(PassResult, p(original)).graph_module
359333
self.assertEqual(
360-
count_node(
361-
graph_after_passes, exir_ops.edge.aten.split_with_sizes_copy.default
362-
),
334+
count_node(graph_after_passes, exir_ops.edge.aten.split_copy.Tensor),
363335
1,
364336
)
365337
self.assertEqual(
366-
count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor),
367-
0,
338+
count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor), 0, x
368339
)
369340

370341
@parameterized.expand(

0 commit comments

Comments
 (0)