[Tripy] Add __len__ for tp.Shape and infer length statically when possible (#92)

In some cases, it is useful to know the length of a `tp.Shape` without
executing the model. This PR adds `__len__` for `tp.Shape`, along with a
method `infer_len` that lets operators specify how to statically infer the
length of `Shape` outputs when possible (overriding it is always optional).
Test cases are added.
slyubomirsky authored Aug 20, 2024
1 parent b4491f7 commit 9456a8b
Showing 18 changed files with 298 additions and 47 deletions.
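
For illustration, a minimal usage sketch of the new behavior (assuming the `import tripy as tp` alias used by the test suite; the calls mirror the test cases below):

import tripy as tp

s = tp.Shape([2, 3, 4])
assert len(s) == 3        # length is inferred from the trace, not by executing

t = tp.Tensor([1, 2, 3])
assert len(t.shape) == 1  # a rank-1 tensor's shape has a single element
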
98 changes: 95 additions & 3 deletions tripy/tests/frontend/test_shape.py
@@ -45,13 +45,15 @@ def test_shape(self, values):
s = tp.Shape(values)

assert isinstance(s, tp.Shape)
assert len(s) == len(values)
assert s.trace_tensor.producer.inputs == []
assert cp.from_dlpack(s).get().tolist() == values

def test_empty_shape(self):
s = tp.Shape([])

assert isinstance(s, tp.Shape)
assert len(s) == 0
assert s.trace_tensor.producer.inputs == []
assert cp.from_dlpack(s).get().tolist() == []

@@ -60,6 +62,7 @@ def test_constructor_from_tensor(self, values):
s = tp.Shape(t)

assert isinstance(s, tp.Shape)
assert len(s) == len(values)
assert s.trace_tensor.producer.inputs == []
# they should be the same underlying value
assert s.trace_tensor == t.trace_tensor
@@ -100,6 +103,12 @@ def test_plus_override(self, values, other_values):
assert isinstance(new_shape.trace_tensor.producer, Concatenate)
assert cp.from_dlpack(new_shape).get().tolist() == values + appended

def test_len_concatenation(self, values):
s = tp.Shape(values)
# we are testing that the length is *inferred*, so do not execute the concatenation
c = s + s
assert len(c) == 2 * len(values)

def test_explicit_addition(self, values):
from tripy.frontend.trace.ops.binary_elementwise import BinaryElementwise

@@ -110,6 +119,14 @@ def test_explicit_addition(self, values):
assert res.trace_tensor.producer.kind == BinaryElementwise.Kind.SUM
assert cp.from_dlpack(res).get().tolist() == [2 * v for v in values]

def test_len_binary_op(self, values):
s = tp.Shape(values)
res = s.add(tp.Shape(values))
assert len(res) == len(values)

res = s * 2
assert len(res) == len(values)

def test_shape_op(self, values):
from tripy.frontend.trace.ops.shape import Shape

@@ -120,6 +137,11 @@ def test_shape_op(self, values):
assert isinstance(s.trace_tensor.producer, Shape)
assert cp.from_dlpack(s).get().tolist() == [len(values)]

def test_len_shape_op(self, values):
t = tp.Tensor(values)
s = t.shape
assert len(s) == 1

def test_flip(self, values):
from tripy.frontend.trace.ops.flip import Flip

@@ -130,6 +152,11 @@ def test_flip(self, values):
assert isinstance(flipped_shape.trace_tensor.producer, Flip)
assert cp.from_dlpack(flipped_shape).get().tolist() == values[::-1]

def test_len_flip(self, values):
s = tp.Shape(values)
flipped = tp.flip(s, dims=0)
assert len(flipped) == len(values)

def test_expand(self):
from tripy.frontend.trace.ops.expand import Expand

@@ -141,6 +168,11 @@ def test_expand(self):
assert isinstance(expanded.trace_tensor.producer, Expand)
assert cp.from_dlpack(expanded).get().tolist() == [1, 1, 1]

def test_len_expand(self):
s = tp.Shape([1])
expanded = tp.expand(s, (3,))
assert len(expanded) == 3

def test_gather(self, values):
from tripy.frontend.trace.ops.gather import Gather

@@ -150,6 +182,11 @@ def test_gather(self, values):
assert isinstance(s2.trace_tensor.producer, Gather)
assert cp.from_dlpack(s2).get().tolist() == [values[0], values[-1]]

def test_len_gather(self, values):
s = tp.Shape(values)
gathered = tp.gather(s, 0, tp.Tensor([0, len(values) - 1]))
assert len(gathered) == 2

def test_matmul(self, values):
s1 = tp.Shape(values)
s2 = tp.Shape(values)
@@ -199,6 +236,38 @@ def test_slice_range(self, values):
assert isinstance(dims.trace_tensor.producer, Slice)
assert cp.from_dlpack(dims).get().tolist() == values[1:]

@pytest.mark.parametrize(
"slice_value",
[
slice(0, 2),
slice(0, 1),
slice(1, 3),
slice(0, 3, 2),
slice(1, 3, 2),
slice(1, 4, 2),
slice(1, 4, 3), # should select only one
slice(1, None, 200), # selects only start point
# some with negative strides
slice(None, None, -1),
slice(None, None, -2),
slice(4, 0, -1),
slice(2, 0, -1),
slice(2, 1, -1),
# check the clamping behavior
slice(-10, 20),
slice(10, -20, -1),
# invalid bounds (length 0 result)
slice(0, 4, -1),
slice(4, 0),
slice(2, 2),
],
)
def test_slice_len(self, slice_value):
# checking consistency against Python list
values = [1, 2, 3, 4]
s1 = tp.Shape(values)
assert len(s1[slice_value]) == len(values[slice_value])
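
The expected lengths above are checked against Python list slicing; for reference, a slice's static length can be computed with the standard-library slice.indices (plain Python, independent of Tripy):

def static_slice_len(sl: slice, dim_len: int) -> int:
    # slice.indices clamps start/stop exactly like Python list slicing
    start, stop, step = sl.indices(dim_len)
    return len(range(start, stop, step))

assert static_slice_len(slice(1, 4, 3), 4) == 1   # selects only index 1
assert static_slice_len(slice(0, 4, -1), 4) == 0  # invalid bounds: length 0
assert static_slice_len(slice(-10, 20), 4) == 4   # clamped to the full range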

def test_reduce(self, values):
from tripy.frontend.trace.ops.reduce import Reduce

@@ -263,6 +332,11 @@ def test_expand_higher_rank_not_wrapped(self):
assert not isinstance(e, tp.Shape)
assert cp.from_dlpack(e).get().tolist() == [[1] for _ in range(3)]

def test_cast_len(self, values):
s = tp.Shape(values)
cast = tp.cast(s, tp.int32)
assert len(cast) == len(values)

def test_split(self, values):
s = tp.Shape(values)
outputs = tp.split(s, len(values))
@@ -271,6 +345,19 @@ def test_split(self, values):
assert isinstance(output, tp.Shape)
assert cp.from_dlpack(output).get().tolist() == [values[i]]

def test_split_len(self, values):
s = tp.Shape(values)
outputs = tp.split(s, len(values))
for output in outputs:
assert len(output) == 1

def test_split_len_intervals(self):
s = tp.Shape([1, 2, 3, 4, 5])
outputs = tp.split(s, [1, 4])
assert len(outputs[0]) == 1 # 0:1
assert len(outputs[1]) == 3 # 1:4
assert len(outputs[2]) == 1 # 4:5

def test_where(self, values):
from tripy.frontend.trace.ops.where import Where

@@ -284,6 +371,13 @@ def test_where(self, values):
assert isinstance(res.trace_tensor.producer, Where)
assert cp.from_dlpack(res).get().tolist() == [0 if values[i] < 2 else values[i] for i in range(len(values))]

def test_where_len(self, values):
s1 = tp.Shape(values)
s2 = tp.Shape([0 for _ in values])
cond = tp.Tensor([i >= 1 for i in range(len(values))], dtype=tp.bool)
res = tp.where(cond, s1, s2)
assert len(res) == len(values)

def test_invalid_input_dtype(self):
with raises(tp.TripyException, match="Data has incorrect dtype"):
_ = tp.Shape(np.array([2.0, 3.0], dtype=np.float32))
@@ -343,9 +437,7 @@ def test_unary_elementwise_fails_at_run_time(self, values):
v = tp.exp(tp.Shape(values))
with raises(
tp.TripyException,
match=(
"'stablehlo.exponential' op operand #0 must be ranked tensor of"
),
match=("'stablehlo.exponential' op operand #0 must be ranked tensor of"),
):
v.eval()

6 changes: 6 additions & 0 deletions tripy/tripy/frontend/shape.py
@@ -161,3 +161,9 @@ def __eq__(self, other):
from tripy.frontend.trace.ops.reduce import all

return bool(all(self.as_tensor() == other.as_tensor()))

# __len__ for shapes gives the number of dims in the shape, i.e., the first dimension of the shape's shape
def __len__(self):
from tripy.frontend.trace.ops import utils as op_utils

return op_utils.get_trace_shape(self.trace_tensor)[0]
16 changes: 12 additions & 4 deletions tripy/tripy/frontend/trace/ops/base.py
@@ -86,8 +86,13 @@ def build(cls, inputs: List["Tensor"], *args, num_outputs=1, **kwargs) -> Union[
raise_error(
f"Error processing shape inputs in operator {cls.__name__}{custom_err}\n(Shape input indices: {shape_arg_msg}.)"
)
# for shape outputs, we infer the length
if len(res.value) != 0:
inferred_lengths = op.infer_len()
for idx in res.value:
outputs[idx] = Shape(outputs[idx])
if inferred_lengths[idx] is not None:
out_trace_tensors[idx].shape = [inferred_lengths[idx]]

if num_outputs == 1:
return outputs[0]
@@ -124,12 +129,15 @@ def infer_shape_output_idxs(self, inputs: List["Tensor"]) -> Result:
return Result.err(["Either all inputs must be tp.Shape or all must be tp.Tensor."])
return Result.ok([])

def infer_shapes(self):
def infer_len(self) -> List[Optional[int]]:
"""
Infers shapes for the operation and updates output tensor shapes accordingly.
Infers the length of all `tp.Shape` outputs. This is, essentially, the "shape" of the shape.
Returns `None` for outputs that are not `tp.Shape`s or whose length (shape) cannot be inferred.
Returns:
A list of inferred lengths for outputs that are `tp.Shape`s.
"""
# Default implementation of infer_shapes fills dynamic dim for all elements.
self.outputs[0].shape = [-1] * self.outputs[0].rank
return [None for _ in self.outputs]

def infer_dtypes(self):
"""
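
Operators opt in by overriding infer_len. A minimal sketch for a hypothetical operator whose shape output keeps its first input's length (MyOp is illustrative and not part of this PR; the get_trace_shape pattern mirrors the real operators below):

from typing import List, Optional

class MyOp(BaseTraceOp):  # hypothetical example
    def infer_len(self) -> List[Optional[int]]:
        from tripy.frontend.trace.ops import utils as op_utils

        # The trace shape of a rank-1 shape tensor is [length].
        return [op_utils.get_trace_shape(self.inputs[0])[0]]
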
9 changes: 9 additions & 0 deletions tripy/tripy/frontend/trace/ops/binary_elementwise.py
@@ -70,6 +70,15 @@ def infer_shape_output_idxs(self, inputs):
else:
return Result.ok([])

def infer_len(self):
# For the shape case, the result will be broadcast to the max of the input shapes
input_lengths = []
for inp in self.inputs:
shape = op_utils.get_trace_shape(inp)
if len(shape) != 0:
input_lengths.append(shape[0])
return [max(input_lengths)]

def infer_dtypes(self):
op_utils.check_input_dtypes_match(self, self.kind.strip())
self.outputs[0].dtype = self.inputs[0].dtype
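
Taking the max of the input lengths mirrors broadcasting: a scalar operand has an empty trace shape and is skipped, so the shape operand's length wins. A sketch consistent with test_len_binary_op above:

s = tp.Shape([2, 3, 4])  # trace shape: [3]
res = s * 2              # the scalar 2 has trace shape [] and contributes no length
assert len(res) == 3     # max over the remaining input lengths
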
3 changes: 3 additions & 0 deletions tripy/tripy/frontend/trace/ops/cast.py
@@ -18,6 +18,7 @@
from dataclasses import dataclass
from tripy import export, dtype_info
from tripy.frontend.trace.ops.base import BaseTraceOp
from tripy.frontend.trace.ops.utils import InferLenPolicies


@dataclass(repr=False)
Expand All @@ -35,6 +36,8 @@ def infer_shape_output_idxs(self, inputs):
return Result.ok([0])
return Result.ok([])

infer_len = InferLenPolicies.infer_same_as_first_input

def infer_dtypes(self):
self.outputs[0].dtype = self.dtype

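
InferLenPolicies itself is not shown in this diff; a plausible sketch of the policy used here (an assumption about its implementation, not the actual code):

class InferLenPolicies:
    @staticmethod
    def infer_same_as_first_input(self):
        # Assigned directly as an operator's infer_len, so it receives self.
        from tripy.frontend.trace.ops import utils as op_utils

        return [op_utils.get_trace_shape(self.inputs[0])[0]]
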
7 changes: 7 additions & 0 deletions tripy/tripy/frontend/trace/ops/concatenate.py
@@ -29,6 +29,13 @@ class Concatenate(BaseTraceOp):
def infer_devices(self):
self.outputs[0].device = self.inputs[0].device

def infer_len(self):
# for shapes, the output length is just the sum of the input lengths
from tripy.frontend.trace.ops import utils as op_utils

out_length = sum(map(lambda inp: op_utils.get_trace_shape(inp)[0], self.inputs))
return [out_length]

def to_flat_ir(self, inputs, outputs):
from tripy.flat_ir.ops import ConcatenateOp

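
Since concatenating shapes appends their elements, the output length is the sum of the input lengths, e.g. (consistent with test_len_concatenation above):

s1 = tp.Shape([2, 3])
s2 = tp.Shape([4])
assert len(s1 + s2) == 3  # 2 + 1, inferred without evaluating the concatenation
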
3 changes: 3 additions & 0 deletions tripy/tripy/frontend/trace/ops/copy.py
@@ -20,6 +20,7 @@
from tripy import export
from tripy.common.device import device
from tripy.frontend.trace.ops.base import BaseTraceOp
from tripy.frontend.trace.ops.utils import InferLenPolicies


@dataclass(repr=False)
Expand All @@ -29,6 +30,8 @@ class Copy(BaseTraceOp):
def infer_devices(self):
self.outputs[0].device = self.target

infer_len = InferLenPolicies.infer_same_as_first_input

def to_flat_ir(self, inputs, outputs):
from tripy.flat_ir.ops import CopyOp

27 changes: 18 additions & 9 deletions tripy/tripy/frontend/trace/ops/expand.py
@@ -16,7 +16,7 @@
#

from dataclasses import dataclass
from typing import Sequence, Union
from typing import Optional, Sequence, Union

from tripy import export
from tripy.utils import Result
@@ -29,6 +29,7 @@
@dataclass(repr=False)
class Expand(BaseTraceOp):
output_rank: int
output_len: Optional[int] = None # only used to help with infer_len for a shape input

def infer_dtypes(self):
self.outputs[0].dtype = self.inputs[0].dtype
@@ -41,14 +42,17 @@ def infer_shape_output_idxs(self, inputs) -> Result:
return Result.ok([0])
return Result.ok([])

def infer_len(self):
if self.output_len is not None:
return [self.output_len]
# if we don't have a static output length, we can't infer without evaluating the input
return [None]

def infer_rank(self):
if self.output_rank is None:
from tripy.backend.mlir.utils import ShapeContext

out_shape = ShapeContext().get_shape_of_dynamic_trace_tensor(self.inputs[1])
out_shape = op_utils.get_trace_shape(self.inputs[1])
assert len(out_shape) == 1
assert out_shape[0] >= 0, f"incorrect shape computation {out_shape}"
self.inputs[1].shape = out_shape
self.output_rank = out_shape[0]

self.outputs[0].rank = self.output_rank
Expand All @@ -65,9 +69,9 @@ def to_flat_ir(self, inputs, outputs):
)


@frontend_utils.convert_inputs_to_tensors(exclude=["input", "output_rank"], shape_argument=["shape"])
def expand_impl(input: "tripy.Tensor", shape: Sequence, output_rank: int):
return Expand.build([input, shape], output_rank)
@frontend_utils.convert_inputs_to_tensors(exclude=["input", "output_rank", "output_len"], shape_argument=["shape"])
def expand_impl(input: "tripy.Tensor", shape: Sequence, output_rank: int, output_len: Optional[int] = None):
return Expand.build([input, shape], output_rank, output_len)


@export.public_api(document_under="operations/functions")
@@ -123,4 +127,9 @@ def expand(input: "tripy.Tensor", sizes: Union["tripy.Shape", Sequence[Union[int
continue
out_shape.append(size)

return expand_impl(input, out_shape, len(sizes))
# only used for inferring the length of a shape output (hence, define only in rank-1 case)
out_len = None
if len(sizes) == 1 and isinstance(out_shape[0], int):
out_len = out_shape[0]

return expand_impl(input, out_shape, len(sizes), out_len)
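
The new output_len parameter records the one case where the expanded length is statically known: a rank-1 target whose size is a literal int. A sketch of the two paths (mirroring test_len_expand above):

s = tp.Shape([1])
e = tp.expand(s, (3,))  # literal size, so output_len=3 is recorded
assert len(e) == 3      # Expand.infer_len returns [3]
# Were the size a runtime value instead, infer_len would return [None]
# and no static length would be recorded.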
16 changes: 6 additions & 10 deletions tripy/tripy/frontend/trace/ops/fill.py
@@ -46,16 +46,12 @@ def infer_devices(self):

def infer_rank(self):
if self.output_rank is None:
if self.inputs[0].shape is None:
from tripy.backend.mlir.utils import ShapeContext

out_shape = ShapeContext().get_shape_of_dynamic_trace_tensor(self.inputs[0])
assert len(out_shape) == 1, f"Expected rank of shape tensor to be 1, got {len(out_shape)}"
assert (
out_shape[0] >= 0
), f"Incorrect shape of shape tensor, expected shape to be positive, got {out_shape[0]}"
self.inputs[0].shape = out_shape
self.output_rank = self.inputs[0].shape[0]
input_shape = op_utils.get_trace_shape(self.inputs[0])
assert len(input_shape) == 1, f"Expected rank of shape tensor to be 1, got {len(input_shape)}"
assert (
input_shape[0] >= 0
), f"Incorrect shape of shape tensor, expected shape to be positive, got {input_shape[0]}"
self.output_rank = input_shape[0]
self.outputs[0].rank = self.output_rank

def to_flat_ir(self, inputs, outputs):
