Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/flag_gems/ops/slice_backward.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,6 @@ def slice_backward(

shape = list(input_sizes)

slice_len = (end - start + step - 1) // step

outer = 1
for i in range(dim):
outer *= shape[i]
Expand All @@ -64,6 +62,10 @@ def slice_backward(

dim_size = shape[dim]

actual_start = max(0, min(start, dim_size))
slice_len = grad_output.shape[dim]
start = actual_start

numel = grad_output.numel()

BLOCK = 1024
Expand Down
52 changes: 52 additions & 0 deletions tests/test_reduction_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,6 +1230,58 @@ def test_accuracy_slice_backward(
gems_assert_equal(res_out, ref_out)


@pytest.mark.slice
@pytest.mark.parametrize("shape", SLICE_BACKWARD_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_slice_backward_oob_end(shape, dtype):
    """Regression: an `end` past the dimension size must not trigger an
    out-of-bounds write in the kernel."""
    device = flag_gems.device
    dim = 1 % len(shape)  # second axis when one exists, else the only axis
    dim_size = shape[dim]
    start, step = 0, 1
    end = dim_size + 100  # deliberately beyond the dimension's extent

    # PyTorch clamps the slice, so grad_output keeps the full extent on `dim`.
    go_shape = list(shape)
    go_shape[dim] = dim_size
    grad_output = torch.randn(go_shape, dtype=dtype, device=device)
    ref_grad_output = to_reference(grad_output)

    ref_out = torch.ops.aten.slice_backward(
        ref_grad_output, shape, dim, start, end, step
    )
    res_out = flag_gems.ops.slice_backward(grad_output, shape, dim, start, end, step)

    gems_assert_equal(res_out, ref_out)


@pytest.mark.slice
@pytest.mark.parametrize("shape", SLICE_BACKWARD_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_slice_backward_oob_start(shape, dtype):
    """Regression: a `start` past the dimension size must not trigger an
    out-of-bounds write in the kernel."""
    device = flag_gems.device
    dim = 1 % len(shape)  # second axis when one exists, else the only axis
    dim_size = shape[dim]
    start, end, step = dim_size + 50, dim_size + 100, 1  # both past the end

    # The clamped slice is empty, so grad_output has extent 0 on `dim`.
    go_shape = list(shape)
    go_shape[dim] = 0
    grad_output = torch.randn(go_shape, dtype=dtype, device=device)
    ref_grad_output = to_reference(grad_output)

    ref_out = torch.ops.aten.slice_backward(
        ref_grad_output, shape, dim, start, end, step
    )
    res_out = flag_gems.ops.slice_backward(grad_output, shape, dim, start, end, step)

    gems_assert_equal(res_out, ref_out)


@pytest.mark.slice_scatter
@pytest.mark.parametrize(("dim", "shape", "stride"), REGULAR_DIM_SHAPE_STRIDES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
Expand Down
Loading