From 4245fd71bde4b91b91147a86ef38605fedb6a064 Mon Sep 17 00:00:00 2001
From: swimdi
Date: Thu, 5 Dec 2024 18:17:30 +0800
Subject: [PATCH 1/6] Add lowering of aten.index.Tensor

---
 torch_ttnn/passes/lowering/to_tt_pass.py | 46 ++++++++++++++++++++++++
 1 file changed, 46 insertions(+)

diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index 54185a928..97c2c9b61 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -19,6 +19,7 @@
 import torch.fx.traceback as fx_traceback
 from . import target_wrappers
 from .to_tt_guard import can_lowering_to_ttnn
+from operator import getitem
 
 relational_scalar_ops = {
     torch.ops.aten.eq.Scalar: ttnn.eq,
@@ -1131,12 +1132,57 @@ def batch_norm_inference(input, weight, bias, mean, var, momentum, eps):
     return gm
 
 
+def DigestAtenOps(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
+    g = gm.graph
+    nodes = list(gm.graph.nodes)
+    for node in nodes:
+
+        def rewrite_node(node):
+            args = node.args
+            kwargs = node.kwargs
+
+            if node.target == torch.ops.aten.index.Tensor:
+                # for example, input.shape = (3, 4, 5), indices = [tensor([[0, 1, 1]]), tensor([[2, 1, 2]])]
+                # then output is [[input[0][2], input[1][1], input[1][2]]]
+                input_tensor, indices = args
+                input_shape = get_shape(input_tensor)
+                num_index = len(indices)
+                # TODO: support broadcasting
+                index_shape = get_shape(indices[0])
+                index_size = index_shape.numel()
+                remained_shape = input_shape[num_index:]
+                reshape_shape = index_shape + remained_shape
+                indices_flatten = [g.call_function(torch.ops.aten.flatten, args=(idx,)) for idx in indices]
+                output = []
+                for i in range(index_size):
+                    indexing = [g.call_function(getitem, args=(indices_flatten[n], i)) for n in range(num_index)]
+                    output.append(g.call_function(getitem, args=(input_tensor, indexing)))
+                # aten.cat cannot concat zero dim tensor
+                if len(remained_shape) == 0:
+                    output = [g.call_function(torch.ops.aten.reshape, args=(o, [1])) for o in output]
+                output_cat = g.call_function(torch.ops.aten.cat, args=(output,))
+                output_reshape = g.call_function(torch.ops.aten.reshape, args=(output_cat, reshape_shape))
+                return output_reshape
+
+        with g.inserting_before(node):
+            new_node = rewrite_node(node)
+            if new_node is not None:
+                node.replace_all_uses_with(
+                    new_node,
+                    delete_user_cb=lambda node: node != new_node,
+                )
+
+    gm = GraphCleanup(gm)
+    return gm
+
+
 class ToTtPass(PassBase):
     def __init__(self, device, use_less_ttnn_op_types):
         self.device = device
         self.use_less_ttnn_op_types = use_less_ttnn_op_types
 
     def call(self, gm: torch.fx.GraphModule):
+        gm = DigestAtenOps(gm)
+
         # Replace more patterns with torch.fx.Transformer
         gm = ReplaceMoreTt(gm, self.device, self.use_less_ttnn_op_types).transform()

From 3462136fa091078aaa377611bbc94d58de777257 Mon Sep 17 00:00:00 2001
From: swimdi
Date: Thu, 5 Dec 2024 18:47:01 +0800
Subject: [PATCH 2/6] Implement broadcast_indices, add test_index.py

---
 tests/lowering/misc/test_index.py        | 47 ++++++++++++++++++++++++
 torch_ttnn/passes/lowering/to_tt_pass.py | 19 +++++++++-
 2 files changed, 64 insertions(+), 2 deletions(-)
 create mode 100644 tests/lowering/misc/test_index.py

diff --git a/tests/lowering/misc/test_index.py b/tests/lowering/misc/test_index.py
new file mode 100644
index 000000000..01b3780c7
--- /dev/null
+++ b/tests/lowering/misc/test_index.py
@@ -0,0 +1,47 @@
+import torch
+import torch_ttnn
+import pytest
+import ttnn
+
+from tests.utils import assert_with_pcc
+
+
+class IndexModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, input, indices):
+        return torch.ops.aten.index.Tensor(input, indices)
+
+
+@pytest.mark.parametrize(
+    "input_shapes, indices",
+    [
+        ((3, 4, 5), [[0, 1], [2, 1], [2, 4]]),
+        ((3, 4, 5), [[0, 1], [2, 1]]),
+        ((3, 4, 5), [[[0, 1]], [[2, 1]]]),
+        ((3, 4, 5), [[[0, 1, 1]], [[2, 1, 2]]]),
+        ((3, 4, 5), [[[0, 1, 1], [1, 1, 0]], [[2, 1, 2]]]),  # broadcast
+    ],
+)
+def test_select(device, input_shapes, indices):
+    m = IndexModule()
+    inputs = torch.rand(input_shapes, dtype=torch.bfloat16)
+    indices = [torch.tensor(index) for index in indices]
+    result_before = m.forward(inputs, indices)
+
+    option = torch_ttnn.TorchTtnnOption(device=device)
+    # option.gen_graphviz = True
+
+    # The compilation is lazy, so we need to run forward once to trigger the compilation
+    m = torch.compile(m, backend=torch_ttnn.backend, options=option)
+
+    result_after = m.forward(inputs, indices)
+    # option._out_fx_graphs[0].print_tabular()
+
+    # Check the graph has been rewritten and contains ttnn ops
+    nodes = [node.target for node in option._out_fx_graphs[0].nodes]
+    assert torch.ops.aten.index.Tensor not in nodes
+
+    # Check inference result
+    assert_with_pcc(result_before, result_after, pcc=0.99)

diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index 97c2c9b61..b884076b6 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -1142,13 +1142,28 @@ def rewrite_node(node):
             kwargs = node.kwargs
 
             if node.target == torch.ops.aten.index.Tensor:
+
+                def broadcast_indices(indices):
+                    import numpy as np
+
+                    indices_shapes = [get_shape(indices[i]) for i in range(len(indices))]
+                    broadcasted_shape = torch.Size(np.broadcast_shapes(*indices_shapes))
+                    broadcasted_indices = []
+                    for i in range(len(indices)):
+                        if indices_shapes[i] == broadcasted_shape:
+                            broadcasted_indices.append(indices[i])
+                        else:
+                            broadcasted_indices.append(
+                                g.call_function(torch.ops.aten.expand.default, (indices[i], broadcasted_shape))
+                            )
+                    return broadcasted_shape, broadcasted_indices
+
                 # for example, input.shape = (3, 4, 5), indices = [tensor([[0, 1, 1]]), tensor([[2, 1, 2]])]
                 # then output is [[input[0][2], input[1][1], input[1][2]]]
                 input_tensor, indices = args
+                index_shape, indices = broadcast_indices(indices)
                 input_shape = get_shape(input_tensor)
                 num_index = len(indices)
-                # TODO: support broadcasting
-                index_shape = get_shape(indices[0])
                 index_size = index_shape.numel()
                 remained_shape = input_shape[num_index:]
                 reshape_shape = index_shape + remained_shape

From ddd7fe3a9b1d99980b763bc30a88760d26530fbd Mon Sep 17 00:00:00 2001
From: swimdi
Date: Thu, 5 Dec 2024 20:38:48 +0800
Subject: [PATCH 3/6] Make sure meta val of aten.index is correct

---
 torch_ttnn/passes/lowering/to_tt_pass.py | 38 ++++++++++++++++++++----
 1 file changed, 32 insertions(+), 6 deletions(-)

diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index b884076b6..def9b1368 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -17,6 +17,7 @@
 from torch.fx.passes.infra.pass_base import PassBase, PassResult
 import torch.fx.traceback as fx_traceback
+from torch._subclasses.fake_tensor import FakeTensorMode
 from . import target_wrappers
 from .to_tt_guard import can_lowering_to_ttnn
 from operator import getitem
 
@@ -447,7 +448,7 @@ def __init__(self, node):
 
     def call_function(self, target, args=(), kwargs={}):
         new_node = self.g.call_function(target, args, kwargs)
-        new_node.meta = self.node.meta
+        new_node.meta = self.node.meta.copy()
         if hasattr(self.node.target, "_schema"):
             new_node.meta["original_input_variations"] = metrics.collect_input_variation_from_node(self.node)
         if target == ttnn.layer_norm:
@@ -1133,9 +1134,9 @@ def batch_norm_inference(input, weight, bias, mean, var, momentum, eps):
 
 
 def DigestAtenOps(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
-    g = gm.graph
     nodes = list(gm.graph.nodes)
     for node in nodes:
+        g = GraphWrapper(node)
 
         def rewrite_node(node):
             args = node.args
@@ -1143,6 +1144,11 @@ def rewrite_node(node):
 
             if node.target == torch.ops.aten.index.Tensor:
 
+                def edit_meta_val(node, shape, dtype):
+                    fake_mode = FakeTensorMode()
+                    fake_tensor = fake_mode.from_tensor(torch.zeros(shape, dtype=dtype))
+                    node.meta["val"] = fake_tensor
+
                 def broadcast_indices(indices):
                     import numpy as np
 
@@ -1156,27 +1162,47 @@ def broadcast_indices(indices):
                             broadcasted_indices.append(
                                 g.call_function(torch.ops.aten.expand.default, (indices[i], broadcasted_shape))
                             )
+                            edit_meta_val(broadcasted_indices[-1], broadcasted_shape, indices[i].meta["val"].dtype)
                     return broadcasted_shape, broadcasted_indices
 
                 # for example, input.shape = (3, 4, 5), indices = [tensor([[0, 1, 1]]), tensor([[2, 1, 2]])]
                 # then output is [[input[0][2], input[1][1], input[1][2]]]
                 input_tensor, indices = args
+                if get_shape(input_tensor) is None:
+                    return None
+                if None in [get_shape(indices[i]) for i in range(len(indices))]:
+                    return None
                 index_shape, indices = broadcast_indices(indices)
                 input_shape = get_shape(input_tensor)
                 num_index = len(indices)
                 index_size = index_shape.numel()
                 remained_shape = input_shape[num_index:]
                 reshape_shape = index_shape + remained_shape
-                indices_flatten = [g.call_function(torch.ops.aten.flatten, args=(idx,)) for idx in indices]
+                input_dtype = input_tensor.meta["val"].dtype
+                flatten_shape = torch.Size([index_size])
+                indices_flatten = [
+                    g.call_function(torch.ops.aten.reshape.default, args=(idx, flatten_shape)) for idx in indices
+                ]
+                for i in range(len(indices_flatten)):
+                    edit_meta_val(indices_flatten[i], flatten_shape, indices[i].meta["val"].dtype)
                 output = []
                 for i in range(index_size):
                     indexing = [g.call_function(getitem, args=(indices_flatten[n], i)) for n in range(num_index)]
+                    for n in range(num_index):
+                        edit_meta_val(indexing[n], torch.Size([]), indices_flatten[n].meta["val"].dtype)
                     output.append(g.call_function(getitem, args=(input_tensor, indexing)))
+                    edit_meta_val(output[-1], remained_shape, input_dtype)
                 # aten.cat cannot concat zero dim tensor
                 if len(remained_shape) == 0:
-                    output = [g.call_function(torch.ops.aten.reshape, args=(o, [1])) for o in output]
-                output_cat = g.call_function(torch.ops.aten.cat, args=(output,))
-                output_reshape = g.call_function(torch.ops.aten.reshape, args=(output_cat, reshape_shape))
+                    remained_shape = torch.Size([1])
+                    output = [g.call_function(torch.ops.aten.reshape.default, args=(o, remained_shape)) for o in output]
+                    for o in output:
+                        edit_meta_val(o, remained_shape, input_dtype)
+                output_cat = g.call_function(torch.ops.aten.cat.default, args=(output,))
+                output_cat_shape = torch.Size([len(output)] + list(remained_shape))
+                edit_meta_val(output_cat, output_cat_shape, input_dtype)
+                output_reshape = g.call_function(torch.ops.aten.reshape.default, args=(output_cat, reshape_shape))
+                edit_meta_val(output_reshape, reshape_shape, input_dtype)
                 return output_reshape

From 4485e8317154a2f3bdf0c2ad6bd9ad4093a0e254 Mon Sep 17 00:00:00 2001
From: swimdi
Date: Mon, 9 Dec 2024 16:00:16 +0800
Subject: [PATCH 4/6] move edit_meta_val to call_function

---
 tests/lowering/misc/test_index.py        |  2 +-
 torch_ttnn/passes/lowering/to_tt_pass.py | 81 ++++++++++++++++--------
 2 files changed, 57 insertions(+), 26 deletions(-)

diff --git a/tests/lowering/misc/test_index.py b/tests/lowering/misc/test_index.py
index 01b3780c7..52a9e7b36 100644
--- a/tests/lowering/misc/test_index.py
+++ b/tests/lowering/misc/test_index.py
@@ -24,7 +24,7 @@ def forward(self, input, indices):
         ((3, 4, 5), [[[0, 1, 1], [1, 1, 0]], [[2, 1, 2]]]),  # broadcast
     ],
 )
-def test_select(device, input_shapes, indices):
+def test_index(device, input_shapes, indices):
     m = IndexModule()
     inputs = torch.rand(input_shapes, dtype=torch.bfloat16)
     indices = [torch.tensor(index) for index in indices]

diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index def9b1368..f83c4519d 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -1,6 +1,7 @@
 import torch
 import ttnn
 import math
+import numpy as np
 from torch._subclasses.fake_tensor import unset_fake_temporarily
 from torch_ttnn.utils import (
     GraphCleanup,
@@ -446,13 +447,19 @@ def __init__(self, node):
         self.g = node.graph
         self.node = node
 
-    def call_function(self, target, args=(), kwargs={}):
+    def call_function(self, target, args=(), kwargs={}, new_shape=None, new_dtype=None):
         new_node = self.g.call_function(target, args, kwargs)
         new_node.meta = self.node.meta.copy()
         if hasattr(self.node.target, "_schema"):
             new_node.meta["original_input_variations"] = metrics.collect_input_variation_from_node(self.node)
         if target == ttnn.layer_norm:
             new_node.meta["val"] = new_node.meta["val"][0]
+        if new_shape is not None or new_dtype is not None:
+            shape = new_shape if new_shape is not None else new_node.meta["val"].size()
+            dtype = new_dtype if new_dtype is not None else new_node.meta["val"].dtype
+            fake_mode = FakeTensorMode()
+            fake_tensor = fake_mode.from_tensor(torch.zeros(shape, dtype=dtype))
+            new_node.meta["val"] = fake_tensor
         return new_node
 
     def inserting_before(self, node):
@@ -1144,14 +1151,7 @@ def rewrite_node(node):
 
             if node.target == torch.ops.aten.index.Tensor:
 
-                def edit_meta_val(node, shape, dtype):
-                    fake_mode = FakeTensorMode()
-                    fake_tensor = fake_mode.from_tensor(torch.zeros(shape, dtype=dtype))
-                    node.meta["val"] = fake_tensor
-
                 def broadcast_indices(indices):
-                    import numpy as np
-
                     indices_shapes = [get_shape(indices[i]) for i in range(len(indices))]
                     broadcasted_shape = torch.Size(np.broadcast_shapes(*indices_shapes))
                     broadcasted_indices = []
@@ -1160,9 +1160,13 @@ def broadcast_indices(indices):
                             broadcasted_indices.append(indices[i])
                         else:
                             broadcasted_indices.append(
-                                g.call_function(torch.ops.aten.expand.default, (indices[i], broadcasted_shape))
+                                g.call_function(
+                                    torch.ops.aten.expand.default,
+                                    (indices[i], broadcasted_shape),
+                                    new_shape=broadcasted_shape,
+                                    new_dtype=indices[i].meta["val"].dtype,
+                                )
                             )
-                            edit_meta_val(broadcasted_indices[-1], broadcasted_shape, indices[i].meta["val"].dtype)
                     return broadcasted_shape, broadcasted_indices
 
                 # for example, input.shape = (3, 4, 5), indices = [tensor([[0, 1, 1]]), tensor([[2, 1, 2]])]
@@ -1181,28 +1185,55 @@ def broadcast_indices(indices):
                 input_dtype = input_tensor.meta["val"].dtype
                 flatten_shape = torch.Size([index_size])
                 indices_flatten = [
-                    g.call_function(torch.ops.aten.reshape.default, args=(idx, flatten_shape)) for idx in indices
+                    g.call_function(
+                        torch.ops.aten.reshape.default,
+                        args=(idx, flatten_shape),
+                        new_shape=flatten_shape,
+                        new_dtype=idx.meta["val"].dtype,
+                    )
+                    for idx in indices
                 ]
-                for i in range(len(indices_flatten)):
-                    edit_meta_val(indices_flatten[i], flatten_shape, indices[i].meta["val"].dtype)
                 output = []
                 for i in range(index_size):
-                    indexing = [g.call_function(getitem, args=(indices_flatten[n], i)) for n in range(num_index)]
-                    for n in range(num_index):
-                        edit_meta_val(indexing[n], torch.Size([]), indices_flatten[n].meta["val"].dtype)
-                    output.append(g.call_function(getitem, args=(input_tensor, indexing)))
-                    edit_meta_val(output[-1], remained_shape, input_dtype)
+                    indexing = [
+                        g.call_function(
+                            getitem,
+                            args=(indices_flatten[n], i),
+                            new_shape=torch.Size([]),
+                            new_dtype=indices_flatten[n].meta["val"].dtype,
+                        )
+                        for n in range(num_index)
+                    ]
+                    output.append(
+                        g.call_function(
+                            getitem, args=(input_tensor, indexing), new_shape=remained_shape, new_dtype=input_dtype
+                        )
+                    )
                 # aten.cat cannot concat zero dim tensor
                 if len(remained_shape) == 0:
                     remained_shape = torch.Size([1])
-                    output = [g.call_function(torch.ops.aten.reshape.default, args=(o, remained_shape)) for o in output]
-                    for o in output:
-                        edit_meta_val(o, remained_shape, input_dtype)
-                output_cat = g.call_function(torch.ops.aten.cat.default, args=(output,))
+                    output = [
+                        g.call_function(
+                            torch.ops.aten.reshape.default,
+                            args=(o, remained_shape),
+                            new_shape=remained_shape,
+                            new_dtype=input_dtype,
+                        )
+                        for o in output
+                    ]
                 output_cat_shape = torch.Size([len(output)] + list(remained_shape))
-                edit_meta_val(output_cat, output_cat_shape, input_dtype)
-                output_reshape = g.call_function(torch.ops.aten.reshape.default, args=(output_cat, reshape_shape))
-                edit_meta_val(output_reshape, reshape_shape, input_dtype)
+                output_cat = g.call_function(
+                    torch.ops.aten.cat.default,
+                    args=(output,),
+                    new_shape=output_cat_shape,
+                    new_dtype=input_dtype,
+                )
+                output_reshape = g.call_function(
+                    torch.ops.aten.reshape.default,
+                    args=(output_cat, reshape_shape),
+                    new_shape=reshape_shape,
+                    new_dtype=input_dtype,
+                )
                 return output_reshape

From 61b2c3929caf96ac1f6e54e14dd592209b6fd29e Mon Sep 17 00:00:00 2001
From: swimdi
Date: Wed, 18 Dec 2024 19:01:39 +0800
Subject: [PATCH 5/6] Try to support aten.index's indices having None, but
 failed because slice(None) is unhashable

---
 tests/utils.py                           |  5 +-
 torch_ttnn/passes/lowering/to_tt_pass.py | 75 ++++++++++++++++--------
 2 files changed, 56 insertions(+), 24 deletions(-)

diff --git a/tests/utils.py b/tests/utils.py
index 46b18d7aa..ee33e79a4 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -423,7 +423,7 @@ def __init__(self, op_name: str, input_strings: List[str]):
             "aten.index.Tensor": self._adjust_index_tensor,
             "aten.index_put.default": self._adjust_index_tensor,
             "aten._native_batch_norm_legit_no_training.default": self._adjust__native_batch_norm_legit_no_training_default,
-            # "aten._unsafe_index.Tensor": self._adjust_index_tensor,
+            "aten._unsafe_index.Tensor": self._adjust_index_tensor,
         }
 
     def _adjust_bitwise_not_default(self, input_vals):
@@ -503,6 +503,9 @@ def _adjust_index_tensor(self, input_vals):
         new_indices = []
         for i in range(len(indices)):
             indice = indices[i]
+            if indice is None:
+                new_indices.append(None)
+                continue
             new_indice = []
             for j in range(len(indice)):
                 new_indice.append(torch.randint(0, self_shape[i], [1]))

diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index ac73ec5e8..b71ba6c44 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -692,6 +692,9 @@ def lower_binary_eltwise(fn, args):
 
         return None
 
+    # if node.target == torch.ops.aten.reshape.default:
+    #     return g.call_function(ttnn.reshape, args, kwargs)
+
     if node.target == torch.ops.aten.squeeze.dim or node.target == torch.ops.aten.squeeze.default:
         if use_less_ttnn_op_types or node.target == torch.ops.aten.squeeze.default:
             # ttnn.squeeze does not support calling the OP without provided dim (torch.ops.aten.squeeze.default)
@@ -1201,14 +1204,16 @@ def decompose_aten_to_aten_ops(gm: torch.fx.GraphModule, g: GraphWrapper, node):
             new_kwargs["dtype"] = node.meta["val"].dtype
         return g.call_function(torch.ops.aten.zeros.default, args=(target_shape, *args[2:]), kwargs=new_kwargs)
 
-    if node.target == torch.ops.aten.index.Tensor:
+    if node.target in [torch.ops.aten.index.Tensor, torch.ops.aten._unsafe_index.Tensor]:
 
         def broadcast_indices(indices):
            indices_shapes = [get_shape(gm, indices[i]) for i in range(len(indices))]
            broadcasted_shape = torch.Size(np.broadcast_shapes(*indices_shapes))
            broadcasted_indices = []
            for i in range(len(indices)):
-                if indices_shapes[i] == broadcasted_shape:
+                if indices_shapes[i] is None:
+                    broadcasted_indices.append(None)
+                elif indices_shapes[i] == broadcasted_shape:
                     broadcasted_indices.append(indices[i])
                 else:
                     broadcasted_indices.append(
@@ -1226,36 +1231,60 @@ def broadcast_indices(indices):
         input_tensor, indices = args
         if get_shape(gm, input_tensor) is None:
             return None
-        if None in [get_shape(gm, indices[i]) for i in range(len(indices))]:
-            return None
+        for index in indices:
+            if index is None:
+                # slice(None) is unhashable!
+                return None
+            if index is not None and get_shape(gm, index) is None:
+                return None
         index_shape, indices = broadcast_indices(indices)
+        if index_shape.numel() > 256:
+            # cannot create too many ops, or it will cause
+            # runtime args targeting kernel reader_concat_stick_layout_interleaved_start_id on
+            # (x=0,y=0) are too large. Max allowable is 256
+            return None
         input_shape = get_shape(gm, input_tensor)
         num_index = len(indices)
         index_size = index_shape.numel()
-        remained_shape = input_shape[num_index:]
+        remained_shape = []
+        for i in range(len(indices)):
+            if indices[i] is None:
+                remained_shape.append(input_shape[i])
+        remained_shape += input_shape[num_index:]
+        remained_shape = torch.Size(remained_shape)
         reshape_shape = index_shape + remained_shape
         input_dtype = input_tensor.meta["val"].dtype
         flatten_shape = torch.Size([index_size])
-        indices_flatten = [
-            g.call_function(
-                torch.ops.aten.reshape.default,
-                args=(idx, flatten_shape),
-                new_shape=flatten_shape,
-                new_dtype=idx.meta["val"].dtype,
-            )
-            for idx in indices
-        ]
+
+        indices_flatten = []
+        for idx in indices:
+            if idx is None:
+                indices_flatten.append(None)
+            else:
+                indices_flatten.append(
+                    g.call_function(
+                        torch.ops.aten.reshape.default,
+                        args=(idx, flatten_shape),
+                        new_shape=flatten_shape,
+                        new_dtype=idx.meta["val"].dtype,
+                    )
+                )
         output = []
         for i in range(index_size):
-            indexing = [
-                g.call_function(
-                    getitem,
-                    args=(indices_flatten[n], i),
-                    new_shape=torch.Size([]),
-                    new_dtype=indices_flatten[n].meta["val"].dtype,
-                )
-                for n in range(num_index)
-            ]
+            indexing = []
+            for n in range(num_index):
+                if indices_flatten[n] is None:
+                    # TODO: slice(None) is unhashable!
+                    indexing.append(slice(None))
+                else:
+                    indexing.append(
+                        g.call_function(
+                            getitem,
+                            args=(indices_flatten[n], i),
+                            new_shape=torch.Size([]),
+                            new_dtype=indices_flatten[n].meta["val"].dtype,
+                        )
+                    )
             output.append(
                 g.call_function(getitem, args=(input_tensor, indexing), new_shape=remained_shape, new_dtype=input_dtype)
             )

From a7347dd720ae5ce40c2ff3865a3fa5edcb17ca68 Mon Sep 17 00:00:00 2001
From: swimdi
Date: Wed, 18 Dec 2024 21:00:37 +0800
Subject: [PATCH 6/6] Skip some reshape cases

---
 torch_ttnn/passes/lowering/to_tt_pass.py | 53 ++++++++++++++----------
 1 file changed, 31 insertions(+), 22 deletions(-)

diff --git a/torch_ttnn/passes/lowering/to_tt_pass.py b/torch_ttnn/passes/lowering/to_tt_pass.py
index b71ba6c44..8e8b0017b 100644
--- a/torch_ttnn/passes/lowering/to_tt_pass.py
+++ b/torch_ttnn/passes/lowering/to_tt_pass.py
@@ -1245,6 +1245,10 @@ def broadcast_indices(indices):
             return None
         input_shape = get_shape(gm, input_tensor)
         num_index = len(indices)
+        if len(index_shape) != 1:
+            need_flatten = True
+        else:
+            need_flatten = False
         index_size = index_shape.numel()
         remained_shape = []
         for i in range(len(indices)):
@@ -1256,33 +1260,35 @@ def broadcast_indices(indices):
         input_dtype = input_tensor.meta["val"].dtype
         flatten_shape = torch.Size([index_size])
 
-        indices_flatten = []
-        for idx in indices:
-            if idx is None:
-                indices_flatten.append(None)
-            else:
-                indices_flatten.append(
-                    g.call_function(
-                        torch.ops.aten.reshape.default,
-                        args=(idx, flatten_shape),
-                        new_shape=flatten_shape,
-                        new_dtype=idx.meta["val"].dtype,
-                    )
-                )
+        if need_flatten:
+            indices_flatten = []
+            for idx in indices:
+                if idx is None:
+                    indices_flatten.append(None)
+                else:
+                    indices_flatten.append(
+                        g.call_function(
+                            torch.ops.aten.reshape.default,
+                            args=(idx, flatten_shape),
+                            new_shape=flatten_shape,
+                            new_dtype=idx.meta["val"].dtype,
+                        )
+                    )
+            indices = indices_flatten
         output = []
         for i in range(index_size):
             indexing = []
             for n in range(num_index):
-                if indices_flatten[n] is None:
+                if indices[n] is None:
                     # TODO: slice(None) is unhashable!
                     indexing.append(slice(None))
                 else:
                     indexing.append(
                         g.call_function(
                             getitem,
-                            args=(indices_flatten[n], i),
+                            args=(indices[n], i),
                             new_shape=torch.Size([]),
-                            new_dtype=indices_flatten[n].meta["val"].dtype,
+                            new_dtype=indices[n].meta["val"].dtype,
                         )
                     )
             output.append(
@@ -1307,13 +1313,16 @@ def broadcast_indices(indices):
                 new_shape=output_cat_shape,
                 new_dtype=input_dtype,
             )
-        output_reshape = g.call_function(
-            torch.ops.aten.reshape.default,
-            args=(output_cat, reshape_shape),
-            new_shape=reshape_shape,
-            new_dtype=input_dtype,
-        )
-        return output_reshape
+        if need_flatten:
+            output_reshape = g.call_function(
+                torch.ops.aten.reshape.default,
+                args=(output_cat, reshape_shape),
+                new_shape=reshape_shape,
+                new_dtype=input_dtype,
+            )
+            return output_reshape
+        else:
+            return output_cat
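
A note on the decomposition itself: what DigestAtenOps (and later decompose_aten_to_aten_ops) emits as FX nodes can be reproduced in eager PyTorch, which makes the intended semantics easy to check against aten.index.Tensor directly. The sketch below mirrors the logic of patches 1-4 (broadcast the index tensors, flatten them, gather one element per flattened position, then cat and reshape); the function name and the assert harness are illustrative only and not part of the patches, and it assumes no None entries in indices, as the pass did before patch 5.

# Eager-mode sketch of the decomposition the pass builds as FX nodes.
import torch

def index_decomposition_sketch(input_tensor, indices):
    # Broadcast all index tensors to one common shape, as broadcast_indices does.
    index_shape = torch.broadcast_shapes(*(idx.shape for idx in indices))
    indices = [idx.expand(index_shape) for idx in indices]
    num_index = len(indices)
    index_size = index_shape.numel()
    remained_shape = input_tensor.shape[num_index:]
    # Flatten each index tensor so its elements can be read out one by one.
    flat = [idx.reshape(index_size) for idx in indices]
    output = []
    for i in range(index_size):
        element = input_tensor[tuple(f[i] for f in flat)]
        # aten.cat cannot concatenate zero-dim tensors, hence the reshape to (1,).
        output.append(element.reshape(remained_shape if remained_shape else (1,)))
    return torch.cat(output).reshape(index_shape + remained_shape)

# Same example as the comment in the pass: input.shape = (3, 4, 5).
x = torch.arange(60).reshape(3, 4, 5)
idx = [torch.tensor([[0, 1, 1]]), torch.tensor([[2, 1, 2]])]
assert torch.equal(index_decomposition_sketch(x, idx),
                   torch.ops.aten.index.Tensor(x, idx))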
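On the meta["val"] bookkeeping that patch 3 introduces and patch 4 folds into GraphWrapper.call_function: downstream passes read shape and dtype from node.meta["val"], so every node the rewrite inserts needs a value that reports the right shape without allocating real storage. A minimal standalone sketch of that mechanism, using the same FakeTensorMode construction as the patches (make_meta_val is an illustrative name, not a helper from the codebase):

# A FakeTensor carries shape/dtype metadata but no real device storage.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

def make_meta_val(shape, dtype):
    fake_mode = FakeTensorMode()
    return fake_mode.from_tensor(torch.zeros(shape, dtype=dtype))

val = make_meta_val(torch.Size([1, 3, 5]), torch.bfloat16)
print(val.shape, val.dtype)  # torch.Size([1, 3, 5]) torch.bfloat16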
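On the TODO left by patches 5 and 6: slice objects do not implement hash() before Python 3.12, so passing slice(None) inside the args of an FX getitem call fails as soon as anything in the pipeline hashes those args, which appears to be the failure the patch 5 subject refers to. A quick way to observe the version-dependent behavior:

# Before Python 3.12, hash(slice(None)) raises; from 3.12 on it succeeds.
s = slice(None)
try:
    hash(s)
    print("hashable (Python >= 3.12)")
except TypeError as err:
    print(err)  # unhashable type: 'slice'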