38 changes: 38 additions & 0 deletions backends/cadence/aot/compiler_funcs.py
@@ -19,6 +19,29 @@
QuantArgs = tuple[float, int, int, int, torch.dtype]


def extract_input_shapes_from_graph(
    module: GraphModule,
) -> dict[int, tuple[int, ...]]:
    """
    Extract input shapes from the FX graph placeholder nodes.

    Returns a dict mapping input index to the expected shape tuple.
    """
    input_shapes: dict[int, tuple[int, ...]] = {}
    idx = 0
    for node in module.graph.nodes:
        if node.op == "placeholder":
            # Read the example value stored in node.meta, if available
            if "val" in node.meta:
                val = node.meta["val"]
                if isinstance(val, torch.Tensor):
                    input_shapes[idx] = tuple(val.shape)
                elif hasattr(val, "shape"):
                    input_shapes[idx] = tuple(val.shape)
            idx += 1
    return input_shapes


@torch.no_grad()
def trace(
    model: torch.nn.Module,
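
For reference, a minimal usage sketch of the new helper (not part of the diff). The toy Add module, the (2, 3) shapes, and the import path are illustrative assumptions; torch.export is used here only because it populates node.meta["val"] with FakeTensors that carry shapes, which is what the helper reads. The file's own trace() entry point is not shown in this hunk.

import torch

# Import path assumed from the file location in this PR.
from executorch.backends.cadence.aot.compiler_funcs import extract_input_shapes_from_graph


class Add(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x + y


# torch.export attaches FakeTensors to placeholder nodes via node.meta["val"].
exported = torch.export.export(Add(), (torch.randn(2, 3), torch.randn(2, 3)))

shapes = extract_input_shapes_from_graph(exported.graph_module)
print(shapes)  # expected: {0: (2, 3), 1: (2, 3)}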
@@ -138,6 +161,9 @@ def __init__(
        super().__init__()
        self.module: GraphModule = module
        self.quant_args: dict[int, QuantArgs] = {}
        self.expected_shapes: dict[int, tuple[int, ...]] = (
            extract_input_shapes_from_graph(module)
        )

        if input_args is not None:
            logger.warning(
@@ -151,6 +177,18 @@

    def forward(self, *args: torch.Tensor) -> Any:
        """Run inference, dequantizing configured inputs."""
        # Validate input shapes for quantized inputs
        for index in self.quant_args:
            if index < len(args):
                actual_shape = tuple(args[index].shape)
                if index in self.expected_shapes:
                    expected_shape = self.expected_shapes[index]
                    if actual_shape != expected_shape:
                        raise ValueError(
                            f"Shape mismatch for quantized input at index {index}: "
                            f"expected {expected_shape}, got {actual_shape}"
                        )

        dequantized_args = []
        for index, node in enumerate(args):
            if index in self.quant_args:
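
The forward() validation added above reduces to comparing each quantized input's runtime shape against the shape recorded from the graph placeholders. A standalone sketch of that check with illustrative data follows; the dict contents and the mismatched (4, 3) input are made up, and this is not the wrapper class from the file.

import torch

expected_shapes = {0: (2, 3)}   # as extract_input_shapes_from_graph would return
quant_indices = {0}             # inputs configured with quantization args
args = (torch.zeros(4, 3),)     # runtime input whose shape does not match

for index in quant_indices:
    if index < len(args) and index in expected_shapes:
        actual_shape = tuple(args[index].shape)
        if actual_shape != expected_shapes[index]:
            # Mirrors the ValueError raised in forward() above.
            raise ValueError(
                f"Shape mismatch for quantized input at index {index}: "
                f"expected {expected_shapes[index]}, got {actual_shape}"
            )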