From bb41ae4ee6a6446614e715a8a15d85e3377ecbb7 Mon Sep 17 00:00:00 2001 From: Steven Lyubomirsky Date: Wed, 14 Aug 2024 00:09:21 -0400 Subject: [PATCH] Add length inference for slice --- tripy/tests/frontend/test_shape.py | 32 ++++++++++++++ tripy/tripy/frontend/trace/ops/slice.py | 57 ++++++++++++++++++++++--- 2 files changed, 82 insertions(+), 7 deletions(-) diff --git a/tripy/tests/frontend/test_shape.py b/tripy/tests/frontend/test_shape.py index dd1214f59..45fbc49a9 100644 --- a/tripy/tests/frontend/test_shape.py +++ b/tripy/tests/frontend/test_shape.py @@ -236,6 +236,38 @@ def test_slice_range(self, values): assert isinstance(dims.trace_tensor.producer, Slice) assert cp.from_dlpack(dims).get().tolist() == values[1:] + @pytest.mark.parametrize( + "slice_value", + [ + slice(0, 2), + slice(0, 1), + slice(1, 3), + slice(0, 3, 2), + slice(1, 3, 2), + slice(1, 4, 2), + slice(1, 4, 3), # should select only one + slice(1, None, 200), # selects only start point + # some with negative strides + slice(None, None, -1), + slice(None, None, -2), + slice(4, 0, -1), + slice(2, 0, -1), + slice(2, 1, -1), + # check the clamping behavior + slice(-10, 20), + slice(10, -20, -1), + # invalid bounds (length 0 result) + slice(0, 4, -1), + slice(4, 0), + slice(2, 2), + ], + ) + def test_slice_len(self, slice_value): + # checking consistency against Python list + values = [1, 2, 3, 4] + s1 = tp.Shape(values) + assert len(s1[slice_value]) == len(values[slice_value]) + def test_reduce(self, values): from tripy.frontend.trace.ops.reduce import Reduce diff --git a/tripy/tripy/frontend/trace/ops/slice.py b/tripy/tripy/frontend/trace/ops/slice.py index 7bd29e54f..f8e498428 100644 --- a/tripy/tripy/frontend/trace/ops/slice.py +++ b/tripy/tripy/frontend/trace/ops/slice.py @@ -17,7 +17,7 @@ import math from dataclasses import dataclass -from typing import Tuple, Union +from typing import Optional, Tuple, Union from tripy import utils from tripy.frontend.ops.registry import TENSOR_METHOD_REGISTRY from tripy.frontend.trace.ops import utils as op_utils @@ -30,6 +30,8 @@ @dataclass(repr=False) class Slice(BaseTraceOp): + shape_slice: Optional[slice] = None # only used for inferring the length of a shape result + def infer_dtypes(self): self.outputs[0].dtype = self.inputs[0].dtype @@ -38,8 +40,43 @@ def infer_rank(self): self.outputs[0].rank = self.inputs[0].rank def infer_len(self): - # Skipping inference here since it depends on the concrete _values_ of the inputs - # rather than their shapes. This can be revisited if necessary + # Only infer if we have concrete values to use. Note that the result is only a Shape if these are *slices*, + # not single indices, so a slice is the only case that needs to be considered + if self.shape_slice is not None: + input_len = op_utils.get_trace_shape(self.inputs[0])[0] + + def convert_to_positive_idx(idx): + return idx if idx >= 0 else input_len + idx + + def clamp_bound(idx): + return 0 if idx < 0 else (idx if idx <= input_len else input_len) + + stride = utils.default(self.shape_slice.step, 1) + if stride > 0: + start = 0 if self.shape_slice.start is None else convert_to_positive_idx(self.shape_slice.start) + stop = input_len if self.shape_slice.stop is None else convert_to_positive_idx(self.shape_slice.stop) + else: + # for negative stride, we compute the indices as they would be on the flipped list, see comments below + start = ( + 0 + if self.shape_slice.start is None + else input_len - convert_to_positive_idx(self.shape_slice.start) - 1 + ) + stop = ( + input_len + if self.shape_slice.stop is None + else input_len - convert_to_positive_idx(self.shape_slice.stop) - 1 + ) + + start_point = clamp_bound(start) + end_point = clamp_bound(stop) + if start_point >= end_point: + return [0] + + # - 1 because the end_point is exclusive. Use // so we round down + strides_in_range = (end_point - start_point - 1) // abs(stride) + # + 1 because we include the starting point and then make strides + return [1 + strides_in_range] return [None] # we only care about the data input @@ -166,6 +203,7 @@ def __getitem__(self, index: Union[slice, int, Tuple[int], "tripy.Tensor"]) -> " assert np.array_equal(cp.from_dlpack(output).get(), np.arange(10)[8:2:-1]) """ + from tripy.frontend.shape import Shape from tripy.frontend.tensor import Tensor from tripy.frontend.trace.ops.flip import flip from tripy.frontend.trace.ops.reshape import reshape, squeeze @@ -176,6 +214,11 @@ def __getitem__(self, index: Union[slice, int, Tuple[int], "tripy.Tensor"]) -> " if isinstance(index, Tensor): return gather(self, 0, index) + # if we are taking a literal slice of a shape, we can pass on the slice to infer the length of the shape statically + shape_slice = None + if isinstance(self, Shape) and isinstance(index, slice): + shape_slice = index + index = make_tuple(index) if len(index) > self.rank: raise_error(f"Input tensor has a rank of {self.rank} but was attempted to be sliced with {len(index)} indices") @@ -238,7 +281,7 @@ def clamp_bound(bound: Union[int, Tensor]) -> Union[int, Tensor]: if flip_dims: input_tensor = flip(input_tensor, dims=flip_dims) - out = slice_helper(input_tensor, *args) + out = slice_helper(input_tensor, *args, shape_slice=shape_slice) squeeze_dims = [] for i, idx in enumerate(index): @@ -255,6 +298,6 @@ def clamp_bound(bound: Union[int, Tensor]) -> Union[int, Tensor]: # Conveniently converts the inputs to tensors. The decorator also fills in column info for the converted tensors. # Because the helper is called inside another function, we need to skip one entry in the call stack to find # the original call to user code. -@frontend_utils.convert_inputs_to_tensors(skip_num_stack_entries=1) -def slice_helper(tensor, *slice_params): - return Slice.build(inputs=[tensor, *slice_params]) +@frontend_utils.convert_inputs_to_tensors(exclude=["shape_slice"], skip_num_stack_entries=1) +def slice_helper(tensor, *slice_params, shape_slice: Optional[slice] = None): + return Slice.build(inputs=[tensor, *slice_params], shape_slice=shape_slice)