diff --git a/backends/qualcomm/builders/op_full.py b/backends/qualcomm/builders/op_full.py
index 5ac2e95c57b..7a109ff0637 100644
--- a/backends/qualcomm/builders/op_full.py
+++ b/backends/qualcomm/builders/op_full.py
@@ -25,8 +25,9 @@ def define_node(
         node: torch.fx.Node,
         nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
     ) -> PyQnnManager.PyQnnOpWrapper:
+        tensor_shape = list(self.get_tensor(node, node).shape)
         out_tensor = torch.full(
-            node.args[0], node.args[1], dtype=node.meta["val"].dtype
+            tensor_shape, node.args[1], dtype=node.meta["val"].dtype
         )
 
         # since we can derive the constant value of current op in AoT stage
diff --git a/backends/qualcomm/builders/op_full_like.py b/backends/qualcomm/builders/op_full_like.py
index 66f80ecc80a..69a03d66b13 100644
--- a/backends/qualcomm/builders/op_full_like.py
+++ b/backends/qualcomm/builders/op_full_like.py
@@ -25,8 +25,8 @@ def define_node(
         node: torch.fx.Node,
         nodes_to_wrappers: Dict[torch.fx.Node, PyQnnManager.TensorWrapper],
     ) -> PyQnnManager.PyQnnOpWrapper:
-        in_tensor = node.args[0].meta["val"]
-        ref_tensor = torch.zeros(in_tensor.shape, dtype=in_tensor.dtype)
+        in_tensor = self.get_tensor(node, node)
+        ref_tensor = torch.zeros(list(in_tensor.shape), dtype=in_tensor.dtype)
         out_tensor = torch.full_like(ref_tensor, node.args[1])
 
         # since we can derive the constant value of current op in AoT stage
diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py
index 12d5e0902db..7a918308d4e 100644
--- a/backends/qualcomm/tests/models.py
+++ b/backends/qualcomm/tests/models.py
@@ -875,6 +875,31 @@ def forward(self, x):
         return self.second(self.first(x))
 
 
+class ConvFull(torch.nn.Module):
+    def __init__(self, fill, full_shape):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(8, 16, 3, padding=1)
+        self.fill = fill
+        self.full_shape = full_shape
+
+    def forward(self, x):
+        y = self.conv(x)
+        c = torch.full(self.full_shape, self.fill, dtype=y.dtype)
+        return torch.cat([y, c], dim=1)
+
+
+class ConvFullLike(torch.nn.Module):
+    def __init__(self, fill):
+        super().__init__()
+        self.conv = torch.nn.Conv2d(8, 16, 3, padding=1)
+        self.fill = fill
+
+    def forward(self, x):
+        y = self.conv(x)
+        c = torch.full_like(y, self.fill)
+        return torch.cat([y, c], dim=1)
+
+
 class ConvTranspose1dSingle(torch.nn.Module):
     def __init__(self, bias=True, dilation=1):
         super().__init__()
diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py
index d76e3ea1df7..fbf9a1342f5 100644
--- a/backends/qualcomm/tests/test_qnn_delegate.py
+++ b/backends/qualcomm/tests/test_qnn_delegate.py
@@ -2344,6 +2344,17 @@ def test_qnn_backend_einsum_outer_product_relu(self):
         )
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_full_layout_transformed(self):
+        full_shape = (1, 16, 4, 6)
+        module = ConvFull(0.5, full_shape)  # noqa: F405
+        sample_input = (torch.randn(1, 8, 4, 6),)
+        self.lower_module_and_test_output(module, sample_input)
+
+    def test_qnn_backend_full_like_layout_transformed(self):
+        module = ConvFullLike(0.5)  # noqa: F405
+        sample_input = (torch.randn(1, 8, 4, 6),)
+        self.lower_module_and_test_output(module, sample_input)
+
     # TODO: Create a new UT class for passes specific checks
     def test_qnn_backend_lift_add_tensor(self):
         module = LiftAddTensor()  # noqa: F405
@@ -5095,6 +5106,19 @@ def test_qnn_backend_einsum_outer_product_relu(self):
         module = self.get_qdq_module(module, sample_input)
         self.lower_module_and_test_output(module, sample_input)
 
+    def test_qnn_backend_full_layout_transformed(self):
+        full_shape = (1, 16, 4, 6)
+        module = ConvFull(0.5, full_shape)  # noqa: F405
+        sample_input = (torch.randn(1, 8, 4, 6),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
+    def test_qnn_backend_full_like_layout_transformed(self):
+        module = ConvFullLike(0.5)  # noqa: F405
+        sample_input = (torch.randn(1, 8, 4, 6),)
+        module = self.get_qdq_module(module, sample_input)
+        self.lower_module_and_test_output(module, sample_input)
+
     @unittest.skipIf(is_qnn_sdk_version_less_than("2.35"), "UT pass after QNN 2.35")
     def test_qnn_backend_masked_softmax(self):
         if self.enable_x86_64: