
Lowering aten.index.Tensor #559

Draft · wants to merge 7 commits into main
Conversation

swimdi (Contributor) commented Dec 5, 2024

Ticket

#535

Problem description

For example, if input.shape = (3, 4, 5) and indices = [tensor([[0, 1, 1]]), tensor([[2, 1, 2]])],
then the output is [[input[0][2], input[1][1], input[1][2]]].
So we lower aten.index into several getitem/reshape/cat ops.

# before digest
input = torch.rand([3, 4, 5])
indices = [tensor([[0, 1, 1]]), tensor([[2, 1, 2]])]
output = aten.index.Tensor(input, indices)
return output

# after digest
input = torch.rand([3, 4, 5])
indices = [tensor([[0, 1, 1]]), tensor([[2, 1, 2]])]
indices_flatten = [aten.reshape(indices[0], (3,)), aten.reshape(indices[1], (3,))]
indexing0 = (getitem(indices_flatten[0], 0), getitem(indices_flatten[1], 0))
output0 = getitem(input, indexing0) # input[0][2]
indexing1 = (getitem(indices_flatten[0], 1), getitem(indices_flatten[1], 1))
output1 = getitem(input, indexing1) # input[1][1]
indexing2 = (getitem(indices_flatten[0], 2), getitem(indices_flatten[1], 2))
output2 = getitem(input, indexing2) # input[1][2]
output = aten.cat([output0, output1, output2])
output_reshape = aten.reshape(output, [1, 3, 5])
return output_reshape
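As a sanity check, the same decomposition can be sketched in plain PyTorch (no TTNN; variable names chosen for illustration) and compared against the reference op:

```python
# Minimal plain-PyTorch sketch of the decomposition above:
# flatten the index tensors, gather one slice per index tuple,
# concat the slices, then reshape to the broadcast output shape.
import torch

inp = torch.rand(3, 4, 5)
indices = [torch.tensor([[0, 1, 1]]), torch.tensor([[2, 1, 2]])]

# Reference: the op being lowered.
expected = torch.ops.aten.index.Tensor(inp, indices)

flat = [idx.reshape(-1) for idx in indices]
slices = []
for k in range(flat[0].numel()):
    i, j = int(flat[0][k]), int(flat[1][k])
    slices.append(inp[i][j].unsqueeze(0))  # inp[i][j] is a 1-D tensor of size 5

out = torch.cat(slices, dim=0).reshape(1, 3, 5)
assert torch.equal(out, expected)
```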

What's changed

  • Add DigestAtenOps, which digests aten.index into several torch ops

I'll merge #532 first, then this PR.

@swimdi swimdi self-assigned this Dec 5, 2024
@swimdi swimdi linked an issue Dec 5, 2024 that may be closed by this pull request
@swimdi swimdi changed the title Support lowering of index Support lowering aten.index.Tensor Dec 5, 2024
@swimdi swimdi changed the title Support lowering aten.index.Tensor Lowering aten.index.Tensor Dec 5, 2024
ayerofieiev-tt (Member) commented on the diff:

if (
src_node.target
not in TTNN_LAYOUT_CHANGE_OPS.union(
class NodeInputAligner:
move to a different file?

swimdi (Contributor, Author) replied Dec 6, 2024:

This is part of #532; I'll merge that PR first and then this one.

swimdi (Contributor, Author) commented Dec 6, 2024

It is not the same as index_fill, right?
https://github.com/tenstorrent/tt-metal/blob/main/ttnn/cpp/ttnn/operations/index_fill/index_fill_pybind.cpp

I think ttnn.index_fill is not the same as aten.index.Tensor: the former fills specific values into the input tensor (without changing its shape), while the latter gathers parts of the input according to indices (changing the shape).
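A small plain-PyTorch sketch of that difference, using torch's `index_fill_` as a stand-in for the fill-style semantics:

```python
# index_fill_ writes a scalar into the input at the given indices (shape
# preserved), while aten.index gathers rows (shape changes).
import torch

x = torch.zeros(3, 4)
idx = torch.tensor([0, 2])

filled = x.clone().index_fill_(0, idx, 7.0)       # still shape (3, 4)
gathered = torch.ops.aten.index.Tensor(x, [idx])  # shape (2, 4)

print(filled.shape, gathered.shape)  # torch.Size([3, 4]) torch.Size([2, 4])
```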

@swimdi swimdi force-pushed the swimdi/stage4-index branch from 86f8693 to ddd7fe3 Compare December 6, 2024 04:22
return broadcasted_shape, broadcasted_indices

# for example, input.shape = (3, 4, 5), indices = [tensor([[0, 1, 1]]), tensor([[2, 1, 2]])]
# then output is [[input[0][2], input[1][1], input[1][2]]]
jerrysky3 (Contributor) commented Dec 9, 2024:
Should this example be [input[0][2], input[1][1], input[1][2]]? (And input[x][y] is a 1-D tensor, so the concat result is a 2-D tensor.)

swimdi (Contributor, Author) replied:
Actually no; the indices are [[0, 1, 1]], not [0, 1, 1], so they have one extra rank.
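This "extra rank" point can be checked quickly in plain PyTorch: the leading dim of the index tensors carries through to the output.

```python
# With 1-D indices the output is (3, 5); with the extra leading dim
# in [[0, 1, 1]] the output keeps that dim and becomes (1, 3, 5).
import torch

inp = torch.rand(3, 4, 5)
i0, i1 = torch.tensor([0, 1, 1]), torch.tensor([2, 1, 2])

print(inp[i0, i1].shape)                            # torch.Size([3, 5])
print(inp[i0.unsqueeze(0), i1.unsqueeze(0)].shape)  # torch.Size([1, 3, 5])
```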

Comment on lines 1163 to 1165
g.call_function(torch.ops.aten.expand.default, (indices[i], broadcasted_shape))
)
edit_meta_val(broadcasted_indices[-1], broadcasted_shape, indices[i].meta["val"].dtype)
Reviewer (Contributor) commented:

Maybe we can add a new call_function_with_new_shape to GraphWrapper and set the new shape in the meta, so we don't need to update the meta separately.

I'm thinking something like def call_function_with_new_shape(self, target, new_shape, args, kwargs)

swimdi (Contributor, Author) replied:
Oh, good idea! I implemented it by extending def call_function(self, target, args=(), kwargs={}, new_shape=None, new_dtype=None)
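A hypothetical sketch of what such an extended call_function could look like; the GraphWrapper class, the meta layout, and the default dtype here are assumptions for illustration, not code from this PR:

```python
# Assumed minimal stand-in for the PR's graph wrapper: the extra
# new_shape/new_dtype arguments set the node's "val" meta in one step
# instead of editing it after the node is created.
import torch
import torch.fx as fx

class GraphWrapper:
    def __init__(self, graph: fx.Graph):
        self.graph = graph

    def call_function(self, target, args=(), kwargs=None,
                      new_shape=None, new_dtype=None):
        node = self.graph.call_function(target, args, kwargs or {})
        if new_shape is not None or new_dtype is not None:
            # Record the expected output as a meta tensor (no real storage).
            node.meta["val"] = torch.empty(
                new_shape or (), dtype=new_dtype or torch.float32, device="meta"
            )
        return node

g = fx.Graph()
w = GraphWrapper(g)
x = g.placeholder("x")
relu = w.call_function(torch.ops.aten.relu.default, (x,), new_shape=(2, 3))
print(relu.meta["val"].shape)  # torch.Size([2, 3])
```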

@swimdi swimdi mentioned this pull request Dec 9, 2024
@swimdi swimdi force-pushed the swimdi/stage4-index branch from fd2ccb1 to 4485e83 Compare December 9, 2024 08:03
@swimdi swimdi added this pull request to the merge queue Dec 9, 2024
@github-merge-queue github-merge-queue bot removed this pull request from the merge queue due to failed status checks Dec 9, 2024
ayerofieiev-tt (Member) commented:

@swimdi can you please share how the Tensor returned by index is used later?
I see that in PyTorch they don't seem to actually calculate much, but rather build an iterator.
Maybe we can look at the pattern of how index is being used and avoid it.

swimdi (Contributor, Author) commented Dec 10, 2024

@swimdi can you please share how the Tensor returned by index is used later?
I see that in PyTorch they don't seem to actually calculate much, but rather build an iterator.
Maybe we can look at the pattern of how index is being used and avoid it.

Take beit for example; part of its code is:

view_13 = torch.ops.aten.view.default(arg223_1, [-1])
index = torch.ops.aten.index.Tensor(arg1_1, [view_13]) # <= indexing here
view_14 = torch.ops.aten.view.default(index, [197, 197, -1])
permute_3 = torch.ops.aten.permute.default(view_14, [2, 0, 1])
clone_2 = torch.ops.aten.clone.default(permute_3, memory_format = torch.contiguous_format)
unsqueeze = torch.ops.aten.unsqueeze.default(clone_2, 0)
add = torch.ops.aten.add.Tensor(div, unsqueeze)
_softmax = torch.ops.aten._softmax.default(add, -1, False)
clone_3 = torch.ops.aten.clone.default(_softmax)
expand_3 = torch.ops.aten.expand.default(clone_3, [1, 12, 197, 197])
view_15 = torch.ops.aten.view.default(expand_3, [12, 197, 197])
bmm_1 = torch.ops.aten.bmm.default(view_15, view_16)

The output of aten.index.Tensor then goes through permute => clone => unsqueeze => add => softmax => clone => expand => view and becomes an argument of bmm.

I'm not sure how to avoid it. For this beit case, the input variation is

["Tensor<[732, 12]> self = ?", "List[Optional[Tensor]] indices = [<[38809]>]"]

This means the input shape is [732, 12]; based on the values in indices, it gathers rows (each of size 12) from the 732 rows a total of 38,809 times, so the output shape is [38809, 12].
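That shape behavior can be checked directly in plain PyTorch:

```python
# With input [732, 12] and a single 1-D index tensor of 38809 values,
# aten.index.Tensor gathers 38809 rows of size 12.
import torch

inp = torch.rand(732, 12)
idx = torch.randint(0, 732, (38809,))
out = torch.ops.aten.index.Tensor(inp, [idx])
print(out.shape)  # torch.Size([38809, 12])
```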

@swimdi swimdi marked this pull request as draft December 13, 2024 06:20
@swimdi swimdi force-pushed the swimdi/stage4-index branch from c2be067 to 7b4af97 Compare December 18, 2024 11:51
@swimdi swimdi force-pushed the swimdi/stage4-index branch from 7b4af97 to 61b2c39 Compare December 18, 2024 12:08
ayerofieiev-tt (Member) commented Dec 20, 2024

@swimdi can you find out what PyTorch function is lowered to this list of ATen ops? I have the impression that we can fuse this back into something that might already be present in TT-NN.

This looks like some Attention function. Maybe something from this list?

  • multi_head_attention_forward
  • index_select
  • gather
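Of the candidates above, index_select (and an equivalent gather) does match the beit pattern when there is a single 1-D index tensor; a small sketch checking the equivalence:

```python
# For a single 1-D index tensor, aten.index.Tensor coincides with
# index_select along dim 0, and with a gather over expanded indices,
# so this pattern could potentially map to an existing TT-NN op.
import torch

inp = torch.rand(732, 12)
idx = torch.randint(0, 732, (38809,))

a = torch.ops.aten.index.Tensor(inp, [idx])
b = torch.index_select(inp, 0, idx)
c = torch.gather(inp, 0, idx.unsqueeze(1).expand(-1, inp.size(1)))

assert torch.equal(a, b) and torch.equal(a, c)
```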

ayerofieiev-tt (Member) commented:

class BeitRelativePositionBias(nn.Module):
    def __init__(self, config: BeitConfig, window_size: tuple) -> None:
        super().__init__()
        self.window_size = window_size
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, config.num_attention_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token 2 cls & cls to cls

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = torch.zeros(
            size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
        )
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", relative_position_index, persistent=False)

    def forward(self) -> torch.Tensor:
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
        )  # Wh*Ww,Wh*Ww,nH

        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

swimdi (Contributor, Author) commented Dec 23, 2024

Yes, I think BeitRelativePositionBias produces aten.index.Tensor, and it actually comes from this line:

self.relative_position_bias_table[self.relative_position_index.view(-1)]

and it can be reproduced with the following code:

import torch
import ttnn
import torch_ttnn
device = ttnn.open_device(device_id=0)

class PatternModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    # represented by self.relative_position_bias_table[self.relative_position_index.view(-1)]
    def forward(self, input, indices) -> torch.Tensor:
        return input[indices]

m = PatternModule()
input = torch.rand([732, 12])
indices = torch.randint(low=0, high=732, size=[197*197])
option = torch_ttnn.TorchTtnnOption(device=device)
m = torch.compile(m, backend=torch_ttnn.backend, options=option)
output = m.forward(input, indices)
nodes = list(option._out_fx_graphs[0].nodes)
print(input.shape)
# torch.Size([732, 12])
print(output.shape)
# torch.Size([38809, 12])
print(nodes)
# [arg0_1, arg1_1, index_tensor, output]
print(nodes[2].target)
# aten.index.Tensor

Successfully merging this pull request may close these issues.

aten.index.Tensor
3 participants