Skip to content

Commit 710cea4

Browse files
ethansfngfacebook-github-bot
authored andcommitted
Walk transparent ops when extracting input quant params (#20139)
Summary: SAM 3's encoder input feeds a transparent shape op first — placeholder → reshape (patchify) → quantize_per_tensor — whereas the original code only recognized placeholder → quantize_per_tensor (a quantize directly on the input). The change walks through transparent ops (reshape/permute/transpose/etc.) from the input to reach the first quantize, so that indirected pattern resolves. Differential Revision: D107922730
1 parent 8e4fe08 commit 710cea4

1 file changed

Lines changed: 45 additions & 7 deletions

File tree

backends/cadence/aot/compiler_funcs.py

Lines changed: 45 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,28 @@
2222

2323
logger: logging.Logger = logging.getLogger(__name__)
2424
QuantArgs = tuple[float, int, int, int, torch.dtype]
25+
TRANSPARENT_OPS: frozenset[torch._ops.OpOverloadPacket] = frozenset(
26+
{
27+
torch.ops.aten.view,
28+
torch.ops.aten.view_copy,
29+
torch.ops.aten._unsafe_view,
30+
torch.ops.aten.reshape,
31+
torch.ops.aten.permute,
32+
torch.ops.aten.permute_copy,
33+
torch.ops.aten.transpose,
34+
torch.ops.aten.transpose_copy,
35+
torch.ops.aten.squeeze,
36+
torch.ops.aten.squeeze_copy,
37+
torch.ops.aten.unsqueeze,
38+
torch.ops.aten.unsqueeze_copy,
39+
torch.ops.aten.slice,
40+
torch.ops.aten.slice_copy,
41+
torch.ops.aten.contiguous,
42+
torch.ops.aten.clone,
43+
torch.ops.aten.to,
44+
torch.ops.aten._to_copy,
45+
}
46+
)
2547

2648

2749
@torch.no_grad()
@@ -251,17 +273,27 @@ def extract_input_quant_params_from_graph(
251273
if not input_names:
252274
return quant_args
253275

276+
# Inputs are referenced by node name, which may be a placeholder or a node
277+
# that unpacks/derives the input (e.g. a `getitem` off a tuple input), so
278+
# look the start node up across all nodes -- not just placeholders.
279+
nodes_by_name = {n.name: n for n in module.graph.nodes}
280+
254281
for idx, name in enumerate(input_names):
255-
for node in module.graph.nodes:
256-
if node.op != "call_function":
282+
start = nodes_by_name.get(name)
283+
if start is None:
284+
continue
285+
seen: set[torch.fx.Node] = set()
286+
to_visit: list[torch.fx.Node] = list(start.users)
287+
while to_visit:
288+
node = to_visit.pop()
289+
if node in seen or node.op != "call_function":
257290
continue
258-
291+
seen.add(node)
292+
target_str = str(node.target)
259293
if (
260-
node.args
261-
and isinstance(node.args[0], torch.fx.Node)
262-
and node.args[0].name == name
294+
"quantize_per_tensor" in target_str
295+
and "dequantize" not in target_str
263296
and not node.name.startswith("_assert_tensor_metadata")
264-
and "quantize_per_tensor" in str(node.target)
265297
):
266298
args = node.args[1:]
267299
if len(args) >= 5:
@@ -274,6 +306,12 @@ def extract_input_quant_params_from_graph(
274306
)
275307
found_names.add(name)
276308
break
309+
target = node.target
310+
if (
311+
isinstance(target, torch._ops.OpOverload)
312+
and target.overloadpacket in TRANSPARENT_OPS
313+
):
314+
to_visit.extend(node.users)
277315

278316
missing_names = set(input_names) - found_names
279317
if missing_names:

0 commit comments

Comments
 (0)