From 3e6cc1903fbee74225259074d1a9273b4dac4545 Mon Sep 17 00:00:00 2001 From: Physics <3150105638@zju.edu.cn> Date: Mon, 30 Mar 2026 15:41:15 +0800 Subject: [PATCH 1/5] fix: clamp start/end in slice_backward to prevent out-of-bounds kernel write --- src/flag_gems/ops/slice_backward.py | 7 +++++-- tests/test_reduction_ops.py | 25 +++++++++++++++++++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/src/flag_gems/ops/slice_backward.py b/src/flag_gems/ops/slice_backward.py index 19d1ee5cc5..f7b63558d3 100644 --- a/src/flag_gems/ops/slice_backward.py +++ b/src/flag_gems/ops/slice_backward.py @@ -52,8 +52,6 @@ def slice_backward( shape = list(input_sizes) - slice_len = (end - start + step - 1) // step - outer = 1 for i in range(dim): outer *= shape[i] @@ -64,6 +62,11 @@ def slice_backward( dim_size = shape[dim] + actual_start = max(0, min(start, dim_size)) + actual_end = max(0, min(end, dim_size)) + slice_len = max(0, (actual_end - actual_start + step - 1) // step) + start = actual_start + numel = grad_output.numel() BLOCK = 1024 diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py index 542f86e177..75a54bb317 100644 --- a/tests/test_reduction_ops.py +++ b/tests/test_reduction_ops.py @@ -1236,6 +1236,31 @@ def test_accuracy_slice_backward( gems_assert_equal(res_out, ref_out) +@pytest.mark.slice +@pytest.mark.parametrize("shape", SLICE_BACKWARD_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_slice_backward_oob_end(shape, dtype): + # Regression test: end > dim_size caused out-of-bounds write in kernel. 
+ device = flag_gems.device + dim = 1 % len(shape) + dim_size = shape[dim] + start = 0 + end = dim_size + 100 # intentionally out of bounds + step = 1 + + slice_len = dim_size # clamped + valid_shape = list(shape) + valid_shape[dim] = slice_len + + grad_output = torch.randn(valid_shape, dtype=dtype, device=device) + ref_grad_output = to_reference(grad_output) + + ref_out = torch.ops.aten.slice_backward(ref_grad_output, shape, dim, start, end, step) + res_out = flag_gems.ops.slice_backward(grad_output, shape, dim, start, end, step) + + gems_assert_equal(res_out, ref_out) + + @pytest.mark.slice_scatter @pytest.mark.parametrize(("dim", "shape", "stride"), REGULAR_DIM_SHAPE_STRIDES) @pytest.mark.parametrize("dtype", FLOAT_DTYPES) From c7a5fba7f96634da9c6216a21b8e93e7ea925c2c Mon Sep 17 00:00:00 2001 From: Physics <3150105638@zju.edu.cn> Date: Mon, 30 Mar 2026 17:06:58 +0800 Subject: [PATCH 2/5] fix: derive slice_len from grad_output shape to prevent index mismatch Also add regression test for oob start case. 
--- src/flag_gems/ops/slice_backward.py | 3 +-- tests/test_reduction_ops.py | 27 +++++++++++++++++++++++++-- 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/flag_gems/ops/slice_backward.py b/src/flag_gems/ops/slice_backward.py index f7b63558d3..36ed8f3e24 100644 --- a/src/flag_gems/ops/slice_backward.py +++ b/src/flag_gems/ops/slice_backward.py @@ -63,8 +63,7 @@ def slice_backward( dim_size = shape[dim] actual_start = max(0, min(start, dim_size)) - actual_end = max(0, min(end, dim_size)) - slice_len = max(0, (actual_end - actual_start + step - 1) // step) + slice_len = grad_output.shape[dim] start = actual_start numel = grad_output.numel() diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py index 75a54bb317..ef77ddba1e 100644 --- a/tests/test_reduction_ops.py +++ b/tests/test_reduction_ops.py @@ -1248,10 +1248,33 @@ def test_slice_backward_oob_end(shape, dtype): end = dim_size + 100 # intentionally out of bounds step = 1 - slice_len = dim_size # clamped + # grad_output shape matches what PyTorch would produce (clamped slice) valid_shape = list(shape) - valid_shape[dim] = slice_len + valid_shape[dim] = dim_size + grad_output = torch.randn(valid_shape, dtype=dtype, device=device) + ref_grad_output = to_reference(grad_output) + + ref_out = torch.ops.aten.slice_backward(ref_grad_output, shape, dim, start, end, step) + res_out = flag_gems.ops.slice_backward(grad_output, shape, dim, start, end, step) + + gems_assert_equal(res_out, ref_out) + +@pytest.mark.slice +@pytest.mark.parametrize("shape", SLICE_BACKWARD_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES) +def test_slice_backward_oob_start(shape, dtype): + # Regression test: start > dim_size caused out-of-bounds write in kernel. 
+ device = flag_gems.device + dim = 1 % len(shape) + dim_size = shape[dim] + start = dim_size + 50 # intentionally out of bounds + end = dim_size + 100 + step = 1 + + # grad_output is empty since clamped slice is empty + valid_shape = list(shape) + valid_shape[dim] = 0 grad_output = torch.randn(valid_shape, dtype=dtype, device=device) ref_grad_output = to_reference(grad_output) From 48ebd5dd0208f408173ade62de02067ae3a78389 Mon Sep 17 00:00:00 2001 From: Physics <3150105638@zju.edu.cn> Date: Mon, 30 Mar 2026 17:27:58 +0800 Subject: [PATCH 3/5] style: apply black formatting to test_reduction_ops.py --- tests/test_reduction_ops.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py index ef77ddba1e..2ae38e3127 100644 --- a/tests/test_reduction_ops.py +++ b/tests/test_reduction_ops.py @@ -1254,7 +1254,9 @@ def test_slice_backward_oob_end(shape, dtype): grad_output = torch.randn(valid_shape, dtype=dtype, device=device) ref_grad_output = to_reference(grad_output) - ref_out = torch.ops.aten.slice_backward(ref_grad_output, shape, dim, start, end, step) + ref_out = torch.ops.aten.slice_backward( + ref_grad_output, shape, dim, start, end, step + ) res_out = flag_gems.ops.slice_backward(grad_output, shape, dim, start, end, step) gems_assert_equal(res_out, ref_out) @@ -1278,7 +1280,9 @@ def test_slice_backward_oob_start(shape, dtype): grad_output = torch.randn(valid_shape, dtype=dtype, device=device) ref_grad_output = to_reference(grad_output) - ref_out = torch.ops.aten.slice_backward(ref_grad_output, shape, dim, start, end, step) + ref_out = torch.ops.aten.slice_backward( + ref_grad_output, shape, dim, start, end, step + ) res_out = flag_gems.ops.slice_backward(grad_output, shape, dim, start, end, step) gems_assert_equal(res_out, ref_out) From 369089b8e43d14c6bb5a811faaa10a529a12e2c2 Mon Sep 17 00:00:00 2001 From: Physics <3150105638@zju.edu.cn> Date: Thu, 2 Apr 2026 15:19:40 +0800 Subject: [PATCH 
4/5] Update src/flag_gems/ops/slice_backward.py delete repetitive assignment Co-authored-by: Qiming Teng Signed-off-by: Physics <3150105638@zju.edu.cn> --- src/flag_gems/ops/slice_backward.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/flag_gems/ops/slice_backward.py b/src/flag_gems/ops/slice_backward.py index 36ed8f3e24..bc4c6107b3 100644 --- a/src/flag_gems/ops/slice_backward.py +++ b/src/flag_gems/ops/slice_backward.py @@ -62,9 +62,8 @@ def slice_backward( dim_size = shape[dim] - actual_start = max(0, min(start, dim_size)) slice_len = grad_output.shape[dim] - start = actual_start + start = max(0, min(start, dim_size)) numel = grad_output.numel() From 7621ca170b314406451ea8b05d021126b31b6f92 Mon Sep 17 00:00:00 2001 From: Physics <3150105638@zju.edu.cn> Date: Wed, 8 Apr 2026 15:22:56 +0800 Subject: [PATCH 5/5] fix: normalize negative dim and start in slice_backward --- src/flag_gems/ops/slice_backward.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/flag_gems/ops/slice_backward.py b/src/flag_gems/ops/slice_backward.py index bc4c6107b3..66ef66a502 100644 --- a/src/flag_gems/ops/slice_backward.py +++ b/src/flag_gems/ops/slice_backward.py @@ -52,6 +52,9 @@ def slice_backward( shape = list(input_sizes) + if dim < 0: + dim += len(shape) + outer = 1 for i in range(dim): outer *= shape[i] @@ -63,6 +66,8 @@ def slice_backward( dim_size = shape[dim] slice_len = grad_output.shape[dim] + if start < 0: + start += dim_size start = max(0, min(start, dim_size)) numel = grad_output.numel()