Lowering aten.index.Tensor #559
base: main
Conversation
It is not the same as index_fill, right?
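For reference, a quick eager-mode check of the difference (plain PyTorch, not code from this PR): aten.index.Tensor is the op behind advanced indexing and reads values, while index_fill writes a scalar at the given indices.

```python
import torch

x = torch.arange(12.0).reshape(3, 4)
idx = torch.tensor([0, 2])

# aten.index.Tensor backs advanced indexing: it *gathers* rows 0 and 2
gathered = torch.ops.aten.index.Tensor(x, [idx])   # same result as x[idx]

# index_fill instead *writes* a scalar at the given indices along a dim
filled = x.index_fill(0, idx, -1.0)

print(gathered.shape)  # torch.Size([2, 4])
print(filled[0])       # tensor([-1., -1., -1., -1.])
```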
if (
    src_node.target
    not in TTNN_LAYOUT_CHANGE_OPS.union(

class NodeInputAligner:
move to a different file?
This is part of #532; I'll merge that PR first and then this PR.
I think
Force-pushed from 86f8693 to ddd7fe3
return broadcasted_shape, broadcasted_indices

# for example, input.shape = (3, 4, 5), indices = [tensor([[0, 1, 1]]), tensor([[2, 1, 2]])]
# then output is [[input[0][2], input[1][1], input[1][2]]]
Should this example be [input[0][2], input[1][1], input[1][2]]? (And input[x][y] is a 1-D tensor, so the concat result is a 2-D tensor.)
Actually no, the indices are [[0, 1, 1]], not [0, 1, 1], so they have one extra rank.
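A quick eager-mode check of this thread (plain PyTorch, independent of the lowering code): each gathered input[x][y] is a 1-D slice, and the nested [[...]] index shape adds one extra rank, so the result here is 3-D.

```python
import torch

input = torch.randn(3, 4, 5)
indices = [torch.tensor([[0, 1, 1]]), torch.tensor([[2, 1, 2]])]

out = input[indices[0], indices[1]]   # what aten.index.Tensor computes
print(out.shape)                      # torch.Size([1, 3, 5])

# each input[x][y] is a (5,) slice; stacking the three gives (3, 5),
# and the [[...]] index shape adds the leading 1, matching [[input[0][2], input[1][1], input[1][2]]]
expected = torch.stack([input[0][2], input[1][1], input[1][2]]).unsqueeze(0)
print(torch.equal(out, expected))     # True
```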
    g.call_function(torch.ops.aten.expand.default, (indices[i], broadcasted_shape))
)
edit_meta_val(broadcasted_indices[-1], broadcasted_shape, indices[i].meta["val"].dtype)
Maybe we can add a new call_function_with_new_shape to GraphWrapper and set the new shape in the meta? So we don't need to update the meta separately. I'm thinking something like def call_function_with_new_shape(self, target, new_shape, args, kwargs).
Oh, good idea! I implemented it by extending call_function to def call_function(self, target, args=(), kwargs={}, new_shape=None, new_dtype=None).
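Not the actual code from this PR, but a minimal sketch of how such an extended call_function could look, assuming GraphWrapper wraps a torch.fx.Graph and keeps a FakeTensor-like value in node.meta["val"]:

```python
import torch


class GraphWrapper:
    """Hypothetical sketch; the real GraphWrapper in this repo may differ."""

    def __init__(self, node):
        self.g = node.graph
        self.node = node

    def call_function(self, target, args=(), kwargs={}, new_shape=None, new_dtype=None):
        new_node = self.g.call_function(target, args, kwargs)
        new_node.meta = dict(self.node.meta)
        if new_shape is not None or new_dtype is not None:
            old_val = new_node.meta.get("val")
            shape = new_shape if new_shape is not None else old_val.size()
            dtype = new_dtype if new_dtype is not None else old_val.dtype
            # patch meta["val"] here so later passes see the corrected shape/dtype,
            # instead of calling a separate edit_meta_val afterwards
            new_node.meta["val"] = torch.empty(shape, dtype=dtype, device="meta")
        return new_node
```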
Force-pushed from fd2ccb1 to 4485e83
@swimdi can you please share how is Tensor returned by
Take
the output of I'm not sure how to avoid it, for this
This means the input shape is
Force-pushed from c2be067 to 7b4af97
Force-pushed from 7b4af97 to 61b2c39
@swimdi can you find out what PyTorch function is lowered to this list of ATen ops? I have an impression that we can fuse this back to something that might already be present in TT-NN. This looks like some attention function. Maybe something from this list?
```python
# Excerpt from transformers' modeling_beit.py; BeitConfig and the meshgrid helper
# are imported in that file.
import torch
from torch import nn


class BeitRelativePositionBias(nn.Module):
    def __init__(self, config: BeitConfig, window_size: tuple) -> None:
        super().__init__()
        self.window_size = window_size
        self.num_relative_distance = (2 * window_size[0] - 1) * (2 * window_size[1] - 1) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, config.num_attention_heads)
        )  # 2*Wh-1 * 2*Ww-1, nH
        # cls to token & token 2 cls & cls to cls

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = torch.zeros(
            size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
        )
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1
        self.register_buffer("relative_position_index", relative_position_index, persistent=False)

    def forward(self) -> torch.Tensor:
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1] + 1, self.window_size[0] * self.window_size[1] + 1, -1
        )  # Wh*Ww,Wh*Ww,nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
```
Yes, I think so, and it can be reproduced with the code above.
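For anyone skimming, a stripped-down version of that forward (toy sizes, not the PR's code): the table[index.view(-1)] advanced indexing is the pattern that shows up as aten.index.Tensor in the lowered graph.

```python
import torch

num_heads, wh, ww = 4, 3, 3
table = torch.randn((2 * wh - 1) * (2 * ww - 1) + 3, num_heads)       # bias table
index = torch.randint(0, table.shape[0], (wh * ww + 1, wh * ww + 1))  # relative_position_index

# table[index.view(-1)] is tensor (advanced) indexing, i.e. aten.index.Tensor after lowering
bias = table[index.view(-1)].view(wh * ww + 1, wh * ww + 1, -1).permute(2, 0, 1)
print(bias.shape)  # torch.Size([4, 10, 10]) -> nH, Wh*Ww+1, Wh*Ww+1
```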
Ticket
#535

Problem description
For example, if input.shape = (3, 4, 5) and indices = [tensor([[0, 1, 1]]), tensor([[2, 1, 2]])],
then the output is [[input[0][2], input[1][1], input[1][2]]].
So this PR lowers aten.index to several getitem/reshape/cat ops.

What's changed
Add DigestAtenOps, which digests aten.index into several torch ops. I'll merge #532 first, then this PR.
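A rough eager-mode sketch of that decomposition (illustrative only; the actual DigestAtenOps pass rewrites fx graph nodes rather than running on eager tensors):

```python
import torch

def index_via_getitem_reshape_cat(input, indices):
    # broadcast the index tensors to a common shape, then gather one slice per position
    broadcasted = torch.broadcast_tensors(*indices)
    flat = [idx.reshape(-1) for idx in broadcasted]
    slices = [input[tuple(int(idx[j]) for idx in flat)] for j in range(flat[0].numel())]
    # cat the gathered slices and reshape back to broadcast_shape + remaining input dims
    out = torch.cat([s.unsqueeze(0) for s in slices], dim=0)
    return out.reshape(*broadcasted[0].shape, *input.shape[len(indices):])

input = torch.randn(3, 4, 5)
indices = [torch.tensor([[0, 1, 1]]), torch.tensor([[2, 1, 2]])]
print(torch.equal(index_via_getitem_reshape_cat(input, indices), input[tuple(indices)]))  # True
```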