Lowering aten.index.Tensor #559

Draft · wants to merge 7 commits into main · Changes from all commits
47 changes: 47 additions & 0 deletions tests/lowering/misc/test_index.py
@@ -0,0 +1,47 @@
import torch
import torch_ttnn
import pytest
import ttnn

from tests.utils import assert_with_pcc


class IndexModule(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, input, indices):
return torch.ops.aten.index.Tensor(input, indices)


@pytest.mark.parametrize(
"input_shapes, indices",
[
((3, 4, 5), [[0, 1], [2, 1], [2, 4]]),
((3, 4, 5), [[0, 1], [2, 1]]),
((3, 4, 5), [[[0, 1]], [[2, 1]]]),
((3, 4, 5), [[[0, 1, 1]], [[2, 1, 2]]]),
((3, 4, 5), [[[0, 1, 1], [1, 1, 0]], [[2, 1, 2]]]), # broadcast
],
)
def test_index(device, input_shapes, indices):
m = IndexModule()
inputs = torch.rand(input_shapes, dtype=torch.bfloat16)
indices = [torch.tensor(index) for index in indices]
result_before = m.forward(inputs, indices)

option = torch_ttnn.TorchTtnnOption(device=device)
# option.gen_graphviz = True

# The compilation is lazy, so we need to run forward once to trigger the compilation
m = torch.compile(m, backend=torch_ttnn.backend, options=option)

result_after = m.forward(inputs, indices)
# option._out_fx_graphs[0].print_tabular()

# Check the graph has been rewritten and no longer contains aten.index.Tensor
nodes = [node.target for node in option._out_fx_graphs[0].nodes]
assert torch.ops.aten.index.Tensor not in nodes

# Check inference result
assert_with_pcc(result_before, result_after, pcc=0.99)
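
For reference, the broadcast case in the parametrization above behaves as follows in eager mode; a standalone sketch for illustration, not part of the PR:

import torch

x = torch.rand(3, 4, 5, dtype=torch.bfloat16)
rows = torch.tensor([[0, 1, 1], [1, 1, 0]])  # shape (2, 3)
cols = torch.tensor([[2, 1, 2]])  # shape (1, 3), broadcasts against rows to (2, 3)
out = torch.ops.aten.index.Tensor(x, [rows, cols])
assert out.shape == (2, 3, 5)  # broadcasted index shape plus the remaining input dim
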
5 changes: 4 additions & 1 deletion tests/utils.py
@@ -423,7 +423,7 @@ def __init__(self, op_name: str, input_strings: List[str]):
"aten.index.Tensor": self._adjust_index_tensor,
"aten.index_put.default": self._adjust_index_tensor,
"aten._native_batch_norm_legit_no_training.default": self._adjust__native_batch_norm_legit_no_training_default,
# "aten._unsafe_index.Tensor": self._adjust_index_tensor,
"aten._unsafe_index.Tensor": self._adjust_index_tensor,
}

def _adjust_bitwise_not_default(self, input_vals):
@@ -503,6 +503,9 @@ def _adjust_index_tensor(self, input_vals):
new_indices = []
for i in range(len(indices)):
indice = indices[i]
if indice is None:
new_indices.append(None)
continue
new_indice = []
for j in range(len(indice)):
new_indice.append(torch.randint(0, self_shape[i], [1]))
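
The adjusted indices only need to be valid for the recorded input shape. Conceptually the helper does the following; a minimal sketch with hypothetical inputs:

import torch

self_shape = (3, 4, 5)
indices = [[0, 1], None]  # None leaves that dimension unindexed
new_indices = []
for i, indice in enumerate(indices):
    if indice is None:
        new_indices.append(None)
        continue
    # replace each recorded entry with a fresh in-range index
    new_indices.append([torch.randint(0, self_shape[i], [1]) for _ in indice])
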
140 changes: 136 additions & 4 deletions torch_ttnn/passes/lowering/to_tt_pass.py
@@ -1,6 +1,7 @@
import torch
import ttnn
import math
import numpy as np
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch_ttnn.utils import (
GraphCleanup,
@@ -17,8 +18,10 @@

from torch.fx.passes.infra.pass_base import PassBase, PassResult
import torch.fx.traceback as fx_traceback
from torch._subclasses.fake_tensor import FakeTensorMode
from . import target_wrappers
from .to_tt_guard import can_lowering_to_ttnn
from operator import getitem

relational_scalar_ops = {
torch.ops.aten.eq.Scalar: ttnn.eq,
@@ -410,13 +413,19 @@ def __init__(self, node):
self.g = node.graph
self.node = node

def call_function(self, target, args=(), kwargs={}):
def call_function(self, target, args=(), kwargs={}, new_shape=None, new_dtype=None):
new_node = self.g.call_function(target, args, kwargs)
new_node.meta = self.node.meta
new_node.meta = self.node.meta.copy()
if hasattr(self.node.target, "_schema"):
new_node.meta["original_input_variations"] = metrics.collect_input_variation_from_node(self.node)
if target == ttnn.layer_norm:
new_node.meta["val"] = new_node.meta["val"][0]
if new_shape is not None or new_dtype is not None:
shape = new_shape if new_shape is not None else new_node.meta["val"].size()
dtype = new_dtype if new_dtype is not None else new_node.meta["val"].dtype
fake_mode = FakeTensorMode()
fake_tensor = fake_mode.from_tensor(torch.zeros(shape, dtype=dtype))
new_node.meta["val"] = fake_tensor
return new_node

def inserting_before(self, node):
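
The new new_shape/new_dtype path fabricates fresh fake-tensor metadata for the node instead of inheriting it from the source node. In isolation the mechanism looks like this; a sketch assuming a torch build that ships torch._subclasses.fake_tensor:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
fake = fake_mode.from_tensor(torch.zeros((2, 3), dtype=torch.bfloat16))
print(fake.shape, fake.dtype)  # torch.Size([2, 3]) torch.bfloat16
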
@@ -683,6 +692,9 @@ def lower_binary_eltwise(fn, args):

return None

# if node.target == torch.ops.aten.reshape.default:
# return g.call_function(ttnn.reshape, args, kwargs)

if node.target == torch.ops.aten.squeeze.dim or node.target == torch.ops.aten.squeeze.default:
if use_less_ttnn_op_types or node.target == torch.ops.aten.squeeze.default:
# ttnn.squeeze does not support calling the OP without provided dim (torch.ops.aten.squeeze.default)
@@ -1178,7 +1190,7 @@ def lower_binary_eltwise(fn, args):
return gm


def decompose_aten_to_aten_ops(g: GraphWrapper, node):
def decompose_aten_to_aten_ops(gm: torch.fx.GraphModule, g: GraphWrapper, node):
args = node.args
kwargs = node.kwargs
if node.target == torch.ops.aten.full_like.default:
@@ -1192,6 +1204,126 @@ def decompose_aten_to_aten_ops(g: GraphWrapper, node):
new_kwargs["dtype"] = node.meta["val"].dtype
return g.call_function(torch.ops.aten.zeros.default, args=(target_shape, *args[2:]), kwargs=new_kwargs)

if node.target in [torch.ops.aten.index.Tensor, torch.ops.aten._unsafe_index.Tensor]:

def broadcast_indices(indices):
indices_shapes = [get_shape(gm, indices[i]) for i in range(len(indices))]
broadcasted_shape = torch.Size(np.broadcast_shapes(*indices_shapes))
broadcasted_indices = []
for i in range(len(indices)):
if indices_shapes[i] is None:
broadcasted_indices.append(None)
elif indices_shapes[i] == broadcasted_shape:
broadcasted_indices.append(indices[i])
else:
broadcasted_indices.append(
g.call_function(
torch.ops.aten.expand.default,
(indices[i], broadcasted_shape),
new_shape=broadcasted_shape,
new_dtype=indices[i].meta["val"].dtype,
)
)
return broadcasted_shape, broadcasted_indices

# For example, if input.shape == (3, 4, 5) and indices == [tensor([[0, 1, 1]]), tensor([[2, 1, 2]])],
# then the output is [[input[0][2], input[1][1], input[1][2]]] (see the eager-mode sketch after this hunk)
input_tensor, indices = args
if get_shape(gm, input_tensor) is None:
return None
for index in indices:
if index is None:
# indexing with None would require passing slice(None) through getitem, and slice(None) is unhashable
return None
if get_shape(gm, index) is None:
return None
index_shape, indices = broadcast_indices(indices)
if index_shape.numel() > 256:
# Do not emit too many ops, or the runtime fails with:
# "runtime args targeting kernel reader_concat_stick_layout_interleaved_start_id on
# (x=0,y=0) are too large. Max allowable is 256"
return None
input_shape = get_shape(gm, input_tensor)
num_index = len(indices)
need_flatten = len(index_shape) != 1
index_size = index_shape.numel()
remained_shape = []
for i in range(len(indices)):
if indices[i] is None:
remained_shape.append(input_shape[i])
remained_shape += input_shape[num_index:]
remained_shape = torch.Size(remained_shape)
reshape_shape = index_shape + remained_shape
input_dtype = input_tensor.meta["val"].dtype
flatten_shape = torch.Size([index_size])

if need_flatten:
indices_flatten = []
for idx in indices:
if idx is None:
indices_flatten.append(None)
else:
indices_flatten.append(
g.call_function(
torch.ops.aten.reshape.default,
args=(idx, flatten_shape),
new_shape=flatten_shape,
new_dtype=idx.meta["val"].dtype,
)
)
indices = indices_flatten
output = []
for i in range(index_size):
indexing = []
for n in range(num_index):
if indices[n] is None:
# TODO: dead for now (None indices bail out above); slice(None) is unhashable as a graph argument
indexing.append(slice(None))
else:
indexing.append(
g.call_function(
getitem,
args=(indices[n], i),
new_shape=torch.Size([]),
new_dtype=indices[n].meta["val"].dtype,
)
)
output.append(
g.call_function(getitem, args=(input_tensor, indexing), new_shape=remained_shape, new_dtype=input_dtype)
)
# aten.cat cannot concatenate zero-dim tensors, so promote the slices to shape (1,)
if len(remained_shape) == 0:
remained_shape = torch.Size([1])
output = [
g.call_function(
torch.ops.aten.reshape.default,
args=(o, remained_shape),
new_shape=remained_shape,
new_dtype=input_dtype,
)
for o in output
]
output_cat_shape = torch.Size([len(output)] + list(remained_shape))
output_cat = g.call_function(
torch.ops.aten.cat.default,
args=(output,),
new_shape=output_cat_shape,
new_dtype=input_dtype,
)
if need_flatten:
output_reshape = g.call_function(
torch.ops.aten.reshape.default,
args=(output_cat, reshape_shape),
new_shape=reshape_shape,
new_dtype=input_dtype,
)
return output_reshape
else:
return output_cat

return None
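
To make the emitted graph concrete: the decomposition flattens each (broadcast) index tensor, gathers one slice of the input per flattened position, concatenates the slices, and reshapes back to index_shape + remained_shape. An eager-mode equivalent of the comment's example above (illustrative only; the pass emits getitem/cat/reshape nodes rather than running this code):

import torch

x = torch.rand(3, 4, 5)
rows = torch.tensor([[0, 1, 1]])  # index_shape == (1, 3), so need_flatten is True
cols = torch.tensor([[2, 1, 2]])

rows_f, cols_f = rows.reshape(3), cols.reshape(3)  # flatten_shape == (3,)
slices = [x[rows_f[i], cols_f[i]] for i in range(3)]  # each slice has remained_shape == (5,)
out = torch.cat(slices).reshape(1, 3, 5)  # cat gives (15,); reshape restores index_shape + remained_shape
assert torch.equal(out, torch.ops.aten.index.Tensor(x, [rows, cols]))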


Expand All @@ -1203,7 +1335,7 @@ def rewrite_graph(gm: torch.fx.GraphModule, rewrite_node_fn) -> torch.fx.GraphMo
continue
g = GraphWrapper(node)
with g.inserting_before(node):
new_node = rewrite_node_fn(g, node)
new_node = rewrite_node_fn(gm, g, node)
if new_node is not None:
node.replace_all_uses_with(
new_node,