Skip to content

Commit

Permalink
Add length inference for slice
Browse files Browse the repository at this point in the history
  • Loading branch information
slyubomirsky committed Aug 14, 2024
1 parent 2d9ecdd commit 60d624a
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 7 deletions.
32 changes: 32 additions & 0 deletions tripy/tests/frontend/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,38 @@ def test_slice_range(self, values):
assert isinstance(dims.trace_tensor.producer, Slice)
assert cp.from_dlpack(dims).get().tolist() == values[1:]

@pytest.mark.parametrize(
"slice_value",
[
slice(0, 2),
slice(0, 1),
slice(1, 3),
slice(0, 3, 2),
slice(1, 3, 2),
slice(1, 4, 2),
slice(1, 4, 3), # should select only one
slice(1, None, 200), # selects only start point
# some with negative strides
slice(None, None, -1),
slice(None, None, -2),
slice(4, 0, -1),
slice(2, 0, -1),
slice(2, 1, -1),
# check the clamping behavior
slice(-10, 20),
slice(10, -20, -1),
# invalid bounds (length 0 result)
slice(0, 4, -1),
slice(4, 0),
slice(2, 2),
],
)
def test_slice_len(self, slice_value):
# checking consistency against Python list
values = [1, 2, 3, 4]
s1 = tp.Shape(values)
assert len(s1[slice_value]) == len(values[slice_value])

def test_reduce(self, values):
from tripy.frontend.trace.ops.reduce import Reduce

Expand Down
57 changes: 50 additions & 7 deletions tripy/tripy/frontend/trace/ops/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import math
from dataclasses import dataclass
from typing import Tuple, Union
from typing import Optional, Tuple, Union
from tripy import utils
from tripy.frontend.ops.registry import TENSOR_METHOD_REGISTRY
from tripy.frontend.trace.ops import utils as op_utils
Expand All @@ -30,6 +30,8 @@

@dataclass(repr=False)
class Slice(BaseTraceOp):
shape_slice: Optional[slice] = None # only used for inferring the length of a shape result

def infer_dtypes(self):
self.outputs[0].dtype = self.inputs[0].dtype

Expand All @@ -38,8 +40,43 @@ def infer_rank(self):
self.outputs[0].rank = self.inputs[0].rank

def infer_len(self):
# Skipping inference here since it depends on the concrete _values_ of the inputs
# rather than their shapes. This can be revisited if necessary
# Only infer if we have concrete values to use. Note that the result is only a Shape if these are *slices*,
# not single indices, so a slice is the only case that needs to be considered
if self.shape_slice is not None:
input_len = op_utils.get_trace_shape(self.inputs[0])[0]

def convert_to_positive_idx(idx):
return idx if idx >= 0 else input_len + idx

def clamp_bound(idx):
return 0 if idx < 0 else (idx if idx <= input_len else input_len)

stride = utils.default(self.shape_slice.step, 1)
if stride > 0:
start = 0 if self.shape_slice.start is None else convert_to_positive_idx(self.shape_slice.start)
stop = input_len if self.shape_slice.stop is None else convert_to_positive_idx(self.shape_slice.stop)
else:
# for negative stride, we compute the indices as they would be on the flipped list, see comments below
start = (
0
if self.shape_slice.start is None
else input_len - convert_to_positive_idx(self.shape_slice.start) - 1
)
stop = (
input_len
if self.shape_slice.stop is None
else input_len - convert_to_positive_idx(self.shape_slice.stop) - 1
)

start_point = clamp_bound(start)
end_point = clamp_bound(stop)
if start_point >= end_point:
return [0]

# - 1 because the end_point is exclusive. Use // so we round down
strides_in_range = (end_point - start_point - 1) // abs(stride)
# + 1 because we include the starting point and then make strides
return [1 + strides_in_range]
return [None]

# we only care about the data input
Expand Down Expand Up @@ -166,6 +203,7 @@ def __getitem__(self, index: Union[slice, int, Tuple[int], "tripy.Tensor"]) -> "
assert np.array_equal(cp.from_dlpack(output).get(), np.arange(10)[8:2:-1])
"""
from tripy.frontend.shape import Shape
from tripy.frontend.tensor import Tensor
from tripy.frontend.trace.ops.flip import flip
from tripy.frontend.trace.ops.reshape import reshape, squeeze
Expand All @@ -176,6 +214,11 @@ def __getitem__(self, index: Union[slice, int, Tuple[int], "tripy.Tensor"]) -> "
if isinstance(index, Tensor):
return gather(self, 0, index)

# if we are taking a literal slice of a shape, we can pass on the slice to infer the length of the shape statically
shape_slice = None
if isinstance(self, Shape) and isinstance(index, slice):
shape_slice = index

index = make_tuple(index)
if len(index) > self.rank:
raise_error(f"Input tensor has a rank of {self.rank} but was attempted to be sliced with {len(index)} indices")
Expand Down Expand Up @@ -238,7 +281,7 @@ def clamp_bound(bound: Union[int, Tensor]) -> Union[int, Tensor]:
if flip_dims:
input_tensor = flip(input_tensor, dims=flip_dims)

out = slice_helper(input_tensor, *args)
out = slice_helper(input_tensor, *args, shape_slice=shape_slice)

squeeze_dims = []
for i, idx in enumerate(index):
Expand All @@ -255,6 +298,6 @@ def clamp_bound(bound: Union[int, Tensor]) -> Union[int, Tensor]:
# Conveniently converts the inputs to tensors. The decorator also fills in column info for the converted tensors.
# Because the helper is called inside another function, we need to skip one entry in the call stack to find
# the original call to user code.
@frontend_utils.convert_inputs_to_tensors(skip_num_stack_entries=1)
def slice_helper(tensor, *slice_params):
return Slice.build(inputs=[tensor, *slice_params])
@frontend_utils.convert_inputs_to_tensors(exclude=["shape_slice"], skip_num_stack_entries=1)
def slice_helper(tensor, *slice_params, shape_slice: Optional[slice] = None):
return Slice.build(inputs=[tensor, *slice_params], shape_slice=shape_slice)

0 comments on commit 60d624a

Please sign in to comment.