[Tripy] Add __len__ for tp.Shape and infer length statically when possible #92

Merged · 6 commits · Aug 20, 2024
98 changes: 95 additions & 3 deletions tripy/tests/frontend/test_shape.py
@@ -45,13 +45,15 @@ def test_shape(self, values):
s = tp.Shape(values)

assert isinstance(s, tp.Shape)
assert len(s) == len(values)
assert s.trace_tensor.producer.inputs == []
assert cp.from_dlpack(s).get().tolist() == values

def test_empty_shape(self):
s = tp.Shape([])

assert isinstance(s, tp.Shape)
assert len(s) == 0
assert s.trace_tensor.producer.inputs == []
assert cp.from_dlpack(s).get().tolist() == []

@@ -60,6 +62,7 @@ def test_constructor_from_tensor(self, values):
s = tp.Shape(t)

assert isinstance(s, tp.Shape)
assert len(s) == len(values)
assert s.trace_tensor.producer.inputs == []
# they should be the same underlying value
assert s.trace_tensor == t.trace_tensor
@@ -100,6 +103,12 @@ def test_plus_override(self, values, other_values):
assert isinstance(new_shape.trace_tensor.producer, Concatenate)
assert cp.from_dlpack(new_shape).get().tolist() == values + appended

def test_len_concatenation(self, values):
s = tp.Shape(values)
# we are testing that the length is *inferred*, so do not execute the concatenation
c = s + s
assert len(c) == 2 * len(values)

def test_explicit_addition(self, values):
from tripy.frontend.trace.ops.binary_elementwise import BinaryElementwise

@@ -110,6 +119,14 @@ def test_explicit_addition(self, values):
assert res.trace_tensor.producer.kind == BinaryElementwise.Kind.SUM
assert cp.from_dlpack(res).get().tolist() == [2 * v for v in values]

def test_len_binary_op(self, values):
s = tp.Shape(values)
res = s.add(tp.Shape(values))
assert len(res) == len(values)

res = s * 2
assert len(res) == len(values)

def test_shape_op(self, values):
from tripy.frontend.trace.ops.shape import Shape

@@ -120,6 +137,11 @@ def test_shape_op(self, values):
assert isinstance(s.trace_tensor.producer, Shape)
assert cp.from_dlpack(s).get().tolist() == [len(values)]

def test_len_shape_op(self, values):
t = tp.Tensor(values)
s = t.shape
assert len(s) == 1

def test_flip(self, values):
from tripy.frontend.trace.ops.flip import Flip

@@ -130,6 +152,11 @@ def test_flip(self, values):
assert isinstance(flipped_shape.trace_tensor.producer, Flip)
assert cp.from_dlpack(flipped_shape).get().tolist() == values[::-1]

def test_len_flip(self, values):
s = tp.Shape(values)
flipped = tp.flip(s, dims=0)
assert len(flipped) == len(values)

def test_expand(self):
from tripy.frontend.trace.ops.expand import Expand

@@ -141,6 +168,11 @@ def test_expand(self):
assert isinstance(expanded.trace_tensor.producer, Expand)
assert cp.from_dlpack(expanded).get().tolist() == [1, 1, 1]

def test_len_expand(self):
s = tp.Shape([1])
expanded = tp.expand(s, (3,))
assert len(expanded) == 3

def test_gather(self, values):
from tripy.frontend.trace.ops.gather import Gather

@@ -150,6 +182,11 @@ def test_gather(self, values):
assert isinstance(s2.trace_tensor.producer, Gather)
assert cp.from_dlpack(s2).get().tolist() == [values[0], values[-1]]

def test_len_gather(self, values):
s = tp.Shape(values)
gathered = tp.gather(s, 0, tp.Tensor([0, len(values) - 1]))
assert len(gathered) == 2

def test_matmul(self, values):
s1 = tp.Shape(values)
s2 = tp.Shape(values)
@@ -199,6 +236,38 @@ def test_slice_range(self, values):
assert isinstance(dims.trace_tensor.producer, Slice)
assert cp.from_dlpack(dims).get().tolist() == values[1:]

@pytest.mark.parametrize(
"slice_value",
[
slice(0, 2),
slice(0, 1),
slice(1, 3),
slice(0, 3, 2),
slice(1, 3, 2),
slice(1, 4, 2),
slice(1, 4, 3), # should select only one
slice(1, None, 200), # selects only start point
# some with negative strides
slice(None, None, -1),
slice(None, None, -2),
slice(4, 0, -1),
slice(2, 0, -1),
slice(2, 1, -1),
# check the clamping behavior
slice(-10, 20),
slice(10, -20, -1),
# invalid bounds (length 0 result)
slice(0, 4, -1),
slice(4, 0),
slice(2, 2),
],
)
def test_slice_len(self, slice_value):
# checking consistency against Python list
values = [1, 2, 3, 4]
s1 = tp.Shape(values)
assert len(s1[slice_value]) == len(values[slice_value])
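For reference, these expectations match Python's own slice semantics: for any slice `sl` applied to a sequence of length `n`, the result length is `len(range(*sl.indices(n)))`, which uniformly handles clamping, negative steps, and empty results. A minimal sketch in plain Python (no tripy required):

sl = slice(10, -20, -1)
n = 4
# slice.indices(n) clamps the bounds to the valid range and normalizes
# negative indices, so the range below has exactly as many elements as
# the sliced list does
assert len(range(*sl.indices(n))) == len([1, 2, 3, 4][sl])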

def test_reduce(self, values):
from tripy.frontend.trace.ops.reduce import Reduce

@@ -263,6 +332,11 @@ def test_expand_higher_rank_not_wrapped(self):
assert not isinstance(e, tp.Shape)
assert cp.from_dlpack(e).get().tolist() == [[1] for _ in range(3)]

def test_cast_len(self, values):
s = tp.Shape(values)
cast = tp.cast(s, tp.int32)
assert len(cast) == len(values)

def test_split(self, values):
s = tp.Shape(values)
outputs = tp.split(s, len(values))
@@ -271,6 +345,19 @@ def test_split(self, values):
assert isinstance(output, tp.Shape)
assert cp.from_dlpack(output).get().tolist() == [values[i]]

def test_split_len(self, values):
s = tp.Shape(values)
outputs = tp.split(s, len(values))
for output in outputs:
assert len(output) == 1

def test_split_len_intervals(self):
s = tp.Shape([1, 2, 3, 4, 5])
outputs = tp.split(s, [1, 4])
assert len(outputs[0]) == 1 # 0:1
assert len(outputs[1]) == 3 # 1:4
assert len(outputs[2]) == 1 # 4:5

def test_where(self, values):
from tripy.frontend.trace.ops.where import Where

@@ -284,6 +371,13 @@ def test_where(self, values):
assert isinstance(res.trace_tensor.producer, Where)
assert cp.from_dlpack(res).get().tolist() == [0 if values[i] < 2 else values[i] for i in range(len(values))]

def test_where_len(self, values):
s1 = tp.Shape(values)
s2 = tp.Shape([0 for _ in values])
cond = tp.Tensor([i >= 1 for i in range(len(values))], dtype=tp.bool)
res = tp.where(cond, s1, s2)
assert len(res) == len(values)

def test_invalid_input_dtype(self):
with raises(tp.TripyException, match="Data has incorrect dtype"):
_ = tp.Shape(np.array([2.0, 3.0], dtype=np.float32))
@@ -343,9 +437,7 @@ def test_unary_elementwise_fails_at_run_time(self, values):
        v = tp.exp(tp.Shape(values))
        with raises(
            tp.TripyException,
-            match=(
-                "'stablehlo.exponential' op operand #0 must be ranked tensor of"
-            ),
+            match=("'stablehlo.exponential' op operand #0 must be ranked tensor of"),
        ):
            v.eval()

6 changes: 6 additions & 0 deletions tripy/tripy/frontend/shape.py
@@ -161,3 +161,9 @@ def __eq__(self, other):
from tripy.frontend.trace.ops.reduce import all

return bool(all(self.as_tensor() == other.as_tensor()))

# __len__ for shapes gives the number of dims in the shape, i.e., the first dimension of the shape's shape
def __len__(self):
from tripy.frontend.trace.ops import utils as op_utils

return op_utils.get_trace_shape(self.trace_tensor)[0]
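With this in place, `len()` of a `tp.Shape` is answered from trace-time metadata rather than by evaluating the tensor. A usage sketch (the values are illustrative; the behavior mirrors the tests above):

import tripy as tp

s = tp.Shape([2, 3, 4])
assert len(s) == 3  # read from the trace tensor's shape; nothing is executed

c = s + s  # adding shapes is concatenation
assert len(c) == 6  # inferred statically; the concatenation is never run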
16 changes: 12 additions & 4 deletions tripy/tripy/frontend/trace/ops/base.py
@@ -86,8 +86,13 @@ def build(cls, inputs: List["Tensor"], *args, num_outputs=1, **kwargs) -> Union[
raise_error(
f"Error processing shape inputs in operator {cls.__name__}{custom_err}\n(Shape input indices: {shape_arg_msg}.)"
)
# for shape outputs, we infer the length
if len(res.value) != 0:
inferred_lengths = op.infer_len()
for idx in res.value:
outputs[idx] = Shape(outputs[idx])
if inferred_lengths[idx] is not None:
out_trace_tensors[idx].shape = [inferred_lengths[idx]]

if num_outputs == 1:
return outputs[0]
@@ -124,12 +129,15 @@ def infer_shape_output_idxs(self, inputs: List["Tensor"]) -> Result:
            return Result.err(["Either all inputs must be tp.Shape or all must be tp.Tensor."])
        return Result.ok([])

-    def infer_shapes(self):
+    def infer_len(self) -> List[Optional[int]]:
        """
-        Infers shapes for the operation and updates output tensor shapes accordingly.
+        Infers the length of all `tp.Shape` outputs. This is, essentially, the "shape" of the shape.
+        Returns `None` for outputs that are not `tp.Shape`s or whose length (shape) cannot be inferred.
+
+        Returns:
+            A list of inferred lengths for outputs that are `tp.Shape`s.
        """
-        # Default implementation of infer_shapes fills dynamic dim for all elements.
-        self.outputs[0].shape = [-1] * self.outputs[0].rank
+        return [None for _ in self.outputs]

    def infer_dtypes(self):
        """
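Individual trace ops override `infer_len` whenever the output length is computable from input metadata alone. A hypothetical op sketching the pattern (`ReverseShape` is not part of this PR; only `BaseTraceOp` and `get_trace_shape` come from the diffs):

from typing import List, Optional

from tripy.frontend.trace.ops import utils as op_utils
from tripy.frontend.trace.ops.base import BaseTraceOp


class ReverseShape(BaseTraceOp):
    # Hypothetical op that reverses a rank-1 shape tensor.

    def infer_len(self) -> List[Optional[int]]:
        # Reversal preserves the element count, so forward the length
        # recorded on the (rank-1) input's trace tensor.
        input_shape = op_utils.get_trace_shape(self.inputs[0])
        return [input_shape[0] if len(input_shape) == 1 else None]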
9 changes: 9 additions & 0 deletions tripy/tripy/frontend/trace/ops/binary_elementwise.py
@@ -70,6 +70,15 @@ def infer_shape_output_idxs(self, inputs):
else:
return Result.ok([])

def infer_len(self):
# For the shape case, the result will be broadcast to the max of the input lengths
input_lengths = []
for inp in self.inputs:
shape = op_utils.get_trace_shape(inp)
if len(shape) != 0:
input_lengths.append(shape[0])
return [max(input_lengths)]

def infer_dtypes(self):
op_utils.check_input_dtypes_match(self, self.kind.strip())
self.outputs[0].dtype = self.inputs[0].dtype
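In `infer_len` above, the `len(shape) != 0` guard skips rank-0 operands, which broadcast against the shape operand rather than contributing a length of their own. A usage sketch mirroring `test_len_binary_op`:

import tripy as tp

s = tp.Shape([2, 3, 4])
res = s * 2  # the scalar 2 is rank 0, so only s contributes to input_lengths
assert len(res) == 3  # inferred statically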
3 changes: 3 additions & 0 deletions tripy/tripy/frontend/trace/ops/cast.py
@@ -18,6 +18,7 @@
from dataclasses import dataclass
from tripy import export, dtype_info
from tripy.frontend.trace.ops.base import BaseTraceOp
from tripy.frontend.trace.ops.utils import InferLenPolicies


@dataclass(repr=False)
@@ -35,6 +36,8 @@ def infer_shape_output_idxs(self, inputs):
return Result.ok([0])
return Result.ok([])

infer_len = InferLenPolicies.infer_same_as_first_input

def infer_dtypes(self):
self.outputs[0].dtype = self.dtype

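`InferLenPolicies` lives in `tripy/frontend/trace/ops/utils.py`, which is not among this PR's visible hunks. A plausible sketch, assuming the policy simply forwards the first input's recorded length (the upstream definition may differ):

from tripy.frontend.trace.ops import utils as op_utils


class InferLenPolicies:
    # Policies are plain functions assigned as `infer_len` on a trace-op
    # class, so each receives the op instance as `self`.
    def infer_same_as_first_input(self):
        # Suitable for length-preserving ops such as Cast and Copy.
        return [op_utils.get_trace_shape(self.inputs[0])[0]]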
7 changes: 7 additions & 0 deletions tripy/tripy/frontend/trace/ops/concatenate.py
@@ -29,6 +29,13 @@ class Concatenate(BaseTraceOp):
def infer_devices(self):
self.outputs[0].device = self.inputs[0].device

def infer_len(self):
# for shapes, the output length is simply the sum of the input lengths
from tripy.frontend.trace.ops import utils as op_utils

out_length = sum(map(lambda inp: op_utils.get_trace_shape(inp)[0], self.inputs))
return [out_length]

def to_flat_ir(self, inputs, outputs):
from tripy.flat_ir.ops import ConcatenateOp

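Because `tp.Shape.__add__` lowers to `Concatenate`, the `infer_len` override above is what lets `len(s1 + s2)` resolve without executing the concatenation. A usage sketch:

import tripy as tp

s1 = tp.Shape([1, 2])
s2 = tp.Shape([3])
assert len(s1 + s2) == 3  # 2 + 1, summed from the inputs' trace shapes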
3 changes: 3 additions & 0 deletions tripy/tripy/frontend/trace/ops/copy.py
@@ -20,6 +20,7 @@
from tripy import export
from tripy.common.device import device
from tripy.frontend.trace.ops.base import BaseTraceOp
from tripy.frontend.trace.ops.utils import InferLenPolicies


@dataclass(repr=False)
@@ -29,6 +30,8 @@ class Copy(BaseTraceOp):
def infer_devices(self):
self.outputs[0].device = self.target

infer_len = InferLenPolicies.infer_same_as_first_input

def to_flat_ir(self, inputs, outputs):
from tripy.flat_ir.ops import CopyOp

27 changes: 18 additions & 9 deletions tripy/tripy/frontend/trace/ops/expand.py
@@ -16,7 +16,7 @@
#

from dataclasses import dataclass
-from typing import Sequence, Union
+from typing import Optional, Sequence, Union

from tripy import export
from tripy.utils import Result
@@ -29,6 +29,7 @@
@dataclass(repr=False)
class Expand(BaseTraceOp):
output_rank: int
output_len: Optional[int] = None # only used to help with infer_len for a shape input

def infer_dtypes(self):
self.outputs[0].dtype = self.inputs[0].dtype
@@ -41,14 +42,17 @@ def infer_shape_output_idxs(self, inputs) -> Result:
return Result.ok([0])
return Result.ok([])

def infer_len(self):
if self.output_len is not None:
return [self.output_len]
# if we don't have a static output length, we can't infer without evaluating the input
return [None]

    def infer_rank(self):
        if self.output_rank is None:
-            from tripy.backend.mlir.utils import ShapeContext
-
-            out_shape = ShapeContext().get_shape_of_dynamic_trace_tensor(self.inputs[1])
+            out_shape = op_utils.get_trace_shape(self.inputs[1])
            assert len(out_shape) == 1
            assert out_shape[0] >= 0, f"incorrect shape computation {out_shape}"
-            self.inputs[1].shape = out_shape
            self.output_rank = out_shape[0]

        self.outputs[0].rank = self.output_rank
@@ -65,9 +69,9 @@ def to_flat_ir(self, inputs, outputs):
)


-@frontend_utils.convert_inputs_to_tensors(exclude=["input", "output_rank"], shape_argument=["shape"])
-def expand_impl(input: "tripy.Tensor", shape: Sequence, output_rank: int):
-    return Expand.build([input, shape], output_rank)
+@frontend_utils.convert_inputs_to_tensors(exclude=["input", "output_rank", "output_len"], shape_argument=["shape"])
+def expand_impl(input: "tripy.Tensor", shape: Sequence, output_rank: int, output_len: Optional[int] = None):
+    return Expand.build([input, shape], output_rank, output_len)


@export.public_api(document_under="operations/functions")
@@ -123,4 +127,9 @@ def expand(input: "tripy.Tensor", sizes: Union["tripy.Shape", Sequence[Union[int
continue
out_shape.append(size)

-    return expand_impl(input, out_shape, len(sizes))
+    # only used for inferring the length of a shape output (hence, defined only in the rank-1 case)
+    out_len = None
+    if len(sizes) == 1 and isinstance(out_shape[0], int):
+        out_len = out_shape[0]
+
+    return expand_impl(input, out_shape, len(sizes), out_len)
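Since a `tp.Shape` is always rank 1, the `output_len` hint is only computed when `sizes` has exactly one statically-known entry. A usage sketch mirroring `test_len_expand`:

import tripy as tp

s = tp.Shape([1])
expanded = tp.expand(s, (3,))  # rank-1 expand of a shape
assert len(expanded) == 3  # output_len recorded at the call site, no evaluation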
16 changes: 6 additions & 10 deletions tripy/tripy/frontend/trace/ops/fill.py
@@ -46,16 +46,12 @@ def infer_devices(self):

    def infer_rank(self):
        if self.output_rank is None:
-            if self.inputs[0].shape is None:
-                from tripy.backend.mlir.utils import ShapeContext
-
-                out_shape = ShapeContext().get_shape_of_dynamic_trace_tensor(self.inputs[0])
-                assert len(out_shape) == 1, f"Expected rank of shape tensor to be 1, got {len(out_shape)}"
-                assert (
-                    out_shape[0] >= 0
-                ), f"Incorrect shape of shape tensor, expected shape to be positive, got {out_shape[0]}"
-                self.inputs[0].shape = out_shape
-            self.output_rank = self.inputs[0].shape[0]
+            input_shape = op_utils.get_trace_shape(self.inputs[0])
+            assert len(input_shape) == 1, f"Expected rank of shape tensor to be 1, got {len(input_shape)}"
+            assert (
+                input_shape[0] >= 0
+            ), f"Incorrect shape of shape tensor, expected shape to be positive, got {input_shape[0]}"
+            self.output_rank = input_shape[0]
        self.outputs[0].rank = self.output_rank

    def to_flat_ir(self, inputs, outputs):